Getting retracing error


Any suggestions as to which part(s) of the following custom training loop cause this retracing warning?

WARNING:tensorflow:5 out of the last 5 calls to <function _BaseOptimizer._update_step_xla at 0x7faa5069ec00> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to Better performance with tf.function  |  TensorFlow Core and tf.function  |  TensorFlow v2.16.1 for more details.
# Fully connected network: three ReLU hidden layers (256, 16 and 16 units)
# followed by one output unit that uses the `return_scaled_sigmoid` callable
# (defined elsewhere) as its activation.
# Left-over note from the original code, not active:
#   kernel_regularizer=tf.keras.regularizers.l2(REG)
gwd_model = tf.keras.Sequential(
    [layers.Dense(units, activation="relu") for units in (256, 16, 16)]
    + [layers.Dense(1, activation=return_scaled_sigmoid)]
)

# AdamW optimizer with default arguments; mean squared error as the loss.
optimizer = tf.keras.optimizers.AdamW()
loss_function = tf.keras.losses.MeanSquaredError()

# Count the batches by iterating over each dataset once. Counting with a
# generator avoids holding every batch in memory at the same time, which
# len(list(dataset)) would do just to obtain a number.
ntraining_batches = sum(1 for _ in training_dataset)
nval_batches = sum(1 for _ in val_dataset)
print(f"Number of training batches: {ntraining_batches}, number of validation batches: {nval_batches}")

# Explicitly set the initial learning rate of the optimizer.
tf.keras.backend.set_value(optimizer.learning_rate, 1.0e-3)

# Number of batches used per epoch (training / validation).
ntraining_batch = nval_batch = 1

# Settings and state of the reduce-on-plateau learning-rate schedule.
patience, factor, min_lr = (
    2,     # number of epochs to wait before reducing the learning rate
    0.96,  # factor by which the learning rate is reduced
    1e-6,  # minimum learning rate
)
best_val_loss = float("inf")  # lowest validation loss seen so far
wait = 0                      # counter for epochs waited

# Imported under an alias so that an existing `datetime` name in this file
# (module or class) is not shadowed.
from datetime import datetime as _datetime

# Custom training loop. Each epoch trains and validates on a limited number of
# batches; that number grows linearly with the epoch index and is clamped to
# the range [1, total number of batches], reaching the total at epoch NEPOCHS.
for epoch in range(NEPOCHS_MAX):
    ntraining_batch = int(epoch * ntraining_batches / NEPOCHS)
    ntraining_batch = max(ntraining_batch, 1)
    ntraining_batch = min(ntraining_batch, ntraining_batches)
    nval_batch = int(epoch * nval_batches / NEPOCHS)
    nval_batch = max(nval_batch, 1)
    nval_batch = min(nval_batch, nval_batches)
    print(f"Epoch: {epoch}, Training batches: {ntraining_batch}, Validation batches: {nval_batch}")

    # --- Training pass over the first `ntraining_batch` batches ---
    loss_mean = tf.keras.metrics.Mean()
    ibatch = 0
    for dataset_features, dataset_labels in training_dataset:
        with tf.GradientTape() as tape:
            predictions = gwd_model(dataset_features, training=True)
            loss = loss_function(dataset_labels, predictions)
        gradients = tape.gradient(loss, gwd_model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, gwd_model.trainable_variables))
        # Accumulate the batch loss; without this the epoch mean stays empty.
        loss_mean.update_state(loss)
        ibatch = ibatch + 1
        if ibatch >= ntraining_batch:
            break
    loss = loss_mean.result()

    # --- Validation pass over the first `nval_batch` batches ---
    val_loss_mean = tf.keras.metrics.Mean()
    ibatch = 0
    for dataset_features, dataset_labels in val_dataset:
        predictions = gwd_model(dataset_features, training=False)
        val_loss = loss_function(dataset_labels, predictions)
        val_loss_mean.update_state(val_loss)
        ibatch = ibatch + 1
        if ibatch >= nval_batch:
            break
    val_loss = val_loss_mean.result()

    current_lr = optimizer.learning_rate
    print(f"Epoch {epoch}: Loss: {loss.numpy():.4e}, Validation Loss: {val_loss.numpy():.4e}, Learning Rate: {current_lr.numpy():.4e}")

    # Reduce-on-plateau schedule: reset the counter when the validation loss
    # improves; otherwise count the epoch and, after `patience` epochs without
    # improvement, multiply the learning rate by `factor` (not below `min_lr`).
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        wait = 0
    else:
        wait += 1
        if wait >= patience:
            # Work with a plain float so max() and the format spec below do
            # not depend on tensor semantics.
            new_lr = max(float(current_lr.numpy()) * factor, min_lr)
            tf.keras.backend.set_value(optimizer.learning_rate, new_lr)
            print(f"New learning rate: {new_lr:.4e}.")
            wait = 0
    logs = {'loss': loss, 'val_loss': val_loss}

    # Write the weights and biases of every layer to CSV files whose names
    # encode the layer sizes, the epoch, the validation loss and a timestamp.
    current_time = _datetime.now()
    current_time_string = current_time.strftime("%H%M%S-%Y%m%d")
    layer_units = [str(layer.units) for layer in gwd_model.layers if hasattr(layer, 'units')]
    layer_units_string = 'x'.join(layer_units)
    postfix = f"{layer_units_string}-epoch-{epoch+1}-val_loss-{float(val_loss):.8f}-{current_time_string}.csv"
    for ilayer, layer in enumerate(gwd_model.layers):
        pd.DataFrame(layer.weights[0]).to_csv(f"/tmp5/gwies/tf/weights{ilayer}-{postfix}", header=False, index=False)
        pd.DataFrame(layer.weights[1]).to_csv(f"/tmp5/gwies/tf/bias{ilayer}-{postfix}" , header=False, index=False)

    # Creating the stop file from outside ends training after the current epoch.
    stop = '/tf/stop'
    if os.path.exists(stop):
        print(f"\nStopping training as '{stop}' exists.")
        break



It turns out this warning is triggered by the optimizer.apply_gradients call. Wrapping the apply_gradients call in a tf.function resolves it, but I only got it to work by passing a single argument (the gradients) and not the model. Here is the updated code:

# Define the model and the optimizer first, at module level: the tf.function
# wrapper around optimizer.apply_gradients uses both of them as globals
# instead of receiving them as arguments.
gwd_model = tf.keras.Sequential(
    [layers.Dense(units, activation="relu") for units in (256, 16, 16)]
    + [layers.Dense(1, activation=return_scaled_sigmoid)]
)
optimizer = tf.keras.optimizers.AdamW()

# tf.function wrapper for optimizer.apply_gradients
@tf.function
def apply_gradients(gradients):
    """Apply `gradients` to the trainable variables of `gwd_model`.

    The function is compiled with `tf.function` so that the optimizer update
    is traced as part of this function; the surrounding post reports that this
    removes the retracing warning. The model and the optimizer are read from
    module scope; only the gradients are passed in.

    Args:
        gradients: sequence of gradient tensors, in the same order as
            `gwd_model.trainable_variables`.
    """
    optimizer.apply_gradients(zip(gradients, gwd_model.trainable_variables))


# Huber loss is used as the training and validation loss in this version.
loss_function = tf.keras.losses.Huber()

# Count the batches by iterating over each dataset once. Counting with a
# generator avoids holding every batch in memory at the same time, which
# len(list(dataset)) would do just to obtain a number.
ntraining_batches = sum(1 for _ in training_dataset)
nval_batches = sum(1 for _ in val_dataset)
print(f"Number of training batches: {ntraining_batches}, number of validation batches: {nval_batches}")

# Explicitly set the initial learning rate of the optimizer.
tf.keras.backend.set_value(optimizer.learning_rate, 1.0e-3)

# Number of batches used per epoch (training / validation).
ntraining_batch = nval_batch = 1

# Settings and state of the reduce-on-plateau learning-rate schedule.
patience, factor, min_lr = (
    2,     # number of epochs to wait before reducing the learning rate
    0.96,  # factor by which the learning rate is reduced
    1e-6,  # minimum learning rate
)
best_val_loss = float("inf")  # lowest validation loss seen so far
wait = 0                      # counter for epochs waited

# Imported under an alias so that an existing `datetime` name in this file
# (module or class) is not shadowed.
from datetime import datetime as _datetime

# Custom training loop. Each epoch trains and validates on a limited number of
# batches; that number grows linearly with the epoch index and is clamped to
# the range [1, total number of batches], reaching the total at epoch NEPOCHS.
for epoch in range(NEPOCHS_MAX):
    # To use every batch from the first epoch on, set ntraining_batch and
    # nval_batch to ntraining_batches and nval_batches instead.
    ntraining_batch = int(epoch * ntraining_batches / NEPOCHS)
    ntraining_batch = max(ntraining_batch, 1)
    ntraining_batch = min(ntraining_batch, ntraining_batches)
    nval_batch = int(epoch * nval_batches / NEPOCHS)
    nval_batch = max(nval_batch, 1)
    nval_batch = min(nval_batch, nval_batches)
    print(f"Epoch: {epoch}, Training batches: {ntraining_batch}, Validation batches: {nval_batch}")
    loss_mean = tf.keras.metrics.Mean()
    current_time = _datetime.now()
    current_time_string = current_time.strftime("%H%M%S-%Y%m%d")
    print("current_time:", current_time_string)

    # --- Training pass: only the first `ntraining_batch` batches are read ---
    dataset = training_dataset.take(ntraining_batch)

    for dataset_features, dataset_labels in dataset:
        with tf.GradientTape() as tape:
            predictions = gwd_model(dataset_features, training=True)
            loss = loss_function(dataset_labels, predictions)
        gradients = tape.gradient(loss, gwd_model.trainable_variables)
        # Use the module-level apply_gradients wrapper rather than calling
        # optimizer.apply_gradients directly here.
        apply_gradients(gradients)
        # Accumulate the batch loss; without this the epoch mean stays empty.
        loss_mean.update_state(loss)
    loss = loss_mean.result()

    # --- Validation pass: only the first `nval_batch` batches are read ---
    val_loss_mean = tf.keras.metrics.Mean()

    current_time = _datetime.now()
    current_time_string = current_time.strftime("%H%M%S-%Y%m%d")
    print("current_time:", current_time_string)
    dataset = val_dataset.take(nval_batch)

    for dataset_features, dataset_labels in dataset:
        predictions = gwd_model(dataset_features, training=False)
        val_loss = loss_function(dataset_labels, predictions)
        val_loss_mean.update_state(val_loss)

    val_loss = val_loss_mean.result()

    current_lr = optimizer.learning_rate
    print(f"Epoch {epoch}: Loss: {loss.numpy():.4e}, Validation Loss: {val_loss.numpy():.4e}, Learning Rate: {current_lr.numpy():.4e}")

    # Reduce-on-plateau schedule: reset the counter when the validation loss
    # improves; otherwise count the epoch and, after `patience` epochs without
    # improvement, multiply the learning rate by `factor` (not below `min_lr`).
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        wait = 0
    else:
        wait += 1
        if wait >= patience:
            # Work with a plain float so max() and the format spec below do
            # not depend on tensor semantics.
            new_lr = max(float(current_lr.numpy()) * factor, min_lr)
            tf.keras.backend.set_value(optimizer.learning_rate, new_lr)
            print(f"New learning rate: {new_lr:.4e}.")
            wait = 0
    logs = {'loss': loss, 'val_loss': val_loss}

    # Write the weights and biases of every layer to CSV files whose names
    # encode the layer sizes, the epoch, the validation loss and a timestamp.
    current_time = _datetime.now()
    current_time_string = current_time.strftime("%H%M%S-%Y%m%d")
    layer_units = [str(layer.units) for layer in gwd_model.layers if hasattr(layer, 'units')]
    layer_units_string = 'x'.join(layer_units)
    postfix = f"{layer_units_string}-epoch-{epoch+1}-val_loss-{float(val_loss):.8f}-{current_time_string}.csv"
    for ilayer, layer in enumerate(gwd_model.layers):
        pd.DataFrame(layer.weights[0]).to_csv(f"/tmp5/gwies/tf/weights{ilayer}-{postfix}", header=False, index=False)
        pd.DataFrame(layer.weights[1]).to_csv(f"/tmp5/gwies/tf/bias{ilayer}-{postfix}" , header=False, index=False)

    # Creating the stop file from outside ends training after the current epoch.
    stop = '/tf/stop'
    if os.path.exists(stop):
        print(f"\nStopping training as '{stop}' exists.")
        break

The idea behind this training loop is to gradually increase the number of training and validation batches used per epoch over the first NEPOCHS epochs, rather than starting with all batches right away. I have 192 one-hot encoded inputs that are compressed into 24 8-bit integers. Unpacking these 24 8-bit integers before training and validation takes quite some time, so gradually increasing the number of batches significantly reduces the overall training time of the network.
