How to modify an embedding directly in TensorFlow distributed training

My model has some learnable embeddings with shape (N, D). I train the model with a parameter-server/worker distributed architecture.

In each training step, only some of the embeddings are updated during back-propagation. I want to reset the unused (not updated) embeddings after back-propagation. How should I implement this? If I do it in the build_model step without any gradient involved, will the change be reflected on both the parameter servers and the workers? In theory, I think the reset should happen on the parameter servers after the parameters are updated, but I have no idea how to implement it.

My training code is as follows; does anyone know where I should insert the code that modifies the unused embeddings?

import tensorflow as tf

# FLAGS (ps_hosts, worker_hosts, job_name, task_index) are assumed to be
# defined elsewhere via tf.app.flags.
ps_hosts = FLAGS.ps_hosts.split(",")
worker_hosts = FLAGS.worker_hosts.split(",")
cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})

# Start a parameter server or a worker.
server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)

if FLAGS.job_name == "ps":
  server.join()
elif FLAGS.job_name == "worker":
  # Client
  with tf.device(tf.train.replica_device_setter(
      worker_device="/job:worker/task:%d" % FLAGS.task_index,
      cluster=cluster)):
    # Build model...
    loss = ...
    train_op = ...

  with tf.train.MonitoredTrainingSession(master=server.target,
                                         is_chief=(FLAGS.task_index == 0),
                                         checkpoint_dir="/tmp/train_logs") as mon_sess:
    while not mon_sess.should_stop():
      mon_sess.run(train_op)

Here is a PyTorch distributed implementation; the reset is done in the forward-propagation function. In TensorFlow, where should I write the corresponding code?

import torch
import torch.distributed as dist

# Here to reset unused embeddings.
if self.restart_unused_codes:
    # Generate candidate replacement embeddings.
    if n_vectors < n_embed:
        vectors = self._tile_with_noise(vectors, n_embed)
    n_vectors = vectors.shape[0]
    _vectors_random = vectors[torch.randperm(n_vectors, device=vectors.device)][:n_embed]

    # Broadcast the new embeddings to every node in the distributed system.
    if dist.is_initialized():
        dist.broadcast(_vectors_random, 0)

    # Assign the new embeddings to the unused rows.
    usage = (self.cluster_size_ema.view(-1, 1) >= 1).float()
    self.embed_ema.mul_(usage).add_(_vectors_random * (1 - usage))
    self.cluster_size_ema.mul_(usage.view(-1))
    self.cluster_size_ema.add_(torch.ones_like(self.cluster_size_ema) * (1 - usage).view(-1))

Hi @Nixon_Jin,

Sorry for the delay in response.
I suggest checking which embeddings didn't receive any updates by looking at the computed gradients. Once you apply the gradients with optimizer.apply_gradients, you can use tf.control_dependencies to create a reset operation (using tf.assign) for those embeddings that weren't updated. Make sure to include these reset operations in your training step so that the unused embeddings are reset only when needed. Kindly refer to the documentation on tf.control_dependencies for more information.
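For illustration, here is a minimal TF1-style sketch of that idea. The names `embeddings` and `loss`, the shape values N and D, and the random re-initialization are placeholders standing in for your own model, not part of your code, and it uses tf.scatter_update (a row-wise variant of tf.assign) so that only the unused rows are touched:

# A minimal sketch, assuming the embedding gradient arrives as tf.IndexedSlices.
N, D = 1000, 64                                     # placeholder embedding table shape
embeddings = tf.get_variable("embeddings", [N, D])
optimizer = tf.train.GradientDescentOptimizer(0.1)

# Use compute_gradients instead of minimize() so the gradients can be inspected.
grads_and_vars = optimizer.compute_gradients(loss)
apply_op = optimizer.apply_gradients(grads_and_vars)

# The gradient of tf.nn.embedding_lookup is a tf.IndexedSlices whose
# `indices` field lists exactly the rows that received an update this step.
embed_grad = None
for g, v in grads_and_vars:
  if v is embeddings:
    embed_grad = g
used_rows = tf.unique(embed_grad.indices).y

# Mark the used rows, then collect the complementary (unused) row indices.
used_mask = tf.scatter_nd(tf.expand_dims(used_rows, 1),
                          tf.ones_like(used_rows, dtype=tf.float32),
                          [N]) > 0
unused_rows = tf.where(tf.logical_not(used_mask))[:, 0]

# Reset the unused rows only after the gradient update has been applied.
with tf.control_dependencies([apply_op]):
  reset_op = tf.scatter_update(
      embeddings, unused_rows,
      tf.random_normal(tf.stack([tf.size(unused_rows), D])))

train_op = tf.group(apply_op, reset_op)  # run this op in MonitoredTrainingSession

Since the variables in your setup live on the parameter servers, a scatter/assign op issued by a worker modifies the shared copy directly, so an explicit broadcast step like dist.broadcast in the PyTorch snippet should not be necessary.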

Hope this helps. Thank you.