When I reload a TensorFlow model to continue training, the initial validation losses are well above the values reached at the end of the previous run. I believe this happens because the optimizer state is reset when the loaded model is compiled. After several searches, I still don’t understand how to continue training without this problem. How can I do this?
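One plausible mechanism behind this jump can be sketched in plain Python, assuming the optimizer is Adam. The `adam_step` helper and the shrinking gradient sequence below are made up for illustration, and no TensorFlow is involved. The sketch compares the next update when the accumulated moments are kept against the same update when they start again from zero:

```python
import math


def adam_step(grad, m, v, t, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    """One scalar Adam update. Returns (delta, new_m, new_v)."""
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad * grad
    m_hat = m / (1 - beta1 ** t)  # bias-corrected first moment
    v_hat = v / (1 - beta2 ** t)  # bias-corrected second moment
    delta = -lr * m_hat / (math.sqrt(v_hat) + eps)
    return delta, m, v


# Hypothetical "previous training": 500 steps with steadily shrinking gradients.
m, v = 0.0, 0.0
for t in range(1, 501):
    _, m, v = adam_step(0.99 ** t, m, v, t)

next_grad = 0.99 ** 501

# Case 1: training continues with the accumulated moments and step count.
resumed_delta, _, _ = adam_step(next_grad, m, v, 501)

# Case 2: the optimizer was recreated, so moments and step count start over.
reset_delta, _, _ = adam_step(next_grad, 0.0, 0.0, 1)

print(f"update with kept state:  {resumed_delta:.2e}")
print(f"update with reset state: {reset_delta:.2e}")
```

With reset state, the first update has a magnitude close to the learning rate no matter how small the gradient is, because the bias-corrected ratio collapses to the sign of the gradient. With kept state, the second moment still remembers the larger early gradients, so the update is much smaller. On a nearly converged model, a full-size step on every weight could be enough to push the loss up.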
Hi @marcocintra, you can use tf.train.Checkpoint for saving both model and optimizer weights.
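import tensorflow as tf
optimizer = tf.keras.optimizers.Adam()
model.compile(loss="categorical_crossentropy", optimizer=optimizer, metrics=["accuracy"])
# Track the model and the optimizer together so both are written to the checkpoint.
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
# save() takes a file prefix and appends a counter, so the first call writes saved_model/ckpt-1.
checkpoint.save(file_prefix="saved_model/ckpt")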
Thank You.
Ok, thanks. Once I’ve saved a model checkpoint with the ModelCheckpoint callback (from tensorflow.keras.callbacks import ModelCheckpoint), can I restore that checkpoint with tf.train.Checkpoint to retrain the model? If so, how?
Hi @marcocintra, once you have saved the model and optimizer state using tf.train.Checkpoint, you have to define a new model with the same architecture as the previous one, restore the checkpoint into it, and then train the new model to continue training. For example, define the new model:
import tensorflow as tf
from tensorflow import keras

model1 = keras.Sequential([
    keras.Input(shape=(28, 28, 1)),
    keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu'),
    keras.layers.MaxPooling2D(pool_size=(2, 2)),
    keras.layers.Conv2D(64, kernel_size=(3, 3), activation='relu'),
    keras.layers.MaxPooling2D(pool_size=(2, 2)),
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(10, activation='softmax')
])
Get the last saved checkpoint:
checkpoint_dir = './checkpoints'
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
Restore that checkpoint into the new model:
# Recreate the optimizer so the checkpoint has an object to restore its state into.
optimizer = tf.keras.optimizers.Adam()
if latest_checkpoint:
    checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model1)
    checkpoint.restore(latest_checkpoint)
    print(f"Restored from {latest_checkpoint}")
else:
    print("No checkpoint found. Starting training from scratch.")
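Compile and train the model:
# Pass the restored optimizer; otherwise compile() falls back to its default ('rmsprop').
model1.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])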
model1.fit(x_train, y_train, epochs=2)
Please refer to this gist for a working code example. Thank you.
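The save and restore steps from this thread can also be checked end to end with a small round trip. This is only a sketch under stated assumptions: two plain `tf.Variable` objects (`weights` and `moment`, hypothetical names) stand in for the model and an optimizer slot, the directory is a temporary one, and `tf.train.CheckpointManager` is used as one possible way to keep numbered checkpoints in a folder that `tf.train.latest_checkpoint` can search.

```python
import os
import tempfile

import tensorflow as tf

# Stand-ins for trained state (hypothetical): in real code these would be
# the Keras model and its optimizer.
weights = tf.Variable([1.0, 2.0, 3.0])
moment = tf.Variable([0.1, 0.2, 0.3])  # plays the role of an optimizer slot

checkpoint_dir = os.path.join(tempfile.mkdtemp(), "checkpoints")
checkpoint = tf.train.Checkpoint(weights=weights, moment=moment)
# Keep only the two most recent checkpoints in checkpoint_dir.
manager = tf.train.CheckpointManager(checkpoint, directory=checkpoint_dir, max_to_keep=2)

for epoch in range(3):
    weights.assign_add([1.0, 1.0, 1.0])  # pretend training changed the weights
    moment.assign(moment * 0.9)          # pretend the optimizer state changed
    manager.save()                       # writes ckpt-1, ckpt-2, ckpt-3

# "New session": fresh objects registered under the same names as when saving.
new_weights = tf.Variable([0.0, 0.0, 0.0])
new_moment = tf.Variable([0.0, 0.0, 0.0])
new_checkpoint = tf.train.Checkpoint(weights=new_weights, moment=new_moment)

latest = tf.train.latest_checkpoint(checkpoint_dir)
status = new_checkpoint.restore(latest)
# Raises if any object tracked by new_checkpoint had no match in the file.
status.assert_existing_objects_matched()

print("restored from:", latest)
print("weights:", new_weights.numpy())
```

The keyword names given to `tf.train.Checkpoint` have to match between saving and restoring, since the values are looked up by those names. Checking the restore status is a cheap way to catch a mismatch that would otherwise leave the new objects silently at their initial values.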