Focal loss vs. binary cross-entropy loss (results comparison)

Hi everyone,

I’m doing binary classification with a CNN. My dataset is highly imbalanced, and I don’t want to use augmentation or oversampling.

So I’m experimenting with the loss function instead, trying to get the best results. I’ve tried two loss functions: binary cross-entropy and focal loss.
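
For context, these are the two losses as I define them in Keras (the same calls as in the compile step further down). Focal loss scales the cross-entropy term by (1 - p_t)^gamma, so easy, well-classified examples contribute less:

import tensorflow as tf

# baseline: plain binary cross-entropy on the raw logits
bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)

# focal loss: down-weights easy examples via (1 - p_t)^gamma;
# apply_class_balancing adds the alpha weighting for the rarer class
focal = tf.keras.losses.BinaryFocalCrossentropy(
    apply_class_balancing=True, from_logits=True, gamma=5)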

I’d like to know what you think of the results: which one is performing better, and which one should I proceed with for the full training? I don’t have enough resources to try both for longer runs.

(I couldn’t upload the result images, so I uploaded the CSV file to my Google Drive and shared the link. This is the error I get when trying to upload the images: File exists @ dir_s_mkdir - /var/www/discourse/public/uploads)

Here are the results:
Focal loss results: [image]
BCE loss results: [image]

This is the GPT-4 answer comparing the results. Do you agree with GPT-4’s observations?

Thank you for providing the full metrics. Let's dive deeper into the comparison:

| Metric | BCE (start → end) | Focal loss (start → end) |
|---|---|---|
| Accuracy | 99.6% → 99.97% | 99.5% → 99.88% |
| AUC | 51.5% → 95.17% | 50.0% → 83.0% |
| Loss | 0.0239 → 0.0011 | 0.0021 → 0.0002 |
| Validation accuracy | 99.79% → 99.91% | 99.79% → 99.96% |
| Validation AUC | 50.96% → 71.15% | 50.0% → 71.15% |
| Validation loss | 0.0107 → 0.0091 | 0.0004 → 0.0002 |
| Balanced accuracy | 74.81% → 99.97% | 74.75% → 99.88% |
| Validation balanced accuracy | 49.90% → 50.0% | 50.0% → 65.06% |
Observations:

- Both BCE and focal loss give comparable results in terms of training accuracy.
- For AUC, BCE seems to perform better towards the end of training.
- Focal loss has a consistently lower loss value for both training and validation.
- Validation accuracy is quite high for both, but focal loss ends slightly higher.
- Balanced accuracy is better with BCE for training, but for validation, focal loss outperforms BCE.
Recommendation:
Given the metrics, Focal Loss appears to be more stable and offers advantages, especially in terms of validation balanced accuracy, which is crucial for imbalanced datasets. Thus, I would recommend proceeding with Focal Loss. However, the difference in performance is not drastic, so you might want to consider other factors like interpretability and computational costs.

Improvement Suggestions:

- Balanced Datasets: the balanced accuracy for validation, especially with BCE, is close to 50%, which suggests the model might be struggling due to the imbalance in the dataset. Techniques like oversampling the minority class or undersampling the majority class (or class weights; see the sketch after this list) can be beneficial.
- Tweak Gamma: you mentioned using a gamma of 5 for the focal loss. Experimenting with other values might affect performance (also in the sketch below).
- Hyperparameter Tuning: experiment with different learning rates, dropout rates, and other hyperparameters.
- Different Model Architectures: while the current architecture seems suitable, trying out variations or different architectures might help in achieving better performance.
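
Since I don’t want to resample, class weights are one way to act on the first suggestion, and gamma can simply be swept for the second. A minimal sketch of both, assuming train_labels is a flat array of 0/1 labels (the gamma values below are arbitrary examples, not tuned):

import numpy as np
import tensorflow as tf
from sklearn.utils.class_weight import compute_class_weight

# derive 'balanced' per-class weights from the imbalanced training labels
w = compute_class_weight('balanced', classes=np.array([0, 1]), y=np.asarray(train_labels))
class_weight = {0: w[0], 1: w[1]}

# pass them to fit() instead of oversampling:
# model.fit(train_gen, validation_data=test_gen, epochs=15, class_weight=class_weight)

# and/or sweep gamma for the focal loss
for gamma in (2, 3, 5):
    loss = tf.keras.losses.BinaryFocalCrossentropy(
        apply_class_balancing=True, from_logits=True, gamma=gamma)
    # recompile and retrain with each candidate, then compare validation balanced accuracy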

This is the classification code:


import tensorflow as tf
from tensorflow.keras import Sequential, layers
from time import time

# 3D CNN: four conv stages (32 -> 64 -> 128 -> 256 filters), each with two
# Conv3D + BatchNormalization pairs, then MaxPool3D and another BatchNormalization
model = Sequential([
    layers.Conv3D(filters=32, kernel_size=3, activation="relu", padding='same'),
    layers.BatchNormalization(),
    layers.Conv3D(filters=32, kernel_size=3, activation="relu", padding='same'),
    layers.BatchNormalization(),
    layers.MaxPool3D(pool_size=2),
    layers.BatchNormalization(),

    layers.Conv3D(filters=64, kernel_size=3, activation="relu", padding='same'),
    layers.BatchNormalization(),
    layers.Conv3D(filters=64, kernel_size=3, activation="relu", padding='same'),
    layers.BatchNormalization(),
    layers.MaxPool3D(pool_size=2),
    layers.BatchNormalization(),

    layers.Conv3D(filters=128, kernel_size=3, activation="relu", padding='same'),
    layers.BatchNormalization(),
    layers.Conv3D(filters=128, kernel_size=3, activation="relu", padding='same'),
    layers.BatchNormalization(),
    layers.MaxPool3D(pool_size=2),
    layers.BatchNormalization(),

    layers.Conv3D(filters=256, kernel_size=3, activation="relu", padding='same'),
    layers.BatchNormalization(),
    layers.Conv3D(filters=256, kernel_size=3, activation="relu", padding='same'),
    layers.BatchNormalization(),
    layers.MaxPool3D(pool_size=2),
    layers.BatchNormalization(),

    layers.GlobalAveragePooling3D(),
    layers.Dense(units=512, activation="relu"),
    layers.BatchNormalization(),
    layers.Dropout(0.4),
    layers.Dense(units=1),  # single logit output; the loss below uses from_logits=True
])



# data generators; DataGenerator is a custom generator class (not shown here)
train_gen = DataGenerator(train_image_paths, train_labels, base_dir, (31, 31, 31), batch_size=128, shuffle=False)
test_gen = DataGenerator(val_data, val_label, base_dir, (31, 31, 31), batch_size=128, shuffle=False)  # used as validation data


# exponential decay: lr(step) = 0.001 * 0.96 ** (step // 100000), stepped because staircase=True
initial_learning_rate = 0.001
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate, decay_steps=100000, decay_rate=0.96, staircase=True
)

# optimizer: AdamW with the schedule above
opt = tf.keras.optimizers.experimental.AdamW(
    learning_rate=lr_schedule,
    beta_1=0.9,
    beta_2=0.99,
    epsilon=1e-06,
    weight_decay=0.004,
    ema_momentum=0.99,
    name="AdamW",
)

# Create a TensorBoard callback (note: defined but not passed to fit() below)
log_dir = "/home/mustafa/project/LUNA16/{}".format(time())
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1, profile_batch='5,10')
csv_logger = tf.keras.callbacks.CSVLogger('metrics_focal.csv')

# for the BCE run, swap in:
# loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)
model.compile(optimizer=opt,
              loss=tf.keras.losses.BinaryFocalCrossentropy(apply_class_balancing=True, from_logits=True, gamma=5),
              # caution: the model outputs raw logits, while string metrics such as 'accuracy',
              # 'Precision', and 'Recall' threshold at 0.5 as if given probabilities, and 'AUC'
              # expects values in [0, 1]; tf.keras.metrics.AUC(from_logits=True) and
              # tf.keras.metrics.BinaryAccuracy(threshold=0.0) would match the logit output
              metrics=['accuracy', 'AUC', tf.keras.metrics.SpecificityAtSensitivity(0.5), 'Precision', 'Recall',
                       'FalseNegatives', 'FalsePositives', 'TrueNegatives', 'TruePositives'])
model.build(input_shape=(128, None, None, None, 1))
model.summary()

# now let's train the model (EarlyStopping was tried earlier but is disabled here)
# EarlyStop = tf.keras.callbacks.EarlyStopping(monitor='precision', patience=4, restore_best_weights=True)
history = model.fit(train_gen, validation_data=test_gen, epochs=15, shuffle=False, verbose=1,
                    callbacks=[csv_logger], use_multiprocessing=True)
# class_weight=None
model.save("gemerator_model_adamw_focal")
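
Keras has no built-in balanced-accuracy metric, so the balanced-accuracy numbers have to be derived from the TP/TN/FP/FN counts that the CSVLogger records. A sketch of that computation, assuming the CSV column names match the default metric names (check the header of your file; they can differ by Keras version):

import pandas as pd

df = pd.read_csv('metrics_focal.csv')
# balanced accuracy = (sensitivity + specificity) / 2
sens = df['true_positives'] / (df['true_positives'] + df['false_negatives'])
spec = df['true_negatives'] / (df['true_negatives'] + df['false_positives'])
df['balanced_accuracy'] = (sens + spec) / 2
print(df[['epoch', 'balanced_accuracy']])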

These are the BCE loss results:
[images]

These are the focal loss results:
[images]