TPU v3-8 TensorFlow CrossReplicaSum Error

I’m getting this error when I fit my model.

tensorflow/core/tpu/kernels/tpu_compilation_cache_external.cc:112] Asked to propagate a dynamic dimension from hlo transpose.3750@{}@0 to hlo %all-reduce.3755 = f32[<=70,256]{1,0} all-reduce(f32[<=70,256]{1,0} %transpose.3750), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=%sum.3751, metadata={op_type="CrossReplicaSum" op_name="CrossReplicaSum_33" source_file="dummy_file_name" source_line=10}, which is not implemented.

1013 tpu_program_group.cc:90] Check failed: xla_tpu_programs.size() > 0 (0 vs. 0) 

However, I am passing explicit input shapes like so:

Config.COMPUTED_BATCH_SIZE = 128

with strategy.scope():
    model = my_model()
    input_shapes = [
        [Config.COMPUTED_BATCH_SIZE, 192], 
        [Config.COMPUTED_BATCH_SIZE, 192], 
        [COMPUTED_CHANNELS, 105, 129, 100], 
        [COMPUTED_CHANNELS, 105, 129, 100], 
        [COMPUTED_CHANNELS, 105, 129, 100], 
        [Config.COMPUTED_BATCH_SIZE, 70], 
        [Config.COMPUTED_BATCH_SIZE, 320]
    ]
    model.build(input_shape=input_shapes)

Edit:

I’ve tracked it down to this code:

hidden_size = 128
self.descriptor_embedding = layers.Dense(
    hidden_size * 2, # 256
    activation='relu',
    input_shape=(Config.COMPUTED_BATCH_SIZE, 70)
)


learned_descriptors = tf.expand_dims(
    self.descriptor_embedding(descriptors),
    1
) # [BS, 1, HS * 2]

Any ideas?

The error you’re encountering likely stems from a dynamic dimension that XLA cannot propagate through a TPU collective. In the HLO dump, f32[<=70,256] is XLA's notation for a bounded dynamic dimension (note that the 70 matches your descriptors feature size), and CrossReplicaSum, the all-reduce that aggregates across your 8 TPU cores, does not support dynamic shapes. Make sure every shape in your model is fully defined and static. In particular, a Keras Dense layer's input_shape should not include the batch size: use input_shape=(70,) instead of (Config.COMPUTED_BATCH_SIZE, 70). Including the batch size makes Keras treat it as an extra feature dimension and can leave the true batch dimension unknown. Also verify that all shapes in your model are consistent, keep model construction and compilation inside the TPU strategy scope (see the sketch at the end of this answer), and consider updating TensorFlow, since dynamic-shape support on TPU improves across releases. Simplifying the model can also help isolate the offending op. For example:

hidden_size = 128
self.descriptor_embedding = layers.Dense(
    hidden_size * 2,  # 256
    activation='relu',
    input_shape=(70,)  # feature shape only; no batch dimension
)

learned_descriptors = tf.expand_dims(
    self.descriptor_embedding(descriptors),
    axis=1  # insert a new axis at position 1 -> [BS, 1, HS * 2]
)