Tensorflow (2.9.1) : Changing the 'trainable' attribute on a layer during training

Consider the following model:

import tensorflow as tf
from tensorflow import keras

class MyModel(keras.Model):
    def __init__(self, **kwargs):
        super(MyModel, self).__init__(**kwargs)
        self.square_layer = keras.layers.Dense(2)
        self.cube_layer = keras.layers.Dense(2)
        self.optimizer = tf.keras.optimizers.Adam()

    @tf.function
    def call(self, X):
        # Stack both layer outputs along a new last axis: shape (batch, 2, 2)
        return tf.stack([self.square_layer(X), self.cube_layer(X)], axis=-1)

    @tf.function
    def train_step(self, inputs, targets):
        with tf.GradientTape() as tape:
            predictions = self(inputs)
            loss = tf.reduce_mean(tf.square(predictions - targets))
        # Only weights of layers with trainable=True are differentiated and updated
        grads = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return loss

If we train with the following ‘train’ method (defined inside ‘MyModel’) and set ‘self.cube_layer.trainable’ to either True or False, the result is as expected in both cases:

    def train(self, inputs, targets, num_epochs=5000):
        self.cube_layer.trainable = False  # True or False
        for epoch in range(num_epochs):
            loss = self.train_step(inputs, targets)
        print("Loss: " + str(loss))

inputs = tf.constant([[1,2]], dtype=tf.float32)
targets = tf.constant([[[3,6], [9,12]]], dtype=tf.float32)

model = MyModel()
model.train(inputs, targets)

But if we change the ‘trainable’ flag during training, freezing ‘cube_layer’ for the first loop and unfreezing it for the second, the result is not as expected:

    def train(self, inputs, targets, num_epochs=5000):
        self.cube_layer.trainable = False
        for epoch in range(num_epochs):
            loss = self.train_step(inputs, targets)
        self.cube_layer.trainable = True
        for epoch in range(num_epochs):
            loss = self.train_step(inputs, targets)
        print("Loss: " + str(loss))

inputs = tf.constant([[1,2]], dtype=tf.float32)
targets = tf.constant([[[3,6], [9,12]]], dtype=tf.float32)

model = MyModel()
model.train(inputs, targets)
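A minimal sketch of the suspected mechanism, using a hypothetical ‘Toy’ class rather than the model above: a tf.function-decorated method reads a plain Python attribute, which is evaluated only while the function is being traced. Later calls pass a tensor with the same shape and dtype, so tf.function reuses the cached graph and the attribute change is ignored. ‘self.trainable_weights’ in ‘train_step’ appears to be captured the same way, so the graph traced while ‘cube_layer’ was frozen would keep updating only ‘square_layer’.

```python
import tensorflow as tf


class Toy:
    """Hypothetical stand-in for MyModel: a Python flag read inside tf.function."""

    def __init__(self):
        self.use_double = False  # plays the role of `cube_layer.trainable`
        self.traces = 0          # how many times the Python body has been traced

    @tf.function
    def step(self, x):
        self.traces += 1  # Python side effect: executes only during tracing
        # A condition on a Python bool is resolved once, at trace time,
        # and only the chosen branch ends up in the graph.
        if self.use_double:
            return x * 2.0
        return x * 1.0


toy = Toy()
x = tf.constant(3.0)
before = toy.step(x)   # first call: traced with use_double == False
toy.use_double = True  # mutating an attribute does not trigger a retrace
after = toy.step(x)    # reuses the cached graph, so still multiplies by 1.0
print(float(before), float(after), toy.traces)  # 3.0 3.0 1
```

Only a change in the arguments (tensor shapes or dtypes, or the values of plain Python arguments) or a brand-new tf.function object makes the Python body run again.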

In the second training example above, if we remove the ‘@tf.function’ decorators from ‘call’ and ‘train_step’, the result is as expected! So I believe it has something to do with tf.function and how TensorFlow traces the Python code into a graph. Is there a way to keep tf.function and still set the ‘trainable’ attribute dynamically during training? I am using TensorFlow 2.9.1.
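One possible workaround, sketched under stated assumptions: stop changing the traced variable list at all. Both layers stay trainable, and ‘cube_layer’ is frozen by multiplying its gradients by a non-trainable tf.Variable. A variable's value is read every time the graph runs, so assigning to it takes effect without a retrace. The names ‘cube_grad_scale’, ‘traces’ and ‘train_phase’ are hypothetical. With Adam, zero gradients leave the moment estimates at zero, so the frozen weights do not move at all; an optimizer with weight decay would still change them.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras


class MyModel(keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.square_layer = keras.layers.Dense(2)
        self.cube_layer = keras.layers.Dense(2)
        self.optimizer = keras.optimizers.Adam()
        # Hypothetical switch for cube_layer: 0.0 = frozen, 1.0 = trainable.
        # It is read when the graph runs, unlike the Python `trainable` flag.
        self.cube_grad_scale = tf.Variable(0.0, trainable=False)
        self.traces = 0  # hypothetical counter of how often train_step is traced

    def call(self, X):
        return tf.stack([self.square_layer(X), self.cube_layer(X)], axis=-1)

    @tf.function
    def train_step(self, inputs, targets):
        self.traces += 1  # plain Python, so it only runs while tracing
        square_w = self.square_layer.trainable_weights
        cube_w = self.cube_layer.trainable_weights
        with tf.GradientTape() as tape:
            predictions = self(inputs)
            loss = tf.reduce_mean(tf.square(predictions - targets))
        # Both layers are always in the variable list, so the graph never changes.
        grads = tape.gradient(loss, square_w + cube_w)
        square_g = grads[:len(square_w)]
        cube_g = [g * self.cube_grad_scale for g in grads[len(square_w):]]
        self.optimizer.apply_gradients(zip(square_g + cube_g, square_w + cube_w))
        return loss

    def train_phase(self, inputs, targets, num_epochs, train_cube):
        # Hypothetical helper replacing `self.cube_layer.trainable = ...`
        self.cube_grad_scale.assign(1.0 if train_cube else 0.0)
        for _ in range(num_epochs):
            loss = self.train_step(inputs, targets)
        return loss


inputs = tf.constant([[1, 2]], dtype=tf.float32)
targets = tf.constant([[[3, 6], [9, 12]]], dtype=tf.float32)

model = MyModel()
model(inputs)  # build both Dense layers eagerly, before any tracing
cube_initial = model.cube_layer.get_weights()[0].copy()

loss1 = model.train_phase(inputs, targets, 500, train_cube=False)
cube_after_frozen = model.cube_layer.get_weights()[0].copy()
traces_after_frozen = model.traces

loss2 = model.train_phase(inputs, targets, 500, train_cube=True)
cube_after_unfrozen = model.cube_layer.get_weights()[0].copy()

print("Loss:", float(loss1), "->", float(loss2))
print("cube_layer moved while frozen:", not np.array_equal(cube_initial, cube_after_frozen))
print("traces:", model.traces)
```

An alternative that keeps the ‘trainable’ flag is to wrap ‘train_step’ in a new tf.function after toggling it, which forces a retrace; newer Keras optimizers may then reject variables they were not originally built with, which the gradient-scaling version avoids by handing the optimizer every variable from the first step. When training with ‘model.fit()’ instead, Keras' guidance is to call ‘compile()’ again after changing ‘trainable’.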

Solved the problem.
See: tensorflow2.0 - Tensorflow (2.9.1) : Changing the 'trainable' attribute on a layer during training - Stack Overflow

Did you find a proper solution?

Yes, a proper solution is listed here: