TensorFlow retracing issue?

Hi there,
I am trying to develop a transformer sequence-to-vector model but am running into performance issues.
I am working with a Tesla V100-PCIE-16GB. Whenever the model encounters an unseen sequence length, the (training) step takes much longer. I assume this is caused by tracing, i.e. TensorFlow building a new graph for the new input shape. However, I do not get any retracing warnings. I know there is the @tf.function(reduce_retracing=True) decorator, but I don't want to rewrite the MultiHeadAttention class from scratch and decorate every function.
I am using TensorFlow version 2.17.0.
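The retracing described above can be observed on a plain tf.function, independent of Keras. The following is a minimal sketch, assuming a stand-alone pooling function (the name pool and the shapes are illustrative, not taken from the model). It counts traces with experimental_get_tracing_count() while the sequence length changes, once with the default settings and once with reduce_retracing=True.

```python
import numpy as np
import tensorflow as tf

D_MODEL = 8  # feature size, mirroring d_model in the question's code


def pool(x):
    # Average over the sequence axis, similar to GlobalAveragePooling1D.
    return tf.reduce_mean(x, axis=1)


# Same Python function, two tf.function wrappers with different retracing policies.
default_fn = tf.function(pool)
relaxed_fn = tf.function(pool, reduce_retracing=True)

# Sequence lengths 11..15 give five distinct input shapes.
for seq_len in range(11, 16):
    x = tf.constant(np.random.random((32, seq_len, D_MODEL)), dtype=tf.float32)
    default_fn(x)
    relaxed_fn(x)

print("default traces:", default_fn.experimental_get_tracing_count())
print("reduce_retracing traces:", relaxed_fn.experimental_get_tracing_count())
```

With the default settings every new sequence length produces a new trace. With reduce_retracing=True, TensorFlow typically generalizes the varying axis to None after the second shape, so later lengths reuse that more generic trace.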

Here is code to reproduce the issue. The snippet trains a small example network on 100 batches of 32 samples each. The first 10 batches have varying shapes and take around 1.5 seconds per batch. The remaining 90 batches share the same shape and take around 0.002 seconds per batch.
Any ideas on how to approach this issue?

import datetime
import numpy as np
import tensorflow as tf
from tensorflow import keras

# Only use the first GPU.
gpus = tf.config.list_physical_devices("GPU")
assert gpus, "no GPUs found"
gpu = gpus[0]
tf.config.set_visible_devices([gpu], "GPU")
tf.config.experimental.set_memory_growth(gpu, True)

nclasses = 3
d_model = 8


# A custom Data Generator, creating random samples
class DG(keras.utils.PyDataset):
    def __init__(self):
        self.nSamplesWithAlternatingShapes = 10
        super().__init__()
        self.last = datetime.datetime.now()

    def __len__(self):
        return 100

    def __getitem__(self, index):
        # Derive the sequence length from the batch index so that the first
        # nSamplesWithAlternatingShapes batches each have a different sequence
        # length and all remaining batches share one length.

        index += 1

        n = self.nSamplesWithAlternatingShapes
        if index < self.nSamplesWithAlternatingShapes:
            n = index
        x = np.random.random((32, n + 10, d_model))
        y = np.random.random((32, nclasses))
        y = np.argmax(y, axis=1)
        y = tf.keras.utils.to_categorical(y, num_classes=nclasses)
        # Print the wall-clock time since the previous batch was requested.
        now = datetime.datetime.now()
        print(f"\n{round((now - self.last).total_seconds(), ndigits=3)} secs elapsed")
        self.last = now
        return x, y

# A Test layer, incorporating a MultiHeadAttention Layer
class TestLayer(tf.keras.layers.Layer):
    def __init__(self):
        super().__init__()
        # some multi head attention
        self.mha = tf.keras.layers.MultiHeadAttention(
            num_heads=1, value_dim=d_model, key_dim=d_model, dropout=0.3
        )
        self.dense = tf.keras.layers.Dense(nclasses, activation="softmax")
        # Create the pooling layer once here instead of on every call.
        self.pool = tf.keras.layers.GlobalAveragePooling1D(data_format="channels_last")

    def call(self, x):
        x = self.mha(x, x, x)
        pool = self.pool(x)
        fin = self.dense(pool)
        return fin


if __name__ == "__main__":
    
    # The input shape excludes the batch axis: the first entry (None) is the
    # variable-length sequence axis, the second is the feature axis.
    inputs = tf.keras.layers.Input((None, d_model), name="test", dtype=np.float32)
    output = TestLayer()(inputs)
    model = tf.keras.models.Model(inputs=[inputs], outputs=output)
    model.compile(
        optimizer="adam",
        loss="categorical_crossentropy",
        metrics=[tf.keras.metrics.CategoricalAccuracy(name="OA")],
    )
    
    model.fit(DG(), shuffle=False)
    
    print("done")

Hi @David_Eidmann,

Thanks for using the TensorFlow Forum.
I recommend wrapping the call method of your custom layer (TestLayer) with tf.function and specifying an input_signature to reduce unnecessary retracing. The signature tells TensorFlow which input shapes to expect, so inputs with different sequence lengths can reuse the same graph. Please refer to this documentation about tf.function for more information.
For example:

@tf.function(input_signature=[tf.TensorSpec(shape=[None, None, d_model], dtype=tf.float32)], reduce_retracing=True)
def call(self, x):
    ...
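The effect of an input_signature with None for the batch and sequence axes can be checked on a small stand-alone function. This is a minimal sketch, assuming a hypothetical pooling function in place of the full layer; it feeds ten different sequence lengths and reads the trace count back with experimental_get_tracing_count().

```python
import numpy as np
import tensorflow as tf

D_MODEL = 8  # feature size, mirroring d_model in the question


# None for the batch and sequence axes: any batch size and any sequence length
# match this single signature, so the function is traced only once.
@tf.function(input_signature=[tf.TensorSpec(shape=[None, None, D_MODEL], dtype=tf.float32)])
def pool(x):
    # Average over the sequence axis, similar to GlobalAveragePooling1D.
    return tf.reduce_mean(x, axis=1)


outputs = []
for seq_len in range(11, 21):
    # NumPy inputs are converted to tensors matching the signature.
    x = np.random.random((32, seq_len, D_MODEL)).astype(np.float32)
    outputs.append(pool(x))

print("traces:", pool.experimental_get_tracing_count())
```

Because the signature fixes the rank and the last dimension, an input with a different rank or feature size is rejected with an error instead of triggering a new trace.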