Consider the following model:
import tensorflow as tf
from tensorflow import keras

class MyModel(keras.Model):
    def __init__(self, **kwargs):
        super(MyModel, self).__init__(**kwargs)
        self.square_layer = keras.layers.Dense(2)
        self.cube_layer = keras.layers.Dense(2)
        self.optimizer = tf.keras.optimizers.Adam()

    @tf.function
    def call(self, X):
        return tf.stack([self.square_layer(X), self.cube_layer(X)], axis=-1)

    @tf.function
    def train_step(self, inputs, targets):
        with tf.GradientTape() as tape:
            predictions = self(inputs)
            loss = tf.reduce_mean(tf.square(predictions - targets))
        grads = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return loss
If we train using the following ‘train’ method and set ‘self.cube_layer.trainable’ to either True or False, the result is as expected in both cases:
def train(self, inputs, targets, num_epochs=5000):
    self.cube_layer.trainable = False  # True or False
    self.compile(optimizer=self.optimizer)
    for epoch in range(num_epochs):
        loss = self.train_step(inputs, targets)
    print("Loss: " + str(loss))

inputs = tf.constant([[1, 2]], dtype=tf.float32)
targets = tf.constant([[[3, 6], [9, 12]]], dtype=tf.float32)
model = MyModel()
model.train(inputs, targets)
print(model(inputs))
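As a sanity check, the flag itself behaves correctly outside a traced function: once the layers are built, flipping ‘trainable’ changes what ‘model.trainable_weights’ reports. A minimal sketch, reusing ‘inputs’ from above:

model = MyModel()
model(inputs)  # build both Dense layers so they have weights
model.cube_layer.trainable = False
# Only square_layer's kernel and bias should remain trainable:
print([w.name for w in model.trainable_weights])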
But if we change the ‘trainable’ flag during training, the result is not as expected:
def train(self, inputs, targets, num_epochs=5000):
    self.cube_layer.trainable = False
    self.compile(optimizer=self.optimizer)
    for epoch in range(num_epochs):
        loss = self.train_step(inputs, targets)

    self.cube_layer.trainable = True
    self.compile(optimizer=self.optimizer)
    for epoch in range(num_epochs):
        loss = self.train_step(inputs, targets)
    print("Loss: " + str(loss))

inputs = tf.constant([[1, 2]], dtype=tf.float32)
targets = tf.constant([[[3, 6], [9, 12]]], dtype=tf.float32)
model = MyModel()
model.train(inputs, targets)
print(model(inputs))
In the above example, if we remove the ‘@tf.function’ decorators from ‘call’ and ‘train_step’, the result is as expected! So I believe this has something to do with ‘tf.function’ and TensorFlow graph compilation. Is there a way to use ‘tf.function’ and still set the ‘trainable’ attribute dynamically during training? I am using TensorFlow 2.9.1.
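For reference, here is a minimal sketch (independent of the model above) of what I suspect is happening: ‘tf.function’ traces the Python body once per input signature and caches the resulting graph, so a plain Python value read during tracing is baked into the graph, and later changes to it are ignored:

flag = True

@tf.function
def f(x):
    # `flag` is a plain Python bool, so this branch is resolved once, at
    # trace time; the cached graph contains only one of the two paths.
    if flag:
        return x + 1.0
    return x - 1.0

print(f(tf.constant(1.0)))  # tf.Tensor(2.0, shape=(), dtype=float32)
flag = False
print(f(tf.constant(1.0)))  # unchanged: same input signature, no retrace

If ‘self.trainable_weights’ is captured at trace time in the same way, that would explain why flipping ‘cube_layer.trainable’ mid-training has no effect on the already-compiled ‘train_step’.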