How are gradients applied in distributed custom loops?

I am having some problems with unstable losses in a complicated distributed custom training loop, and I have a question about how the gradients get applied.

Consider the custom training loop example from the docs:

The final steps of the MirroredStrategy are described like this:

  • Each replica calculates the loss and gradients for the input it received.
  • The gradients are synced across all the replicas by summing them.
  • After the sync, the same update is made to the copies of the variables on each replica.

Then in the code, the train_step() is:

def train_step(inputs):
  images, labels = inputs

  with tf.GradientTape() as tape:
    predictions = model(images, training=True)
    loss = compute_loss(labels, predictions)

  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  train_accuracy.update_state(labels, predictions)
  return loss 

@tf.function
def distributed_train_step(dataset_inputs):
  per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
  return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                         axis=None)

Reading the code, it seems to me that apply_gradients() is getting applied to each replica separately. This step happens in train_step(), which is ostensibly per-replica code. The losses are being reduced at the end of distributed_train_step(), but the gradients aren’t being explicitly synced or summed. So, what goes on under the hood? Is strategy.run() coordinating the syncing and application of gradients between the replicas? Or is there somewhere else in the code where the gradients are synced and added explicitly? Or does this example code do something differently than the MirroredStrategy description given at the beginning of the docs?
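
One thing I thought of is checking whether the per-replica copies of the variables actually stay identical after a step. This is just a sketch of the kind of check I have in mind, reusing the tutorial's strategy, model, and distributed dataset (which I'm calling train_dist_dataset here), not something from the docs:

# Run one distributed step, then compare each variable's per-replica copies.
distributed_train_step(next(iter(train_dist_dataset)))

for var in model.trainable_variables:
  # experimental_local_results() returns the copy of the variable on each replica.
  per_replica_values = strategy.experimental_local_results(var)
  for value in per_replica_values[1:]:
    tf.debugging.assert_near(per_replica_values[0], value)

If those copies match after every step, then the variables are clearly being updated identically, but that still doesn't tell me where the gradient sum actually happens.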

Thanks in advance to anyone who can help me understand what’s going on under the hood here.


Hello @QEDan

Thank you for using TensorFlow.

In distributed training, strategy.run() runs train_step() once on each replica (one per device, e.g. one per GPU) with that replica's slice of the batch, so each replica computes its own loss and gradients. The sync you are looking for happens implicitly inside optimizer.apply_gradients(): when it is called in that per-replica code, the optimizer first aggregates the gradients across all replicas (summing them with an all-reduce) and then applies the same aggregated update to every replica's copy of the variables, which is exactly the behavior described in the bullet points you quoted. Nothing in the example needs to sum the gradients explicitly, and the strategy.reduce() at the end only combines the per-replica loss values for reporting; it has nothing to do with the gradients. This is also why the tutorial scales the per-example loss by the global batch size: summing the per-replica gradients then gives the gradient of the mean loss over the whole global batch.
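
If you ever want to make that aggregation explicit (for example, while debugging your unstable losses), you can all-reduce the gradients yourself inside the replica context and tell the optimizer not to aggregate them a second time. This is only a sketch, assuming a tf.keras optimizer whose apply_gradients() accepts experimental_aggregate_gradients (the optimizers used in that tutorial do; newer Keras optimizers expose a skip_gradients_aggregation argument instead), and the name train_step_explicit_sync is just something I made up:

def train_step_explicit_sync(inputs):
  images, labels = inputs

  with tf.GradientTape() as tape:
    predictions = model(images, training=True)
    loss = compute_loss(labels, predictions)

  gradients = tape.gradient(loss, model.trainable_variables)

  # Sum the per-replica gradients across all replicas ourselves.
  replica_ctx = tf.distribute.get_replica_context()
  summed_gradients = replica_ctx.all_reduce(tf.distribute.ReduceOp.SUM, gradients)

  # Skip the optimizer's built-in cross-replica aggregation, since we already summed.
  optimizer.apply_gradients(
      zip(summed_gradients, model.trainable_variables),
      experimental_aggregate_gradients=False)

  train_accuracy.update_state(labels, predictions)
  return loss

Run it through strategy.run() exactly like the tutorial's train_step(); it should behave the same as the default path, which simply does the equivalent sum for you inside apply_gradients().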