Occasionally skip weight updates inside train_step in TensorFlow Keras

I'm building an RNN model in which I want to apply weight updates only when some condition, based on the state of the RNN, is met:

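import tensorflow as tf
from tensorflow import keras

class CustomModel(keras.Model):
    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:  # forward pass
            y_pred = self(x, training=True)
            loss = self.compute_loss(y=y, y_pred=y_pred)
        grads = tape.gradient(loss, self.trainable_variables)
        if some_condition:  # placeholder, e.g. t_step > 5; this is the part that fails
            self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        return {"loss": loss}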

I tried an if condition that applies the update only when a 1 x 1 tf.Variable, t_step, passes a threshold (for example, if t_step > 5:), but got the error below:

OperatorNotAllowedInGraphError: using a `tf.Tensor` as a Python `bool` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.

t_step is essentially the time step of the input, but when eager execution is disabled, it seems I cannot use it in an if condition.
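For context, this class of error is raised when Python's if needs a concrete True/False while a function is being traced into a graph, where t_step > 5 is only a symbolic tensor with no value yet. The minimal sketch below reproduces it outside Keras, assuming a standalone tf.function with AutoGraph explicitly turned off (autograph=False) so the Python if is traced as-is; the function name plain_python_if is purely illustrative.

```python
import tensorflow as tf

t_step = tf.Variable([[7]])  # 1 x 1 variable, mirroring the question's t_step


@tf.function(autograph=False)  # AutoGraph off: the Python `if` is traced as-is
def plain_python_if():
    # While tracing, `t_step > 5` is a symbolic tensor with no concrete value,
    # so Python cannot turn it into True/False.
    if t_step > 5:
        return tf.constant(1)
    return tf.constant(0)


try:
    plain_python_if()
    error_name = None
except TypeError as err:  # OperatorNotAllowedInGraphError is a TypeError subclass
    error_name = type(err).__name__

print(error_name)
```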

Have you tried tf.cond (documented in the TensorFlow v2.16.1 API reference)?

Yes, I tried something like if tf.cond(t_step > 5, lambda: True, lambda: False): and got the same error.
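Using tf.cond only as the condition of a Python if does not avoid the problem, because tf.cond itself returns a tensor, which the if then still has to convert to a Python bool. One way around this is to move the side effect, the apply_gradients call, into the branch functions and let tf.cond decide which branch runs. The sketch below assumes TensorFlow 2.11 or later (tensorflow.keras, either Keras 2 or Keras 3); ConditionalUpdateModel, threshold, the single Dense layer, the MSE loss and the random toy data are hypothetical stand-ins for the real RNN. Two details matter: the predicate passed to tf.cond should be a scalar boolean, so the 1 x 1 t_step is reshaped to shape [], and the optimizer is built before fit() as a precaution so that its state variables are not first created inside a tf.cond branch.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras


class ConditionalUpdateModel(keras.Model):
    """Skips optimizer updates until a (hypothetical) step counter passes a threshold."""

    def __init__(self, threshold=5, **kwargs):
        super().__init__(**kwargs)
        self.dense = keras.layers.Dense(1)  # stand-in for the real RNN layers
        self.threshold = threshold
        # 1 x 1 counter mirroring the question's t_step variable.
        self.t_step = tf.Variable([[0]], dtype=tf.int32, trainable=False)

    def call(self, x, training=False):
        return self.dense(x)

    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = tf.reduce_mean(tf.square(y - y_pred))  # plain MSE for the sketch
        grads = tape.gradient(loss, self.trainable_variables)

        self.t_step.assign_add([[1]])
        # tf.cond needs a scalar bool predicate, so flatten the 1 x 1 variable first.
        should_update = tf.reshape(self.t_step, []) > self.threshold

        def apply_update():
            # The update happens inside the branch; the return value is only a flag.
            self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
            return tf.constant(1.0)

        def skip_update():
            return tf.constant(0.0)

        updated = tf.cond(should_update, apply_update, skip_update)
        return {"loss": loss, "updated": updated}


x = np.random.rand(8, 3).astype("float32")
y = np.random.rand(8, 1).astype("float32")

model = ConditionalUpdateModel(threshold=5)
model(x[:1])  # build the layer weights eagerly
model.compile(optimizer=keras.optimizers.SGD(learning_rate=0.01), jit_compile=False)
# Create optimizer state up front so no variables are created inside a tf.cond branch.
model.optimizer.build(model.trainable_variables)
model.fit(x, y, batch_size=1, epochs=1, verbose=0)

print(int(model.t_step.numpy()[0, 0]), int(model.optimizer.iterations.numpy()))
```

With 8 samples and batch_size=1, fit() runs 8 steps, so t_step should end at 8, and only the steps where t_step exceeds 5 (steps 6 to 8) call apply_gradients, which the optimizer's iterations counter should reflect. Multiplying the gradients by zero instead would not be equivalent: optimizers with momentum or Adam state can still move the weights on a zero gradient.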

I think you need to use tf.greater, or wrap your original train_step with tf.function.

See:
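On the tf.function suggestion: inside a tf.function, AutoGraph rewrites a Python if whose condition is a scalar boolean tensor into tf.cond. tf.greater(a, b) builds the same op as a > b, so what matters is that the condition is a scalar and that the update itself sits inside the if body. Keras normally already runs train_step as a tf.function (unless run_eagerly=True), so this Keras-free sketch only isolates the mechanism; maybe_update and num_updates are hypothetical, with num_updates standing in for applying the gradients.

```python
import tensorflow as tf

t_step = tf.Variable([[0]])   # 1 x 1 variable, as in the question
num_updates = tf.Variable(0)  # hypothetical stand-in for "apply the gradients"


@tf.function
def maybe_update(step):
    t_step.assign([[step]])
    # AutoGraph converts this Python `if` into tf.cond because the condition
    # is a scalar boolean tensor rather than a Python bool.
    if tf.greater(tf.reshape(t_step, []), 5):
        num_updates.assign_add(1)


for step in range(10):
    maybe_update(tf.constant(step))  # tensor argument avoids retracing per step

print(int(num_updates.numpy()))
```

Here the body runs for steps 6 through 9, so num_updates should end at 4.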