I want to implement a Grappler graph optimization pass that splits only the data input node. For example, a node computing `y = wx + b` may have three input nodes: `w`, `x`, and `b`. I only want to apply `tf.split(x)` and leave `w` and `b` untouched.
So:
- Is there any way to pick `x` out of a list of input nodes (for example, from `node->inputs()`)?
- Similarly, if `w` above is trainable and `b` is not, how can I get this information from the graph?
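To make the first question concrete, here is a toy sketch of the selection I have in mind. It is plain Python over a hand-written dictionary, not the TensorFlow or Grappler API. The graph contents, the `FusedDense` op name, and the helpers `producer_name` and `data_inputs` are all hypothetical. It also assumes that the data input is the one produced by a `Placeholder` node, which may not hold for every graph.

```python
# Toy stand-in for a graph: node name -> (op type, list of input names).
# Hypothetical data for illustration only, not a real GraphDef.
TOY_GRAPH = {
    "x": ("Placeholder", []),
    "w": ("VarHandleOp", []),
    "b": ("Const", []),
    # "FusedDense" is a made-up op name standing in for y = wx + b.
    "y": ("FusedDense", ["w", "x:0", "b"]),
}

# Assumption: only Placeholder-produced inputs count as "data" inputs.
DATA_OPS = {"Placeholder"}


def producer_name(input_name):
    """Strip the control-dependency prefix '^' and the ':port' suffix."""
    name = input_name.lstrip("^")
    return name.split(":", 1)[0]


def data_inputs(node_name, graph):
    """Return the producers of node_name's inputs whose op is a data op."""
    _, inputs = graph[node_name]
    producers = [producer_name(i) for i in inputs]
    return [p for p in producers if graph[p][0] in DATA_OPS]


print(data_inputs("y", TOY_GRAPH))
```

With this toy graph the only input selected for splitting is `x`. What I am missing is the equivalent lookup inside a Grappler pass, and whether the op type of the producer is a reliable criterion there.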
Thanks in advance