I built a custom model with TensorFlow Recommenders (TFRS). The model trains and can make predictions, but when I try to save it, saving fails with the error `FailedPreconditionError: Failed to serialize the input pipeline graph: ResourceGather is stateful. [Op:DatasetToGraphV2]`.
More details:
class DeepCrossMultitaskModel(tfrs.models.Model):
    """Compute user-item-interaction embeddings and rate the interaction.

    A multitask TFRS model: ``call`` returns user embeddings, item embeddings
    and rating predictions, and ``compute_loss`` returns a weighted loss.

    NOTE(review): the method bodies are elided in this snippet, so the exact
    role of most constructor arguments is not visible here. The names suggest
    ``rating_weight`` / ``retrieval_weight`` weight the two task losses;
    confirm against the full implementation.

    NOTE(review): ``model.save`` on this model was reported to fail with
    ``FailedPreconditionError: Failed to serialize the input pipeline graph:
    ResourceGather is stateful. [Op:DatasetToGraphV2]``. A likely cause is
    that the model keeps a ``tf.data.Dataset`` as an attribute whose ``map``
    runs an embedding lookup, e.g. candidates passed to a retrieval metric.
    SavedModel export then tries to serialize that dataset. This is
    unverified from this snippet; check what the elided ``__init__`` stores.
    """

    def __init__(self, label, df, user_id, item_id, u_cat_cols,
                 u_cnt_cols_bins, u_txt_cols, i_cat_cols, i_cnt_cols_bins,
                 i_txt_cols, interactions, embedding_dimension, flag_use_norm,
                 layer_sizes, rating_weight, retrieval_weight,
                 projection_dim=None, user_layer_sizes=None,
                 item_layer_sizes=None):
        """Build the model.

        ``user_layer_sizes`` and ``item_layer_sizes`` default to ``[32]`` when
        omitted or ``None``. A new list is created for every instance, so
        mutating it cannot leak into other instances.
        """
        # A literal ``[32]`` default is evaluated once, when the function is
        # defined, and is then shared by every call. Use a None sentinel and
        # build a fresh list here instead.
        if user_layer_sizes is None:
            user_layer_sizes = [32]
        if item_layer_sizes is None:
            item_layer_sizes = [32]
        ...

    def call(self, features: Dict[Text, tf.Tensor]
             ) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
        """Run the model on a dict of named feature tensors.

        Returns:
            The 3-tuple ``(user_embeddings, item_embeddings,
            rating_predictions)``.
        """
        ...
        return (user_embeddings, item_embeddings, rating_predictions)

    def compute_loss(self, features: Dict[Text, tf.Tensor],
                     training=False) -> tf.Tensor:
        """Return the weighted multitask loss for a batch of ``features``.

        ``training`` is part of the ``compute_loss`` override signature; how
        it is used is elided in this snippet.
        """
        ...
        return weighted_loss
# Train the multitask model.
model = DeepCrossMultitaskModel(...)
model.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.1))
one_layer_history = model.fit(
    cached_train,
    epochs=10,
    validation_data=cached_test,
    validation_freq=5,  # validate only on every 5th epoch
    verbose=0,          # no per-epoch progress output
)

# Export the whole model in TensorFlow SavedModel format.
# (The commented-out notebook magic `!mkdir -p saved_model` would create the
# parent directory.)
filepath = 'saved_model/model0'
# NOTE(review): this call is the one reported to raise
#   FailedPreconditionError: Failed to serialize the input pipeline graph:
#   ResourceGather is stateful. [Op:DatasetToGraphV2]
# Likely, but unconfirmed, cause: the model holds a tf.data.Dataset whose
# map() runs an embedding lookup (e.g. retrieval-metric candidates), and
# SavedModel export cannot serialize such a pipeline. Workarounds to evaluate:
# - export only the query/candidate sub-models, or a tfrs BruteForce index,
#   with tf.saved_model.save;
# - persist the weights with model.save_weights.
model.save(filepath=filepath)
Detailed error message:
FailedPreconditionError Traceback (most recent call last)
/tmp/ipykernel_2710054/203170664.py in <module>
6 #!mkdir -p saved_model
7 filepath = 'saved_model/model0'
----> 8 model.save(filepath) # tf.saved_model.save(model, filepath)
9
10 # to load it back
/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
65 except Exception as e: # pylint: disable=broad-except
66 filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67 raise e.with_traceback(filtered_tb) from None
68 finally:
69 del filtered_tb
/napi/k_working_dir/ds_test/shared-space/venv/lib/python3.9/site-packages/tensorflow/python/framework/ops.py in raise_from_not_ok_status(e, name)
7105 def raise_from_not_ok_status(e, name):
7106 e.message += (" name: " + name if name is not None else "")
-> 7107 raise core._status_to_exception(e) from None # pylint: disable=protected-access
7108
7109
FailedPreconditionError: Failed to serialize the input pipeline graph: ResourceGather is stateful. [Op:DatasetToGraphV2]