Hi folks.
I am currently implementing a custom training loop by overriding the train_step() function.
I am also not using the default compile() method, so I believe I need to implement loss scaling myself.
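As a rough illustration of what loss scaling means in this setting, here is a NumPy sketch with made-up numbers (not the actual model code). It assumes two replicas and assumes that per-replica results are summed across replicas, which is the usual setup. The values and variable names are hypothetical.

```python
import numpy as np

# Hypothetical per-example losses for a global batch of 8, split over 2 replicas.
per_example = np.array([0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4, 1.6])
global_batch_size = per_example.shape[0]
replica_batches = np.split(per_example, 2)

# Each replica sums its own losses and divides by the GLOBAL batch size.
scaled = [batch.sum() / global_batch_size for batch in replica_batches]

# Summing the per-replica values recovers the mean over the global batch.
total_scaled = sum(scaled)

# A plain per-replica mean, summed the same way, overcounts by the replica count.
total_unscaled = sum(batch.mean() for batch in replica_batches)

print(total_scaled, per_example.mean(), total_unscaled)
```

The point of dividing by the global batch size rather than the per-replica batch size is that the summed result then matches what a single device would compute on the full batch.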
Here’s how the fundamental loop is implemented (that runs as expected on a single GPU):
with tf.GradientTape() as tape:
    fake_colorized = self.gen_model(grayscale)
    fake_input = tf.concat([grayscale, fake_colorized], axis=-1)
    predictions = self.disc_model(fake_input)
    misleading_labels = tf.ones_like(predictions)
    g_loss = -self.loss_fn(misleading_labels, predictions)
    l1_loss = tf.keras.losses.mean_absolute_error(colorized, fake_colorized)
    final_g_loss = g_loss + self.reg_strength * l1_loss
self.loss_fn is binary cross-entropy.
Here’s how the distributed variant is implemented:
with tf.GradientTape() as tape:
    fake_colorized = self.gen_model(grayscale)
    fake_input = tf.concat([grayscale, fake_colorized], axis=-1)
    predictions = self.disc_model(fake_input)
    misleading_labels = tf.ones_like(predictions)
    g_loss = -self.loss_fn(misleading_labels, predictions)
    g_loss /= tf.cast(
        tf.reduce_prod(tf.shape(misleading_labels)[1:]),
        tf.float32)
    g_loss = tf.nn.compute_average_loss(g_loss,
                                        self.global_batch_size)
    l1_loss = tf.keras.losses.MeanAbsoluteError(
        reduction=tf.keras.losses.Reduction.NONE)(colorized,
                                                  fake_colorized)
    l1_loss /= tf.cast(
        tf.reduce_prod(tf.shape(colorized)[1:]),
        tf.float32)
    l1_loss = tf.nn.compute_average_loss(l1_loss,
                                         self.global_batch_size)
    final_g_loss = g_loss + (l1_loss * self.reg_strength)
self.loss_fn is binary cross-entropy, but in this case it's initialized without any reduction.
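One detail that may matter for the normalization above: as far as I understand, a Keras loss created with reduction disabled still averages over the last axis, so a 4-D input yields a 3-D per-pixel loss. The NumPy sketch below mimics that behavior for mean absolute error. The shapes are hypothetical, and the arrays are random stand-ins for colorized and fake_colorized.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, height, width, channels = 4, 8, 8, 3  # hypothetical shapes
colorized = rng.random((batch, height, width, channels))
fake_colorized = rng.random((batch, height, width, channels))

# Mimics an unreduced MAE: the mean is taken over the LAST axis only.
per_pixel = np.abs(colorized - fake_colorized).mean(axis=-1)  # (B, H, W)

# Reference value: mean over every element of the batch.
reference = np.abs(colorized - fake_colorized).mean()

# Normalizing by prod(shape[1:]) of the 4-D input (H * W * C).
by_input_shape = per_pixel.sum() / np.prod(colorized.shape[1:]) / batch

# Normalizing by prod(shape[1:]) of the per-pixel loss itself (H * W).
by_loss_shape = per_pixel.sum() / np.prod(per_pixel.shape[1:]) / batch

print(reference, by_input_shape, by_loss_shape)
```

Under these assumptions, normalizing by the input's shape divides by the channel count a second time, while normalizing by the shape of the unreduced loss matches the full mean. If the discriminator output has a single channel, the same calculation for g_loss would be unaffected.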
This loop does not behave as expected: the losses are way off. Am I missing something?
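Another thing that may be worth checking is the argument order of tf.nn.compute_average_loss. Its documented signature is (per_example_loss, sample_weight=None, global_batch_size=None), so a second positional argument is treated as sample_weight rather than as the global batch size. The sketch below uses a hypothetical NumPy stand-in, average_loss, that copies only that parameter order and assumes a single replica. It is not the TensorFlow implementation.

```python
import numpy as np

def average_loss(per_example_loss, sample_weight=None, global_batch_size=None):
    """Hypothetical stand-in with the same parameter order as
    tf.nn.compute_average_loss (single replica assumed)."""
    per_example_loss = np.asarray(per_example_loss, dtype=np.float32)
    if sample_weight is not None:
        # Weights multiply the per-example losses before the sum.
        per_example_loss = per_example_loss * sample_weight
    if global_batch_size is None:
        # Fallback: infer the batch size from the leading dimension.
        global_batch_size = per_example_loss.shape[0]
    return per_example_loss.sum() / global_batch_size

losses = np.array([0.5, 1.0, 1.5, 2.0])
global_batch_size = 4

# Second positional argument lands in sample_weight, so losses are multiplied.
positional = average_loss(losses, global_batch_size)

# Keyword argument reaches the intended parameter.
keyword = average_loss(losses, global_batch_size=global_batch_size)

print(positional, keyword)
```

In this toy case the positional call comes out larger than the keyword call by a factor of the global batch size, which is the kind of discrepancy that could make losses look way off.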