I’d like some confirmation here:
So I implemented a vector quantization layer, which I'd also like to use during training. Straight-through estimation is a common technique which, as far as I understand it, just means ignoring the layer during backpropagation, i.e., passing the incoming gradient through the layer unchanged. I'd like to implement it.
Now, I know of TensorFlow's stop_gradient function; however, shouldn't a custom gradient defined as

def grad(dy):
    return dy
yield the same result? I.e. doing
class VectorQuantization(tf.keras.layers.Layer):
    def __init__(self, codebook=None, **kwargs):
        super(VectorQuantization, self).__init__(**kwargs)
        ... init ...

    @tf.custom_gradient
    def call(self, inputs):
        def grad(dy):
            return dy
        ... regular VQ stuff ...
        return quantized, grad
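To make the question concrete, here is a minimal, self-contained sketch of what I mean, written as a standalone function rather than a layer method (decorating a bound `call` with `tf.custom_gradient` can be finicky, since gradients for layer variables then need to be handled via the `variables` argument). I'm using `tf.round` as a stand-in for the actual codebook lookup:

```python
import tensorflow as tf

@tf.custom_gradient
def quantize_st(x):
    # Toy "quantization": round to the nearest integer.
    # In the real layer this would be the nearest-codebook-entry lookup.
    q = tf.round(x)

    def grad(dy):
        # Straight-through: pass the upstream gradient through unchanged.
        return dy

    return q, grad

x = tf.constant([0.2, 1.7])
with tf.GradientTape() as tape:
    tape.watch(x)
    y = quantize_st(x)
g = tape.gradient(y, x)  # identity gradient: [1., 1.]
```

The forward pass produces the quantized values, while the backward pass behaves as if the quantization were the identity function.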
In my understanding, training a model with a structure like

mymodel = somelayers(VQ(somemorelayers))

acts, with respect to backpropagation, like

dmymodel = dsomelayers * 1 * dsomemorelayers

since the VQ layer contributes an identity gradient,
right? This should be the/a straight-through estimator. Am I missing something? I looked up some VQ-VAE code and they used `inputs + tf.stop_gradient(quantized - inputs)` for the straight-through estimate: the forward value equals `quantized`, but the gradient only flows through the `inputs` term. Shouldn't this be equivalent to defining the gradient as I did here?
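As a sanity check on that equivalence, here is the `stop_gradient` variant of the same toy example (again using `tf.round` as a stand-in for the codebook lookup, so the numbers are directly comparable to the custom-gradient version):

```python
import tensorflow as tf

x = tf.constant([0.2, 1.7])
with tf.GradientTape() as tape:
    tape.watch(x)
    q = tf.round(x)                   # stands in for the codebook lookup
    # Forward value equals q; gradient flows only through the x term,
    # so the backward pass sees an identity gradient.
    y = x + tf.stop_gradient(q - x)
g = tape.gradient(y, x)  # identity gradient: [1., 1.]
```

Both formulations yield the same forward values and the same identity gradient, which is what I would expect if they are indeed equivalent straight-through estimators.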