Hi, I am trying to get straight-through estimation working for a vector quantization (VQ) layer. The VQ layer sits inside a recurrent network that I built. When I call the recurrent network on some input, this is what happens:
    for i in tf.range(timesteps):
        encoded = self.encoder(input_data[i], state)
        encoded_q = self.quantizer(encoded)
        decoded, state = self.decoder(encoded_q, state)
        outputs = outputs.write(i, decoded)
The vector quantizer layer is a direct copy of the standard VQ-VAE implementation found in "Vector-Quantized Variational Autoencoders".
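For reference, the straight-through step in that kind of quantizer can be sketched roughly as below. This is a simplified illustration, not the exact layer from the example: the function name `quantize_straight_through` and the toy codebook values are hypothetical, and the codebook and commitment losses are left out.

```python
import tensorflow as tf


def quantize_straight_through(x, codebook):
    """Snap each row of x to its nearest codebook row (hypothetical helper).

    x:        float tensor of shape (batch, dim)
    codebook: float tensor of shape (num_codes, dim)
    """
    # Squared Euclidean distance between every input row and every code.
    distances = (
        tf.reduce_sum(x ** 2, axis=1, keepdims=True)
        - 2.0 * tf.matmul(x, codebook, transpose_b=True)
        + tf.reduce_sum(codebook ** 2, axis=1)
    )
    indices = tf.argmin(distances, axis=1)
    quantized = tf.gather(codebook, indices)
    # Straight-through estimator: the forward pass yields the quantized
    # values, while the backward pass treats the op as identity w.r.t. x.
    return x + tf.stop_gradient(quantized - x)


# Toy data, assumed purely for illustration.
codebook = tf.constant([[0.0, 0.0], [1.0, 1.0]])
x = tf.constant([[0.9, 0.8], [0.1, -0.2]])

with tf.GradientTape() as tape:
    tape.watch(x)
    q = quantize_straight_through(x, codebook)
    loss = tf.reduce_sum(q)

# Gradient flows to x as if quantization were the identity.
grad = tape.gradient(loss, x)
```

Here `q` holds the nearest codebook rows, while `grad` is all ones, since the non-differentiable `argmin` lookup is hidden behind `tf.stop_gradient`.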
However, training the model raises an error, apparently because the graph cannot be tracked across the loop iterations.
Can this be solved somehow? I have searched and tried a lot of things, but to no avail.