General approach to serializing/saving a custom model

TF version: 2.6

What is currently the most general way to serialize/save a custom TensorFlow model (a subclass of tf.keras.Model)?

I have tried several approaches, some of them taken from Q&As posted several years ago. I wonder whether there are any updates, or a new "one for all" approach.

  • pickle: a PyTorch model, for example, can be pickled directly. But if I pickle a subclass of tf.keras.Model, it raises an error saying it cannot pickle a weakref object.

  • SavedModel: this requires an input signature when loading, which I consider an extra step. I also want to save a model that takes a list of tensors as input, but I have only seen workarounds for this, e.g. stacking the list into a single Tensor or using a dict instead.

  • Keras H5 format: this cannot save subclassed models; it only supports Functional (and Sequential) models.
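On the list-of-tensors point, here is a minimal sketch, assuming the low-level tf.saved_model API on a plain tf.Module rather than a Keras model (ListInputModule is a hypothetical name). Tracing the tf.function once with a list of two TensorSpecs before saving records that list structure in the SavedModel, so the restored object should accept a list directly, without stacking or converting to a dict, and tf.saved_model.load takes no signature argument.

```python
import tempfile

import tensorflow as tf


class ListInputModule(tf.Module):
    """Hypothetical module whose forward pass takes a list of two tensors."""

    def __init__(self):
        super().__init__()
        self.scale = tf.Variable(2.0)

    @tf.function
    def __call__(self, inputs):
        # `inputs` is a Python list of two tensors with matching shapes
        first, second = inputs
        return self.scale * (first + second)


module = ListInputModule()

# Trace once with a list-of-TensorSpec structure before saving; the traced
# concrete function (and therefore the list structure) is what gets saved.
spec = tf.TensorSpec(shape=[None, 3], dtype=tf.float32)
module.__call__.get_concrete_function([spec, spec])

export_dir = tempfile.mkdtemp()
tf.saved_model.save(module, export_dir)

# Loading needs neither the Python class nor a signature argument; the
# restored callable accepts inputs that match the traced structure.
restored = tf.saved_model.load(export_dir)
result = restored([tf.ones([2, 3]), tf.ones([2, 3])])
print(result.numpy())  # every entry is 2.0 * (1.0 + 1.0)
```

The trade-off is that the restored object only accepts input structures that were traced before saving; a call with, say, a three-element list would need its own trace.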

@Litchy_S,

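You can use model.save() to save a subclassed model. In TF 2.x it defaults to the TensorFlow SavedModel format (save_format="tf") unless the path ends in .h5.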

Example:

```python
import tensorflow as tf
from tensorflow import keras


class CustomModel(keras.Model):
    def __init__(self, hidden_units):
        super(CustomModel, self).__init__()
        self.hidden_units = hidden_units
        # One Dense layer per entry in hidden_units
        self.dense_layers = [keras.layers.Dense(u) for u in hidden_units]

    def call(self, inputs):
        x = inputs
        for layer in self.dense_layers:
            x = layer(x)
        return x

    def get_config(self):
        # Constructor arguments needed to rebuild the model
        return {"hidden_units": self.hidden_units}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


model = CustomModel([16, 16, 10])
# Build the model by calling it once on sample data
input_arr = tf.random.uniform((1, 5))
outputs = model(input_arr)
# Writes a SavedModel directory named "my_model"
model.save("my_model")
```
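The get_config/from_config pair above lets the model be rebuilt from its constructor arguments. A minimal in-memory sketch of that idea (not what model.save writes to disk): rebuild a clone from the config, call it once so its variables exist, then copy the weights over. The config is a plain dict and get_weights() returns a list of NumPy arrays, so both pieces could be stored with any serializer, pickle included. The **kwargs pass-through and the list() copies are small assumptions added to the class from the answer.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras


class CustomModel(keras.Model):
    def __init__(self, hidden_units, **kwargs):
        super().__init__(**kwargs)
        self.hidden_units = list(hidden_units)
        self.dense_layers = [keras.layers.Dense(u) for u in hidden_units]

    def call(self, inputs):
        x = inputs
        for layer in self.dense_layers:
            x = layer(x)
        return x

    def get_config(self):
        # Plain Python list, so the config is easy to serialize
        return {"hidden_units": list(self.hidden_units)}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


x = tf.random.uniform((1, 5))
original = CustomModel([16, 16, 10])
y_original = original(x)  # first call creates the weights

# Rebuild the architecture from the config, build it, then copy the weights
clone = CustomModel.from_config(original.get_config())
clone(x)
clone.set_weights(original.get_weights())
y_clone = clone(x)

# Same architecture + same weights -> same outputs
np.testing.assert_allclose(np.asarray(y_original), np.asarray(y_clone), rtol=1e-6)
```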

For more details, please refer to the TensorFlow Core guide "Save, serialize, and export models".

Thank you!