Help Debugging MirroredStrategy with Loss Going to NaN

I’ve been having a lot of difficulty getting distributed training to work as expected using MirroredStrategy. This is the first TensorFlow project where I’ve worked with distributed training. My input data is effectively an image that feeds into a fairly generic CNN (defined below), trained with a generic Keras training loop (also defined below). The model trains without issues on a single GPU, but when I switch to MirroredStrategy the loss goes to NaN. I’ve tried different optimizers and optimizer settings (adjusting the learning rate and epsilon), changing the activation function feeding into the model’s final softmax, and removing regularization from the model, but none of these attempts have made any progress. I’m considering writing my own training loop, but I’d like to learn what other options I can try to debug this issue.
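One thing I haven’t tried yet is TensorFlow’s op-level numerics checking to localize where the first NaN actually appears; as far as I understand it’s a one-line switch (sketch below, nothing here depends on my code):

import tensorflow as tf

# Raises an error with a stack trace at the first op whose output contains
# NaN or Inf, instead of letting it propagate silently into the loss.
# Needs to be called before the model and dataset are built.
tf.debugging.enable_check_numerics()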

# models/cnn/cnn.py
from tensorflow import keras

model = keras.Sequential([
    keras.layers.Input(shape=input_shape),  # input_shape is defined elsewhere in this module
    keras.layers.Conv2D(32, kernel_size=3, activation='relu', data_format='channels_last'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPool2D((2, 2), data_format='channels_last'),
    keras.layers.Conv2D(64, kernel_size=3, activation='relu', data_format='channels_last'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPool2D((2, 2), data_format='channels_last'),
    keras.layers.Conv2D(128, kernel_size=3, activation='relu', data_format='channels_last'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPool2D((2, 2), data_format='channels_last'),
    keras.layers.Conv2D(128, kernel_size=3, activation='relu', data_format='channels_last'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPool2D((2, 2), data_format='channels_last'),
    keras.layers.Conv2D(128, kernel_size=3, activation='relu', data_format='channels_last'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPool2D((2, 2), data_format='channels_last'),
    keras.layers.Dropout(0.5),
    keras.layers.Flatten(),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(64, activation='relu'),
    keras.layers.Dense(32, activation='relu'),
    keras.layers.Dense(2, activation='softmax')
])
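One suspicion I have is the BatchNormalization layers: from what I’ve read, under MirroredStrategy each replica normalizes over its own slice of the global batch, so the per-replica batch shrinks as GPUs are added and the BN statistics get noisier. A sketch of the two mitigations I’m considering (the batch size and learning rate below are placeholders, not my real values):

import tensorflow as tf
from tensorflow import keras

strategy = tf.distribute.MirroredStrategy()
num_replicas = strategy.num_replicas_in_sync

# Option 1: keep the per-replica batch constant by scaling the global batch
# with the replica count (and optionally scale the learning rate to match).
per_replica_batch_size = 32   # placeholder
global_batch_size = per_replica_batch_size * num_replicas
base_lr = 1e-3                # placeholder single-GPU learning rate
optimizer = keras.optimizers.Adam(learning_rate=base_lr * num_replicas)

# Option 2: synchronize BatchNorm statistics across replicas (TF >= 2.2).
sync_bn = keras.layers.experimental.SyncBatchNormalization()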
And the training script:

import math
import os

import tensorflow as tf
from tensorflow.keras import callbacks, metrics

# strategy, optimization_helper, optimization, results_dir, train_datasets,
# val_datasets, and the dataset lengths/batch sizes are defined elsewhere.
with strategy.scope():
    from models.cnn.cnn import model

    model.summary()
    optimizer = optimization_helper.return_optimizer()
    loss_object = tf.keras.losses.CategoricalCrossentropy()
    metrics_list = [
        metrics.Precision(),
        metrics.CategoricalAccuracy(),
    ]

    earlystopping = callbacks.EarlyStopping(monitor='val_loss', mode='min',
                                            patience=optimization['early_stopping_patience'],
                                            restore_best_weights=True)

    scheduler = callbacks.ReduceLROnPlateau(monitor='val_loss', factor=optimization['scheduler_factor'],
                                            patience=optimization['scheduler_patience'],
                                            min_lr=optimization['scheduler_min_lr'], verbose=1)

    checkpointer = callbacks.ModelCheckpoint(filepath=os.path.join(results_dir, 'best_model_state.hdf5'),
                                             verbose=1, save_best_only=True)

    callback_list = [
        earlystopping,
        scheduler,
        checkpointer,
    ]

    # Model creation and compile() both live inside the strategy scope.
    model.compile(
        run_eagerly=False,
        optimizer=optimizer,
        loss=loss_object,
        metrics=metrics_list
    )

steps_per_epoch = math.ceil(train_len / train_batch_size)
steps_per_val = math.ceil(val_len / val_batch_size)

history = model.fit(x=train_datasets,
                    validation_data=val_datasets,
                    steps_per_epoch=steps_per_epoch,
                    validation_steps=steps_per_val,
                    epochs=optimization['total_target_epochs'],
                    verbose=1,
                    callbacks=callback_list,
                    use_multiprocessing=False,
                    shuffle=True)
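If I do end up writing my own loop, my understanding is that the main distributed-training pitfall is loss reduction: per-replica losses need to be averaged over the global batch rather than the per-replica batch. This is roughly the sketch I’d start from (reusing strategy, model, optimizer, and train_datasets from above; global_batch_size is assumed):

import tensorflow as tf

# Unreduced loss so the averaging can be done explicitly below.
loss_object = tf.keras.losses.CategoricalCrossentropy(
    reduction=tf.keras.losses.Reduction.NONE)

dist_dataset = strategy.experimental_distribute_dataset(train_datasets)

@tf.function
def train_step(dist_inputs):
    def step_fn(inputs):
        x, y = inputs
        with tf.GradientTape() as tape:
            preds = model(x, training=True)
            per_example_loss = loss_object(y, preds)
            # Average over the GLOBAL batch; dividing by the per-replica
            # batch instead inflates gradients by the replica count.
            loss = tf.nn.compute_average_loss(
                per_example_loss, global_batch_size=global_batch_size)
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        return loss

    per_replica_losses = strategy.run(step_fn, args=(dist_inputs,))
    return strategy.reduce(tf.distribute.ReduceOp.SUM,
                           per_replica_losses, axis=None)

for batch in dist_dataset:
    loss = train_step(batch)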