I have a custom TensorFlow `TokenAndPositionEmbedding` layer. Training and saving the model work fine, but deserializing (loading) the model fails. For loading I just use `tf.keras.models.load_model`. Can anyone help me find the problem?
Below is the layer implementation, followed by the error I get.
@tf.keras.utils.register_keras_serializable(
    package="Custom", name="TokenAndPositionEmbedding"
)
class TokenAndPositionEmbedding(layers.Layer):
    """Add a learned token embedding to a learned absolute-position embedding.

    The class is registered under ``"Custom>TokenAndPositionEmbedding"``
    (the ``registered_name`` that appears in the saved config). Registration
    is a side effect of executing the decorator above. The process that calls
    ``tf.keras.models.load_model`` must therefore import the module that
    defines this class *before* loading. Otherwise Keras raises
    "Could not locate class 'TokenAndPositionEmbedding'".

    NOTE(review): register and load with the same Keras package. For example,
    do not decorate with ``tf_keras`` and then load with ``keras``, because
    they keep separate registries.

    Args:
        max_len: Number of rows in the position-embedding table. Inputs whose
            last dimension exceeds this would index past the table.
        vocab_size: Number of rows in the token-embedding table.
        embed_dim: Output width of both embedding tables.
        **kwargs: Forwarded to ``layers.Layer`` (e.g. ``name``, ``dtype``).
    """

    def __init__(self, max_len, vocab_size, embed_dim, **kwargs):
        # Must be ``__init__`` (not ``init``). Otherwise the constructor
        # arguments go straight to ``Layer.__init__`` and the embedding
        # sub-layers below are never created.
        super().__init__(**kwargs)
        self.max_len = max_len
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        # Attribute names are part of the saved-weights layout; keep them stable.
        self.token_emb = layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)
        self.pos_emb = layers.Embedding(input_dim=max_len, output_dim=embed_dim)

    def build(self, input_shape):
        """Create both embedding tables eagerly.

        The saved config carries a ``build_config`` with the layer's
        ``input_shape``. Building the sub-layers here means their variables
        exist as soon as the layer is rebuilt from that config.
        """
        self.token_emb.build(input_shape)
        # Positions are a 1-D range of indices (see ``call``).
        self.pos_emb.build((None,))
        super().build(input_shape)

    def call(self, x):
        """Return ``token_emb(x)`` plus the position embedding of ``0..len-1``.

        The sequence length is taken from the last axis of ``x``. The position
        embeddings are added to the token embeddings via broadcasting.
        """
        maxlen = tf.shape(x)[-1]
        positions = tf.range(start=0, limit=maxlen, delta=1)
        positions = self.pos_emb(positions)
        x = self.token_emb(x)
        return x + positions

    def get_config(self):
        """Return the base layer config plus the three constructor arguments."""
        config = super().get_config()
        config.update(
            {
                "max_len": self.max_len,
                "vocab_size": self.vocab_size,
                "embed_dim": self.embed_dim,
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild the layer from ``get_config`` output.

        Every config value is a plain constructor argument or a base
        ``Layer`` kwarg, so the config is passed straight to the constructor.
        """
        return cls(**config)
could not be deserialized properly. Please ensure that components that are Python object instances (layers, models, etc.) returned by get_config()
are explicitly deserialized in the model’s from_config()
method.
config={‘module’: ‘keras.src.models.functional’, ‘class_name’: ‘Functional’, ‘config’: {}, ‘registered_name’: ‘Functional’, ‘build_config’: {‘input_shape’: None}, ‘compile_config’: {‘optimizer’: {‘module’: ‘keras.optimizers’, ‘class_name’: ‘Adam’, ‘config’: {‘name’: ‘adam’, ‘learning_rate’: 0.0010000000474974513, ‘weight_decay’: None, ‘clipnorm’: None, ‘global_clipnorm’: None, ‘clipvalue’: None, ‘use_ema’: False, ‘ema_momentum’: 0.99, ‘ema_overwrite_frequency’: None, ‘loss_scale_factor’: None, ‘gradient_accumulation_steps’: None, ‘beta_1’: 0.9, ‘beta_2’: 0.999, ‘epsilon’: 1e-07, ‘amsgrad’: False}, ‘registered_name’: None}, ‘loss’: [{‘module’: ‘keras.losses’, ‘class_name’: ‘SparseCategoricalCrossentropy’, ‘config’: {‘name’: ‘sparse_categorical_crossentropy’, ‘reduction’: ‘sum_over_batch_size’, ‘from_logits’: False, ‘ignore_class’: None}, ‘registered_name’: None}], ‘loss_weights’: None, ‘metrics’: None, ‘weighted_metrics’: None, ‘run_eagerly’: False, ‘steps_per_execution’: 1, ‘jit_compile’: False}}.
Exception encountered: Could not locate class ‘TokenAndPositionEmbedding’. Make sure custom classes are decorated with @keras.saving.register_keras_serializable()
. Full object config: {‘module’: None, ‘class_name’: ‘TokenAndPositionEmbedding’, ‘config’: {‘name’: ‘token_and_position_embedding’, ‘trainable’: True, ‘dtype’: {‘module’: ‘keras’, ‘class_name’: ‘DTypePolicy’, ‘config’: {‘name’: ‘float32’}, ‘registered_name’: None, ‘shared_object_id’: 6043273584}, ‘max_len’: 1000, ‘vocab_size’: 50000, ‘embed_dim’: 256}, ‘registered_name’: ‘Custom>TokenAndPositionEmbedding’, ‘build_config’: {‘input_shape’: [None, 1000]}, ‘name’: ‘token_and_position_embedding’, ‘inbound_nodes’: [{‘args’: [{‘class_name’: ‘keras_tensor’, ‘config’: {‘shape’: [None, 1000], ‘dtype’: ‘int64’, ‘keras_history’: [‘text_vectorization’, 0, 0]}}], ‘kwargs’: {}}]}