Hello,
I am currently trying to save model signatures with a static batch size so that I can compile the network “ahead-of-time”. To do this, I create a `concrete_function` using a static `TensorSpec`. I'm using TensorFlow 2.8.0.
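For context, here is a minimal sketch of what tracing with a static `TensorSpec` does, using a hypothetical toy function `add_features` in place of `model.__call__` (it is not part of my model). The concrete function returned by `get_concrete_function` is traced for inputs whose batch dimension is exactly `BATCH_SIZE`, and its `structured_input_signature` reports that fixed shape.

```python
import tensorflow as tf

BATCH_SIZE = 1  # hypothetical static batch size


# Toy stand-in for model.__call__ (hypothetical): works for any batch size.
@tf.function
def add_features(first, second):
    # Sum each row of both inputs and add the results.
    return tf.reduce_sum(first, axis=1) + tf.reduce_sum(second, axis=1)


# Trace once with a fully static shape: the batch dimension is fixed.
static_fn = add_features.get_concrete_function(
    tf.TensorSpec(shape=(BATCH_SIZE, 2), dtype=tf.float32, name="first"),
    tf.TensorSpec(shape=(BATCH_SIZE, 3), dtype=tf.float32, name="second"),
)

# structured_input_signature is an (args, kwargs) pair of specs.
args, kwargs = static_fn.structured_input_signature
print(args)

# Call the concrete function with tensors matching the static shapes.
result = static_fn(tf.ones((BATCH_SIZE, 2)), tf.ones((BATCH_SIZE, 3)))
print(result)
```

The same call on the original `tf.function` would instead retrace for each new input shape; the concrete function is the single fixed-shape graph.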
This works well for models with only one input node, and for Keras models with multiple input nodes. However, when I try to extract the `concrete_function` of a multi-input model that was saved and loaded with `tf.saved_model`, I get the following error, which I did not expect:
```
ValueError: Could not find matching concrete function to call loaded from the SavedModel. Got:
Positional arguments (3 total):
* [<tf.Tensor 'inputs:0' shape=(1, 2) dtype=float32>, <tf.Tensor 'inputs_1:0' shape=(1, 3) dtype=float32>, <tf.Tensor 'inputs_2:0' shape=(1, 10) dtype=float32>]
* False
* None
Keyword arguments: {}
Expected these arguments to match one of the following 4 option(s):
Option 1:
Positional arguments (3 total):
* (TensorSpec(shape=(None, 2), dtype=tf.float32, name='first'), TensorSpec(shape=(None, 3), dtype=tf.float32, name='second'), TensorSpec(shape=(None, 10), dtype=tf.float32, name='third'))
* False
* None
Keyword arguments: {}
Option 2:
Positional arguments (3 total):
* (TensorSpec(shape=(None, 2), dtype=tf.float32, name='inputs/0'), TensorSpec(shape=(None, 3), dtype=tf.float32, name='inputs/1'), TensorSpec(shape=(None, 10), dtype=tf.float32, name='inputs/2'))
* False
* None
Keyword arguments: {}
Option 3:
Positional arguments (3 total):
* (TensorSpec(shape=(None, 2), dtype=tf.float32, name='inputs/0'), TensorSpec(shape=(None, 3), dtype=tf.float32, name='inputs/1'), TensorSpec(shape=(None, 10), dtype=tf.float32, name='inputs/2'))
* True
* None
Keyword arguments: {}
Option 4:
Positional arguments (3 total):
* (TensorSpec(shape=(None, 2), dtype=tf.float32, name='first'), TensorSpec(shape=(None, 3), dtype=tf.float32, name='second'), TensorSpec(shape=(None, 10), dtype=tf.float32, name='third'))
* True
* None
```
So no matching function could be found, and I assume I am doing something wrong. I also do not understand why my input `TensorSpec`s lose their names when passed to `get_concrete_function`: instead of `first`, `second` and `third`, the traced tensors are called `inputs:0`, `inputs_1:0` and `inputs_2:0`.
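One difference between the call and all four options is visible in the error message itself: the traced call passes the three tensors as a list (`[...]`), while every saved option expects a tuple (`(...)`), matching the `inputs=(x1, x2, x3)` tuple used to build the model. The static batch dimension is probably not the issue, since `(1, 2)` is compatible with `(None, 2)`. The sketch below only checks these two observations with `tf.nest.assert_same_structure` and `TensorShape.is_compatible_with`; the helper `same_structure` is hypothetical, and whether the list/tuple difference is what actually breaks the lookup is an assumption, not something confirmed here.

```python
import tensorflow as tf

BATCH_SIZE = 1

# What static_concrete_function passes: a Python list of static specs.
list_specs = [
    tf.TensorSpec(shape=(BATCH_SIZE, 2), dtype=tf.float32, name="first"),
    tf.TensorSpec(shape=(BATCH_SIZE, 3), dtype=tf.float32, name="second"),
    tf.TensorSpec(shape=(BATCH_SIZE, 10), dtype=tf.float32, name="third"),
]

# What every option in the error message expects: a tuple with a dynamic batch.
saved_specs = (
    tf.TensorSpec(shape=(None, 2), dtype=tf.float32, name="first"),
    tf.TensorSpec(shape=(None, 3), dtype=tf.float32, name="second"),
    tf.TensorSpec(shape=(None, 10), dtype=tf.float32, name="third"),
)


def same_structure(a, b, check_types):
    """Hypothetical helper: True if tf.nest sees a and b as the same nesting."""
    try:
        tf.nest.assert_same_structure(a, b, check_types=check_types)
        return True
    except (TypeError, ValueError):
        return False


# Ignoring container types, both are flat sequences of three specs.
loose = same_structure(list_specs, saved_specs, check_types=False)
# With type checking, a list and a tuple count as different structures.
strict = same_structure(list_specs, saved_specs, check_types=True)
# Each static shape is compatible with the corresponding saved shape.
shapes_ok = all(
    s.shape.is_compatible_with(t.shape) for s, t in zip(list_specs, saved_specs)
)
print(loose, strict, shapes_ok)
```

If the container type is the culprit, passing `tuple(static_tensorspec)` instead of the list would be the first thing to try.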
My question is: how do I correctly extract a concrete function with a static batch size from a model loaded with `tf.saved_model.load`?
I expected my code to work, since the same approach works for Keras models.
To reproduce this behavior, please use the following code example:
```python
import tensorflow as tf


def create_test_model():
    # model with 3 input nodes (the third is unused) and 1 output node, with non-static batch size
    x1 = tf.keras.Input(shape=(2,), name="first")
    x2 = tf.keras.Input(shape=(3,), name="second")
    x3 = tf.keras.Input(shape=(10,), name="third")
    x = tf.concat([x1, x2], axis=1)
    a1 = tf.keras.layers.Dense(10, activation="elu")(x)
    y = tf.keras.layers.Dense(5, activation="softmax")(a1)
    model = tf.keras.Model(inputs=(x1, x2, x3), outputs=y)
    return model


def static_concrete_function(model, batch_size: int):
    static_tensorspec = [tf.TensorSpec(shape=(batch_size, 2), dtype=tf.float32, name='first'),
                         tf.TensorSpec(shape=(batch_size, 3), dtype=tf.float32, name='second'),
                         tf.TensorSpec(shape=(batch_size, 10), dtype=tf.float32, name='third')]
    # get the concrete function for the signature: static_tensorspec
    new_signature = tf.function(
        model.__call__).get_concrete_function(inputs=static_tensorspec, training=False, mask=None)
    return new_signature


def main():
    # create and save model
    model = create_test_model()
    path_tf = "./tf_model"
    tf.saved_model.save(model, path_tf)
    path_keras = "./keras_model"
    model.save(path_keras, overwrite=True, include_optimizer=False)

    # load model
    keras_model = tf.keras.models.load_model(path_keras)
    tf_model = tf.saved_model.load(path_tf)

    # extract concrete function with static batch size
    batch_size = 1

    # works for keras model
    keras_concrete = static_concrete_function(keras_model, batch_size)
    print("*" * 50)
    print(f"Keras Models: Input Signature\n{keras_concrete.structured_input_signature}")
    print("*" * 50)

    # fails to find matching signature for tensorflow saved model
    tf_concrete = static_concrete_function(tf_model, batch_size)


if __name__ == "__main__":
    main()
```