Add_loss inside of RNN cell

Hi, I'm trying to implement straight-through estimation for a vector-quantized (VQ) recurrent network, where the VQ step quantizes a latent representation inside the RNN cell. However, TensorFlow throws the following error:

The tensor <tf.Tensor '…/rnn/while/custom_rnn_cell/enc_vq/mul_2:0' shape=() dtype=float32> cannot be accessed from FuncGraph(name=train_function, id=139672471995008), because it was defined in FuncGraph(name=…_rnn_while_body_574022, id=139672472052640), which is out of scope.

So I cannot use add_loss as usual. Is there a workaround? Otherwise I cannot use straight-through estimation, which appears to be crucial for my application.
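For context, the failing setup looks roughly like the following minimal sketch; VQCell, the codebook details, and the layer sizes are hypothetical stand-ins for my actual cell, but the add_loss call inside the cell's call() is the relevant part:

```python
import tensorflow as tf

class VQCell(tf.keras.layers.Layer):
    """Custom RNN cell that vector-quantizes its hidden state."""

    def __init__(self, units, num_codes, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.num_codes = num_codes
        self.state_size = units

    def build(self, input_shape):
        in_dim = input_shape[-1]
        self.kernel = self.add_weight(
            name="kernel", shape=(in_dim + self.units, self.units),
            initializer="glorot_uniform")
        self.bias = self.add_weight(
            name="bias", shape=(self.units,), initializer="zeros")
        self.codebook = self.add_weight(
            name="codebook", shape=(self.num_codes, self.units),
            initializer="random_uniform")

    def call(self, inputs, states):
        h = tf.tanh(tf.matmul(tf.concat([inputs, states[0]], axis=-1),
                              self.kernel) + self.bias)
        # nearest codebook entry per sample
        d = tf.reduce_sum((h[:, None, :] - self.codebook[None, :, :]) ** 2, axis=-1)
        q = tf.gather(self.codebook, tf.argmin(d, axis=-1))
        # straight-through estimator: forward the quantized value,
        # backpropagate through the unquantized one
        h_st = h + tf.stop_gradient(q - h)
        # this loss is registered inside the tf.while_loop body that
        # tf.keras.layers.RNN builds, which seems to be where the
        # out-of-scope FuncGraph error at fit() time comes from
        self.add_loss(tf.reduce_mean((tf.stop_gradient(h) - q) ** 2))
        return h_st, [h_st]
```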


So I appear to have figured out a solution, but it is somewhat of a hack:
One can keep track of the running VQ loss as an additional state variable of the RNN. Then, with return_state=True, one can read off the final VQ loss from one of the return values of the RNN layer.
To make the shapes work, one has to use tf.fill to broadcast the per-step VQ loss to a tensor that covers the batch dimension, since every RNN state needs a batch axis. Finally, taking that state output of the RNN, tf.reduce_mean collapses it to a scalar, which can then be passed to add_loss outside the cell.
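A minimal sketch of what I mean, with hypothetical names (VQCellWithLossState, the layer sizes, etc.) standing in for my actual model:

```python
import tensorflow as tf

class VQCellWithLossState(tf.keras.layers.Layer):
    """Same idea as before, but the running VQ loss rides along as a state."""

    def __init__(self, units, num_codes, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.num_codes = num_codes
        # second state slot: running VQ loss, one value per batch element
        self.state_size = [units, 1]

    def build(self, input_shape):
        in_dim = input_shape[-1]
        self.kernel = self.add_weight(
            name="kernel", shape=(in_dim + self.units, self.units),
            initializer="glorot_uniform")
        self.bias = self.add_weight(
            name="bias", shape=(self.units,), initializer="zeros")
        self.codebook = self.add_weight(
            name="codebook", shape=(self.num_codes, self.units),
            initializer="random_uniform")

    def call(self, inputs, states):
        h_prev, vq_loss_prev = states
        h = tf.tanh(tf.matmul(tf.concat([inputs, h_prev], axis=-1),
                              self.kernel) + self.bias)
        d = tf.reduce_sum((h[:, None, :] - self.codebook[None, :, :]) ** 2, axis=-1)
        q = tf.gather(self.codebook, tf.argmin(d, axis=-1))
        h_st = h + tf.stop_gradient(q - h)  # straight-through estimator
        # instead of add_loss: accumulate the step loss in the extra state;
        # tf.fill broadcasts the scalar over the batch so it matches the state shape
        step_loss = tf.reduce_mean((tf.stop_gradient(h) - q) ** 2)
        step_loss = tf.fill(tf.shape(vq_loss_prev), step_loss)
        return h_st, [h_st, vq_loss_prev + step_loss]


# usage: expose the final states via return_state=True, then reduce the loss
# state to a scalar and hand it to add_loss outside the RNN
inputs = tf.keras.Input(shape=(None, 4))
outputs, final_h, final_vq_loss = tf.keras.layers.RNN(
    VQCellWithLossState(units=8, num_codes=16), return_state=True)(inputs)
model = tf.keras.Model(inputs, outputs)
model.add_loss(tf.reduce_mean(final_vq_loss))
model.compile(optimizer="adam", loss="mse")
```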