Running tf.distribute.MultiWorkerMirroredStrategy

I have connected two machines together with a switch, and they can communicate and ping each other. I am trying to run distributed training on them, and as a first test I'm starting with a simple model. However, I can't get around this error.

import os
import json
import tensorflow as tf
import tensorflow_datasets as tfds

NUM_WORKERS = 2
IP_ADDRS = ['192.168.1.10','192.168.1.11']
PORTS = [20000,20001]

os.environ['TF_CONFIG'] = json.dumps({
  'cluster': {
    'worker': ['%s:%d' % (IP_ADDRS[w], PORTS[w]) for w in range(NUM_WORKERS)]
  },
  'task': {'type': 'worker', 'index': 1}
})


BUFFER_SIZE = 10000
BATCH_SIZE = 64
LEARNING_RATE = 1e-4


def prepare_dataset():
    datasets = tfds.load('mnist', as_supervised=True)
    train_dataset = datasets['train']
    test_dataset = datasets['test']

    def scale(image, label):
        image = tf.cast(image, tf.float32) / 255.0
        image = tf.expand_dims(image, -1)  
        return image, label
                                                                                                                                                                                                                    
    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF

    train_dataset = (train_dataset.map(scale)
                                  .shuffle(BUFFER_SIZE)
                                  .batch(BATCH_SIZE)
                                  .with_options(options))  

    test_dataset = (test_dataset.map(scale)
                                .batch(BATCH_SIZE)
                                .with_options(options))
    return train_dataset, test_dataset

def build_model():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model

strategy = tf.distribute.MultiWorkerMirroredStrategy()

train_dataset, test_dataset = prepare_dataset()

with strategy.scope():
    model = build_model()

model.fit(train_dataset, epochs=10, validation_data=test_dataset)

loss, accuracy = model.evaluate(test_dataset)
print(f"Test accuracy: {accuracy:.4f}")

This is the error I get:
I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable TF_ENABLE_ONEDNN_OPTS=0.
2025-02-06 18:38:28.488146: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1738856308.495593 32908 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1738856308.497828 32908 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1738856308.503623 32908 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1738856308.503632 32908 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1738856308.503634 32908 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1738856308.503636 32908 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
2025-02-06 18:38:28.505584: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
I0000 00:00:1738856309.346805 32908 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 13813 MB memory: -> device: 0, name: NVIDIA GeForce RTX 4080, pci bus id: 0000:01:00.0, compute capability: 8.9
I0000 00:00:1738856309.348309 32908 gpu_device.cc:2019] Created device /job:worker/replica:0/task:1/device:GPU:0 with 13813 MB memory: -> device: 0, name: NVIDIA GeForce RTX 4080, pci bus id: 0000:01:00.0, compute capability: 8.9
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1738856309.351987 32908 grpc_server_lib.cc:463] Started server with target: grpc://192.168.1.11:20001
I0000 00:00:1738856324.785742 32908 coordination_service_agent.cc:369] Coordination agent has successfully connected.
AttributeError: module 'ml_dtypes' has no attribute 'float4_e2m1fn'
/home/n2/miniconda3/envs/cluster/lib/python3.11/site-packages/keras/src/layers/convolutional/base_conv.py:107: UserWarning: Do not pass an input_shape/input_dim argument to a layer. When using Sequential models, prefer using an Input(shape) object as the first layer in the model instead.
super().__init__(activity_regularizer=activity_regularizer, **kwargs)
2025-02-06 18:38:45.549428: I tensorflow/core/kernels/data/tf_record_dataset_op.cc:387] The default buffer size is 262144, which is overridden by the user specified buffer_size of 8388608
Traceback (most recent call last):
  File "/home/n2/script.py", line 73, in <module>
    model.fit(train_dataset, epochs=10, validation_data=test_dataset)
  File "/home/n2/miniconda3/envs/cluster/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/n2/miniconda3/envs/cluster/lib/python3.11/site-packages/tensorflow/python/framework/constant_op.py", line 108, in convert_to_eager_tensor
    return ops.EagerTensor(value, ctx.device_name, dtype)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: Attempt to convert a value (PerReplica:{
  0: <tf.Tensor: shape=(32, 28, 28, 1, 1), dtype=float32, numpy=
     array([[[[[0.]],
               ...,
               [[0.]]]]], dtype=float32)>
}) with an unsupported type (<class 'tensorflow.python.distribute.values.PerReplica'>) to a Tensor.
2025-02-06 18:38:45.769598: W tensorflow/core/kernels/data/cache_dataset_ops.cc:916] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to dataset.cache().take(k).repeat(). You should use dataset.take(k).cache().repeat() instead.