TensorFlow/Keras equivalent of PyTorch `nn.Module.eval()`

I am migrating a PyTorch application and need the TensorFlow/Keras equivalent of PyTorch's nn.Module.eval() method. A Google search turns up Model.evaluate() and Model.predict(), but nothing about tf.Module subclasses. Or, how do I set training to False?

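Hi @khteh, TensorFlow/Keras has no separate method equivalent to PyTorch's nn.Module.eval(). Instead, you pass the training=False argument when calling the model (the forward pass). Unlike eval(), this is a per-call argument rather than a persistent mode stored on the model.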
For custom tf.keras.Model or tf.keras.layers.Layer subclasses, you can use:
output = model(input_data, training=False)
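As a minimal sketch (the names `ScaledNoise` and `SmallNet` are hypothetical, for illustration only), a custom subclass accepts a `training` argument in `call()` and forwards it to sublayers such as `Dropout`. That argument plays the role of PyTorch's train/eval switch:

```python
import numpy as np
import tensorflow as tf


class ScaledNoise(tf.keras.layers.Layer):
    """Hypothetical custom layer: adds Gaussian noise only in training mode."""

    def __init__(self, stddev=0.5, **kwargs):
        super().__init__(**kwargs)
        self.stddev = stddev

    def call(self, inputs, training=None):
        # Branch on the `training` flag, much like checking `self.training`
        # inside a PyTorch module after .train() / .eval().
        if training:
            return inputs + tf.random.normal(tf.shape(inputs), stddev=self.stddev)
        return inputs


class SmallNet(tf.keras.Model):
    """Hypothetical model mixing the custom layer with built-in Dropout."""

    def __init__(self):
        super().__init__()
        self.noise = ScaledNoise()
        self.dense = tf.keras.layers.Dense(8)
        self.dropout = tf.keras.layers.Dropout(0.5)

    def call(self, inputs, training=None):
        # Pass `training` explicitly to every layer whose behavior depends on it.
        x = self.noise(inputs, training=training)
        x = self.dense(x)
        return self.dropout(x, training=training)


model = SmallNet()
x = np.ones((4, 16), dtype="float32")

# Inference ("eval") mode: no noise, no dropout, so repeated calls match.
eval_a = model(x, training=False).numpy()
eval_b = model(x, training=False).numpy()

# Training mode: noise and dropout are applied, so the output changes.
train_out = model(x, training=True).numpy()
```

Because `training` is passed per call, there is nothing to switch back afterwards; a custom training loop simply calls the model with `training=True` again.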
Also, Keras's built-in model.evaluate() and model.predict() handle this automatically by setting training=False internally.
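One way to check this, sketched with a hypothetical toy model and random data, is to compare `predict()` and `evaluate()` against an explicit `training=False` call on a model that contains `Dropout`:

```python
import numpy as np
import tensorflow as tf

# Hypothetical functional model with Dropout, so training and inference differ.
inputs = tf.keras.Input(shape=(16,))
hidden = tf.keras.layers.Dense(8, activation="relu")(inputs)
hidden = tf.keras.layers.Dropout(0.5)(hidden)
outputs = tf.keras.layers.Dense(1)(hidden)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer="adam", loss="mse")

x = np.random.rand(32, 16).astype("float32")
y = np.random.rand(32, 1).astype("float32")

# predict() runs the forward pass in inference mode (Dropout disabled) ...
preds = model.predict(x, verbose=0)
# ... so it should match an explicit training=False call.
manual = model(x, training=False).numpy()

# evaluate() likewise computes the loss with the model in inference mode.
loss = model.evaluate(x, y, verbose=0)
manual_loss = float(np.mean(np.square(y - manual)))
```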
