Hi, I have noticed an inconsistency: wrapping the metric in MeanMetricWrapper and passing the plain function directly to compile produce different results some of the time. Reproduction: Google Colab notebook.
If you run the last cell multiple times, you will see instances where mean_squared_wrapped and mean_squared_error_fn are not equal to each other.
How can we explain this?
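For context, the two metrics reduce their values differently. squared_error_fn returns a per-element tensor, so MeanMetricWrapper averages over every element it has seen, while mean_squared_error_fn already returns one scalar per batch, which is then averaged over batches. Below is a minimal NumPy sketch, assuming both end up as running (sum, count) means, which is how Keras' Mean-style metrics accumulate. running_mean and split are hypothetical helpers, and the data is random, not MNIST. With equal-sized batches the two agree up to float rounding. With a smaller final batch they differ even without any bug.

```python
import numpy as np

def squared_error_fn(y_true, y_pred):
    # Per-element squared error (NumPy version of the thread's function).
    return np.square(y_true - y_pred)

def mean_squared_error_fn(y_true, y_pred):
    # Already reduced to one scalar per batch.
    return np.mean(np.square(y_true - y_pred))

def running_mean(batches, fn):
    # Toy stand-in for a Mean-style metric: accumulate the sum and the number
    # of values `fn` returns for each batch, and divide once at the end.
    total, count = 0.0, 0
    for y_true, y_pred in batches:
        values = np.asarray(fn(y_true, y_pred))
        total += values.sum()
        count += values.size
    return total / count

rng = np.random.default_rng(0)
y_true = rng.random((100, 10))
y_pred = rng.random((100, 10))

def split(batch_size):
    # Consecutive batches; the last one is smaller if batch_size does not divide 100.
    return [(y_true[i:i + batch_size], y_pred[i:i + batch_size])
            for i in range(0, len(y_true), batch_size)]

results = {}
for batch_size in (25, 32):  # 100 = 4 * 25, but 100 = 3 * 32 + 4
    per_element = running_mean(split(batch_size), squared_error_fn)
    per_batch = running_mean(split(batch_size), mean_squared_error_fn)
    results[batch_size] = (per_element, per_batch)
    print(batch_size, per_element, per_batch)
```

With evaluate's default batch_size of 32, the 60,000 MNIST training samples split into 1,875 full batches. This effect alone should therefore not explain a mismatch in the script. It also does not apply when comparing the same metric between two evaluate calls.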
Thanks for taking a look. The issue happens only on repeated compiles, not on the first compile of a model.
I added a for loop in the new Colab so that a single run shows the mismatch. I am not sure whether this is expected.
I've pasted the same code here:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
def squared_error_fn(y_true, y_pred):
    # Per-element squared error; MeanMetricWrapper averages it.
    return tf.square(y_true - y_pred)
def mean_squared_error_fn(y_true, y_pred):
    # Already reduced to one scalar per batch.
    return tf.math.reduce_mean(tf.square(y_true - y_pred))
# A single module-level metric instance, reused by every model built below.
mean_squared_wrapped = tf.keras.metrics.MeanMetricWrapper(fn=squared_error_fn, name='mean_squared_wrapped')
def custom_mean_squared_error(y_true, y_pred):
    return tf.math.reduce_mean(tf.square(y_true - y_pred))
def get_compiled_model():
    inputs = keras.Input(shape=(784,), name="digits")
    x = layers.Dense(64, activation="relu", name="dense_1")(inputs)
    x = layers.Dense(64, activation="relu", name="dense_2")(x)
    outputs = layers.Dense(10, activation="softmax", name="predictions")(x)
    model = keras.Model(inputs=inputs, outputs=outputs)
    # The start of this compile call was lost when pasting; the optimizer and
    # loss are assumptions (custom_mean_squared_error is otherwise unused).
    model.compile(optimizer="adam", loss=custom_mean_squared_error,
                  metrics=['accuracy', mean_squared_wrapped, mean_squared_error_fn])
    return model
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype("float32") / 255
x_test = x_test.reshape(10000, 784).astype("float32") / 255
y_train = y_train.astype("float32")
y_test = y_test.astype("float32")
print("TF version: ",tf.__version__)
one_hot_y_train = tf.one_hot(y_train, depth=10)
for i in range(3):
    compiled_model = get_compiled_model()
    # Evaluate the same model twice on the same data; the results should match.
    eval1 = compiled_model.evaluate(x_train, one_hot_y_train, verbose=2)
    eval2 = compiled_model.evaluate(x_train, one_hot_y_train, verbose=2)
    metric_index = 2  # [loss, accuracy, mean_squared_wrapped, mean_squared_error_fn]
    if abs(eval1[metric_index] - eval2[metric_index]) > 1e-5:
        print(f"mismatch found in compile: {i}")
        print("eval1: ", eval1)
        print("eval2: ", eval2)
MeanMetricWrapper does have state (the running mean). Is it possible that it's just not being reset correctly?
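To make the state point concrete: a mean metric keeps a running total and count across update calls, and its result is only meaningful if that state was cleared before the run. Below is a minimal pure-Python sketch. RunningMean is a hypothetical stand-in that only mimics the update_state / result / reset_state pattern of Keras metrics. It is not the Keras class.

```python
class RunningMean:
    """Hypothetical toy stand-in for a stateful mean metric such as
    MeanMetricWrapper: it keeps a running total and count."""

    def __init__(self):
        self.reset_state()

    def reset_state(self):
        # Clear the accumulated state.
        self.total = 0.0
        self.count = 0

    def update_state(self, values):
        # Fold one batch of values into the running total and count.
        values = list(values)
        self.total += sum(values)
        self.count += len(values)

    def result(self):
        return self.total / self.count if self.count else 0.0


metric = RunningMean()
metric.update_state([4.0, 4.0])  # leftover values from an earlier run
metric.update_state([1.0, 1.0])  # the run we actually want to measure
without_reset = metric.result()  # 2.5: blended with the stale values

metric.reset_state()
metric.update_state([1.0, 1.0])
with_reset = metric.result()     # 1.0
```

Re-running exactly the same batches without a reset leaves the mean unchanged. A missing reset alone would therefore only show up as eval1 != eval2 if the stale values came from somewhere else, such as another model sharing the same metric object.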
I have run into a similar issue where compile modified the metric object, and multiple compile calls then stacked up those modifications. This feels similar.
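On the repeated-compile angle: in the script, mean_squared_wrapped is created once at module level, so every model returned by get_compiled_model() references the same metric object. The toy sketch below shows why that matters. It uses hypothetical pure-Python classes rather than Keras, and it deliberately skips resets. Anything one model leaves on a shared instance shows up in the next model's results, while constructing the metric inside the factory gives each model its own instance.

```python
class RunningMean:
    """Hypothetical toy metric: a running total and count (not the Keras class)."""

    def __init__(self):
        self.total, self.count = 0.0, 0

    def update_state(self, values):
        self.total += sum(values)
        self.count += len(values)

    def result(self):
        return self.total / self.count


class ToyModel:
    """Hypothetical stand-in for a compiled model; it only keeps metric references."""

    def __init__(self, metrics):
        self.metrics = metrics

    def evaluate(self, values):
        # No reset here on purpose, so shared state stays visible.
        for metric in self.metrics:
            metric.update_state(values)
        return [metric.result() for metric in self.metrics]


# Shared: one module-level instance, like mean_squared_wrapped in the script.
shared_metric = RunningMean()

def get_shared_model():
    return ToyModel([shared_metric])

get_shared_model().evaluate([4.0, 4.0])
shared_result = get_shared_model().evaluate([1.0, 1.0])  # [2.5]

# Fresh: construct the metric inside the factory, one instance per model.
def get_fresh_model():
    return ToyModel([RunningMean()])

get_fresh_model().evaluate([4.0, 4.0])
fresh_result = get_fresh_model().evaluate([1.0, 1.0])  # [1.0]
```

The analogous change to the script would be to move the tf.keras.metrics.MeanMetricWrapper(...) construction inside get_compiled_model(). That is a suggestion to try, not a confirmed fix.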