I’d like to be able to obtain the underlying computational graph, consisting of the low-level TF operators, of a TF2 model. Basically, for any given TF model architecture (along with its training script), is there a straightforward way to modify my code to get a list of the TF operators called during training?
import tensorflow as tf
fun = lambda x: tf.nn.softmax(x)
tf_fun = tf.function(fun)
# Tracing the tf.function with an example input gives a ConcreteFunction whose .graph is a tf.Graph
graph = tf_fun.get_concrete_function(tf.constant([1.0])).graph
isinstance(graph, tf.Graph)  # True
graph.get_operations()  # returns a list of the tf.Operation objects in the graph
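The same trace can be done without a concrete input value: passing a tf.TensorSpec tells tf.function only the shape and dtype to trace for. The sketch below assumes TF 2.x; the input shape (a batch of 3-element vectors, with None meaning any batch size) is an arbitrary choice for illustration. It also prints each op's name and type, which is usually the readable list you are after.

```python
import tensorflow as tf

def fun(x):
    return tf.nn.softmax(x)

tf_fun = tf.function(fun)

# Trace from a TensorSpec: only the shape and dtype are needed, not real data.
concrete_fun = tf_fun.get_concrete_function(
    tf.TensorSpec(shape=[None, 3], dtype=tf.float32))
graph = concrete_fun.graph

# Each tf.Operation has a unique name and an op type (e.g. "Placeholder", "Softmax").
for op in graph.get_operations():
    print(op.name, op.type)

op_types = [op.type for op in graph.get_operations()]
```

The function's input shows up as a Placeholder op, and the softmax itself as a Softmax op.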
Note that the model’s forward pass needs to be a tf.function, and you need to know at least the input shapes ahead of time so that it can be traced. If the model is a Keras model, here’s a somewhat hacky way to do that:
import tensorflow as tf
# Written against tf.keras in TF 2.x; Keras internals such as predict_function may differ in other versions
model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(1, activation='softmax', input_shape=(1,))
])
x = tf.constant([[1.0]])  # one sample with one feature, matching input_shape=(1,)
model.predict(x)  # model.predict_function is only populated after the first call to predict; model.train_function works the same way after a training call
graph = model.predict_function.get_concrete_function(iter([x])).graph  # predict_function takes an iterator over batches, so trace it with one
isinstance(graph, tf.Graph)  # True
graph.get_operations()
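For the forward pass alone, a possibly less hacky alternative is to wrap the model call itself in tf.function and trace it from a tf.TensorSpec, which avoids relying on Keras’s internal predict_function. This is a sketch assuming tf.keras in TF 2.x; the exact op counts depend on the layers and the TF version.

```python
import collections

import tensorflow as tf

model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(1, activation='softmax', input_shape=(1,))
])

# Wrap the forward pass directly; training=False selects inference behaviour.
forward = tf.function(lambda x: model(x, training=False))
graph = forward.get_concrete_function(
    tf.TensorSpec(shape=[None, 1], dtype=tf.float32)).graph

# Count how often each op type appears in the forward-pass graph.
op_counts = collections.Counter(op.type for op in graph.get_operations())
print(op_counts)
```

For a Dense layer with a softmax activation you should see, among others, a MatMul op for the kernel multiplication and a Softmax op for the activation.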
Thanks for the reply! I’m actually looking for the operators that are called during training (not inference) of the model, including the ones called during backpropagation.
Thanks for sharing! Unfortunately, I don’t think the ops from a SavedModel would represent those invoked during training (which is what I’m interested in). To clarify my question: I’d like a way to store the computational graph (op graph) that’s visible in TensorBoard as a JSON file or some other readable format.
import tensorflow as tf
model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(2, activation='softmax', input_shape=(1,))
])
model.compile(optimizer='rmsprop', loss='categorical_crossentropy')
xs = tf.constant([[1.0]])  # one sample with one feature
ys = tf.constant([[0.2, 0.8]])  # target probabilities for the 2 output classes
model.train_on_batch(xs, ys)  # model.train_function is only populated after the first training call
graph = model.train_function.get_concrete_function(iter([(xs, ys)])).graph  # train_function takes an iterator over (x, y) batches
isinstance(graph, tf.Graph)  # True
graph.get_operations()
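This will list the operations performed during a single compiled training step, i.e. the forward pass plus the backpropagation and weight-update ops. If you want to export that to a readable format, graph.as_graph_def() returns the graph as a GraphDef protocol buffer, which can be written out as text or converted to JSON.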
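If you want to see the backward pass spelled out, or aren’t using Keras, a hand-written training step can be traced the same way: tf.GradientTape adds the gradient ops to the same graph, and the resulting GraphDef can be serialized with protobuf’s json_format.MessageToJson or with tf.io.write_graph. This is a sketch assuming TF 2.x; the toy model (a single hypothetical weight matrix w updated with plain gradient descent) and the output file names are made up for illustration.

```python
import json
import os
import tempfile

import tensorflow as tf
from google.protobuf import json_format

# Hypothetical toy model: one weight matrix mapping 1 feature to 2 logits.
w = tf.Variable(tf.random.normal([1, 2]), name="w")
learning_rate = 0.1

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = tf.matmul(x, w)
        loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))
    grad = tape.gradient(loss, w)       # backward-pass ops go into the same graph
    w.assign_sub(learning_rate * grad)  # weight-update op
    return loss

graph = train_step.get_concrete_function(
    tf.TensorSpec(shape=[None, 1], dtype=tf.float32),
    tf.TensorSpec(shape=[None, 2], dtype=tf.float32)).graph

op_types = [op.type for op in graph.get_operations()]
print(sorted(set(op_types)))

# GraphDef is a protobuf message: dump it as JSON and as a text .pbtxt file.
graph_def = graph.as_graph_def()
graph_json = json_format.MessageToJson(graph_def)
out_dir = tempfile.mkdtemp()  # replace with your own output directory
with open(os.path.join(out_dir, "train_step_graph.json"), "w") as f:
    f.write(graph_json)
tf.io.write_graph(graph_def, out_dir, "train_step_graph.pbtxt", as_text=True)
print(len(json.loads(graph_json)["node"]), "nodes exported to", out_dir)
```

The JSON follows the GraphDef schema, so each entry under "node" has fields such as "name", "op" and "input". If the graph contains function-call ops (for example StatefulPartitionedCall), their bodies are stored under "library" rather than as top-level nodes, so check there as well when exporting a Keras train_function graph.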