Model.fit() and model.predict() on a single sample give different results

I am doing monocular depth estimation on the NYUv2 dataset and want to overfit a model on a single sample. This is the (prediction, ground truth, input) triplet that troubles me:

These are the losses reported by model.fit():

This clearly indicates the model is overfitting the sample. But when I call model.predict() on that same sample and compute the loss on its output, I get a much higher value (0.63 vs. the 0.0145 reported by fit()).

Why do I see such a big difference in loss, and such a poor depth estimate, for the very sample the model was overfit on?
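
Concretely, the 0.63 comes from computing the same loss on the output of model.predict(), roughly like this (a sketch: model, ds_train, and calculate_loss are as defined below, and ds_train is assumed to yield a single batched (image, depth) pair):

    for x, y in ds_train.take(1):
        y_pred = model.predict(x)
        print(float(calculate_loss(y, y_pred)))  # ~0.63, vs ~0.0145 from fit()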

The code:
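
Imports first, plus a simplified sketch of how the single-sample dataset is built (img and depth are placeholders for a preprocessed 64x64 input image and its ground-truth depth map, not my actual pipeline):

    import tensorflow as tf
    from tensorflow import keras
    from tensorflow.keras import layers

    # Sketch of a single-sample pipeline; img and depth are float32 arrays of
    # shape (64, 64, 1), scaled to the model's sigmoid output range
    ds_train = tf.data.Dataset.from_tensors((img, depth)).batch(1)
    ds_val = ds_train  # sketch: validate on the same sample while overfitting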

    def get_model(img_size, in_channels=1):
        inputs = layers.Input(shape=(*img_size, in_channels), name="input")

        ### [First half of the network: downsampling inputs] ###

        # Entry block
        filters = [16, 32, 64]
        # filters = [16, 32, 64, 128, 256]
        # Note: the second positional argument of Conv2D is the kernel size,
        # so with in_channels == 1 this entry convolution is 1x1
        x = layers.Conv2D(filters[0], in_channels, strides=2, padding="same")(inputs)
        x = layers.BatchNormalization()(x)
        x = layers.Activation("relu")(x)
    
        previous_block_activation = x  # Set aside residual
    
        # Downsampling blocks, identical apart from the feature depth
        for n_filters in filters[1:]:
            x = layers.Activation("relu")(x)
            x = layers.SeparableConv2D(n_filters, 3, padding="same")(x)
            x = layers.BatchNormalization()(x)

            x = layers.Activation("relu")(x)
            x = layers.SeparableConv2D(n_filters, 3, padding="same")(x)
            x = layers.BatchNormalization()(x)

            x = layers.MaxPooling2D(3, strides=2, padding="same")(x)

            # Project residual
            residual = layers.Conv2D(n_filters, 1, strides=2, padding="same")(
                previous_block_activation
            )
            x = layers.add([x, residual])  # Add back residual
            previous_block_activation = x  # Set aside next residual
    
        ### [Second half of the network: upsampling inputs] ###
    
        for n_filters in filters[::-1]:
            x = layers.Activation("relu")(x)
            x = layers.Conv2DTranspose(n_filters, 3, padding="same")(x)
            x = layers.BatchNormalization()(x)

            x = layers.Activation("relu")(x)
            x = layers.Conv2DTranspose(n_filters, 3, padding="same")(x)
            x = layers.BatchNormalization()(x)

            x = layers.UpSampling2D(2)(x)

            # Project residual
            residual = layers.UpSampling2D(2)(previous_block_activation)
            residual = layers.Conv2D(n_filters, 1, padding="same")(residual)
            x = layers.add([x, residual])  # Add back residual
            previous_block_activation = x  # Set aside next residual
    
        # x = layers.Activation("sigmoid")(x)
    
        # Per-pixel depth output; the sigmoid squashes predictions into (0, 1)
        outputs = layers.Conv2D(
            1, in_channels, activation="sigmoid", padding="same", name="output"
        )(x)
    
        # Define the model
        model = keras.Model(inputs, outputs)
    
        return model
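
    # Three stride-2 downsampling stages are matched by three 2x upsampling
    # stages, so the output resolution equals the input resolution.
    # Quick sanity check (a sketch):
    _m = get_model((64, 64), 1)
    print(_m(tf.zeros((1, 64, 64, 1))).shape)  # (1, 64, 64, 1), values in (0, 1)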

    # Report the loss value as a metric as well
    metrics = tf.keras.metrics.Mean(name="loss")

    def custom_metric(y_true, y_pred, sample_weight=None):
        metric_value = calculate_loss(y_true, y_pred)
        # Keras averages the returned value itself; this external Mean is
        # updated here but never read or reset by fit()
        metrics.update_state(metric_value, sample_weight=sample_weight)
        return metric_value

    model = get_model((64, 64), 1)

    # Compile the model (calculate_loss is defined further down)
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.0005)
    model.compile(optimizer=optimizer, loss=calculate_loss, metrics=[custom_metric])

    history = model.fit(
        x=ds_train,
        epochs=cfg.epochs * 3,
        validation_data=ds_val,
        # callbacks=[es],
        verbose=1,
    )

    preds = model.predict(ds_train)  # inference-mode forward pass
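
    # fit() runs the forward pass with training=True, while predict() runs it
    # with training=False; comparing both call modes on the same sample
    # isolates layers that behave differently between the two, such as
    # BatchNormalization or Dropout (a sketch; calculate_loss is defined below)
    for x, y in ds_train.take(1):
        print("training=True :", float(calculate_loss(y, model(x, training=True))))
        print("training=False:", float(calculate_loss(y, model(x, training=False))))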

    def calculate_loss(target, pred):
        if not isinstance(target, tf.Tensor):
            target = tf.convert_to_tensor(target)
        if not isinstance(pred, tf.Tensor):
            pred = tf.convert_to_tensor(pred)
        if len(target.shape) == 3:
            target = target[tf.newaxis, ...]
        if len(pred.shape) == 3:
            pred = pred[tf.newaxis, ...]

        # Edges
        dy_true, dx_true = tf.image.image_gradients(target)
        dy_pred, dx_pred = tf.image.image_gradients(pred)
        weights_x = tf.exp(tf.reduce_mean(tf.abs(dx_true)))
        weights_y = tf.exp(tf.reduce_mean(tf.abs(dy_true)))

        # Depth smoothness
        smoothness_x = dx_pred * weights_x
        smoothness_y = dy_pred * weights_y
        depth_smoothness_loss = tf.reduce_mean(tf.abs(smoothness_x)) + tf.reduce_mean(
            tf.abs(smoothness_y)
        )

        # Structural similarity (SSIM) index
        ssim_loss = tf.reduce_mean(
            1
            - tf.image.ssim(
                target, pred, max_val=cfg.w, filter_size=7, k1=0.01**2, k2=0.03**2
            )
        )

        # Point-wise depth (L1)
        l1_loss = tf.reduce_mean(tf.abs(target - pred))

        loss = (
            (cfg.ssim_loss_weight * ssim_loss)
            + (cfg.l1_loss_weight * l1_loss)
            + (cfg.edge_loss_weight * depth_smoothness_loss)
        )
        return loss
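
    # Sanity check on the loss itself (a sketch): with pred == target, the SSIM
    # and L1 terms vanish, but the smoothness term does not, because it
    # penalizes any gradient in the prediction; so even a perfectly overfit
    # sample cannot drive this loss exactly to 0
    y = tf.random.uniform((1, 64, 64, 1))
    print(float(calculate_loss(y, y)))  # > 0: only the smoothness term remains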

I have already tried double-checking the input pipeline, tuning the training hyperparameters, increasing the model's capacity, removing BatchNorm, and so on.