I am having trouble saving and loading a model composed of custom layers.
A good example of this kind of model is the VAE example in the official TensorFlow documentation at https://www.tensorflow.org/guide/keras/making_new_layers_and_models_via_subclassing#putting_it_all_together_an_end-to-end_example
Saving works without error:
vae.save("vae211.keras")
But when I load it:
vae_ = tf.keras.models.load_model("vae211.keras")
I get the following error:
TypeError: Could not locate class 'VariationalAutoEncoder'. Make sure
custom classes are decorated with
@keras.saving.register_keras_serializable()
. Full object config:
{'module': None, 'class_name': 'VariationalAutoEncoder', 'config':
{'name': 'autoencoder', 'trainable': True, 'dtype': 'float32',
'img_d': 784, 'hidden_d': 128, 'latent_d': 32}, 'registered_name':
'Custom>VariationalAutoEncoder', 'build_config': {'input_shape': [64,
784]}}
I tried adding get_config() methods, but nothing worked.
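For reference, this is roughly the get_config() override I added to the model class (a sketch, assuming the constructor arguments are stored as attributes; it is where the 'img_d', 'hidden_d' and 'latent_d' keys in the config above come from):

def get_config(self):
    config = super().get_config()
    config.update({
        "img_d": self.img_d,
        "hidden_d": self.hidden_d,
        "latent_d": self.latent_d,
    })
    return config

The code below is the version without this override, since adding it did not change the error.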
Please find the full code below, in case you can help:
import tensorflow as tf
img_d = 784
hidden_d = 128
latent_d = 32
epochs = 2
"""
Dataset
"""
(x_tra, _), _ = tf.keras.datasets.mnist.load_data()
x_tra = x_tra.reshape(60000, 784).astype("float32") / 255
tra_ds = tf.data.Dataset.from_tensor_slices(x_tra)
tra_ds = tra_ds.shuffle(buffer_size=1024).batch(64)
"""
Model
"""
@tf.keras.saving.register_keras_serializable()
class Sampling(tf.keras.layers.Layer):
    def call(self, z_mean, z_log_var):
        bs, latent_dim = tf.shape(z_mean)
        epsilon = tf.random.normal(shape=(bs, latent_dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon
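# Encoder: maps inputs to the latent distribution parameters (z_mean, z_log_var)
# and a sampled latent vector z.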
@tf.keras.saving.register_keras_serializable()
class Encoder(tf.keras.layers.Layer):
    def __init__(self, latent_d=32, hidden_d=64, name="encoder", **kwargs):
        super().__init__(name=name, **kwargs)
        self.dense1 = tf.keras.layers.Dense(hidden_d, activation="relu")
        self.dense_mean = tf.keras.layers.Dense(latent_d)
        self.dense_log_var = tf.keras.layers.Dense(latent_d)
        self.sampling = Sampling()

    def call(self, inputs):
        x = self.dense1(inputs)
        z_mean = self.dense_mean(x)
        z_log_var = self.dense_log_var(x)
        z = self.sampling(z_mean, z_log_var)
        return z_mean, z_log_var, z
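# Decoder: maps a latent vector z back to a reconstruction of the input.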
@tf.keras.saving.register_keras_serializable()
class Decoder(tf.keras.layers.Layer):
    def __init__(self, img_d, hidden_d=64, name="decoder", **kwargs):
        super().__init__(name=name, **kwargs)
        self.dense1 = tf.keras.layers.Dense(hidden_d, activation="relu")
        self.dense2 = tf.keras.layers.Dense(img_d, activation="sigmoid")

    def call(self, inputs):
        x = self.dense1(inputs)
        return self.dense2(x)
@tf.keras.saving.register_keras_serializable()
class VariationalAutoEncoder(tf.keras.Model):
    def __init__(self, img_d, hidden_d=64, latent_d=32,
                 name="autoencoder", **kwargs):
        super().__init__(name=name, **kwargs)
        self.img_d = img_d
        self.encoder = Encoder(latent_d=latent_d, hidden_d=hidden_d)
        self.decoder = Decoder(img_d=img_d, hidden_d=hidden_d)

    def call(self, inputs):
        z_mean, z_log_var, z = self.encoder(inputs)
        reconstructed = self.decoder(z)
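        # Closed-form KL divergence between N(z_mean, exp(z_log_var)) and N(0, I);
        # add_loss() registers it so it can be combined with the reconstruction loss.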
        kl_loss = -0.5 * tf.reduce_mean(
            z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1
        )
        self.add_loss(kl_loss)
        return reconstructed
vae = VariationalAutoEncoder(
    img_d=img_d, hidden_d=hidden_d, latent_d=latent_d
)
optimizer = tf.keras.optimizers.Adam(1e-3)
mse_loss_fn = tf.keras.losses.MeanSquaredError()
"""
Fitting
"""
loss_metric = tf.keras.metrics.Mean()
for epoch in range(epochs):
    print(f"Start of {epoch = }")
    for step, x in enumerate(tra_ds):
        with tf.GradientTape() as tape:
            reconstructed = vae(x)
            loss = mse_loss_fn(x, reconstructed)
            loss += sum(vae.losses)
        grads = tape.gradient(loss, vae.trainable_weights)
        optimizer.apply_gradients(zip(grads, vae.trainable_weights))
        loss_metric(loss)
        if step % 100 == 0:
            print(f"{step=}: mean loss = {loss_metric.result():.4f}")
"""
Save & Load
"""
vae.save("vae211.keras")
vae_ = tf.keras.models.load_model("vae211.keras") # <===== Error
decoder_ = vae_.decoder  # intended use of the reloaded model
...