GPT2: loss function is called only twice during training

In the Keras GPT2 Text Generation tutorial code, I added my own custom loss function that prints out the shape of y_true. But surprisingly, it got printed only twice, not 2492 times (total training examples / batch_size).

# (None, 128) got printed only twice, like this:
(None, 128)
(None, 128)

Does anyone know why this behavior is happening?

Here’s the minimal code to reproduce it:

!pip install -q keras-nlp
import keras_nlp
import tensorflow as tf
from tensorflow import keras
import time

preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
    "gpt2_base_en",
    sequence_length=128,
)
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset(
    "gpt2_base_en", preprocessor=preprocessor
)

import tensorflow_datasets as tfds
reddit_ds = tfds.load("reddit_tifu", split="train", as_supervised=True)
train_ds = (
    reddit_ds.map(lambda document, _: document)
    .batch(32)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)

# the custom loss function
def custom_sparse_categorical_crossentropy(y_true, y_pred):
  scc = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
  print(y_true.shape)
  return scc(y_true, y_pred)

train_ds = train_ds.take(500)
num_epochs = 1

learning_rate = keras.optimizers.schedules.PolynomialDecay(
    5e-5,
    decay_steps=train_ds.cardinality() * num_epochs,
    end_learning_rate=0.0,
)
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
gpt2_lm.compile(
    optimizer=keras.optimizers.Adam(learning_rate),
    loss=custom_sparse_categorical_crossentropy,
    weighted_metrics=["accuracy"],
)

gpt2_lm.fit(train_ds, epochs=num_epochs)

Hi @Seungjun_Lee, Keras wraps the training step (and with it your loss function) in a tf.function, so it gets traced into a graph; the Python print statement only runs during tracing, not on every training step. You can use tf.print to overcome this:

def custom_sparse_categorical_crossentropy(y_true, y_pred):
  scc = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
  # print(y_true.shape)  # Python print: runs only while the graph is traced
  tf.print(y_true.shape)  # graph op: runs on every training step
  # tf.print(tf.shape(y_true))  # use tf.shape for the dynamic (runtime) shape
  return scc(y_true, y_pred)
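
Here is a minimal standalone sketch (separate from the tutorial code) showing the difference between the two print styles under tf.function tracing:

import tensorflow as tf

@tf.function
def step(x):
  print("traced!")  # Python side effect: runs only while the function is traced
  tf.print("executed!")  # graph op: runs on every call
  return x * 2

step(tf.constant(1))  # prints "traced!" then "executed!"
step(tf.constant(2))  # same input signature, no retrace: prints only "executed!"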

Since your model is compiled with XLA, tf.print is not supported inside it. You can work around this by disabling XLA compilation:

gpt2_lm.jit_compile = False

and then train the model

gpt2_lm.fit(train_ds, epochs=num_epochs)

Please refer to this gist for a working code example. Thank you.


Thanks for the reply, but the above code actually doesn’t work. It causes a Graph execution error instead.

Hi @Seungjun_Lee, I executed the code with TensorFlow 2.13 in Colab and did not face any error. Could you please share the TensorFlow version, the error log, and the environment you are using to run the code? Thank you.

I found out that the cause was putting this line of code at the beginning, like this:

gpt2_lm.jit_compile = False  # here

learning_rate = keras.optimizers.schedules.PolynomialDecay(
    5e-5,
    decay_steps=train_ds.cardinality() * num_epochs,
    end_learning_rate=0.0,
)
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
gpt2_lm.compile(
    optimizer=keras.optimizers.Adam(learning_rate),
    loss=custom_sparse_categorical_crossentropy,
    weighted_metrics=["accuracy"],
)

gpt2_lm.fit(train_ds, epochs=num_epochs)

It seems that somewhere in the rest of the code gpt2_lm.jit_compile gets set back to True; calling compile() apparently resets it, since the model compiles with XLA by default.

I tried putting that line of code right before the fit() call, and this time it worked:

learning_rate = keras.optimizers.schedules.PolynomialDecay(
    5e-5,
    decay_steps=train_ds.cardinality() * num_epochs,
    end_learning_rate=0.0,
)
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
gpt2_lm.compile(
    optimizer=keras.optimizers.Adam(learning_rate),
    loss=custom_sparse_categorical_crossentropy,
    weighted_metrics=["accuracy"],
)

gpt2_lm.jit_compile = False  # here
gpt2_lm.fit(train_ds, epochs=num_epochs)

For the record, I was using the latest version of TensorFlow.
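
An alternative that avoids the ordering issue altogether would be to pass jit_compile=False directly to compile() (recent Keras versions, including the TF 2.13 used above, accept this argument), so a later reset can’t occur:

gpt2_lm.compile(
    optimizer=keras.optimizers.Adam(learning_rate),
    loss=custom_sparse_categorical_crossentropy,
    weighted_metrics=["accuracy"],
    jit_compile=False,  # disable XLA at compile time instead of setting the attribute later
)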