How to load a TensorFlow model to retrain it without the optimizer states being reset?

When I load a saved TensorFlow model to continue training it, the first validation losses are well above the values reached at the end of the previous training run. This happens because the optimizer state is reset when the loaded model is compiled again. After several searches, I still don't understand how to continue training without this problem. How can I do this?


Hi @marcocintra, you can use tf.train.Checkpoint to save both the model weights and the optimizer state (for Adam, its step counter and moment estimates).

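import tensorflow as tf

optimizer = tf.keras.optimizers.Adam()
model.compile(loss="categorical_crossentropy", optimizer=optimizer, metrics=['accuracy'])
# ... train with model.fit(...) so the optimizer's state (e.g. Adam's moments) exists ...

# Track both objects; the same keyword names (model=, optimizer=) are needed to restore
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
checkpoint.save(file_prefix='./checkpoints/ckpt')  # writes ./checkpoints/ckpt-1, then ckpt-2, ...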
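To save automatically while model.fit runs (similar in spirit to the ModelCheckpoint callback, but keeping the optimizer inside a tf.train.Checkpoint), a small custom callback can call tf.train.CheckpointManager.save() at the end of every epoch. This is a minimal sketch, not a built-in Keras API: the SaveCheckpointEachEpoch class, the toy model and data, the temporary directory and max_to_keep=2 are illustrative assumptions.

```python
import tempfile

import numpy as np
import tensorflow as tf


class SaveCheckpointEachEpoch(tf.keras.callbacks.Callback):
    """Hypothetical callback: writes model + optimizer state at the end of each epoch."""

    def __init__(self, manager):
        super().__init__()
        self.manager = manager  # a tf.train.CheckpointManager

    def on_epoch_end(self, epoch, logs=None):
        # Writes <directory>/ckpt-N and updates the "checkpoint" index file
        self.manager.save()


# Hypothetical toy model and data (32 samples -> one batch per epoch at batch_size=32)
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(3, activation="softmax"),
])
x = np.random.rand(32, 4).astype("float32")
y = tf.keras.utils.to_categorical(np.random.randint(0, 3, size=32), 3)

optimizer = tf.keras.optimizers.Adam()
model.compile(loss="categorical_crossentropy", optimizer=optimizer)

# A temporary directory keeps the sketch self-contained; use e.g. './checkpoints' in practice
checkpoint_dir = tempfile.mkdtemp()
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=2)

model.fit(x, y, epochs=3, verbose=0, callbacks=[SaveCheckpointEachEpoch(manager)])
# Newest checkpoint path; checkpoints beyond max_to_keep have been deleted
print(manager.latest_checkpoint)
```

Because the manager maintains the standard `checkpoint` index file in that directory, tf.train.latest_checkpoint(checkpoint_dir) returns its newest entry, so a restore step based on it should work unchanged.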

Thank you.


Ok, thanks. Once I've saved a model checkpoint using the ModelCheckpoint callback (from tensorflow.keras.callbacks import ModelCheckpoint), can I restore that checkpoint with tf.train.Checkpoint to continue training the model? If so, how?

Hi @marcocintra, once the model and optimizer state have been saved with tf.train.Checkpoint, define a new model with the same architecture as the previous one, restore the checkpoint into it, and then train the new model to continue training. For example, define the new model:

import tensorflow as tf
from tensorflow import keras

# Must have the same architecture as the model that was saved
model1 = keras.Sequential([
    keras.Input(shape=(28, 28, 1)),
    keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu'),
    keras.layers.MaxPooling2D(pool_size=(2, 2)),
    keras.layers.Conv2D(64, kernel_size=(3, 3), activation='relu'),
    keras.layers.MaxPooling2D(pool_size=(2, 2)),
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(10, activation='softmax')
])

Get the latest saved checkpoint:

checkpoint_dir = './checkpoints'
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)

Restore the checkpoint into the new model:

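# Fresh optimizer of the same type; its state is loaded from the checkpoint
optimizer = tf.keras.optimizers.Adam()

if latest_checkpoint:
    # The keyword names (model=, optimizer=) must match the ones used when saving
    checkpoint = tf.train.Checkpoint(model=model1, optimizer=optimizer)
    checkpoint.restore(latest_checkpoint)
    print(f"Restored from {latest_checkpoint}")
else:
    print("No checkpoint found. Starting training from scratch.")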
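tf.train.Checkpoint pairs saved values with objects by the keyword names given to it, so restoring with optimizer= a checkpoint that was saved with optim= returns without an error and silently leaves the optimizer at its initial state. One way to catch such a mismatch is to check the status object returned by restore(). The sketch below assumes TF 2.11 or newer, where Keras optimizers create their step counter when they are constructed; the restore_and_check helper and the toy model are hypothetical.

```python
import tempfile

import tensorflow as tf


def build_model():
    # Hypothetical toy architecture; use the same architecture as the saved model
    return tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(3, activation="softmax"),
    ])


def restore_and_check(checkpoint, checkpoint_dir):
    """Hypothetical helper: restore the newest checkpoint and fail loudly on a key mismatch."""
    latest = tf.train.latest_checkpoint(checkpoint_dir)
    if latest is None:
        return False  # nothing saved yet: start training from scratch
    status = checkpoint.restore(latest)
    # Raises AssertionError if a variable already owned by an object tracked by
    # `checkpoint` (e.g. the optimizer's step counter) has no entry in the file
    status.assert_existing_objects_matched()
    return True


checkpoint_dir = tempfile.mkdtemp()
# Saved under the key "optim"
tf.train.Checkpoint(model=build_model(), optim=tf.keras.optimizers.Adam()).save(
    f"{checkpoint_dir}/ckpt")

# Same keys as when saving: restores cleanly
same_keys_ok = restore_and_check(
    tf.train.Checkpoint(model=build_model(), optim=tf.keras.optimizers.Adam()),
    checkpoint_dir)

# Key "optimizer" instead of "optim": the mismatch is reported
try:
    restore_and_check(
        tf.train.Checkpoint(model=build_model(), optimizer=tf.keras.optimizers.Adam()),
        checkpoint_dir)
    mismatch_detected = False
except AssertionError:
    mismatch_detected = True

print(same_keys_ok, mismatch_detected)
```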

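Compile and train the model. Pass the restored optimizer to compile(); if no optimizer is given, Keras creates a fresh default one (RMSprop) and the restored optimizer state is not used:

model1.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
model1.fit(x_train, y_train, epochs=2)  # x_train, y_train: your training data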
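To confirm that the optimizer state really carries over, a sketch like the following trains a toy model, saves it, restores it into fresh objects, and continues training with the restored optimizer passed to compile(). The build_model helper, the toy architecture and the random data are hypothetical stand-ins for model1 and x_train/y_train.

```python
import tempfile

import numpy as np
import tensorflow as tf


def build_model():
    # Hypothetical toy model standing in for model1's architecture
    return tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(3, activation="softmax"),
    ])


# Hypothetical toy data: 32 samples -> one batch per epoch with the default batch_size=32
x = np.random.rand(32, 4).astype("float32")
y = tf.keras.utils.to_categorical(np.random.randint(0, 3, size=32), 3)
checkpoint_dir = tempfile.mkdtemp()

# First session: train for 2 epochs, then save model + optimizer
model, optimizer = build_model(), tf.keras.optimizers.Adam()
model.compile(loss="categorical_crossentropy", optimizer=optimizer)
model.fit(x, y, epochs=2, verbose=0)
tf.train.Checkpoint(model=model, optimizer=optimizer).save(f"{checkpoint_dir}/ckpt")

# Second session: fresh objects, restore, then compile with the restored optimizer
model1, optimizer1 = build_model(), tf.keras.optimizers.Adam()
checkpoint1 = tf.train.Checkpoint(model=model1, optimizer=optimizer1)
status = checkpoint1.restore(tf.train.latest_checkpoint(checkpoint_dir))

restored_step = int(optimizer1.iterations.numpy())
weights_match = all(
    np.allclose(a, b) for a, b in zip(model.get_weights(), model1.get_weights()))

model1.compile(loss="categorical_crossentropy", optimizer=optimizer1)
model1.fit(x, y, epochs=2, verbose=0)  # Adam's slot variables are restored when created here
final_step = int(optimizer1.iterations.numpy())

print(restored_step, final_step, weights_match)
```

With 32 samples and the default batch size of 32, each epoch is a single optimizer step, so the restored step counter should read 2 and then continue to 4 instead of restarting at 0.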

Please refer to this gist for a working code example. Thank you.
