Difficulties saving/loading a generic subclassed `Model`: TypeError: __init__() missing 2 required positional arguments: 'inputs' and 'outputs'

Hi all,

I’m trying to subclass keras.Model in order to create a class that can store additional metadata in a human-readable format, which is saved alongside the model.

I’ve got something that works well in terms of basic functionality; the only problem is that I’m not able to load a saved model again. All code posted below is a simplified example that reproduces the problem and can also be accessed here.

Note: The following is a bit of a long story about how I got here. I’ve tried lots of different options (plenty of them wrong), but I include them anyway to show what I’ve already tried. They might also help other people better understand the options for saving/loading custom models.

```python
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv1D, GlobalMaxPool1D, Dense


class MyModel(tf.keras.Model):
    def __init__(self, a, *args, **kwargs):
        # `a` is a required argument to initialize this specific model type,
        # e.g. it is used somewhere during the custom training step
        super().__init__(*args, **kwargs)
        self._a = a

    @property
    def a(self):
        return self._a


# Create test model
inp = Input(shape=(100, 3))
feat = Conv1D(8, 1, activation='relu')(inp)
feat = GlobalMaxPool1D()(feat)
feat = Dense(3)(feat)

model = MyModel(inputs=inp, outputs=feat, name='MyModel', a=12)
```

The following way of saving/loading works, but the loaded model is not of type `MyModel`; it is a regular `Functional` model and has lost the property `a`. AFAIK this is normal and expected behavior:

```python
model.save(tmp_path)  # tmp_path: path of a temporary save location
loaded_model = tf.keras.models.load_model(tmp_path)
print(type(loaded_model))   # <class 'keras.engine.functional.Functional'>
```

A workaround would be to use this loaded model to create a new `MyModel`, but even if I still had the value of `a` to do so correctly, the new model would have lost all optimizer/loss/metrics info:

```python
fake_loaded_model = MyModel(inputs=loaded_model.inputs, outputs=loaded_model.outputs, a=some_value)
```

OK, so injecting the class via custom_objects should help, right?

```python
loaded_model = tf.keras.models.load_model(tmp_path, custom_objects={'MyModel': MyModel})
# TypeError: __init__() missing 1 required positional argument: 'a'
```

That makes sense: I didn’t include `get_config` and `from_config`. Correcting that:

```python
class MyModel(tf.keras.Model):
    ...  # same as before

    def get_config(self):
        cfg = super().get_config()
        cfg.update({'a': self._a})
        return cfg

    @classmethod
    def from_config(cls, config):
        return cls(**config)

# Recreate and re-save model
# ...

# Try loading
loaded_model = tf.keras.models.load_model(tmp_path, custom_objects={'MyModel': MyModel})
# TypeError: __init__() missing 2 required positional arguments: 'inputs' and 'outputs'
# Call stack: from_config --> MyModel.__init__ --> super().__init__(*args, **kwargs)
```

This also kind of makes sense to me. What happens now is that `from_config` correctly tries to instantiate my class by calling its constructor with the loaded config. The problem is that this config does not contain `inputs` and `outputs`, but instead something called `input_layers` and `output_layers`:

```python
# {'name': 'MyModel', 'layers': [... a list of layers...], 'input_layers': [['input_34', 0, 0]], 'output_layers': [['dense_33', 0, 0]], 'a': 12}
```
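The config printed above has the standard layout of a functional model's config, which can be checked on a plain functional `tf.keras.Model` built from the same layer stack. A minimal sketch, assuming TensorFlow 2.x: it prints the config keys and then calls `tf.keras.Model.from_config` directly, which recreates the `Input` tensors and layers from `layers`/`input_layers`/`output_layers` without ever receiving `inputs`/`outputs`. The exact key set (for example whether `trainable` is included) may differ between Keras versions.

```python
import tensorflow as tf

# Same layer stack as the test model, but as a plain functional Model
inp = tf.keras.layers.Input(shape=(100, 3))
feat = tf.keras.layers.Conv1D(8, 1, activation='relu')(inp)
feat = tf.keras.layers.GlobalMaxPool1D()(feat)
feat = tf.keras.layers.Dense(3)(feat)
plain = tf.keras.Model(inputs=inp, outputs=feat, name='plain')

cfg = plain.get_config()
# The graph is stored as layer configs plus [layer name, node index, tensor index]
# references, not as the `inputs`/`outputs` tensors that __init__ takes.
print(sorted(cfg.keys()))
print(cfg['input_layers'], cfg['output_layers'])

# Calling from_config on tf.keras.Model itself rebuilds a plain functional
# model from those entries (freshly initialized weights, same architecture).
rebuilt = tf.keras.Model.from_config(cfg)
print(type(rebuilt).__name__, rebuilt.count_params())
```

This only illustrates the config round trip; it does not show which internal path `load_model` takes for a particular save format.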

So... finally, my question. My understanding of subclassing `keras.Model` is that it is designed to encapsulate a specific network architecture within a `Model` class, and I'm trying to use it in a different way (i.e. to add functionality to the basic `Model` class without tying it to a specific hard-coded architecture, specifying only the inputs and outputs at construction time). Am I simply trying to do something that is not supported, or is there a way around this? And how is it that, when I load this from the trace (i.e. without specifying `custom_objects`), Keras is able to load the model without having to supply `inputs` and `outputs`?

If you got this far, THANKS for reading 'til the end. If someone can show me the light on how to approach this, I’d be very grateful.

Kind regards,

Hello @Steven_Tondeur

Thank you for using TensorFlow.
According to the documentation, when saving a model that includes custom objects, such as a subclassed Layer, you must define a `get_config()` method on the object class. If the arguments passed to the constructor (`__init__()` method) of the custom object aren’t Python objects (anything other than base types like ints, strings, etc.), then you must also explicitly deserialize these arguments in the `from_config()` class method.
Thank you.
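Applied to the model in the question, the constructor arguments that are not base types are `inputs` and `outputs` (Keras tensors), so `from_config` has to rebuild them from the `layers`/`input_layers`/`output_layers` entries before calling the constructor. A minimal sketch, assuming TensorFlow 2.x and one possible approach rather than an official recipe: `get_config` serializes the graph through a plain functional `tf.keras.Model` that shares the subclass's layers (on some newer Keras versions `super().get_config()` may not emit the graph for a subclass with a custom `__init__`), and `from_config` lets `tf.keras.Model.from_config` rebuild that plain graph, then wraps its tensors in `MyModel` together with `a`. It rebuilds the architecture only; the weights are expected to come from the saved model when this runs inside `load_model`, which has not been verified here for every TF/Keras release or save format.

```python
import tensorflow as tf


class MyModel(tf.keras.Model):
    def __init__(self, a, *args, **kwargs):
        # `a` is the extra metadata the question wants to keep
        super().__init__(*args, **kwargs)
        self._a = a

    @property
    def a(self):
        return self._a

    def get_config(self):
        # Serialize the graph via a plain functional model sharing our layers,
        # so the config holds 'layers', 'input_layers' and 'output_layers'.
        graph = tf.keras.Model(inputs=self.inputs, outputs=self.outputs, name=self.name)
        config = graph.get_config()
        config['a'] = self._a
        return config

    @classmethod
    def from_config(cls, config, custom_objects=None):
        config = dict(config)  # don't mutate the caller's dict
        a = config.pop('a')
        # Called on tf.keras.Model (not cls), this rebuilds a plain functional
        # model, recreating the Input tensors and layers from the config.
        graph = tf.keras.Model.from_config(config, custom_objects=custom_objects)
        # Wrap the rebuilt tensors in the subclass and restore `a`.
        return cls(inputs=graph.inputs, outputs=graph.outputs, name=graph.name, a=a)


inp = tf.keras.layers.Input(shape=(100, 3))
feat = tf.keras.layers.Conv1D(8, 1, activation='relu')(inp)
feat = tf.keras.layers.GlobalMaxPool1D()(feat)
feat = tf.keras.layers.Dense(3)(feat)
model = MyModel(inputs=inp, outputs=feat, name='MyModel', a=12)

# Config round trip without touching the disk
restored = MyModel.from_config(model.get_config())
print(type(restored).__name__, restored.a, restored.count_params())
```

With these two methods in place, the `load_model(tmp_path, custom_objects={'MyModel': MyModel})` call from the question would be expected to go through this `from_config`; it is worth re-checking `type(loaded_model)` and `loaded_model.a` after loading on the TF version in use.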