Hi all, I have been working on a model that makes liberal use of @tf.function(jit_compile=True). However, when I try to export the model as a frozen graph, I am hitting issues with missing output tensor shapes. Removing jit_compile=True sidesteps the issue.
I have distilled this down into the following reproducer:
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

@tf.function(jit_compile=True)
def jit_compiled_fn(x, y):
    x_squared = tf.square(x)
    result = x_squared + y
    return x_squared, result

@tf.function
def not_jit_compiled_fn(x, y):
    return jit_compiled_fn(x, y)

concrete_fn = not_jit_compiled_fn.get_concrete_function(
    tf.TensorSpec(shape=[1, 2], dtype=tf.float32),
    tf.TensorSpec(shape=[1, 2], dtype=tf.float32))

print("Before Freezing:")
print(concrete_fn.structured_outputs)

# Freeze the graph by converting variables to constants.
frozen_fn = convert_variables_to_constants_v2(concrete_fn)

print("After Freezing:")
for op in frozen_fn.graph.get_operations():
    if op.type == 'Identity':
        print(f"{op.name}: {op.outputs[0].shape}")
This prints the following:
Before Freezing:
(<tf.Tensor 'Identity:0' shape=(1, 2) dtype=float32>, <tf.Tensor 'Identity_1:0' shape=(1, 2) dtype=float32>)
After Freezing:
Identity: <unknown>
Identity_1: <unknown>
However, removing the jit_compile kwarg prints:
Before Freezing:
(<tf.Tensor 'Identity:0' shape=(1, 2) dtype=float32>, <tf.Tensor 'Identity_1:0' shape=(1, 2) dtype=float32>)
After Freezing:
Func/PartitionedCall/input/_0: (1, 2)
PartitionedCall/Identity: (1, 2)
Func/PartitionedCall/output/_2: (1, 2)
Identity: (1, 2)
Func/PartitionedCall/input/_1: (1, 2)
PartitionedCall/Identity_1: (1, 2)
Func/PartitionedCall/output/_3: (1, 2)
Identity_1: (1, 2)
It is also worth noting that adding jit_compile=True to not_jit_compiled_fn also obscures the output shapes. However, freezing the concrete function of jit_compiled_fn directly does not (see the sketch below).
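For concreteness, this is the variant I mean, reusing the definitions from the reproducer above (the variable names here are just for illustration):

# Freeze the XLA-compiled function's own concrete function directly.
direct_concrete_fn = jit_compiled_fn.get_concrete_function(
    tf.TensorSpec(shape=[1, 2], dtype=tf.float32),
    tf.TensorSpec(shape=[1, 2], dtype=tf.float32))
direct_frozen_fn = convert_variables_to_constants_v2(direct_concrete_fn)

for op in direct_frozen_fn.graph.get_operations():
    if op.type == 'Identity':
        # Per the behaviour described above, this prints (1, 2) rather than <unknown>.
        print(f"{op.name}: {op.outputs[0].shape}")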
I was hoping to sanity-check my approach before considering this a bug; it's far more likely that I'm missing some AutoGraph or XLA nuance.
EDIT: I had forgotten that frozen graphs have been deprecated/unsupported since TF 2.0, so perhaps it's not surprising that XLA compilation of functions causes issues with freezing, which is a throwback to TF1 sessions.
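For anyone following along, here is a minimal sketch of the supported TF2 export path (SavedModel) instead of a frozen graph, reusing the functions from the reproducer; the export directory is just a placeholder:

# Sketch only: export via SavedModel rather than freezing.
# "/tmp/jit_model" is a placeholder export directory.
module = tf.Module()
module.fn = not_jit_compiled_fn
tf.saved_model.save(
    module,
    "/tmp/jit_model",
    signatures=module.fn.get_concrete_function(
        tf.TensorSpec(shape=[1, 2], dtype=tf.float32),
        tf.TensorSpec(shape=[1, 2], dtype=tf.float32)))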