Hello. I have been decorating parts of my code with tf.function, but I get different results when I run the same code with and without the decorator. I first noticed this when training GANs; the short example below reproduces it.
import numpy as np
import tensorflow as tf
from tensorflow import keras
tf.random.set_seed(1)
# Training data
x = tf.expand_dims(tf.linspace(0,1,200),axis=1) # Needs to be a column for training
x = tf.cast(x,dtype=tf.float32)
y = tf.math.sin(10*np.pi*x)
# Model
model = tf.keras.Sequential()
model.add(keras.layers.Dense(50, input_shape=(1,), activation='tanh'))
model.add(keras.layers.Dense(50, activation='tanh'))
model.add(keras.layers.Dense(50, activation='tanh'))
model.add(keras.layers.Dense(1))
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
# Training step
@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        pred = model(x)
        loss = tf.reduce_mean(tf.math.squared_difference(y, pred))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss
# Training loop
max_epoch = 5000
for epoch in range(max_epoch):
    loss = train_step(x, y)
    if epoch == 0 or (epoch+1) % 500 == 0:
        print(f"Epoch: {epoch+1}, loss: {loss.numpy():.12f}")
When I run the above code, I get the following output:
Epoch: 1, loss: 0.526220142841
Epoch: 500, loss: 0.485644072294
Epoch: 1000, loss: 0.457582473755
Epoch: 1500, loss: 0.081991739571
Epoch: 2000, loss: 0.008774421178
Epoch: 2500, loss: 0.000981826452
Epoch: 3000, loss: 0.000448303996
Epoch: 3500, loss: 0.000237140950
Epoch: 4000, loss: 0.000145462996
Epoch: 4500, loss: 0.000079287915
Epoch: 5000, loss: 0.000049306964
Since I have set the global random seed, which ensures the same initialization of the MLP model, I get the same result each time I re-run the code. If I now comment out the tf.function decorator and run the code with the same random seed, I get a different result:
Epoch: 1, loss: 0.526220202446
Epoch: 500, loss: 0.485644072294
Epoch: 1000, loss: 0.457641303539
Epoch: 1500, loss: 0.085796415806
Epoch: 2000, loss: 0.008686196990
Epoch: 2500, loss: 0.001000389108
Epoch: 3000, loss: 0.000672267575
Epoch: 3500, loss: 0.000234877603
Epoch: 4000, loss: 0.000125381193
Epoch: 4500, loss: 0.000075728072
Epoch: 5000, loss: 0.000333589123
It appears to me that the implementations with and without the decorator differ by some roundoff error. The losses at epoch 1 already differ in the eighth significant digit, and this error seems to accumulate over the epochs until the differences become noticeable.
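To illustrate the kind of roundoff I have in mind, here is a minimal sketch in plain Python. It assumes, without my having verified it, that the graph and eager paths may evaluate the same float32 arithmetic in a different order. The helpers f32 and add32 are hypothetical names I made up to emulate float32 rounding with the standard library; they are not part of TensorFlow.

```python
import struct


def f32(value):
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", value))[0]


def add32(a, b):
    """Add two numbers and round the sum to float32, emulating a float32 add."""
    return f32(f32(a) + f32(b))


# Three values that are all exactly representable in float32.
a, b, c = 1e8, -1e8, 1.0

# Same three operands, two different evaluation orders.
left_to_right = add32(add32(a, b), c)  # (a + b) + c
right_to_left = add32(a, add32(b, c))  # a + (b + c)

print(left_to_right, right_to_left)
```

The two orders give 1.0 and 0.0, because float32 addition is not associative: near 1e8 the spacing between float32 values is 8, so adding 1.0 to -1e8 is lost to rounding. If something similar happens in a reduction or a fused op, a tiny per-step difference could then be amplified by thousands of optimizer updates.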
Could someone confirm whether this behavior is expected? Or am I making a mistake in the way I am using tf.function?