Can you check my implementation of ParameterServerStrategy?

I have this implementation of ParameterServerStrategy, and I need some explanation of how the strategy works, because the implementation works sometimes and sometimes it raises errors.

import tensorflow as tf
import multiprocessing
import os
import portpicker

def create_in_process_cluster(num_workers, num_ps):
    """Creates and starts local servers and returns the cluster_resolver."""
    worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
    ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]

    cluster_dict = {
        "worker": ["localhost:%s" % port for port in worker_ports],
        "ps": ["localhost:%s" % port for port in ps_ports]
    }

    cluster_spec = tf.train.ClusterSpec(cluster_dict)

    # Workers need inter_op threads to work properly.
    worker_config = tf.compat.v1.ConfigProto()
    if multiprocessing.cpu_count() < num_workers + 1:
        worker_config.inter_op_parallelism_threads = num_workers + 1

    # Start in-process tf.distribute.Server instances for the worker and ps tasks
    for i in range(num_workers):
        tf.distribute.Server(
            cluster_spec,
            job_name="worker",
            task_index=i,
            config=worker_config,
            protocol="grpc"
        )

    for i in range(num_ps):
        tf.distribute.Server(
            cluster_spec,
            job_name="ps",
            task_index=i,
            protocol="grpc"
        )

    return tf.distribute.cluster_resolver.SimpleClusterResolver(
        cluster_spec, rpc_layer="grpc"
    )

NUM_WORKERS = 3
NUM_PS = 1
cluster_resolver = create_in_process_cluster(NUM_WORKERS, NUM_PS)



# Variables created under strategy.scope() are placed on the parameter server task(s);
# training steps are dispatched to the workers through the ClusterCoordinator.
strategy = tf.distribute.ParameterServerStrategy(
    cluster_resolver
)

coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(strategy)
with strategy.scope():
    def create_model():
        model = tf.keras.Sequential([
            tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
            tf.keras.layers.MaxPooling2D((2, 2)),
            tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
            tf.keras.layers.MaxPooling2D((2, 2)),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(64, activation='relu'),
            tf.keras.layers.Dense(10, activation='softmax')
        ])
        return model

    model = create_model()
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
    # Use Reduction.NONE and average manually, as recommended for custom training
    # loops under a tf.distribute strategy.
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=False, reduction=tf.keras.losses.Reduction.NONE)
    accuracy = tf.keras.metrics.Accuracy()

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
# Normalize the dataset
x_train, x_test = x_train / 255.0, x_test / 255.0


# Add channel dimension
x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]

BATCH_SIZE = 64

def dataset_fn(input_context):
    """Builds the per-worker training dataset: shard first, then shuffle, batch, prefetch."""
    train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    # Shard before any random ops so each input pipeline reads a disjoint slice,
    # and repeat so the per-worker iterator is never exhausted mid-training.
    train_dataset = train_dataset.shard(
        input_context.num_input_pipelines, input_context.input_pipeline_id)
    return train_dataset.shuffle(10000).repeat().batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

@tf.function
def step_fn(iterator):
    """A single training step for workers."""

    def worker_step(batch_data, labels):
        with tf.GradientTape() as tape:
            predictions = model(batch_data, training=True)
            per_example_loss = loss_object(labels, predictions)
            # Reduction.NONE above returns per-example losses; average them explicitly.
            loss = tf.nn.compute_average_loss(per_example_loss)

        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        accuracy.update_state(labels, tf.argmax(predictions, axis=1))
        return loss

    batch_data, labels = next(iterator)
    losses = strategy.run(worker_step, args=(batch_data, labels))
    return strategy.reduce(tf.distribute.ReduceOp.SUM, losses, axis=None)

@tf.function
def per_worker_dataset_fn():
    return strategy.distribute_datasets_from_function(dataset_fn)

per_worker_dataset = coordinator.create_per_worker_dataset(per_worker_dataset_fn)
per_worker_iterator = iter(per_worker_dataset)



# Train the model
num_epochs = 4
steps_per_epoch = 5

for epoch in range(num_epochs):
    # Reset the metric at the start of each epoch, before scheduling new steps.
    accuracy.reset_state()

    for _ in range(steps_per_epoch):
        coordinator.schedule(step_fn, args=(per_worker_iterator,))

    # Wait for all scheduled steps to complete before reporting and starting the next epoch.
    coordinator.join()

    print(f"Finished epoch {epoch+1}, Accuracy: {accuracy.result().numpy():.4f}")

Hi @Ali_saaeddin, apologies for the delayed response!
Could you please share the error message you are encountering, so we can investigate this issue more efficiently? Also, kindly refer to this document for more information about ParameterServerStrategy. Thanks!
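
In the meantime, here is a minimal sketch of one way to capture the full traceback: errors raised inside a scheduled step_fn do not surface immediately, but are re-raised later by coordinator.schedule() or coordinator.join(), so wrapping the wait at the epoch boundary makes the underlying worker error visible and easy to copy into your reply. This reuses the names from your snippet (coordinator, step_fn, per_worker_iterator, accuracy, num_epochs, steps_per_epoch); the try/except placement is only an illustration for debugging, not a recommended production pattern.

import traceback

for epoch in range(num_epochs):
    accuracy.reset_state()
    for _ in range(steps_per_epoch):
        # schedule() can also re-raise an error from a previously failed step.
        coordinator.schedule(step_fn, args=(per_worker_iterator,))

    try:
        # Block until all scheduled steps finish; failures on workers surface here.
        coordinator.join()
    except Exception:
        traceback.print_exc()
        raise

    print(f"Finished epoch {epoch+1}, Accuracy: {accuracy.result().numpy():.4f}")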