I’m trying to find a way in TF2 to use the tf.keras.layers.BatchNormalization layer in training mode (i.e. normalizing using the statistics of the current batch) but without updating the moving mean and variance (for some batches, not all).
In TF1, using tf.layers.batch_normalization, you could do something like
x = my_first_inputs   # I want to use these data for updating the moving statistics
y = my_second_inputs  # I do not want to use these data for updating the moving statistics

out_x = my_model(x, training=True)
# Collect the update ops now, so that only the ops created for x are included
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
out_y = my_model(y, training=True)
…
train_op = …  # gradient step to minimize the loss
…
with tf.control_dependencies([train_op]):
    train_op = tf.group(*update_ops)
session.run(train_op)
Does anyone have an idea of how to replicate this in TF2?
Unfortunately I can’t really find a solution for my problem in the migration guide. All it says is that the moving statistics for BatchNorm will be updated automatically in TF2 when calling with “training=True”, which is what I don’t want.
I am also unfortunately not familiar enough with all the inner mechanics of TF2 to understand how the snippet from eager_utils.py helps me.
I think I finally found a fairly good solution to this. Posting in case anyone with the same problem finds this thread.
One cause of my original problem is that tf.keras.layers.BatchNormalization has custom behavior for layer.trainable = False. From the docs:
However, in the case of the BatchNormalization layer, setting trainable = False on the layer means that the layer will be subsequently run in inference mode (meaning that it will use the moving mean and the moving variance to normalize the current batch, rather than using the mean and variance of the current batch).
This behavior has been introduced in TensorFlow 2.0, in order to enable layer.trainable = False to produce the most commonly expected behavior in the convnet fine-tuning use case.
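For reference, here is a minimal sketch (my own illustration, not from the docs) of that default behavior: with trainable = False, the stock layer normalizes with the moving statistics even when it is called with training=True.

import numpy as np
import tensorflow as tf

bn = tf.keras.layers.BatchNormalization()
bn.trainable = False

x = tf.constant(np.random.randn(8, 4), dtype=tf.float32)
out = bn(x, training=True)

# Freshly initialized moving statistics are mean = 0 and variance = 1, so the
# output is (up to epsilon) just the input, not a per-batch standardization.
print(np.allclose(out.numpy(), x.numpy(), atol=1e-2))  # True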
This behavior can be disabled by subclassing tf.keras.layers.BatchNormalization and overriding the _get_training_value() method:
import tensorflow as tf
from tensorflow.keras import backend


class MyBatchNorm(tf.keras.layers.BatchNormalization):

    def _get_training_value(self, training=None):
        if training is None:
            training = backend.learning_phase()
        if self._USE_V2_BEHAVIOR:
            if isinstance(training, int):
                training = bool(training)
            # if not self.trainable:
            #     # When the layer is not trainable, it overrides the value passed
            #     # from model.
            #     training = False
        return training
Note that the custom behavior for layer.trainable is disabled by commenting out those four lines.
We can then normalize with the current batch statistics, without updating the moving statistics, using something like
model = MyBatchNorm()
model.trainable = False        # trainable = False still skips the moving-average updates
out = model(x, training=True)  # but the current batch statistics are now used for normalization
whereas the unmodified tf.keras.layers.BatchNormalization would use the moving statistics in this call.
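As a quick sanity check, here is a sketch of how one could verify this, using the MyBatchNorm class defined above. It assumes a TF 2.x version in which Layer.add_update skips the moving-average updates for frozen (trainable = False) layers, which is what makes this trick work; the data and tolerances are arbitrary.

import numpy as np
import tensorflow as tf

bn = MyBatchNorm()
bn.trainable = False

# Data that is clearly not standard normal, so the effect is visible.
x = tf.constant(np.random.randn(256, 4) * 5.0 + 3.0, dtype=tf.float32)
out = bn(x, training=True)

# The output is normalized with the statistics of this batch ...
print(np.allclose(out.numpy().mean(axis=0), 0.0, atol=1e-3))  # True
print(np.allclose(out.numpy().std(axis=0), 1.0, atol=1e-2))   # True
# ... while the moving statistics keep their initial values,
# i.e. the call did not update them.
print(np.allclose(bn.moving_mean.numpy(), 0.0))       # True
print(np.allclose(bn.moving_variance.numpy(), 1.0))   # True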