Retracing with Distributed Training

Hi guys,

I am training a custom model with distributed training across multiple GPUs, with the training step wrapped in tf.function(). The graph retraces on every call. To fix this I passed the input_signature argument, with explicit tf.TensorSpec() entries, to tf.function(). This works fine on 1 GPU, but as soon as I use multiple GPUs it fails with the error 'PerReplica does not have dtype'.
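
Roughly, my setup looks like the sketch below (the model, shapes and loss are simplified placeholders; my real model is more involved):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    # Placeholder model/optimizer/loss for illustration only.
    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    optimizer = tf.keras.optimizers.Adam()
    loss_fn = tf.keras.losses.MeanSquaredError()

def train_step(x, y):
    with tf.GradientTape() as tape:
        loss = loss_fn(y, model(x, training=True))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

# Fixing the signature stops the retracing on 1 GPU, where the distributed
# dataset yields plain tensors ...
@tf.function(input_signature=[
    tf.TensorSpec(shape=[None, 10], dtype=tf.float32),  # features (shape simplified)
    tf.TensorSpec(shape=[None, 1], dtype=tf.float32),   # labels
])
def distributed_train_step(x, y):
    return strategy.run(train_step, args=(x, y))

# ... but on multiple GPUs the dataset yields PerReplica objects, which a
# tf.TensorSpec cannot describe, and this loop raises the dtype error:
# for x, y in strategy.experimental_distribute_dataset(dataset):
#     distributed_train_step(x, y)
```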

Does anyone have an idea how I can solve this problem?

Hello @King_Gee

Thank you for using TensorFlow.
In the training step, add the following line to avoid the PerReplica dtype issue, as mentioned in the documentation:
per_replica_inputs = strategy.experimental_local_results(inputs)
and pass per_replica_inputs to the model as a list.
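
Below is a minimal sketch of one way this can fit into a custom training loop with MirroredStrategy. The model, shapes, batch size and loss are placeholder assumptions, not your actual code: the fixed input_signature stays on the per-replica step, which only ever receives plain tensors inside strategy.run(), and strategy.experimental_local_results() is used to unpack a PerReplica value (the per-replica losses here; the same call works on the distributed inputs) into an ordinary tuple you can pass on as a list.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
GLOBAL_BATCH = 64  # assumed global batch size

with strategy.scope():
    # Placeholder model/optimizer/loss for illustration only.
    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    optimizer = tf.keras.optimizers.Adam()
    loss_fn = tf.keras.losses.MeanSquaredError(
        reduction=tf.keras.losses.Reduction.NONE)

# The fixed signature lives on the per-replica step: inside strategy.run
# each replica receives plain tensors, so the TensorSpec matches and the
# function is traced only once.
@tf.function(input_signature=[
    tf.TensorSpec(shape=[None, 10], dtype=tf.float32),  # features (assumed shape)
    tf.TensorSpec(shape=[None, 1], dtype=tf.float32),   # labels
])
def train_step(x, y):
    with tf.GradientTape() as tape:
        per_example_loss = loss_fn(y, model(x, training=True))
        loss = tf.nn.compute_average_loss(per_example_loss,
                                          global_batch_size=GLOBAL_BATCH)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

# No input_signature here, because this step receives PerReplica values.
@tf.function
def distributed_train_step(x, y):
    per_replica_loss = strategy.run(train_step, args=(x, y))
    # experimental_local_results unpacks a PerReplica value into a tuple of
    # ordinary tensors (one per GPU); the same call works on the distributed
    # inputs x / y if you need to hand them somewhere as a plain list.
    local_losses = strategy.experimental_local_results(per_replica_loss)
    return tf.add_n(local_losses)

# Toy data, just to make the sketch runnable.
dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal([256, 10]), tf.random.normal([256, 1]))
).batch(GLOBAL_BATCH)

for x, y in strategy.experimental_distribute_dataset(dataset):
    loss = distributed_train_step(x, y)
```

The essential point is that a tf.TensorSpec signature should only ever meet plain tensors, while experimental_local_results() gives you an ordinary tuple to work with whenever a PerReplica value shows up.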