Subclassing tf.keras.Model and overriding its train_step() function gives us the kind of flexibility we need to control our training loops. It lets us easily plug in our favorite callbacks and do almost everything else that is readily available through model.fit().
I am wondering how to use the ModelCheckpoint callback in this instance. Consider the following use-case (taken from here):
Of course. But that somehow defeats the purpose of progressive disclosure of complexity IMO. I wanted to be able to focus on my training loop and delegate the rest of the things to the framework whenever possible.
And the link does not describe a workaround for the use-case I mentioned in any detail.
Callback list seems to be a good option. I will try it out.
If your subclassed model (where I am overriding train_step()) contains two or more models, and you pass a ModelCheckpoint callback when calling .fit() on the subclassed model, the callback gets confused.
It shouldn’t be like that. Models are supposed to be nestable.
…
The problem here is that the callback defaults to saving the model in HDF5 format (which apparently requires calling .fit to set the input shape, and we don’t call fit on the nested models).
Set save_weights_only=True to save in the TensorFlow checkpoint format, and then it works.
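For anyone landing here later, a minimal sketch of that fix might look like the following. It is not the original poster's code: the TwoModelWrapper class, its encoder and head sub-models, the toy data, and the hand-written MSE loss are all made up for illustration. The version check is also an assumption on my part: newer Keras (3.x) only accepts weights-only checkpoint paths ending in ".weights.h5", while older tf.keras releases write the TensorFlow checkpoint format for a bare prefix.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class TwoModelWrapper(tf.keras.Model):
    """Hypothetical wrapper that nests two models and overrides train_step()."""

    def __init__(self):
        super().__init__()
        # Two nested models; .fit() is never called on them directly.
        self.encoder = tf.keras.Sequential([tf.keras.layers.Dense(8, activation="relu")])
        self.head = tf.keras.Sequential([tf.keras.layers.Dense(1)])

    def call(self, inputs, training=False):
        features = self.encoder(inputs, training=training)
        return self.head(features, training=training)

    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            # Plain mean squared error, computed by hand for the sketch.
            loss = tf.reduce_mean(tf.square(y - y_pred))
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        return {"loss": loss}


# Toy regression data (illustrative only).
x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")

ckpt_dir = tempfile.mkdtemp()
# Assumption: Keras 3 wants a ".weights.h5" suffix for weights-only checkpoints;
# older tf.keras writes a TensorFlow-format checkpoint for a bare prefix.
is_keras3 = int(tf.keras.__version__.split(".")[0]) >= 3
ckpt_path = os.path.join(ckpt_dir, "ckpt.weights.h5" if is_keras3 else "ckpt")

# save_weights_only=True avoids whole-model saving, which is where the
# nested models caused trouble.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath=ckpt_path,
    save_weights_only=True,
)

model = TwoModelWrapper()
model.compile(optimizer="adam")
history = model.fit(x, y, epochs=2, batch_size=8, callbacks=[checkpoint_cb], verbose=0)
```

After training, ckpt_dir should contain the checkpoint files written at the end of each epoch. To restore, model.load_weights(ckpt_path) on a fresh instance of the same class should do it.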