I want to implement a Grappler graph optimization pass that splits only the data input of a node. For example, a node computing y = wx + b may have three input nodes: w, x, and b.
I only want to apply tf.split to x and leave w and b untouched.
So:
- Is there any way to get x out of a list of nodes (for example, get x from node->inputs())?
- Similarly, if w above is trainable and b is not, how can I get this information from the graph?
Thanks in advance