My model has a learnable embedding table of shape (N, D), and I train it with the parameter server (ps/worker) distribution architecture.
In each training step, only some of the embedding rows receive gradients during back-propagation. I want to reset the unused (not updated) rows after back-propagation. How should I implement this? If I do it in the model-building code, without any gradient involved, will the change show up on both the parameter servers and the workers? In theory I think it should happen on the parameter servers right after the parameter update, but I have no idea how to implement that.
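For concreteness, here is a minimal sketch of the kind of reset op I have in mind, written with plain TF1 ops. Everything in it (the sizes, the `updated_ids` placeholder, the random re-initialization) is a placeholder of mine, not my real model; what I don't know is where this op should be built and who should run it.

import tensorflow as tf

N, D = 1000, 64  # hypothetical table size
embeddings = tf.get_variable("embeddings", shape=[N, D])

# Hypothetical input: ids of the rows that actually received gradients
# in this step (e.g. the unique ids appearing in the current batch).
updated_ids = tf.placeholder(tf.int32, shape=[None])

# Mark which rows were touched; duplicate ids simply sum up.
used_count = tf.scatter_nd(tf.expand_dims(updated_ids, 1),
                           tf.ones_like(updated_ids), [N])
unused = tf.equal(used_count, 0)

# Overwrite the untouched rows with fresh random values; with a rank-1
# condition, tf.where selects whole rows.
reset_op = tf.assign(embeddings,
                     tf.where(unused, tf.random_normal([N, D]), embeddings))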
My training code is as follows. Does anyone know where I should insert the code that modifies those unused embeddings?
import tensorflow as tf

ps_hosts = FLAGS.ps_hosts.split(",")
worker_hosts = FLAGS.worker_hosts.split(",")
cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})

# Start a ps or worker server.
server = tf.train.Server(cluster, job_name=FLAGS.job_name,
                         task_index=FLAGS.task_index)

if FLAGS.job_name == "ps":
    server.join()
elif FLAGS.job_name == "worker":
    # Client: variables are placed on the ps tasks, ops on this worker.
    with tf.device(tf.train.replica_device_setter(
            worker_device="/job:worker/task:%d" % FLAGS.task_index,
            cluster=cluster)):
        # Build model...
        loss = ...
        train_op = ...

    with tf.train.MonitoredTrainingSession(master=server.target,
                                           is_chief=(FLAGS.task_index == 0),
                                           checkpoint_dir="/tmp/train_logs") as mon_sess:
        while not mon_sess.should_stop():
            mon_sess.run(train_op)
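If I had to guess, the reset would go right after each mon_sess.run(train_op) call, guarded so that only the chief performs it (reset_op and updated_ids are the hypothetical pieces from the sketch above), but I am not sure whether that is correct or race-free with the other workers' updates:

# Untested guess: the variables live on the parameter servers, so an
# assign issued by the chief should update the shared copy that every
# worker reads. batch_ids is a hypothetical tensor of this step's ids.
while not mon_sess.should_stop():
    _, ids = mon_sess.run([train_op, batch_ids])
    if FLAGS.task_index == 0:  # only the chief resets
        mon_sess.run(reset_op, feed_dict={updated_ids: ids})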
For reference, here is a PyTorch distributed implementation; there the reset is written inside the forward-propagation function. In TensorFlow, where should I write the equivalent code?
import torch
import torch.distributed as dist

# Here: reset the unused embeddings.
if self.restart_unused_codes:
    # Generate candidate replacement vectors for the unused rows.
    if n_vectors < n_embed:
        vectors = self._tile_with_noise(vectors, n_embed)
    n_vectors = vectors.shape[0]
    _vectors_random = vectors[torch.randperm(n_vectors, device=vectors.device)][:n_embed]

    # Broadcast the replacements from rank 0 so every node in the
    # distributed system uses the same new embeddings.
    if dist.is_initialized():
        dist.broadcast(_vectors_random, 0)

    # Keep rows whose EMA usage count is >= 1; overwrite the rest with
    # the random vectors and reset their usage counts to 1.
    usage = (self.cluster_size_ema.view(-1, 1) >= 1).float()
    self.embed_ema.mul_(usage).add_(_vectors_random * (1 - usage))
    self.cluster_size_ema.mul_(usage.view(-1))
    self.cluster_size_ema.add_(torch.ones_like(self.cluster_size_ema) * (1 - usage).view(-1))
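In TensorFlow, my rough guess at the equivalent of cluster_size_ema would be an EMA usage counter kept as a non-trainable variable (the names and the decay value below are my own hypothetical translation). I also assume the dist.broadcast step has no analogue in the ps setup, since every worker already reads the variables from the parameter servers:

# Hypothetical TF1 analogue of cluster_size_ema: an EMA usage count per
# row, stored as a non-trainable variable (placed on the ps by the
# replica_device_setter).
usage_ema = tf.get_variable("usage_ema", shape=[N], trainable=False,
                            initializer=tf.zeros_initializer())

# Count how many times each row was hit this step (duplicate ids sum up).
hits = tf.scatter_nd(tf.expand_dims(updated_ids, 1),
                     tf.ones_like(updated_ids, dtype=tf.float32), [N])

decay = 0.99  # hypothetical decay rate
update_usage_op = tf.assign(usage_ema, decay * usage_ema + (1 - decay) * hits)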