We are using TensorFlow 2.12 to train a prediction model and serve it via the TensorFlow Serving REST API for prediction. Training works fine, and we build the serving model as follows:
inp = {
    "request_features": tf.keras.Input(
        name="request_features",
        shape=(10,),
        dtype=tf.float32,
    ),
    "content_id": tf.keras.Input(
        name="content_id", shape=(), dtype=tf.string
    ),
    "product_sku": tf.keras.Input(
        name="product_sku",
        shape=(None,),
        dtype=tf.string,
        ragged=True,
    ),
    "country_id": tf.keras.Input(
        shape=(), dtype=tf.int64, name="country_id"
    ),
    "sales_channel": tf.keras.Input(
        shape=(), dtype=tf.int64, name="sales_channel"
    ),
}
serving_core_layer = ServingLayer(
    model,
    name="serving_layer",
)
out = serving_core_layer(inp)
serving_model = tf.keras.Model(inputs=inp, outputs=out)
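We then export this serving model with a plain Keras save call, roughly like this (a sketch; export_path is a placeholder for our actual model directory):

serving_model.save(export_path)  # writes a SavedModel directory that TF Serving loads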
ServingLayer itself is defined generally as

class ServingLayer(tf.keras.layers.Layer):
    # We do some complex handling of the inputs and additional processing
    # here, which I cannot easily reproduce in this abstracted example.
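In spirit it just wraps the trained model and delegates to it; a heavily simplified sketch of what it does (the real call() performs far more input handling than this):

class ServingLayer(tf.keras.layers.Layer):
    def __init__(self, model, **kwargs):
        super().__init__(**kwargs)
        self.model = model  # the trained prediction model

    def call(self, inputs):
        # Simplified: the real code transforms and combines the dict of
        # inputs (lookups, reshapes, etc.) before calling the trained model.
        return self.model(inputs)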
After the model is trained and saved, I found that the exported signature lost the real input names such as request_features and content_id; instead, the input names became args_0, args_0_1, and so on. There is also a warning when the model is saved:
WARNING:absl:Function `_wrapped_model` contains input name(s) args_0 with unsupported characters which will be renamed to args_0_3 in the SavedModel
and the SavedModel signature metadata looks like this:
signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['args_0'] tensor_info:
        dtype: DT_INT64
        shape: (-1)
        name: serving_default_args_0:0
    inputs['args_0_1'] tensor_info:
        dtype: DT_STRING
        shape: (-1)
        name: serving_default_args_0_1:0
    inputs['args_0_2'] tensor_info:
        dtype: DT_INT64
        shape: (-1)
        name: serving_default_args_0_2:0
    inputs['args_0_3'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 10)
        name: serving_default_args_0_3:0
    inputs['args_0_4'] tensor_info:
        dtype: DT_INT64
        shape: (-1)
        name: serving_default_args_0_4:0
    inputs['args_0_5'] tensor_info:
        dtype: DT_STRING
        shape: (-1)
        name: serving_default_args_0_5:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['serving_layer'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 1)
        name: StatefulPartitionedCall:0
  Method name is: tensorflow/serving/predict
So it seems we have 5 Input layers, but there are now 6 inputs in the signature metadata (including args_0), and the real input names are lost, so I cannot build a correct request body for calling the REST API for prediction.
Could anyone help explain what might cause this problem and how to fix it? Thanks!
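For reference, this is the kind of request I was hoping to send once the named inputs are preserved (a sketch; the host, port, model name, and example values are placeholders):

import requests

# Hypothetical endpoint; 8501 is TF Serving's default REST port.
url = "http://localhost:8501/v1/models/my_model:predict"

payload = {
    "instances": [
        {
            "request_features": [0.1] * 10,     # shape (10,), float32
            "content_id": "some-content-id",    # scalar string
            "product_sku": ["sku-1", "sku-2"],  # ragged list of strings
            "country_id": 42,                   # scalar int64
            "sales_channel": 1,                 # scalar int64
        }
    ]
}

response = requests.post(url, json=payload)
print(response.json())

With the current signature, the only keys TF Serving accepts are args_0 through args_0_5, and it is not obvious which original input each one corresponds to.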