Hi, I am trying to implement straight-through estimation for a vector-quantized (VQ) recurrent network, where the VQ layer quantizes latent data inside the RNN. However, TensorFlow throws the following error:
The tensor <tf.Tensor '…/rnn/while/custom_rnn_cell/enc_vq/mul_2:0' shape=() dtype=float32> cannot be accessed from FuncGraph(name=train_function, id=139672471995008), because it was defined in FuncGraph(name=…_rnn_while_body_574022, id=139672472052640), which is out of scope.
So I cannot use add_loss as usual. Is there a workaround? Otherwise I cannot use straight-through estimation, which appears to be crucial for my application.
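
For reference, here is a minimal sketch of one possible workaround, not a confirmed fix for the model above. The idea is to avoid calling add_loss inside the cell: the cell emits its per-step VQ loss as an extra output channel, so the value leaves the while loop as a regular RNN output, and add_loss is called in an outer layer. Everything named here is hypothetical (VQCell, VQRNN, units, num_codes, beta), and the loss form (codebook term plus a beta-weighted commitment term) is an assumption, since the original loss is not shown.

```python
import tensorflow as tf


class VQCell(tf.keras.layers.Layer):
    """Hypothetical RNN cell that quantizes its hidden state with a codebook."""

    def __init__(self, units, num_codes, beta=0.25, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.num_codes = num_codes
        self.beta = beta  # assumed commitment weight
        self.state_size = units
        # One extra output channel carries the per-example VQ loss of the step.
        self.output_size = units + 1

    def build(self, input_shape):
        self.kernel = self.add_weight(
            name="kernel", shape=(input_shape[-1], self.units),
            initializer="glorot_uniform")
        self.recurrent = self.add_weight(
            name="recurrent", shape=(self.units, self.units),
            initializer="orthogonal")
        self.codebook = self.add_weight(
            name="codebook", shape=(self.num_codes, self.units),
            initializer="random_normal")

    def call(self, inputs, states):
        codebook = tf.convert_to_tensor(self.codebook)
        h = tf.tanh(tf.matmul(inputs, self.kernel)
                    + tf.matmul(states[0], self.recurrent))
        # Nearest codebook entry for each example (squared Euclidean distance).
        dist = tf.reduce_sum(
            tf.square(tf.expand_dims(h, 1) - tf.expand_dims(codebook, 0)),
            axis=-1)
        q = tf.gather(codebook, tf.argmin(dist, axis=-1))
        # Per-example loss: codebook term plus weighted commitment term.
        loss = (tf.reduce_mean(tf.square(tf.stop_gradient(h) - q), axis=-1)
                + self.beta
                * tf.reduce_mean(tf.square(h - tf.stop_gradient(q)), axis=-1))
        # Straight-through estimator: the forward value is q, while the
        # gradient flows to h as if quantization were the identity.
        h_q = h + tf.stop_gradient(q - h)
        # Do NOT call add_loss here; pack the loss into the output instead.
        packed = tf.concat([h_q, tf.expand_dims(loss, -1)], axis=-1)
        return packed, [h_q]


class VQRNN(tf.keras.layers.Layer):
    """Hypothetical wrapper that registers the VQ loss outside the loop."""

    def __init__(self, units, num_codes, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.rnn = tf.keras.layers.RNN(
            VQCell(units, num_codes), return_sequences=True)

    def call(self, inputs):
        packed = self.rnn(inputs)               # (batch, time, units + 1)
        outputs = packed[..., :self.units]      # quantized hidden states
        step_losses = packed[..., self.units]   # (batch, time)
        # The loss tensor now lives in the outer graph, not the while body.
        self.add_loss(tf.reduce_mean(step_losses))
        return outputs


x = tf.random.normal((4, 5, 3))  # (batch, time, features), dummy data
layer = VQRNN(units=8, num_codes=16)
y = layer(x)
```

The design choice is simply to move the loss through the loop as data: anything returned as a cell output is a valid tensor in the enclosing graph, whereas a tensor captured by add_loss inside the cell stays in the while-body graph.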