Getting retracing warning

Hi,

Any suggestions as to which part(s) of the following custom training loop cause this warning:

WARNING:tensorflow:5 out of the last 5 calls to <function _BaseOptimizer._update_step_xla at 0x7faa5069ec00> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.
WARNING:tensorflow:6 out of the last 6 calls to <function _BaseOptimizer._update_step_xla at 0x7faa5069ec00> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.

NEPOCHS = 20
NEPOCHS_MAX = 200
REG=0.01
#os.remove('/tf/stop')
#, kernel_regularizer=tf.keras.regularizers.l2(REG)
gwd_model = tf.keras.Sequential([
    layers.Dense(256, activation="relu"),
    layers.Dense(16, activation="relu"),
    layers.Dense(16, activation="relu"),
    layers.Dense(1, activation=return_scaled_sigmoid),
])

optimizer = tf.keras.optimizers.AdamW()
loss_function = tf.keras.losses.MeanSquaredError()

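# Note: len(list(dataset)) iterates the full dataset just to count batches;
# tf.data's dataset.cardinality() avoids that pass when the size is known
# (it returns tf.data.UNKNOWN_CARDINALITY otherwise).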
ntraining_batches = len(list(training_dataset))
nval_batches = len(list(val_dataset))
print(f"Number of training batches: {ntraining_batches}, number of validation batches: {nval_batches}")

tf.keras.backend.set_value(optimizer.learning_rate, 1.0e-3)

ntraining_batch = 1
nval_batch = 1

best_val_loss = float('inf')
patience = 2  # Number of epochs to wait before reducing LR
wait = 0      # Counter for epochs waited
factor = 0.96  # Factor by which to reduce LR
min_lr = 1e-6 # Minimum learning rate
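# Note: this manual schedule mirrors what tf.keras.callbacks.ReduceLROnPlateau
# does during model.fit().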

for epoch in range(NEPOCHS_MAX):    
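    # Ramp the number of batches used per epoch linearly from 1 up to the full
    # dataset over the first NEPOCHS epochs (clamped below by 1 and above by
    # the total batch count).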
    ntraining_batch = int(epoch * ntraining_batches / NEPOCHS)
    ntraining_batch = max(ntraining_batch, 1) 
    ntraining_batch = min(ntraining_batch, ntraining_batches) 
    nval_batch = int(epoch * nval_batches / NEPOCHS)
    nval_batch = max(nval_batch, 1) 
    nval_batch = min(nval_batch, nval_batches) 
                          
    print(f"Epoch: {epoch}, Training batches: {ntraining_batch}, Validation batches: {nval_batch}")
    
    loss_mean = tf.keras.metrics.Mean()
    
    ibatch = 0
    for dataset_features, dataset_labels in training_dataset:
        #print("ibatch=", ibatch, " nbatches=", nbatches)
        with tf.GradientTape() as tape:
            predictions = gwd_model(dataset_features, training=True)
            loss = loss_function(dataset_labels, predictions)
        gradients = tape.gradient(loss, gwd_model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, gwd_model.trainable_variables))
        loss_mean.update_state(loss)
        ibatch = ibatch + 1
        if ibatch >= ntraining_batch:
            break
            
    loss = loss_mean.result()

    val_loss_mean = tf.keras.metrics.Mean()

    ibatch = 0
    for dataset_features, dataset_labels in val_dataset:
        #print("ibatch=", ibatch, " nbatches=", nbatches)
        predictions = gwd_model(dataset_features, training=False)
        val_loss = loss_function(dataset_labels, predictions)
        val_loss_mean.update_state(val_loss)
        ibatch = ibatch + 1
        if ibatch >= nval_batch:
            break

    val_loss = val_loss_mean.result()

    current_lr = optimizer.learning_rate
    
    print(f"Epoch {epoch}: Loss: {loss.numpy():.4e}, Validation Loss: {val_loss.numpy():.4e}, Learning Rate: {current_lr.numpy():.4e}")
    
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        wait = 0
    else:
        wait += 1
        if wait >= patience:
            new_lr = max(float(current_lr) * factor, min_lr)  # plain float, so it formats cleanly
            tf.keras.backend.set_value(optimizer.learning_rate, new_lr)
            print(f"New learning rate: {new_lr:.4e}.")
            wait = 0
            
    logs = {'loss': loss, 'val_loss': val_loss}
    
    current_time = datetime.now()
    current_time_string = current_time.strftime("%H%M%S-%Y%m%d")
    layer_units = [str(layer.units) for layer in gwd_model.layers if hasattr(layer, 'units')]
    layer_units_string = 'x'.join(layer_units)
    postfix = f"{layer_units_string}-epoch-{epoch+1}-val_loss-{val_loss:.8f}-{current_time_string}.csv"
    for ilayer, layer in enumerate(gwd_model.layers):
        pd.DataFrame(layer.weights[0]).to_csv(f"/tmp5/gwies/tf/weights{ilayer}-{postfix}", header=False, index=False)
        pd.DataFrame(layer.weights[1]).to_csv(f"/tmp5/gwies/tf/bias{ilayer}-{postfix}" , header=False, index=False)
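    # Note: gwd_model.save_weights(path) would be a one-call alternative to
    # this per-layer CSV dump.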
        
    stop = '/tf/stop'
    if os.path.exists(stop):
        print(f"\nStopping training as '{stop}' exists.")
        break

Regards,
GW

Hi,

It turns out the warning is triggered by the optimizer.apply_gradients call. Wrapping that call in a tf.function resolves it, but I only got it to work by passing a single argument (the gradients), not the model. Here is the updated code:

# First define the model and optimizer so the tf.function below can capture
# them by closure (they are not passed as arguments)

gwd_model = tf.keras.Sequential([
    layers.Dense(256, activation="relu"),
    layers.Dense(16, activation="relu"),
    layers.Dense(16, activation="relu"),
    layers.Dense(1, activation=return_scaled_sigmoid),
])
optimizer = tf.keras.optimizers.AdamW()

#tf.function wrapper for optimizer.apply_gradients
@tf.function
def apply_gradients(gradients):
    optimizer.apply_gradients(zip(gradients, gwd_model.trainable_variables))
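# Alternative sketch (untested): a single tf.function train step defined once
# outside the loop should also avoid retracing; the model, optimizer and loss
# function are captured by closure, so only tensors cross the tf.function
# boundary. Not used below, just for reference.
@tf.function
def train_step(features, labels):
    with tf.GradientTape() as tape:
        predictions = gwd_model(features, training=True)
        loss = loss_function(labels, predictions)
    gradients = tape.gradient(loss, gwd_model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, gwd_model.trainable_variables))
    return loss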

NEPOCHS = 20
NEPOCHS_MAX = 200
REG=0.01
#os.remove('/tf/stop')

loss_function = tf.keras.losses.Huber()

ntraining_batches = len(list(training_dataset))
nval_batches = len(list(val_dataset))
print(f"Number of training batches: {ntraining_batches}, number of validation batches: {nval_batches}")

tf.keras.backend.set_value(optimizer.learning_rate, 1.0e-3)

ntraining_batch = 1
nval_batch = 1

best_val_loss = float('inf')
patience = 2  # Number of epochs to wait before reducing LR
wait = 0      # Counter for epochs waited
factor = 0.96  # Factor by which to reduce LR
min_lr = 1e-6 # Minimum learning rate

for epoch in range(NEPOCHS_MAX):    
    ntraining_batch = int(epoch * ntraining_batches / NEPOCHS)
    ntraining_batch = max(ntraining_batch, 1) 
    ntraining_batch = min(ntraining_batch, ntraining_batches) 
    #ntraining_batch = ntraining_batches
    nval_batch = int(epoch * nval_batches / NEPOCHS)
    nval_batch = max(nval_batch, 1) 
    nval_batch = min(nval_batch, nval_batches) 
    #nval_batch = nval_batches
                          
    print(f"Epoch: {epoch}, Training batches: {ntraining_batch}, Validation batches: {nval_batch}")
    
    loss_mean = tf.keras.metrics.Mean()
    
    current_time = datetime.now()
    current_time_string = current_time.strftime("%H%M%S-%Y%m%d")
    print("current_time:", current_time_string)

    dataset = training_dataset.take(ntraining_batch)

    for dataset_features, dataset_labels in dataset:
        #print("ibatch", ibatch, " nbatches=", ntraining_batches)
        #loss = train_step(gwd_model, dataset_features, dataset_labels);
        with tf.GradientTape() as tape:
            predictions = gwd_model(dataset_features, training=True)
            loss = loss_function(dataset_labels, predictions)
        gradients = tape.gradient(loss, gwd_model.trainable_variables)
        apply_gradients(gradients)
        #optimizer.apply_gradients(zip(gradients, gwd_model.trainable_variables))
        loss_mean.update_state(loss)
            
    loss = loss_mean.result()

    val_loss_mean = tf.keras.metrics.Mean()

    current_time = datetime.now()
    current_time_string = current_time.strftime("%H%M%S-%Y%m%d")
    print("current_time:", current_time_string)
    
    dataset = val_dataset.take(nval_batch)

    for dataset_features, dataset_labels in dataset:
        #print("ibatch=", ibatch, " nbatches=", nval_batches)
        predictions = gwd_model(dataset_features, training=False)
        val_loss = loss_function(dataset_labels, predictions)
        val_loss_mean.update_state(val_loss)

    val_loss = val_loss_mean.result()

    current_lr = optimizer.learning_rate
    
    print(f"Epoch {epoch}: Loss: {loss.numpy():.4e}, Validation Loss: {val_loss.numpy():.4e}, Learning Rate: {current_lr.numpy():.4e}")
    
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        wait = 0
    else:
        wait += 1
        if wait >= patience:
            new_lr = max(float(current_lr) * factor, min_lr)  # plain float, so it formats cleanly
            tf.keras.backend.set_value(optimizer.learning_rate, new_lr)
            print(f"New learning rate: {new_lr:.4e}.")
            wait = 0
            
    logs = {'loss': loss, 'val_loss': val_loss}
    
    current_time = datetime.now()
    current_time_string = current_time.strftime("%H%M%S-%Y%m%d")
    layer_units = [str(layer.units) for layer in gwd_model.layers if hasattr(layer, 'units')]
    layer_units_string = 'x'.join(layer_units)
    postfix = f"{layer_units_string}-epoch-{epoch+1}-val_loss-{val_loss:.8f}-{current_time_string}.csv"
    for ilayer, layer in enumerate(gwd_model.layers):
        pd.DataFrame(layer.weights[0]).to_csv(f"/tmp5/gwies/tf/weights{ilayer}-{postfix}", header=False, index=False)
        pd.DataFrame(layer.weights[1]).to_csv(f"/tmp5/gwies/tf/bias{ilayer}-{postfix}" , header=False, index=False)
        
    stop = '/tf/stop'
    if os.path.exists(stop):
        print(f"\nStopping training as '{stop}' exists.")
        break

The idea behind this training loop is to ramp up the number of training and validation batches over the first NEPOCHS epochs, rather than starting with all batches right away. I have 192 one-hot encoded inputs that are compressed into 24 8-bit integers. Unpacking these 24 integers before training and validation takes quite some time, so gradually increasing the number of batches significantly reduces the overall training time of the network.
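For reference, here is a rough sketch of how the unpacking could be pushed into the tf.data pipeline and cached, so that the cost is paid only once instead of every epoch. This is untested, assumes plain LSB-first bit packing (8 one-hot flags per byte) and a map applied per example before batching, and unpack_bits is just a hypothetical name:

def unpack_bits(packed, label):
    # packed: (24,) uint8 -> bits: (24, 8) -> one-hot vector: (192,) float32
    shifts = tf.cast(tf.range(8), packed.dtype)  # bit positions 0..7, LSB first
    bits = tf.bitwise.bitwise_and(
        tf.bitwise.right_shift(packed[:, tf.newaxis], shifts), 1)
    return tf.cast(tf.reshape(bits, [-1]), tf.float32), label

#training_dataset = training_dataset.map(unpack_bits).cache()

With the unpacked dataset cached in memory (or on disk via cache(filename)), the per-epoch unpacking cost disappears, which may also reduce the need to ramp up the batch count.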

Regards,
GW