In the code below, I overridded train_step(), but self.metrics does not contain compiled metrics.
As a result the verbose of fit() does not show the compiled metrics and the history does not contains the metrics.
How can I modified this train_step() to reproduce exactly the same behaviour as tensorflow?
import tensorflow as tf
from tensorflow.keras.layers import Dense
print(f"{tf.__version__ = }")
tf.config.run_functions_eagerly(True)
# Dataset
(x_tra, y_tra), (x_tst, y_tst) = tf.keras.datasets.mnist.load_data()
x_tra = x_tra.reshape(-1, 784).astype("float32") / 255
x_tst = x_tst.reshape(-1, 784).astype("float32") / 255
y_tra = tf.one_hot(y_tra.astype("float32"), depth=10)
y_tst = tf.one_hot(y_tst.astype("float32"), depth=10)
# Model
class MModel(tf.keras.Model):
def __init__(self):
super().__init__()
self.dense1 = Dense(64, activation="relu")
self.dense2 = Dense(64, activation="relu")
self.softmax = Dense(10, activation="softmax")
def call(self, inputs):
x = self.dense1(inputs)
x = self.dense2(x)
return self.softmax(x)
def train_step(self, data):
x, y = data
with tf.GradientTape() as tape:
ŷ = self(x, training=True)
loss_value = self.compiled_loss(y, ŷ, regularization_losses=self.losses)
# Compute gradients and Update weights
gradients = tape.gradient(loss_value, self.trainable_variables)
self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
"""
With a debugger, you can see that self.metrics list contains only the loss (Mean(name=loss,dtype=float32)) but
not the compiled metrics
Also, self.compiled_metrics.metrics list is empty.
"""
for metric in self.metrics:
if metric.name == "loss":
metric.update_state(loss_value)
else:
metric.update_state(y, ŷ)
return {m.name: m.result() for m in self.metrics}
model = MModel()
model.compile(
optimizer=tf.keras.optimizers.Adam(),
loss=tf.keras.losses.CategoricalCrossentropy(),
metrics=[tf.keras.metrics.CategoricalAccuracy()]
)
history = model.fit(
x_tra, y_tra,
epochs=1,
validation_data=(x_tst, y_tst)
)
print(f"{model.evaluate( x_tst, y_tst) = }")
"""
1875/1875 [==============================] - 12s 5ms/step - loss: 0.2838 - val_loss: 0.1499 - val_categorical_accuracy: 0.9525
"""
Please note that the validation metrics is present because I did not overide train_test()