I’m building an RNN model in which I would like to update the weights only when some condition, based on the state of the RNN, is met:
    class CustomModel(keras.Model):
        def train_step(self, data):
            x, y = data
            # Perform forward pass and compute gradients
            if (some condition is met):
                # Apply gradients
I tried an if condition that applies the update only depending on the value of a 1 x 1 tf.Variable t_step, like if t_step > 5, but got the error below:
OperatorNotAllowedInGraphError: using a `tf.Tensor` as a Python `bool` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.
t_step is essentially the time step of the input, but when eager execution is disabled it seems like I cannot use it in the if condition.
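For reference, here is a minimal sketch of the behaviour I'm after, written with tf.cond, which as far as I understand is the graph-compatible way to branch on a tensor value. Everything besides t_step is an assumption made for illustration: the weight variable w, the constant LEARNING_RATE, the squared-error loss, and the hand-rolled SGD update standing in for a Keras optimizer inside train_step.

```python
import tensorflow as tf

# Hypothetical stand-ins for the real model: one weight matrix and a fixed step size.
w = tf.Variable([[1.0], [2.0]])
LEARNING_RATE = 0.1

# 1 x 1 variable holding the current time step, as in the question.
t_step = tf.Variable([[0]], dtype=tf.int32)


@tf.function
def train_step(x, y):
    # Forward pass and gradient computation.
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(tf.matmul(x, w) - y))
    grad = tape.gradient(loss, w)

    def apply_update():
        # Plain SGD step (an assumption; a real model would use its optimizer).
        w.assign_sub(LEARNING_RATE * grad)
        return tf.constant(True)

    def skip_update():
        return tf.constant(False)

    # tf.cond needs a scalar boolean, so the 1 x 1 comparison is reduced first.
    should_update = tf.reduce_all(t_step > 5)
    updated = tf.cond(should_update, apply_update, skip_update)

    # Advance the time step after the condition has been evaluated.
    t_step.assign_add(tf.ones_like(t_step))
    return loss, updated
```

Calling train_step(x, y) repeatedly should leave w untouched while t_step is at most 5 and start changing it once t_step exceeds 5. The second return value only reports which branch ran. Whether the same pattern can wrap an optimizer call inside a Keras train_step is the part I have not verified.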