NaN loss when training a transformer model for machine translation

I am trying to train my model. I had no issues building it, but the loss goes to NaN and the gradients do not seem to be computing. I have tried gradient clipping and switching optimizers, but neither worked, and I have also filtered my data to make sure no NaN values existed. It would be very helpful if someone could help me figure this out.
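A quick sanity check like the following (a sketch using the padded arrays en_input, decoder_input and decoder_output built in the training code below) is how the inputs can be confirmed NaN/inf-free; tf.debugging.enable_check_numerics can additionally flag the first op that produces a NaN during training:

import numpy as np
import tensorflow as tf

def assert_finite(name, arr):
    # Token ids are integers, so NaN/inf should not appear, but the check is cheap.
    values = np.asarray(arr, dtype=np.float64)
    if not np.isfinite(values).all():
        raise ValueError(f'{name} contains NaN or inf values')

assert_finite('en_input', en_input)
assert_finite('decoder_input', decoder_input)
assert_finite('decoder_output', decoder_output)

# Optional: make TensorFlow raise as soon as any op produces NaN/inf,
# which points at the first offending layer instead of failing at the loss.
tf.debugging.enable_check_numerics()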

Code for the Transformer:

import tensorflow as tf
from tensorflow.keras.layers import Dropout, MultiHeadAttention, LayerNormalization, Dense, Embedding, Input
import numpy as np

def positional_encoding(max_seq_len, d_model):
    pos_enc = np.zeros((max_seq_len, d_model))

    for pos in range(max_seq_len):
        for i in range(0, d_model, 2):
            pos_enc[pos, i] = np.sin(pos / np.power(10000, (2 * i) / d_model))
            if i + 1 < d_model:
                pos_enc[pos, i + 1] = np.cos(pos / np.power(10000, (2 * i) / d_model))
    return pos_enc

def create_padding_mask(seq):
    mask = tf.cast(tf.math.equal(seq, 0), tf.float32)
    return mask[:, tf.newaxis, tf.newaxis, :]

def encoder_layer(input, d_model, num_heads, dff, mask, training, dropout_rate=0.1):
    mha_output = MultiHeadAttention(num_heads, d_model, dropout=dropout_rate)(input, input, input, attention_mask=mask)
    layernorm1 = LayerNormalization(epsilon=1e-6)(input + mha_output)

    ffn = Dense(dff, activation='relu')(layernorm1)
    ffn = Dense(d_model)(ffn)
    ffn = Dropout(dropout_rate)(ffn, training=training)

    output = LayerNormalization(epsilon=1e-6)(layernorm1 + ffn)

    return output

def encoder(input, d_model, num_heads, num_layers, dff, mask, training, max_seq_len, vocab_size, dropout_rate=0.1):
    emb_output = Embedding(vocab_size, d_model)(input)
    emb_output *= tf.math.sqrt(tf.cast(d_model, tf.float32))
    pos_out = positional_encoding(max_seq_len, d_model)
    emb_output += pos_out[np.newaxis, :, :]
    emb_output = Dropout(dropout_rate)(emb_output, training=training)
    enc_output = emb_output
    for i in range(num_layers):
        enc_output = encoder_layer(enc_output, d_model, num_heads, dff, mask, training, dropout_rate)
    return enc_output

def decoder_layer(input, d_model, num_heads, dff, training, padding_mask, enc_output, dropout_rate=0.1):
    mha1_output, attn_weights1 = MultiHeadAttention(num_heads, d_model, dropout=dropout_rate)(input, input, input, use_causal_mask=True, return_attention_scores=True)
    layernorm1 = LayerNormalization(epsilon=1e-6)(input + mha1_output)
    mha2_output, attn_weights2 = MultiHeadAttention(num_heads, d_model, dropout=dropout_rate)(layernorm1, enc_output, enc_output, attention_mask=padding_mask, return_attention_scores=True)
    layernorm2 = LayerNormalization(epsilon=1e-6)(layernorm1 + mha2_output)

    ffn = Dense(dff, activation='relu')(layernorm2)
    ffn = Dense(d_model)(ffn)
    ffn = Dropout(dropout_rate)(ffn, training=training)

    output = LayerNormalization(epsilon=1e-6)(layernorm2 + ffn)

    return output, attn_weights1, attn_weights2

def decoder(input, d_model, num_layers, num_heads, dff, training, max_seq_len, padding_mask, vocab_size, enc_output, dropout_rate=0.1):
    attention_weights = {}
    emb_output = Embedding(vocab_size, d_model)(input)
    emb_output *= tf.math.sqrt(tf.cast(d_model, tf.float32))
    pos_out = positional_encoding(max_seq_len, d_model)
    emb_output += pos_out[np.newaxis, :, :]
    emb_output = Dropout(dropout_rate)(emb_output, training=training)
    dec_output = emb_output

    for i in range(num_layers):
        dec_output, block1, block2 = decoder_layer(dec_output, d_model, num_heads, dff, training, padding_mask, enc_output, dropout_rate)
        attention_weights['decoder_layer{}_block1_self_att'.format(i + 1)] = block1
        attention_weights['decoder_layer{}_block2_self_att'.format(i + 1)] = block2

    return dec_output, attention_weights

def Transformer(num_layers, d_model, num_heads, dff, training, en_vocab_size, ta_vocab_size, max_seq_len, dropout_rate=0.1):
    input = Input(shape=(max_seq_len,), dtype='int32', name='inputs')
    target = Input(shape=(max_seq_len,), dtype='int32', name='targets')

    en_mask = create_padding_mask(input)
    ta_mask = create_padding_mask(target)

    encoder_output = encoder(input, d_model, num_heads, num_layers, dff, en_mask, training, max_seq_len, en_vocab_size, dropout_rate=dropout_rate)
    decoder_output, _ = decoder(target, d_model, num_layers, num_heads, dff, training, max_seq_len, ta_mask, ta_vocab_size, encoder_output, dropout_rate=dropout_rate)

    outputs = Dense(ta_vocab_size, activation='softmax')(decoder_output)

    return tf.keras.models.Model(inputs=[input, target], outputs=outputs)

class MaskedSparseCategoricalCrossentropy(tf.keras.losses.Loss):
    def __init__(self, from_logits=False, reduction=tf.keras.losses.Reduction.AUTO, name='masked_sparse_categorical_crossentropy'):
        super().__init__(reduction=reduction, name=name)
        self.sparse_categorical_crossentropy = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=from_logits, reduction=tf.keras.losses.Reduction.NONE)

    def call(self, y_true, y_pred):
        mask = tf.math.not_equal(y_true, 0)
        loss = self.sparse_categorical_crossentropy(y_true, y_pred)
        mask = tf.cast(mask, dtype=loss.dtype)
        loss *= mask
        return tf.reduce_sum(loss) / (tf.reduce_sum(mask) + tf.keras.backend.epsilon())

Code for Training:

import pickle
import numpy as np
from collections import Counter
from transformer import Transformer
import tensorflow as tf
from transformer import MaskedSparseCategoricalCrossentropy

with open('en-ta/English.txt', 'r') as file:
    en_sentences = file.readlines()
with open('en-ta/Tamil.txt', 'r') as file:
    ta_sentences = file.readlines()

TOTAL_SENTENCES = 200000
en_sentences = en_sentences[:TOTAL_SENTENCES]
ta_sentences = ta_sentences[:TOTAL_SENTENCES]
en_sentences = [sentence.rstrip('\n').lower() for sentence in en_sentences]
ta_sentences = [sentence.rstrip('\n') for sentence in ta_sentences]

max(len(x) for x in ta_sentences), max(len(x) for x in en_sentences)

# Assuming ta_sentences and en_sentences are lists of strings (sentences)

# Find the longest Tamil sentence
longest_ta_sentence = max(ta_sentences, key=len)

# Find the longest English sentence
longest_en_sentence = max(en_sentences, key=len)

print("Longest Tamil sentence:", longest_ta_sentence)
print("Longest English sentence:", longest_en_sentence)

# %%

PERCENTILE = 97
print( f"{PERCENTILE}th percentile length Tamil: {np.percentile([len(x) for x in ta_sentences], PERCENTILE)}" ) # roughly 250
print( f"{PERCENTILE}th percentile length English: {np.percentile([len(x) for x in en_sentences], PERCENTILE)}" ) # roughly 250

# %%

START_TOKEN = '<START>'
PADDING_TOKEN = '<PADDING>'
END_TOKEN = '<END>'
UNKNOWN_TOKEN = '<UNKNOWN>'

# %%

ta_vocab = [PADDING_TOKEN, START_TOKEN, ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/',
            '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', '<', '=', '>', '?', 'ˌ',
            'ஃ', 'அ', 'ஆ', 'இ', 'ஈ', 'உ', 'ஊ', 'எ', 'ஏ', 'ஐ', 'ஒ', 'ஓ', 'ஔ', 'க்', 'க', 'கா', 'கி', 'கீ', 'கு', 'கூ', 'கெ',
            'கே', 'கை', 'கொ', 'கோ', 'கௌ', 'ங்', 'ங', 'ஙா', 'ஙி', 'ஙீ', 'ஙு', 'ஙூ', 'ஙெ', 'ஙே', 'ஙை', 'ஙொ', 'ஙோ', 'ஙௌ', 'ச்',
            'ச', 'சா', 'சி', 'சீ', 'சு', 'சூ', 'செ', 'சே', 'சை', 'சொ', 'சோ', 'சௌ',
            'ஞ்', 'ஞ', 'ஞா', 'ஞி', 'ஞீ', 'ஞு', 'ஞூ', 'ஞெ', 'ஞே', 'ஞை', 'ஞொ', 'ஞோ', 'ஞௌ',
            'ட்', 'ட', 'டா', 'டி', 'டீ', 'டு', 'டூ', 'டெ', 'டே', 'டை', 'டொ', 'டோ', 'டௌ',
            'ண்', 'ண', 'ணா', 'ணி', 'ணீ', 'ணு', 'ணூ', 'ணெ', 'ணே', 'ணை', 'ணொ', 'ணோ', 'ணௌ',
            'த்', 'த', 'தா', 'தி', 'தீ', 'து', 'தூ', 'தெ', 'தே', 'தை', 'தொ', 'தோ', 'தௌ',
            'ந்', 'ந', 'நா', 'நி', 'நீ', 'நு', 'நூ', 'நெ', 'நே', 'நை', 'நொ', 'நோ', 'நௌ',
            'ப்', 'ப', 'பா', 'பி', 'பீ', 'பு', 'பூ', 'பெ', 'பே', 'பை', 'பொ', 'போ', 'பௌ',
            'ம்', 'ம', 'மா', 'மி', 'மீ', 'மு', 'மூ', 'மெ', 'மே', 'மை', 'மொ', 'மோ', 'மௌ',
            'ய்', 'ய', 'யா', 'யி', 'யீ', 'யு', 'யூ', 'யெ', 'யே', 'யை', 'யொ', 'யோ', 'யௌ',
            'ர்', 'ர', 'ரா', 'ரி', 'ரீ', 'ரு', 'ரூ', 'ரெ', 'ரே', 'ரை', 'ரொ', 'ரோ', 'ரௌ',
            'ல்', 'ல', 'லா', 'லி', 'லீ', 'லு', 'லூ', 'லெ', 'லே', 'லை', 'லொ', 'லோ', 'லௌ',
            'வ்', 'வ', 'வா', 'வி', 'வீ', 'வு', 'வூ', 'வெ', 'வே', 'வை', 'வொ', 'வோ', 'வௌ',
            'ழ்', 'ழ', 'ழா', 'ழி', 'ழீ', 'ழு', 'ழூ', 'ழெ', 'ழே', 'ழை', 'ழொ', 'ழோ', 'ழௌ',
            'ள்', 'ள', 'ளா', 'ளி', 'ளீ', 'ளு', 'ளூ', 'ளெ', 'ளே', 'ளை', 'ளொ', 'ளோ', 'ளௌ',
            'ற்', 'ற', 'றா', 'றி', 'றீ', 'று', 'றூ', 'றெ', 'றே', 'றை', 'றொ', 'றோ', 'றௌ',
            'ன்', 'ன', 'னா', 'னி', 'னீ', 'னு', 'னூ', 'னெ', 'னேனை',
            'ஶ்', 'ஶ', 'ஶா', 'ஶி', 'ஶீ', 'ஶு', 'ஶூ', 'ஶெ', 'ஶே', 'ஶை', 'ஶொ', 'ஶோ', 'ஶௌ',
            'ஜ்', 'ஜ', 'ஜா', 'ஜி', 'ஜீ', 'ஜு', 'ஜூ', 'ஜெ', 'ஜே', 'ஜை', 'ஜொ', 'ஜோ', 'ஜௌ',
            'ஷ்', 'ஷ', 'ஷா', 'ஷி', 'ஷீ', 'ஷு', 'ஷூ', 'ஷெ', 'ஷே', 'ஷை', 'ஷொ', 'ஷோ', 'ஷௌ',
            'ஸ்', 'ஸ', 'ஸா', 'ஸி', 'ஸீ', 'ஸு', 'ஸூ', 'ஸெ', 'ஸே', 'ஸை', 'ஸொ', 'ஸோ', 'ஸௌ',
            'ஹ்', 'ஹ', 'ஹா', 'ஹி', 'ஹீ', 'ஹு', 'ஹூ', 'ஹெ', 'ஹே', 'ஹை', 'ஹொ', 'ஹோ', 'ஹௌ',
            'க்ஷ்', 'க்ஷ', 'க்ஷா', 'க்ஷ', 'க்ஷீ', 'க்ஷு', 'க்ஷூ', 'க்ஷெ', 'க்ஷே', 'க்ஷை', 'க்ஷொ', 'க்ஷோ', 'க்ஷௌ',
            '்', 'ா', 'ி', 'ீ', 'ு', 'ூ', 'ெ', 'ே', 'ை', 'ொ', 'ோ', 'ௌ', END_TOKEN]

# %%

en_vocab = [PADDING_TOKEN, START_TOKEN, ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/',
            '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
            ':', '<', '=', '>', '?', '@',
            '[', '\\', ']', '^', '_', '`',
            'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l',
            'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x',
            'y', 'z', '{', '|', '}', '~', END_TOKEN]

def is_valid_token(sentence, vocab):
    return all(token in vocab for token in sentence)

def find_invalid_tokens(sentence, vocab):
    return [token for token in set(sentence) if token not in vocab]

def is_valid_length(sentence, max_sequence_length):
    return len(sentence) <= max_sequence_length

ta_vocab = {v: k for k, v in enumerate(ta_vocab)}
en_vocab = {v: k for k, v in enumerate(en_vocab)}

invalid_tokens_list = []
valid_sentence_indices = []
invalid_sentence_indices = []

for index, (ta_sentence, en_sentence) in enumerate(zip(ta_sentences, en_sentences)):
    invalid_ta_tokens = find_invalid_tokens(ta_sentence, ta_vocab)
    invalid_en_tokens = find_invalid_tokens(en_sentence, en_vocab)

    if is_valid_length(ta_sentence, 250) and is_valid_length(en_sentence, 250):
        if is_valid_token(ta_sentence, ta_vocab) and is_valid_token(en_sentence, en_vocab):
            valid_sentence_indices.append(index)
        else:
            invalid_tokens_list.append((invalid_ta_tokens, invalid_en_tokens))
            invalid_sentence_indices.append(index)

ta_sentences = [ta_sentences[i] for i in valid_sentence_indices]
en_sentences = [en_sentences[i] for i in valid_sentence_indices]

def text_to_indices(sequences, vocab):
    sequences_to_ids = []
    for sequence in sequences:
        seq = [vocab[char] for char in sequence]
        sequences_to_ids.append(seq)
    return sequences_to_ids

def create_decoder_sequences(sequences, vocab, max_len=250):
    sos_token = vocab[START_TOKEN]
    eos_token = vocab[END_TOKEN]
    pad_token = vocab[PADDING_TOKEN]

    decoder_input_seqs = []
    decoder_output_seqs = []

    for seq in sequences:
        input_seq = [sos_token] + seq
        output_seq = seq + [eos_token]

        decoder_input_seqs.append(input_seq)
        decoder_output_seqs.append(output_seq)

    decoder_input_seqs = tf.keras.preprocessing.sequence.pad_sequences(decoder_input_seqs, maxlen=max_len, padding='post', truncating='post', value=pad_token)
    decoder_output_seqs = tf.keras.preprocessing.sequence.pad_sequences(decoder_output_seqs, maxlen=max_len, padding='post', truncating='post', value=pad_token)

    return decoder_input_seqs, decoder_output_seqs

en_input = tf.keras.preprocessing.sequence.pad_sequences(text_to_indices(en_sentences, en_vocab), maxlen=250, padding='post', truncating='post', value=en_vocab[PADDING_TOKEN])
decoder_input, decoder_output = create_decoder_sequences(text_to_indices(ta_sentences, ta_vocab), ta_vocab)

en_vocab_size = 71
ta_vocab_size = 367
num_layers=1
num_heads=8
d_model=512
dff=2048
dropout_rate=0.1
training=True
max_seq_len=250
batch_size = 30

model = Transformer(num_layers, d_model, num_heads, dff, training, en_vocab_size, ta_vocab_size, max_seq_len, dropout_rate=dropout_rate)
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5, clipvalue=1.0)

model.compile(optimizer=optimizer, loss=MaskedSparseCategoricalCrossentropy(), metrics=['acc'])

history = model.fit([en_input, decoder_input], decoder_output, epochs=1, batch_size=batch_size)

Hi @KVcoder,

Apologies for the delay in response.

I'd suggest the following:

- Convert the positional encoding to a tensor with tf.convert_to_tensor instead of adding a raw NumPy array, which can break the gradient chain.
- In your custom loss function, check that the mask still allows gradients to propagate by using tf.reduce_sum for the loss computation.
- Make sure you are passing training=True when training, as dropout needs to be active.
- Validate that your input data doesn't contain NaN values, and check that the learning rate in your optimizer is suitable (try increasing it).
- Finally, if you're doing custom training, use tf.GradientTape for the backward pass and confirm that your model has trainable parameters by printing model.summary().

A minimal sketch of the first two points is shown below.
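This is only a sketch under the assumptions of your code above (padding id 0, the positional_encoding helper, and a softmax output), not a verified fix:

import numpy as np
import tensorflow as tf

# Add the positional encoding as a constant tensor so the addition stays
# inside the TensorFlow graph instead of mixing in a raw NumPy array.
def add_positional_encoding(emb_output, max_seq_len, d_model):
    pos_enc = positional_encoding(max_seq_len, d_model)  # NumPy array from the code above
    pos_enc = tf.convert_to_tensor(pos_enc[np.newaxis, :, :], dtype=emb_output.dtype)
    return emb_output + pos_enc

# Masked loss written with tf.reduce_sum, averaging only over non-padding
# positions (token id 0 is assumed to be the padding id, as above).
def masked_loss(y_true, y_pred):
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)
    loss = loss_fn(y_true, y_pred)
    mask = tf.cast(tf.math.not_equal(y_true, 0), loss.dtype)
    return tf.reduce_sum(loss * mask) / (tf.reduce_sum(mask) + 1e-9)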
Additionally, kindly refer to the tutorial on machine translation and the tf.GradientTape guide for more information.
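If you do move to a custom loop, a minimal tf.GradientTape step could look like the sketch below; it reuses the model, optimizer, MaskedSparseCategoricalCrossentropy, and the padded arrays from your training script, and the batch size and logging interval are just illustrative choices:

loss_fn = MaskedSparseCategoricalCrossentropy()

@tf.function
def train_step(en_batch, dec_in_batch, dec_out_batch):
    with tf.GradientTape() as tape:
        predictions = model([en_batch, dec_in_batch], training=True)
        loss = loss_fn(dec_out_batch, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss

dataset = tf.data.Dataset.from_tensor_slices((en_input, decoder_input, decoder_output)).batch(30)
for step, (en_batch, dec_in_batch, dec_out_batch) in enumerate(dataset):
    loss = train_step(en_batch, dec_in_batch, dec_out_batch)
    if step % 100 == 0:
        tf.print('step', step, 'loss', loss)  # watch for the first batch where the loss turns NaN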

Hope this helps. Thank you.