NOTE: Link to Google Colab Showing Example
I was teaching today, showing how to create custom training loops with the Keras API, and came across some unexpected behavior. I am posting here to see whether this is a bug I should open an issue for, or whether somebody can explain what is happening.
A neural network with no hidden layers and a single output node is, mathematically, identical to linear regression. I can create a Keras model for this like so:
```python
import tensorflow as tf
from tensorflow import keras

class LinearRegression(keras.Model):
    def build(self, input_shapes):
        # A single output unit with no activation: y_hat = Wx + b
        self.layer = keras.layers.Dense(1, kernel_initializer="zeros")

    def call(self, input_data, training=None):
        return self.layer(input_data, training=training)
```
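Spelled out, that `Dense(1)` layer computes exactly the linear-regression hypothesis, and MSE is the usual linear-regression loss:

$$\hat{y} = w^\top x + b, \qquad \mathrm{MSE} = \frac{1}{n}\sum_{i=1}^{n}\left(y_i - \hat{y}_i\right)^2$$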
Furthermore, I can fit this model very simply using the following MSE implementation:
```python
def mse(y_true, y_pred):
    return tf.reduce_mean(tf.math.square(y_true - y_pred))

linear_model = LinearRegression()
linear_model.compile(optimizer=keras.optimizers.SGD(), loss=mse)
# x, y, EPOCHS, and BATCH_SIZE are defined in the linked notebook
linear_model.fit(x, y, epochs=EPOCHS, batch_size=BATCH_SIZE)
```
And, to no surprise, the above works perfectly: in the Colab I simulate a dataset that this model can fit exactly.
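For reference, here is one plausible way to simulate such a dataset; the exact shapes, coefficients, and hyperparameters in the notebook may differ:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(1000, 3)).astype("float32")
# A noiseless linear target, so the model above can fit it exactly
y = (x @ np.array([2.0, -1.0, 0.5]) + 3.0).astype("float32")

EPOCHS = 10
BATCH_SIZE = 32
```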
Now, we can customize the training loop by overriding the `train_step` method. The classic implementation for this class might look something like this:
```python
class LinearRegression(keras.Model):
    def build(self, input_shapes):
        self.layer = keras.layers.Dense(1, kernel_initializer="zeros")

    def call(self, input_data, training=None):
        return self.layer(input_data)

    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            prediction = self(x, training=True)
            loss = self.compiled_loss(y, prediction)
        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        return {"loss": loss}
```
And, again, to nobody's surprise, this works! HOWEVER, if I change the line `loss = self.compiled_loss(y, prediction)` to `loss = mse(y, prediction)`, the model fails to fit!
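Concretely, the failing variant differs in only that one line (the method below replaces `train_step` in the class above):

```python
    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            prediction = self(x, training=True)
            # The only change: call my mse directly instead of self.compiled_loss
            loss = mse(y, prediction)
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        return {"loss": loss}
```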
Can somebody explain why this is the case? What is `self.compiled_loss` doing that is required here? I encourage you to run the code in the shared notebook to see for yourself!