Suggestions regarding loss scaling in a distributed training loop

Hi folks.

I am currently implementing a custom training loop by overriding the train_step() function, and I am not using the default compile() method, so I believe I need to implement the loss scaling myself.
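For context, the overall structure is roughly the following (a simplified, hypothetical sketch: the class name and constructor are illustrative, and the optimizer/discriminator handling is omitted):

import tensorflow as tf

class ColorizationGAN(tf.keras.Model):
    def __init__(self, gen_model, disc_model, loss_fn, reg_strength):
        super().__init__()
        self.gen_model = gen_model
        self.disc_model = disc_model
        self.loss_fn = loss_fn
        self.reg_strength = reg_strength

    def train_step(self, data):
        grayscale, colorized = data
        # The generator loss is computed here, as in the loops shown below,
        # and the gradients are applied to self.gen_model
        return {}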

Here’s how the fundamental loop is implemented (it runs as expected on a single GPU):

with tf.GradientTape() as tape:
    # Generate a colorized image from the grayscale input
    fake_colorized = self.gen_model(grayscale)
    # Condition the discriminator on the (grayscale, generated) pair
    fake_input = tf.concat([grayscale, fake_colorized], axis=-1)
    predictions = self.disc_model(fake_input)
    # The generator tries to get these classified as real
    misleading_labels = tf.ones_like(predictions)
    g_loss = -self.loss_fn(misleading_labels, predictions)
    # L1 reconstruction term against the ground-truth colorized image
    l1_loss = tf.keras.losses.mean_absolute_error(colorized, fake_colorized)
    final_g_loss = g_loss + self.reg_strength * l1_loss

self.loss_fn is binary cross-entropy.
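For reference, it is created along these lines (the from_logits value is an assumption and depends on the discriminator’s last layer):

# Single-GPU case: keep the default reduction, which averages over the batch
self.loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)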

Here’s how the distributed variant is implemented:

with tf.GradientTape() as tape:
    fake_colorized = self.gen_model(grayscale)
    fake_input = tf.concat([grayscale, fake_colorized], axis=-1)
    predictions = self.disc_model(fake_input)
    misleading_labels = tf.ones_like(predictions)

    # Per-example adversarial loss (loss_fn has no reduction here)
    g_loss = -self.loss_fn(misleading_labels, predictions)
    # Divide by the number of output elements per example
    g_loss /= tf.cast(
        tf.reduce_prod(tf.shape(misleading_labels)[1:]), tf.float32)
    g_loss = tf.nn.compute_average_loss(g_loss, self.global_batch_size)

    # Per-example L1 loss, treated the same way
    l1_loss = tf.keras.losses.MeanAbsoluteError(
        reduction=tf.keras.losses.Reduction.NONE)(colorized, fake_colorized)
    l1_loss /= tf.cast(
        tf.reduce_prod(tf.shape(colorized)[1:]), tf.float32)
    l1_loss = tf.nn.compute_average_loss(l1_loss, self.global_batch_size)

    final_g_loss = g_loss + (l1_loss * self.reg_strength)

self.loss_fn is again binary cross-entropy, but in this case it’s initialized without any reduction.
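Concretely, the definition looks roughly like this (the from_logits flag is incidental here; the relevant part is Reduction.NONE):

# Distributed case: no reduction, so per-example values come back and the
# averaging / global-batch scaling is done manually, as in the loop above
self.loss_fn = tf.keras.losses.BinaryCrossentropy(
    from_logits=True,
    reduction=tf.keras.losses.Reduction.NONE)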

This loop is not behaving as expected: the losses are way off. Am I missing something?


Are you in the same case as Custom training with tf.distribute.Strategy | TensorFlow Core, or is it something different?


Didn’t quite get your point. Could you elaborate?


I mean that this is currently the official tutorial we propose for users/devs who want to do distributed training with a custom training loop. So I was just asking whether you have specific needs and whether we could expand that page.
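For reference, the loss-scaling pattern in that tutorial looks roughly like this (the tutorial uses a classification loss; it’s sketched here with binary cross-entropy and a GLOBAL_BATCH_SIZE placeholder):

import tensorflow as tf

GLOBAL_BATCH_SIZE = 64  # placeholder: num_replicas * per-replica batch size

# Per-example loss (no reduction), then explicit scaling by the global
# batch size, as recommended for custom distributed training loops
loss_object = tf.keras.losses.BinaryCrossentropy(
    from_logits=True,
    reduction=tf.keras.losses.Reduction.NONE)

def compute_loss(labels, predictions):
    per_example_loss = loss_object(labels, predictions)
    return tf.nn.compute_average_loss(
        per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)

The key point is that each replica divides by the global batch size rather than its local batch size, so summing the per-replica gradients gives the correct overall scale.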


No. I was actually asking if the distributed variant of my training loop is correctly implemented.


Hi @Sayak_Paul, were you able to make headway with this? I am experiencing something similar with my own custom loss in a distributed custom training loop running on a TPU, and my loss is way, way off.