Save, load, and retrain a machine translation model (help needed)

I've been trying to train a model for machine translation. It works fine when I train it for 10 epochs in one go and then test it. But when I try to train it for 1 epoch at a time, save the model, and load it to continue from where I left off, it gives some errors.

import tensorflow as tf
import einops
import numpy as np
import os
import tensorflow_text as tf_text
import pathlib

from keras.layers import TextVectorization

class ShapeChecker():
    def __init__(self):
        self.shapes = {}

    def __call__(self, tensor, names, broadcast=False):
        if not tf.executing_eagerly():
            return

        parsed = einops.parse_shape(tensor, names)

        for name, new_dim in parsed.items():
            old_dim = self.shapes.get(name, None)

            if broadcast and new_dim == 1:
                continue

            if old_dim is None:
                self.shapes[name] = new_dim
                continue

            if new_dim != old_dim:
                raise ValueError(f"Shape mismatch for dimension: '{name}'\n"
                                 f"    found: {new_dim}\n"
                                 f"    expected: {old_dim}\n")


class Encoder(tf.keras.layers.Layer):
    def __init__(self, text_processor, units):
        super(Encoder, self).__init__()
        self.text_processor = text_processor
        self.vocab_size = text_processor.vocabulary_size()
        self.units = units

        self.embedding = tf.keras.layers.Embedding(self.vocab_size, units,
                                                   mask_zero=True)

        self.rnn = tf.keras.layers.Bidirectional(
            merge_mode='sum',
            layer=tf.keras.layers.GRU(units,
                                      return_sequences=True,
                                      recurrent_initializer='glorot_uniform'))

    def call(self, x):
        shape_checker = ShapeChecker()
        shape_checker(x, 'batch s')

        x = self.embedding(x)
        shape_checker(x, 'batch s units')

        x = self.rnn(x)
        shape_checker(x, 'batch s units')

        return x

    def convert_input(self, texts):
        texts = tf.convert_to_tensor(texts)
        if len(texts.shape) == 0:
            texts = tf.convert_to_tensor(texts)[tf.newaxis]
        context = self.text_processor(texts).to_tensor()
        context = self(context)
        return context


class CrossAttention(tf.keras.layers.Layer):
    def __init__(self, units, **kwargs):
        super().__init__()
        self.mha = tf.keras.layers.MultiHeadAttention(key_dim=units, num_heads=1, **kwargs)
        self.layernorm = tf.keras.layers.LayerNormalization()
        self.add = tf.keras.layers.Add()

    def call(self, x, context):
        shape_checker = ShapeChecker()

        shape_checker(x, 'batch t units')
        shape_checker(context, 'batch s units')

        attn_output, attn_scores = self.mha(
            query=x,
            value=context,
            return_attention_scores=True)

        shape_checker(x, 'batch t units')
        shape_checker(attn_scores, 'batch heads t s')

        attn_scores = tf.reduce_mean(attn_scores, axis=1)
        shape_checker(attn_scores, 'batch t s')
        self.last_attention_weights = attn_scores

        x = self.add([x, attn_output])
        x = self.layernorm(x)

        return x


class Decoder(tf.keras.layers.Layer):
    @classmethod
    def add_method(cls, fun):
        setattr(cls, fun.__name__, fun)
        return fun

    def __init__(self, text_processor, units):
        super(Decoder, self).__init__()
        self.text_processor = text_processor
        self.vocab_size = text_processor.vocabulary_size()
        self.word_to_id = tf.keras.layers.StringLookup(
            vocabulary=text_processor.get_vocabulary(),
            mask_token='', oov_token='[UNK]')
        self.id_to_word = tf.keras.layers.StringLookup(
            vocabulary=text_processor.get_vocabulary(),
            mask_token='', oov_token='[UNK]',
            invert=True)
        self.start_token = self.word_to_id('[START]')
        self.end_token = self.word_to_id('[END]')

        self.units = units

        self.embedding = tf.keras.layers.Embedding(self.vocab_size, units, mask_zero=True)

        self.rnn = tf.keras.layers.GRU(units,
                                       return_sequences=True,
                                       return_state=True,
                                       recurrent_initializer='glorot_uniform')

        self.attention = CrossAttention(units)

        self.output_layer = tf.keras.layers.Dense(self.vocab_size)


@Decoder.add_method
def call(self,
         context, x,
         state=None,
         return_state=False):
    shape_checker = ShapeChecker()
    shape_checker(x, 'batch t')
    shape_checker(context, 'batch s units')

    x = self.embedding(x)
    shape_checker(x, 'batch t units')

    x, state = self.rnn(x, initial_state=state)
    shape_checker(x, 'batch t units')

    x = self.attention(x, context)
    self.last_attention_weights = self.attention.last_attention_weights
    shape_checker(x, 'batch t units')
    shape_checker(self.last_attention_weights, 'batch t s')

    logits = self.output_layer(x)
    shape_checker(logits, 'batch t target_vocab_size')

    if return_state:
        return logits, state
    else:
        return logits


@Decoder.add_method
def get_initial_state(self, context):
    batch_size = tf.shape(context)[0]
    start_tokens = tf.fill([batch_size, 1], self.start_token)
    done = tf.zeros([batch_size, 1], dtype=tf.bool)
    embedded = self.embedding(start_tokens)
    return start_tokens, done, self.rnn.get_initial_state(embedded)[0]


@Decoder.add_method
def tokens_to_text(self, tokens):
    words = self.id_to_word(tokens)
    result = tf.strings.reduce_join(words, axis=-1, separator=' ')
    result = tf.strings.regex_replace(result, r'^ *\[START\] *', '')
    result = tf.strings.regex_replace(result, r' *\[END\] *$', '')
    return result


@Decoder.add_method
def get_next_token(self, context, next_token, done, state, temperature=0.0):
    logits, state = self(
        context, next_token,
        state=state,
        return_state=True)

    if temperature == 0.0:
        next_token = tf.argmax(logits, axis=-1)
    else:
        logits = logits[:, -1, :] / temperature
        next_token = tf.random.categorical(logits, num_samples=1)

    done = done | (next_token == self.end_token)
    next_token = tf.where(done, tf.constant(0, dtype=tf.int64), next_token)

    return next_token, done, state


class Translator(tf.keras.Model):
    @classmethod
    def add_method(cls, fun):
        setattr(cls, fun.__name__, fun)
        return fun

    def __init__(self, units,
                 context_text_processor,
                 target_text_processor):
        super().__init__()

        self.ctp = context_text_processor
        self.ttp = target_text_processor

        encoder = Encoder(context_text_processor, units)
        decoder = Decoder(target_text_processor, units)

        self.encoder = encoder
        self.decoder = decoder

    def call(self, inputs):
        context, x = inputs
        context = self.encoder(context)
        logits = self.decoder(context, x)

        return logits

    def get_config(self):
        return {
            "context": self.ctp,
            "target": self.ttp
        }


path_to_zip = tf.keras.utils.get_file(
    'spa-eng.zip', origin='http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip',
    extract=True)

path_to_file = pathlib.Path(path_to_zip).parent / 'spa-eng/spa.txt'


def load_data(path):
    text = path.read_text(encoding='utf-8')
    lines = text.splitlines()

    pairs = [line.split('\t') for line in lines]

    context = np.array([context for target, context in pairs])
    target = np.array([target for target, context in pairs])

    return target, context


target_raw, context_raw = load_data(path_to_file)

buffer_size = len(context_raw)
batch_size = 64

is_train = np.random.uniform(size=(len(target_raw),)) < .8

train_raw = (
    tf.data.Dataset
    .from_tensor_slices((context_raw[is_train], target_raw[is_train]))
    .shuffle(buffer_size)
    .batch(batch_size)
)

val_raw = (
    tf.data.Dataset
    .from_tensor_slices((context_raw[~is_train], target_raw[~is_train]))
    .shuffle(buffer_size)
    .batch(batch_size)
)

for example_context_strings, example_target_strings in train_raw.take(1):
    print()
    break

example_text = tf.constant('¿Todavía está en casa?')


def tf_lower_and_split_punctuation(text):
    # Split accented characters.
    text = tf_text.normalize_utf8(text, 'NFKD')
    text = tf.strings.lower(text)
    # Keep space, a to z, and select punctuation.
    text = tf.strings.regex_replace(text, '[^ a-z.?!,¿]', '')
    # Add spaces around punctuation.
    text = tf.strings.regex_replace(text, '[.?!,¿]', r' \0 ')
    # Strip whitespace.
    text = tf.strings.strip(text)

    text = tf.strings.join(['[START]', text, '[END]'], separator=' ')
    return text


max_vocab_size = 5000
context_text_processor = TextVectorization(
    standardize=tf_lower_and_split_punctuation,
    max_tokens=max_vocab_size,
    ragged=True
)

context_text_processor.adapt(train_raw.map(lambda context, target: context))

target_text_processor = TextVectorization(
    standardize=tf_lower_and_split_punctuation,
    max_tokens=max_vocab_size,
    ragged=True
)

target_text_processor.adapt(train_raw.map(lambda context, target: target))

example_tokens = context_text_processor(example_context_strings)

context_vocab = np.array(context_text_processor.get_vocabulary())
tokens = context_vocab[example_tokens[0].numpy()]
' '.join(tokens)


def process_text(context, target):
    context = context_text_processor(context).to_tensor()
    target = target_text_processor(target)
    targ_in = target[:, :-1].to_tensor()
    targ_out = target[:, 1:].to_tensor()
    return (context, targ_in), targ_out


train_ds = train_raw.map(process_text, tf.data.AUTOTUNE)
val_ds = val_raw.map(process_text, tf.data.AUTOTUNE)

for (ex_context_tok, ex_tar_in), ex_tar_out in train_ds.take(1):
    print()

def masked_loss(y_true, y_pred):
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
    loss = loss_fn(y_true, y_pred)

    mask = tf.cast(y_true != 0, loss.dtype)
    loss *= mask

    return tf.reduce_sum(loss) / tf.reduce_sum(mask)


def masked_acc(y_true, y_pred):
    y_pred = tf.argmax(y_pred, axis=-1)
    y_pred = tf.cast(y_pred, y_true.dtype)

    match = tf.cast(y_true == y_pred, tf.float32)
    mask = tf.cast(y_true != 0, tf.float32)

    return tf.reduce_sum(match) / tf.reduce_sum(mask)


UNITS = 256

model = Translator(UNITS, context_text_processor, target_text_processor)
model.compile(
    optimizer='adam',
    loss=masked_loss,
    metrics=[masked_acc, masked_loss]
)
vocab_size = 1.0 * target_text_processor.vocabulary_size()

model_path = "./Saved Model"

initial_epoch = 0

os.makedirs(model_path, exist_ok=True)

for (dir_path, dir_names, filenames) in os.walk(model_path):

    if len(dir_names) != 0:
        dir_names.sort()

        initial_epoch = int(dir_names[-1])

        model = tf.keras.models.load_model(os.path.join(model_path, dir_names[-1]))
    else:
        model.compile(
            optimizer='adam',
            loss=masked_loss,
            metrics=[masked_acc, masked_loss]
        )
    break

history = model.fit(
    train_ds.repeat(),
    initial_epoch=initial_epoch,
    epochs=initial_epoch + 1,
    steps_per_epoch=100,
    validation_data=val_ds,
    validation_steps=20,
)

model.save(os.path.join(model_path, f'{initial_epoch + 1:02d}'))

This isn't even my code; it comes from the TensorFlow docs. I just modified it to train for multiple epochs separately.

When I train, the first epoch runs smoothly. But when I load the saved model to train the 2nd epoch, the following error message is shown:

RuntimeError: Unable to restore object of class 'TextVectorization'. One of several possible causes could be a missing custom object. Decorate your custom object with @keras.utils.register_keras_serializable and include that file in your program, or pass your class in a keras.utils.CustomObjectScope that wraps this load call.

Exception: Error when deserializing class 'TextVectorization' using config={'name': 'text_vectorization', 'trainable': True, 'dtype': 'string', 'batch_input_shape': (None,), 'max_tokens': 5000, 'standardize': 'tf_lower_and_split_punctuation', 'split': 'whitespace', 'ngrams': None, 'output_mode': 'int', 'output_sequence_length': None, 'pad_to_max_tokens': False, 'sparse': False, 'ragged': True, 'vocabulary': None, 'idf_weights': None, 'encoding': 'utf-8', 'vocabulary_size': 5000, 'has_input_vocabulary': False}.

Exception encountered: Unkown value for standardize argument of layer TextVectorization. If restoring a model and standardize is a custom callable, please ensure the callable is registered as a custom object. See Save, serialize, and export models | TensorFlow Core for details. Allowed values are: None, a Callable, or one of the following values: ('lower_and_strip_punctuation', 'lower', 'strip_punctuation'). Received: tf_lower_and_split_punctuation
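
My reading of the message is that Keras can't restore the custom standardize callable (and presumably my custom loss/metric functions either) unless they are registered or passed as custom objects when the model is loaded. Something like the sketch below is what I think it's asking for, using the names from my code above, but I haven't confirmed that it actually fixes the problem:

# Untested sketch of what the error message seems to suggest.
# Register the custom standardize callable so Keras can deserialize it:
@tf.keras.utils.register_keras_serializable(package='Custom')
def tf_lower_and_split_punctuation(text):
    ...  # same body as above

# And pass the custom functions explicitly when loading the saved model:
model = tf.keras.models.load_model(
    os.path.join(model_path, dir_names[-1]),
    custom_objects={
        'tf_lower_and_split_punctuation': tf_lower_and_split_punctuation,
        'masked_loss': masked_loss,
        'masked_acc': masked_acc,
    })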

Does anyone have any idea why this is happening?
Thanks.

Hi @JR_Jahed, I executed the given code in Colab and did not face any error. Could you please share which version of TensorFlow you are using and in which environment you are executing the code? Thank you.

@Kiran_Sai_Ramineni

I'm using the PyCharm IDE with a miniconda env. The tensorflow version is 2.12.0 and the tensorflow_text version is 2.12.1.
After reading your comment I also ran the code in Colab and hit the same problem during the 2nd epoch. For how many epochs did you run the code?