Hi, I’m currently trying to understand the internal workings of `MirroredStrategy`, and in particular how gradients of `MirroredVariable`s are recorded. I understand the concept of a `MirroredVariable`, but it’s unclear to me how a correct gradient tape is recorded over these variables in `_call_for_each_replica` in `mirrored_strategy`. As this implementation seems mostly covered by `mirrored_run`, I mainly focused on that file instead. Say we have one `MirroredVariable` with the following signature:
```
MirroredVariable {
    0: <tf.Variable 'w:0' shape=() dtype=float32>,
    1: <tf.Variable 'w/replica_1:0' shape=() dtype=float32>
}
```
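For reference, this is roughly how I end up with such a variable; the device list and the optimizer below are just illustrative stand-ins for my actual setup:

```python
import tensorflow as tf

# Illustrative setup: two devices, so the strategy keeps one copy of `w` per replica.
strategy = tf.distribute.MirroredStrategy(devices=["/gpu:0", "/gpu:1"])

with strategy.scope():
    w = tf.Variable(1.0, name="w")
    optimizer = tf.keras.optimizers.SGD(0.1)

# The per-replica components of the resulting MirroredVariable can be inspected
# through its `values` attribute (not part of the stable public API):
print(w.values)
# (<tf.Variable 'w:0' shape=() dtype=float32>, <tf.Variable 'w/replica_1:0' shape=() dtype=float32>)
```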
I’ve tried to understand this behavior by altering the `_call_for_each_replica` implementation so that it runs every function sequentially on the defined devices (essentially just removing the replica threads; a rough sketch of that change is at the end of this post). This works for variable creation, computation and reduction, but it breaks when recording gradients. Say I have the following function:
```python
from tensorflow.python.eager import backprop, def_function

@def_function.function
def step(x):
    with backprop.GradientTape() as tape:
        loss = w * x
    optimizer.minimize(loss, var_list=[w], tape=tape)

strategy.run(step, args=(2.0,))
```
This yields:
```
No gradients provided for any variable: ['w:0']
```
Adding `tape.watch(w)` doesn’t change anything, and my guess is that this is due to the function wrapping that happens in `call_for_each_replica` in `mirrored_run`. Could anyone shed some light on how these gradients are recorded here?
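For completeness, my sequential variant of `_call_for_each_replica` looks roughly like the following. This is a heavily simplified sketch of the idea rather than my exact code; the real implementation in `mirrored_run.py` spawns one `_MirroredReplicaThread` per replica and sets up a per-replica context, both of which this version skips:

```python
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.framework import ops


def _call_for_each_replica(distribution, fn, args, kwargs):
    """Sequential stand-in for mirrored_run._call_for_each_replica (sketch only)."""
    results = []
    for device in distribution.extended.worker_devices:
        with ops.device(device):
            # No replica threads and no replica context here: the function is
            # simply called once per device, one device after the other.
            results.append(fn(*args, **kwargs))
    # Pack the per-device results back into a per-replica structure.
    return distribute_utils.regroup(tuple(results))
```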