VectorQuantization: tf.stop_gradient (STE) inside a for loop / recurrent network

Hi, I am trying to get straight-through estimation (STE) working for a vector quantization (VQ) layer. The VQ layer sits inside a recurrent network I built. When I call the recurrent network with some input, this is what happens:

    outputs = tf.TensorArray(tf.float32, size=timesteps)

    for i in tf.range(timesteps):
        encoded = self.encoder(input_data[i], state)
        encoded_q = self.quantizer(encoded)  # VQ layer with straight-through estimator
        decoded, state = self.decoder(encoded_q, state)

        outputs = outputs.write(i, decoded)

The vector quantizer layer is simply a copy of the standard VQ-VAE implementation found in Vector-Quantized Variational Autoencoders.
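For reference, here is a minimal sketch of such a quantizer, following the structure of that example (the codebook shape, the beta commitment weight, and all names are illustrative). The straight-through estimator is the x + tf.stop_gradient(quantized - x) trick at the end, and the add_loss call is what later collides with the loop:

    import tensorflow as tf

    class VectorQuantizer(tf.keras.layers.Layer):
        def __init__(self, num_embeddings, embedding_dim, beta=0.25, **kwargs):
            super().__init__(**kwargs)
            self.embedding_dim = embedding_dim
            self.beta = beta
            self.embeddings = self.add_weight(
                name="embeddings",
                shape=(embedding_dim, num_embeddings),
                initializer="random_uniform",
                trainable=True,
            )

        def call(self, x):
            flat = tf.reshape(x, [-1, self.embedding_dim])
            # squared L2 distance of every input vector to every codebook entry
            distances = (
                tf.reduce_sum(flat ** 2, axis=1, keepdims=True)
                - 2.0 * tf.matmul(flat, self.embeddings)
                + tf.reduce_sum(self.embeddings ** 2, axis=0, keepdims=True)
            )
            indices = tf.argmin(distances, axis=1)
            quantized = tf.reshape(
                tf.nn.embedding_lookup(tf.transpose(self.embeddings), indices),
                tf.shape(x),
            )
            # codebook loss + commitment loss; registering these via add_loss
            # is exactly what fails inside the tf.range loop below
            self.add_loss(
                tf.reduce_mean((quantized - tf.stop_gradient(x)) ** 2)
                + self.beta * tf.reduce_mean((tf.stop_gradient(quantized) - x) ** 2)
            )
            # straight-through estimator: forward pass uses the quantized values,
            # backward pass copies gradients straight through to x
            return x + tf.stop_gradient(quantized - x)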

Now, this raises an error when I try to train the model, as the graph apparently cannot be traced across the loop iterations (the quantizer registers its loss via add_loss inside the tf.range loop).

Can this be solved somehow? I have searched and tried a lot, but to no avail.

I solved this: it is best to use a class inheriting from TensorFlow's RNN cell interface (it is apparently slightly faster than a custom tf.while_loop), and, most importantly, to hide the vq_loss inside a dummy state variable and only read the states out at the very end of the RNN processing. (Alternatively, one could merge the vq_loss into the actual RNN output, but I never got that to work.)
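For anyone hitting the same problem, here is a minimal sketch of that idea (in TF 2.x, subclassing tf.keras.layers.AbstractRNNCell). All names (VQRecurrentCell, rnn_state_size, output_dim) are illustrative, and it assumes the quantizer has been changed to return its loss instead of calling add_loss; the extra state slot just ferries the accumulated loss out of the loop:

    import tensorflow as tf

    class VQRecurrentCell(tf.keras.layers.AbstractRNNCell):
        """Wraps encoder -> quantizer -> decoder as a single RNN step."""

        def __init__(self, encoder, quantizer, decoder,
                     rnn_state_size, output_dim, **kwargs):
            super().__init__(**kwargs)
            self.encoder = encoder
            self.quantizer = quantizer
            self.decoder = decoder
            self.rnn_state_size = rnn_state_size
            self.output_dim = output_dim

        @property
        def state_size(self):
            # second slot is the dummy state that accumulates the VQ loss
            return (self.rnn_state_size, 1)

        @property
        def output_size(self):
            return self.output_dim

        def call(self, inputs, states):
            rnn_state, vq_loss_acc = states
            encoded = self.encoder(inputs, rnn_state)
            # quantizer returns (quantized, loss) here instead of using add_loss
            quantized, vq_loss = self.quantizer(encoded)
            decoded, new_rnn_state = self.decoder(quantized, rnn_state)
            # ferry the loss through the dummy state slot; broadcasting turns
            # the scalar loss into shape (batch, 1)
            return decoded, (new_rnn_state, vq_loss_acc + vq_loss)

Usage would then look roughly like this (assuming existing encoder/quantizer/decoder layers and a surrounding Keras model): with return_state=True, the RNN hands back the final states, so the accumulated loss can be added once, outside the loop, where tracing works:

    cell = VQRecurrentCell(encoder, quantizer, decoder,
                           rnn_state_size=64, output_dim=32)
    rnn = tf.keras.layers.RNN(cell, return_sequences=True, return_state=True)
    outputs, final_rnn_state, vq_loss_total = rnn(input_data)
    model.add_loss(tf.reduce_mean(vq_loss_total))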