Hi TF developer community.
I’m a maintainer of RLlib, a library for distributed reinforcement learning. I’m writing a new distributed training stack on TensorFlow 2.11, and I have a question about a warning that I see when using tf.function with tf.distribute.Strategy.run.
I frequently see the following warning:
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead
currently. We will be working on improving this in the future, but for now
please wrap `call_for_each_replica` or `experimental_run` or `run` inside a
tf.function to get the best performance.
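As far as I can tell, I am already doing what the warning asks for. For reference, my understanding of the recommended shape (roughly the custom-training-loop pattern from the tf.distribute guide; the toy model, optimizer, and names like train_step here are illustrative, not my actual code) is:

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy()

    # Variables must be created under the strategy scope.
    with strategy.scope():
        model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
        optimizer = tf.keras.optimizers.SGD(0.01)

    # The outer tf.function wraps strategy.run, as the warning suggests.
    @tf.function
    def train_step(batch):
        def step_fn(inputs):
            x, y = inputs
            with tf.GradientTape() as tape:
                loss = tf.reduce_mean(tf.square(model(x) - y))
            grads = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
            return loss
        return strategy.run(step_fn, args=(batch,))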
The code pattern that roughly describes my actual setup (simplified, not an exact reproduction) is as follows:
from typing import Any, Mapping

# Inside my learner class:
def _do_update_fn(self, batch) -> Mapping[str, Any]:
    def helper(_batch):
        with tf.GradientTape() as tape:
            # An underlying feed-forward call to multiple Keras models.
            fwd_out = self.keras_model_container.forward_train(_batch)
            # The loss computation.
            loss = self.compute_loss(fwd_out=fwd_out, batch=_batch)
        gradients = self.compute_gradients(loss, tape)
        self.apply_gradients(gradients)
        return {
            "loss": loss,
            "fwd_out": fwd_out,
            "postprocessed_gradients": gradients,
        }

    # self.strategy is a tf.distribute.Strategy object.
    return self.strategy.run(helper, args=(batch,))

# Also inside the class:
update_fn = tf.function(self._do_update_fn, reduce_retracing=True)
batch = ...
update_fn(batch)
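In case a runnable snippet helps, here is a minimal self-contained version of the same structure; the model, loss, and random batch are placeholders standing in for my actual model container and data, not my real code:

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy()

    with strategy.scope():
        model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
        optimizer = tf.keras.optimizers.Adam()

    def _do_update_fn(batch):
        def helper(_batch):
            with tf.GradientTape() as tape:
                fwd_out = model(_batch["obs"])
                loss = tf.reduce_mean(tf.square(fwd_out - _batch["target"]))
            gradients = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))
            return {"loss": loss, "fwd_out": fwd_out}
        return strategy.run(helper, args=(batch,))

    update_fn = tf.function(_do_update_fn, reduce_retracing=True)

    batch = {
        "obs": tf.random.normal((8, 4)),
        "target": tf.random.normal((8, 1)),
    }
    update_fn(batch)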
Can anyone explain why I might be getting this warning even though I am wrapping the strategy.run call in a tf.function?
Thanks!