Hello Team,
I have a requirement to export my saved model in frozen graph format. I have not been able to find any pointers on converting my `tfdf.keras.GradientBoostedTreesModel` model to a frozen graph (`.pb`) for our internal serving infrastructure.
I followed the tutorials here, but it didn't work. I got this error: `AttributeError: 'GradientBoostedTreesModel' object has no attribute 'graph'`
I am in a time crunch and have been stuck on this for three days. Can someone help me out?
Here is how I am saving my model:
# Train, evaluate and export a TF-DF gradient boosted trees classifier.
import sys  # `exit` is a site.py convenience; `sys.exit` is always available.

# Convert the pandas splits to tf.data datasets. Batching is configured here:
# Keras documents that `batch_size` must not be passed to `fit` when the input
# is already a tf.data.Dataset, because the dataset yields its own batches.
label_col = feature_spec.get_label_col()
train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(train_df, label=label_col, batch_size=BATCH_SIZE)
valid_ds = tfdf.keras.pd_dataframe_to_tf_dataset(valid_df, label=label_col, batch_size=BATCH_SIZE)
test_ds = tfdf.keras.pd_dataframe_to_tf_dataset(test_df, label=label_col, batch_size=BATCH_SIZE)

if not model_config.gradient_boosted_trees_config:
    sys.exit(f"Boosted trees config not set, check the config file {model_config}")

# Look the blacklist up once instead of once per candidate feature.
blacklisted_names = set(feature_spec.get_blacklist_feature_names())
model = tfdf.keras.GradientBoostedTreesModel(
    task=tfdf.keras.Task.CLASSIFICATION,
    **model_config.gradient_boosted_trees_config.model_dump(),
    features=[f for f in feature_spec.get_all_tffeature() if f.name not in blacklisted_names],
    exclude_non_specified_features=True,
    num_threads=12,
)

# One run identifier shared by the TensorBoard logs and the exported model so
# the two artifacts can be matched up. (The original "/{}" suffix was never
# formatted and wrote every run into a directory literally named "{}".)
run_id = utils.get_now()
tensorboard_callback = keras.callbacks.TensorBoard(log_dir=f"{model_config.tensorboard_log_dir}/{run_id}")
model.fit(train_ds, validation_data=valid_ds, callbacks=[tensorboard_callback])

# Some information about the model.
print(model.make_inspector().variable_importances())
model.summary()  # summary() prints on its own and returns None.

# Evaluate on the held-out test split. valid_ds was passed as validation_data
# during training, so scores on it would be optimistic.
# "accuracy" lets Keras choose binary vs. sparse-categorical accuracy from the
# prediction shape; tf.keras.metrics.Accuracy() would require predicted
# probabilities to exactly equal the integer labels.
model.compile(
    metrics=[
        "accuracy",
        tf.keras.metrics.Precision(),
        tf.keras.metrics.Recall(),
    ]
)
# return_dict avoids guessing which position holds which metric. No loss was
# passed to compile(), so the reported "loss" is not a cross-entropy value.
evaluation = model.evaluate(test_ds, return_dict=True)
for metric_name, metric_value in evaluation.items():
    print(f"{metric_name}: {metric_value}")

saved_model_path = f"{model_config.saved_model_dir}/boosted_trees/my_saved_model_{run_id}"
# NOTE(review): TF-DF documents the TF SavedModel format; confirm the .keras
# artifact can actually be reloaded before relying on it.
model.save(saved_model_path + ".keras", save_format="keras")
# NOTE(review): the SavedModel directory below is the deployable artifact. A
# TF2 Keras model has no `.graph` attribute, which is what TF1-style
# "freeze graph" recipes use. TF-DF inference also runs through custom ops that
# read the tree model from the SavedModel's assets, so it cannot be made into
# a self-contained frozen GraphDef (.pb). Serve the SavedModel with a runtime
# that has the TF-DF ops available (e.g. a TF Serving build including them).
model.save(saved_model_path, save_format="tf")