Hi,
Any suggestions as to which part(s) of the following custom training loop cause this warning:
WARNING:tensorflow:5 out of the last 5 calls to <function _BaseOptimizer._update_step_xla at 0x7faa5069ec00> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to Better performance with tf.function | TensorFlow Core and tf.function | TensorFlow v2.16.1 for more details.
WARNING:tensorflow:6 out of the last 6 calls to <function _BaseOptimizer._update_step_xla at 0x7faa5069ec00> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to Better performance with tf.function | TensorFlow Core and tf.function | TensorFlow v2.16.1 for more details.
import os
from datetime import datetime

import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers

# return_scaled_sigmoid, training_dataset and val_dataset are defined elsewhere.
NEPOCHS = 20
NEPOCHS_MAX = 200
REG = 0.01
#os.remove('/tf/stop')
#, kernel_regularizer=tf.keras.regularizers.l2(REG)
gwd_model = tf.keras.Sequential([layers.Dense(256, activation="relu"),
                                 layers.Dense(16, activation="relu"),
                                 layers.Dense(16, activation="relu"),
                                 layers.Dense(1, activation=return_scaled_sigmoid)])
optimizer = tf.keras.optimizers.AdamW()
loss_function = tf.keras.losses.MeanSquaredError()
ntraining_batches = len(list(training_dataset))
nval_batches = len(list(val_dataset))
print(f"Number of training batches: {ntraining_batches}, number of validation batches: {nval_batches}")
tf.keras.backend.set_value(optimizer.learning_rate, 1.0e-3)
ntraining_batch = 1
nval_batch = 1
best_val_loss = float('inf')
patience = 2 # Number of epochs to wait before reducing LR
wait = 0 # Counter for epochs waited
factor = 0.96 # Factor by which to reduce LR
min_lr = 1e-6 # Minimum learning rate
for epoch in range(NEPOCHS_MAX):
    ntraining_batch = int(epoch * ntraining_batches / NEPOCHS)
    ntraining_batch = max(ntraining_batch, 1)
    ntraining_batch = min(ntraining_batch, ntraining_batches)
    nval_batch = int(epoch * nval_batches / NEPOCHS)
    nval_batch = max(nval_batch, 1)
    nval_batch = min(nval_batch, nval_batches)
    print(f"Epoch: {epoch}, Training batches: {ntraining_batch}, Validation batches: {nval_batch}")
    loss_mean = tf.keras.metrics.Mean()
    ibatch = 0
    for dataset_features, dataset_labels in training_dataset:
        #print("ibatch=", ibatch, " nbatches=", nbatches)
        with tf.GradientTape() as tape:
            predictions = gwd_model(dataset_features, training=True)
            loss = loss_function(dataset_labels, predictions)
        gradients = tape.gradient(loss, gwd_model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, gwd_model.trainable_variables))
        loss_mean.update_state(loss)
        ibatch = ibatch + 1
        if ibatch >= ntraining_batch:
            break
    loss = loss_mean.result()
    val_loss_mean = tf.keras.metrics.Mean()
    ibatch = 0
    for dataset_features, dataset_labels in val_dataset:
        #print("ibatch=", ibatch, " nbatches=", nbatches)
        predictions = gwd_model(dataset_features, training=False)
        val_loss = loss_function(dataset_labels, predictions)
        val_loss_mean.update_state(val_loss)
        ibatch = ibatch + 1
        if ibatch >= nval_batch:
            break
    val_loss = val_loss_mean.result()
    current_lr = optimizer.learning_rate
    print(f"Epoch {epoch}: Loss: {loss.numpy():.4e}, Validation Loss: {val_loss.numpy():.4e}, Learning Rate: {current_lr.numpy():.4e}")
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        wait = 0
    else:
        wait += 1
        if wait >= patience:
            new_lr = max(float(current_lr) * factor, min_lr)  # keep new_lr a plain Python float
            tf.keras.backend.set_value(optimizer.learning_rate, new_lr)
            print(f"New learning rate: {new_lr:.4e}.")
            wait = 0
    logs = {'loss': loss, 'val_loss': val_loss}
    current_time = datetime.now()
    current_time_string = current_time.strftime("%H%M%S-%Y%m%d")
    layer_units = [str(layer.units) for layer in gwd_model.layers if hasattr(layer, 'units')]
    layer_units_string = 'x'.join(layer_units)
    postfix = f"{layer_units_string}-epoch-{epoch+1}-val_loss-{val_loss:.8f}-{current_time_string}.csv"
    for ilayer, layer in enumerate(gwd_model.layers):
        pd.DataFrame(layer.weights[0]).to_csv(f"/tmp5/gwies/tf/weights{ilayer}-{postfix}", header=False, index=False)
        pd.DataFrame(layer.weights[1]).to_csv(f"/tmp5/gwies/tf/bias{ilayer}-{postfix}", header=False, index=False)
    stop = '/tf/stop'
    if os.path.exists(stop):
        print(f"\nStopping training as '{stop}' exists.")
        break
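
Would something like the sketch below address point (1) of the warning, i.e. defining the traced step once outside the loop? This is only a rough, untested restructuring of my per-batch body (train_step is my own name), reusing the gwd_model, optimizer and loss_function defined above:

@tf.function
def train_step(features, labels):
    # Single optimizer update; the function is traced once and then reused
    # for every batch and every epoch instead of being retraced.
    with tf.GradientTape() as tape:
        predictions = gwd_model(features, training=True)
        loss = loss_function(labels, predictions)
    gradients = tape.gradient(loss, gwd_model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, gwd_model.trainable_variables))
    return loss

# The inner training loop above would then only call train_step:
#     loss = train_step(dataset_features, dataset_labels)
#     loss_mean.update_state(loss)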
Regards,
GW