Graph mode input signature for distributed dataset

I am trying to specify an `input_signature` for the `tf.function` decorator of a function that takes its inputs from a distributed dataset. However, the PerReplica object cannot be converted to a tensor, and the signature requires inputs that can be converted to tensors. How can I solve this?

This is the code segment:

 @tf.function#(input_signature=[tf.TensorSpec((None, None, None, None), dtype='float64'), tf.TensorSpec((None, None), dtype='uint8')])
    def dist_train_step(image, label):
        print('-dist_train_step trace: image: ', image.shape, image.dtype, ' label: ', label.shape, label.dtype)
        per_replica_loss_values = strategy.run(train_step, args=(image, label))
        reduced_loss = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss_values, axis=None)
        print('exiting dist_train_step')
        return reduced_loss

This is the error I get when I try to specify the input signature:

ValueError                                Traceback (most recent call last)
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/ in _convert_inputs_to_signature(inputs, input_signature, flat_input_signature)
   2879         flatten_inputs[index] = ops.convert_to_tensor(
-> 2880             value, dtype_hint=spec.dtype)
   2881         need_packing = True

/opt/conda/lib/python3.7/site-packages/tensorflow/python/profiler/ in wrapped(*args, **kwargs)
    162           return func(*args, **kwargs)
--> 163       return func(*args, **kwargs)

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/ in convert_to_tensor(value, dtype, name, as_ref, preferred_dtype, dtype_hint, ctx, accepted_result_types)
   1565     if ret is None:
-> 1566       ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/ in _constant_tensor_conversion_function(v, dtype, name, as_ref)
    345   _ = as_ref
--> 346   return constant(v, dtype=dtype, name=name)

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/ in constant(value, dtype, shape, name)
    271   return _constant_impl(value, dtype, shape, name, verify_shape=False,
--> 272                         allow_broadcast=True)

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/ in _constant_impl(value, dtype, shape, name, verify_shape, allow_broadcast)
    282         return _constant_eager_impl(ctx, value, dtype, shape, verify_shape)
--> 283     return _constant_eager_impl(ctx, value, dtype, shape, verify_shape)

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/ in _constant_eager_impl(ctx, value, dtype, shape, verify_shape)
    307   """Creates a constant on the current device."""
--> 308   t = convert_to_eager_tensor(value, ctx, dtype)
    309   if shape is None:

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/ in convert_to_eager_tensor(value, ctx, dtype)
    105   ctx.ensure_initialized()
--> 106   return ops.EagerTensor(value, ctx.device_name, dtype)

ValueError: Attempt to convert a value (PerReplica:{
  0: <tf.Tensor: shape=(32, 28, 28, 1), dtype=float64, numpy=










































  1: <tf.Tensor: shape=(32, 28, 28, 1), dtype=float64, numpy=










































}) with an unsupported type (<class 'tensorflow.python.distribute.values.PerReplica'>) to a Tensor.

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
/tmp/ipykernel_23/ in <module>
     80 start = time.time()
---> 81 train_losses, validation_losses = train()
     82 print('time: ',time.time() - start)

/tmp/ipykernel_23/ in train(train_data, test_data, batch_size, strategy)
     64         for image, label in train_data:
---> 65             total_training_loss.assign_add(dist_train_step(image, label))

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/ in __call__(self, *args, **kwds)
    884       with OptionalXlaContext(self._jit_compile):
--> 885         result = self._call(*args, **kwds)
    887       new_tracing_count = self.experimental_get_tracing_count()

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/ in _call(self, *args, **kwds)
    952       _, _, _, filtered_flat_args = \
    953           self._stateful_fn._function_spec.canonicalize_function_inputs(  # pylint: disable=protected-access
--> 954               *args, **kwds)
    955       # If we did not create any variables the trace we have is good enough.
    956       return self._concrete_stateful_fn._call_flat(

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/ in canonicalize_function_inputs(self, *args, **kwargs)
   2785       assert not kwargs
   2786       inputs, flat_inputs, filtered_flat_inputs = _convert_inputs_to_signature(
-> 2787           inputs, self._input_signature, self._flat_input_signature)
   2789     self._validate_inputs(flat_inputs)

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/ in _convert_inputs_to_signature(inputs, input_signature, flat_input_signature)
   2884                          "the Python function must be convertible to "
   2885                          "tensors:\n%s" %
-> 2886                          format_error_message(inputs, input_signature))
   2888   if any(not spec.is_compatible_with(other) for spec, other in zip(

ValueError: When input_signature is provided, all inputs to the Python function must be convertible to tensors:
  inputs: (
  0: tf.Tensor(










































   [0.]]]], shape=(32, 28, 28, 1), dtype=float64),
  1: tf.Tensor(










































   [0.]]]], shape=(32, 28, 28, 1), dtype=float64)
  0: tf.Tensor(
 [3]], shape=(32, 1), dtype=uint8),
  1: tf.Tensor(
 [0]], shape=(32, 1), dtype=uint8)
  input_signature: (
    TensorSpec(shape=(None, None, None, None), dtype=tf.float64, name=None),
    TensorSpec(shape=(None, None), dtype=tf.uint8, name=None))

Hi @lakshmi_poda,

Sorry for the delay in response.

This error occurs because the distributed dataset returns a PerReplica object, which cannot be directly converted to a tensor when the @tf.function is declared with an input signature of tf.TensorSpec. Since you are using a distributed dataset, I would recommend either not specifying an input signature, or creating a separate function to handle the dataset that returns the PerReplica object.

Thank You.