Can you check my implementation of ParameterServerStrategy?

I have this implementation of ParameterServerStrategy, and I need some explanation of the strategy, because the implementation sometimes works and sometimes raises errors.

import tensorflow as tf
import multiprocessing
import os
import portpicker

def create_in_process_cluster(num_workers, num_ps):
    """Start an in-process gRPC cluster and return a resolver that points at it.

    Every task is a ``tf.distribute.Server`` started inside the current Python
    process, each listening on its own ``localhost`` port chosen by
    ``portpicker``.

    Args:
        num_workers: number of tasks to start under the "worker" job.
        num_ps: number of tasks to start under the "ps" job.

    Returns:
        A ``tf.distribute.cluster_resolver.SimpleClusterResolver`` for the
        cluster, using the "grpc" RPC layer.
    """
    # All worker ports are picked first, then all ps ports.
    # NOTE(review): nothing in this function holds a port between the moment
    # it is picked and the moment the Server below binds it, so another
    # process could grab it in between (a possible source of intermittent
    # startup failures).
    worker_addresses = [
        "localhost:%s" % portpicker.pick_unused_port() for _ in range(num_workers)
    ]
    ps_addresses = [
        "localhost:%s" % portpicker.pick_unused_port() for _ in range(num_ps)
    ]

    cluster_spec = tf.train.ClusterSpec(
        {"worker": worker_addresses, "ps": ps_addresses}
    )

    # On machines with fewer cores than workers + 1, raise the worker
    # inter-op thread pool so workers have enough threads to make progress.
    worker_config = tf.compat.v1.ConfigProto()
    needed_threads = num_workers + 1
    if needed_threads > multiprocessing.cpu_count():
        worker_config.inter_op_parallelism_threads = needed_threads

    # Workers are started before parameter servers; only workers get the
    # custom config (config=None is the Server default).
    # NOTE(review): the Server objects are not kept; this relies on the
    # servers continuing to run after the Python handles are collected.
    for job_name, task_count, job_config in (
        ("worker", num_workers, worker_config),
        ("ps", num_ps, None),
    ):
        for task_index in range(task_count):
            tf.distribute.Server(
                cluster_spec,
                job_name=job_name,
                task_index=task_index,
                config=job_config,
                protocol="grpc",
            )

    return tf.distribute.cluster_resolver.SimpleClusterResolver(
        cluster_spec, rpc_layer="grpc"
    )

NUM_WORKERS = 3
NUM_PS = 1

# Start NUM_WORKERS worker servers and NUM_PS parameter-server servers inside
# this process and get a resolver that describes them.
cluster_resolver = create_in_process_cluster(num_workers=NUM_WORKERS, num_ps=NUM_PS)

# The strategy is built on the resolver's cluster; the coordinator is what
# schedules tf.functions onto the "worker" tasks in the training loop.
strategy = tf.distribute.ParameterServerStrategy(cluster_resolver)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(strategy)
with strategy.scope():
    # Model, optimizer, loss and metric are all created under the strategy's
    # scope so that any variables they create are created by the strategy.

    def create_model():
        """Return an uncompiled CNN for (28, 28, 1) images with 10 softmax outputs."""
        model = tf.keras.Sequential([
            tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
            tf.keras.layers.MaxPooling2D((2, 2)),
            tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
            tf.keras.layers.MaxPooling2D((2, 2)),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(64, activation='relu'),
            tf.keras.layers.Dense(10, activation='softmax')
        ])
        return model

    model = create_model()
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
    # reduction="none" makes the loss return one value per example.  Under
    # tf.keras (Keras 2) the default AUTO / SUM_OVER_BATCH_SIZE reduction
    # raises a ValueError when the loss is called inside strategy.run() in a
    # custom training loop.  The training step averages the per-example values
    # with tf.reduce_mean, so the resulting loss value is unchanged.
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=False, reduction="none"
    )
    # Compares integer labels against argmax'ed class predictions.
    accuracy = tf.keras.metrics.Accuracy()

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Scale pixels to [0, 1].  Dividing the uint8 arrays by a Python float yields
# float64; cast to float32 (Keras' default float dtype, which the layers would
# cast to anyway) so the arrays embedded by from_tensor_slices take half the
# memory.
x_train = (x_train / 255.0).astype("float32")
x_test = (x_test / 255.0).astype("float32")

# Add a trailing channel axis so images match the model's input_shape=(28, 28, 1).
x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]

BATCH_SIZE = 64

def dataset_fn(input_context):
    """Build the training input pipeline for one worker.

    Invoked on each worker via ``strategy.distribute_datasets_from_function``.

    Args:
        input_context: the ``tf.distribute.InputContext`` for this pipeline;
            ``input_pipeline_id`` identifies it among ``num_input_pipelines``.

    Returns:
        An infinite, shuffled, batched ``tf.data.Dataset`` of
        ``(image, label)`` pairs drawn only from this pipeline's shard.
    """
    # get_per_replica_batch_size divides BATCH_SIZE by
    # input_context.num_replicas_in_sync.
    batch_size = input_context.get_per_replica_batch_size(BATCH_SIZE)

    dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    # Shard FIRST so each pipeline owns a disjoint slice of the examples.
    # Sharding after a shuffle (each pipeline shuffles with its own random
    # order) lets pipelines draw overlapping examples and skip others.
    dataset = dataset.shard(
        input_context.num_input_pipelines, input_context.input_pipeline_id
    )
    # repeat(): the coordinator may schedule any number of steps on any
    # worker, and the ParameterServerStrategy guide recommends infinite
    # per-worker datasets; a finite one fails with OutOfRangeError once a
    # worker's shard runs out.
    dataset = dataset.shuffle(10000).repeat().batch(batch_size)
    return dataset.prefetch(tf.data.experimental.AUTOTUNE)

@tf.function
def step_fn(iterator):
    """Run one training step on whichever worker the coordinator picks.

    Pulls one batch from the per-worker ``iterator``, runs the forward and
    backward pass on each local replica via ``strategy.run``, applies the
    gradients with ``optimizer`` and updates ``accuracy``.

    Returns:
        The replicas' mean losses combined with ``ReduceOp.SUM``.
    """

    def replica_fn(images, targets):
        with tf.GradientTape() as tape:
            probs = model(images, training=True)
            batch_loss = tf.reduce_mean(loss_object(targets, probs))

        weights = model.trainable_variables
        grads = tape.gradient(batch_loss, weights)
        optimizer.apply_gradients(zip(grads, weights))

        # Accuracy compares integer labels with the most probable class.
        predicted_class = tf.argmax(probs, axis=1)
        accuracy.update_state(targets, predicted_class)
        return batch_loss

    images, targets = next(iterator)
    per_replica_loss = strategy.run(replica_fn, args=(images, targets))
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None)

@tf.function
def per_worker_dataset_fn():
    """Return the distributed dataset built from ``dataset_fn`` for one worker."""
    return strategy.distribute_datasets_from_function(dataset_fn)


# The coordinator creates one dataset per worker from per_worker_dataset_fn;
# iterating it gives the per-worker iterator that step_fn pulls batches from.
per_worker_dataset = coordinator.create_per_worker_dataset(per_worker_dataset_fn)
per_worker_iterator = iter(per_worker_dataset)



# Train the model
num_epochs = 4
steps_per_epoch = 5

for epoch in range(num_epochs):
    # schedule() is non-blocking: it queues step_fn for some worker and
    # returns immediately, so steps may still be running after this loop.
    for _ in range(steps_per_epoch):
        coordinator.schedule(step_fn, args=(per_worker_iterator,))

    # Wait for every scheduled step to finish (join also re-raises an error
    # hit by any scheduled step) before reading the metric.
    coordinator.join()

    print(f"Finished epoch {epoch+1}, Accuracy: {accuracy.result().numpy():.4f}")

    # Reset only after the result has been read.  Resetting before join()
    # races with update_state() in steps still running on the workers, so the
    # printed accuracy would cover only an arbitrary part of the epoch.
    accuracy.reset_state()