I’m building an RNN model in which I would like to update the weights only when some condition, based on the state of the RNN, is met:

    class CustomModel(keras.Model):
        def train_step(self, data):
            x, y = data
            # Perform forward pass and compute gradients
            if (some condition is met):
                # Apply gradients

I tried an `if` condition that applies the update based on the value of a 1 x 1 `tf.Variable` called `t_step`, like `if t_step > 5`, but got the error below:
OperatorNotAllowedInGraphError: using a `tf.Tensor` as a Python `bool` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.
`t_step` is essentially the time step of the input, but when eager execution is disabled, it seems that I cannot use it in the `if` condition.
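
For illustration, a condition like this can be written with `tf.cond` rather than a Python `if`, so that the branch is chosen inside the graph instead of by converting a tensor to a Python `bool`. The sketch below is only a minimal example under assumptions, not the model from the question: `w`, `maybe_update`, `grad`, and `lr` are hypothetical stand-ins, the update is a plain gradient-descent step rather than a Keras optimizer call, and it assumes TensorFlow 2.x with `t_step` as a scalar variable.

```python
import tensorflow as tf

# Hypothetical stand-ins for a model weight and the time-step counter.
w = tf.Variable(1.0)
t_step = tf.Variable(0, dtype=tf.int32)


@tf.function
def maybe_update(grad, lr):
    """Apply a gradient-descent step to w only when t_step > 5."""

    def apply_update():
        # Runs only when the predicate is true.
        w.assign_sub(lr * grad)
        return tf.constant(True)

    def skip_update():
        # Leaves w untouched.
        return tf.constant(False)

    # tf.cond expects a scalar boolean tensor; a 1 x 1 value would first
    # need tf.reshape(value, []) to collapse it to a scalar.
    return tf.cond(t_step > 5, apply_update, skip_update)
```

Calling `maybe_update(tf.constant(2.0), tf.constant(0.1))` while `t_step` is 0 should leave `w` unchanged, and calling it again after `t_step.assign(6)` should apply the step. Whether the same pattern carries over to an optimizer's `apply_gradients` inside `train_step` depends on the TensorFlow/Keras version, so that part is left open here.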