Seeing a warning saying that I am not using `tf.function` when calling `tf.distribute.Strategy.run`, even though I am

Hi TF developer community.

I’m a maintainer of RLlib, a library for distributed reinforcement learning. I’m writing a new distributed training stack using TensorFlow 2.11, and I have a question about a warning I see when using tf.function with tf.distribute.Strategy.run.

I frequently see the following warning:

WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead
currently. We will be working on improving this in the future, but for now
please wrap `call_for_each_replica` or `experimental_run` or `run` inside a
tf.function to get the best performance.

The pattern of code that roughly describes my setup (not an exact reproduction, for the sake of simplicity) is as follows:

def _do_update_fn(self, batch) -> Mapping[str, Any]:
    def helper(_batch):
        with tf.GradientTape() as tape:
            # This is an underlying feed-forward call to multiple Keras models
            fwd_out = self.keras_model_container.forward_train(_batch)
            # This is a loss computation
            loss = self.compute_loss(fwd_out=fwd_out, batch=_batch)
        gradients = self.compute_gradients(loss, tape)
        self.apply_gradients(gradients)
        return {
            "loss": loss,
            "fwd_out": fwd_out,
            "postprocessed_gradients": gradients,
        }

    # self.strategy is a tf.distribute.Strategy object
    return self.strategy.run(helper, args=(batch,))


update_fn = tf.function(self._do_update_fn, reduce_retracing=True)

batch = ...
update_fn(batch)

Can anyone explain why I might be getting this warning despite the fact that I am using tf.function?

Thanks!

Hi @Avnish_Narayan,

Thanks for using the TensorFlow Forum.

The warning above is raised because the tf.distribute call (self.strategy.run) is not wrapped with tf.function at the point where it is traced. In your snippet, tf.function is applied as a runtime function call, which can end up compiling different graphs dynamically depending on runtime conditions, and that may be less efficient. I recommend using the @tf.function decorator on the update method together with MirroredStrategy; this ensures that the same compiled graph is replicated across all devices, which is more efficient.

You can rewrite the code with the @tf.function decorator as follows:

@tf.function(reduce_retracing=True)
def _do_update_fn(self, batch) -> Mapping[str, Any]:
    def helper(_batch):
        ......
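
For completeness, here is a minimal, self-contained sketch of this pattern. The Learner class, the single Dense model, the squared-error loss, and the batch shapes are illustrative placeholders (not RLlib's actual code); the point is that strategy.run is called from inside the decorated method, so the whole update step is traced as one graph:

# Minimal sketch (illustrative placeholders, not RLlib code).
import tensorflow as tf


class Learner:
    def __init__(self):
        self.strategy = tf.distribute.MirroredStrategy()
        with self.strategy.scope():
            # input_shape builds the variables eagerly under the strategy scope,
            # so no variables are created inside strategy.run.
            self.model = tf.keras.Sequential(
                [tf.keras.layers.Dense(1, input_shape=(4,))]
            )
            self.optimizer = tf.keras.optimizers.SGD(0.01)

    @tf.function(reduce_retracing=True)
    def _do_update_fn(self, batch):
        def helper(_batch):
            x, y = _batch
            with tf.GradientTape() as tape:
                pred = self.model(x, training=True)
                loss = tf.reduce_mean(tf.square(pred - y))
            grads = tape.gradient(loss, self.model.trainable_variables)
            self.optimizer.apply_gradients(
                zip(grads, self.model.trainable_variables)
            )
            return loss

        # strategy.run is now called from inside the tf.function trace,
        # so the "using MirroredStrategy eagerly" warning should not appear.
        return self.strategy.run(helper, args=(batch,))


learner = Learner()
batch = (tf.random.normal((8, 4)), tf.random.normal((8, 1)))
per_replica_loss = learner._do_update_fn(batch)

If you also need a single aggregated loss value, the per-replica result can afterwards be passed to self.strategy.reduce.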

Thank You.