I am trying to load a saved model within a MirroredStrategy scope. The model loaded within MirroredStrategy does not work correctly: only one replica is detected, even though four visible devices are specified. This does not happen for a model that is constructed directly within MirroredStrategy.
It is worth mentioning that subclasses of tf.keras.models.Model and tf.keras.layers.Layer are used here, which I suspect is the cause of this wrong behavior. I have confirmed that loading a saved tf.keras.Sequential model works well within MirroredStrategy.
Reproducible code:
import tensorflow as tf
import tensorflow.distribute as tf_dist
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.ops import array_ops
import numpy as np
import os
# Expose GPUs 0-3 to CUDA for this process.
# NOTE(review): this assignment runs after `import tensorflow`; presumably it
# must take effect before TensorFlow first initialises its GPU devices - confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
def _get_current_replica_id_in_group_sync():
    """Return the id of the current replica within its sync group.

    When a replica context is active, its ``replica_id_in_sync_group`` is
    used; otherwise the value of ``distribute_lib.get_update_replica_id()``
    is used. If the chosen value is ``None``, an int32 constant ``0`` is
    returned instead.
    """
    ctx = tf_dist.get_replica_context()
    current_id = (
        ctx.replica_id_in_sync_group
        if ctx
        else distribute_lib.get_update_replica_id()
    )
    if current_id is not None:
        return current_id
    # No replica id could be determined: fall back to replica 0.
    return array_ops.constant(0, dtype=array_ops.dtypes.int32)
class TestLayer(tf.keras.layers.Layer):
    """Keras layer that logs the current replica id and outputs zeros."""

    def __init__(self, **kwargs):
        """Forward all keyword arguments to the base ``Layer`` constructor."""
        super().__init__(**kwargs)

    def call(self, inputs, training=False):
        """Log the replica id and return ``tf.zeros_like(inputs)``.

        ``training`` is accepted but not used by this layer.
        """
        replica_id = _get_current_replica_id_in_group_sync()
        # The message is assembled with str.format *before* tf.print receives
        # it, so the logged text is the string form of `replica_id` at the
        # moment this Python code executes.
        message = "global_replica_id: {}".format(replica_id)
        tf.print(message)
        return tf.zeros_like(inputs)
class Demo(tf.keras.models.Model):
    """Subclassed Keras model: a ``TestLayer`` followed by a one-unit ``Dense``."""

    def __init__(self, **kwargs):
        """Create the two sub-layers; ``kwargs`` go to the base ``Model``."""
        super().__init__(**kwargs)
        self.test_layer = TestLayer()
        # Linear projection (no activation) with kernel initialised to ones
        # and bias initialised to zeros.
        self.dense_layer = tf.keras.layers.Dense(
            units=1,
            activation=None,
            kernel_initializer="ones",
            bias_initializer="zeros",
        )

    def call(self, inputs):
        """Return ``(logit, vector)``.

        ``vector`` is the output of ``test_layer`` and ``logit`` is the
        result of applying ``dense_layer`` to it.
        """
        embedding = self.test_layer(inputs)
        logits = self.dense_layer(embedding)
        return logits, embedding

    def summary(self):
        """Print a summary via a functional model wrapped around ``call``.

        A symbolic int64 input of shape ``(10,)`` is passed through ``call``
        and the resulting functional model's ``summary()`` is returned.
        """
        symbolic_in = tf.keras.Input(shape=(10,), dtype=tf.int64)
        symbolic_out = self.call(symbolic_in)
        wrapper = tf.keras.models.Model(inputs=symbolic_in, outputs=symbolic_out)
        return wrapper.summary()
@tf.function
def _step(inputs, labels, model):
    """Run one forward pass of ``model`` on ``inputs``.

    ``labels`` is part of the signature but is not used here. The model is
    expected to yield exactly two outputs, which are returned as a tuple.
    """
    predictions, embeddings = model(inputs)
    return predictions, embeddings
def tf_dataset(keys, labels, batchsize, repeat):
    """Build a ``tf.data.Dataset`` of ``(keys, labels)`` slices.

    The sliced dataset is repeated ``repeat`` times and then grouped into
    batches of ``batchsize``; a trailing partial batch is dropped.
    """
    return (
        tf.data.Dataset.from_tensor_slices((keys, labels))
        .repeat(repeat)
        .batch(batchsize, drop_remainder=True)
    )
def _dataset_fn(input_context):
    """Create the shard of the synthetic dataset for one input pipeline.

    Generates 16384 samples: keys are an all-ones array of shape
    ``(16384, 10)`` and labels are random integers in ``{0, 1}`` of shape
    ``(16384, 1)``. The data is batched to the per-replica batch size
    reported by ``input_context`` and then sharded by input pipeline.
    """
    global_batch_size = 16384
    features = np.ones((global_batch_size, 10))
    targets = np.random.randint(low=0, high=2, size=(global_batch_size, 1))
    per_replica = input_context.get_per_replica_batch_size(global_batch_size)
    batched = tf_dataset(features, targets, batchsize=per_replica, repeat=1)
    # Sharding is applied after batching, so whole batches are distributed
    # across input pipelines.
    return batched.shard(
        input_context.num_input_pipelines, input_context.input_pipeline_id
    )
# Save model within MirroredStrategy scope.
# Creates a MirroredStrategy over four GPUs, constructs and compiles Demo,
# runs _step on every batch of the distributed dataset, then saves the model
# to the path "demo".
# NOTE(review): indentation was lost in this paste, so the extent of the
# `with strategy.scope():` and `for` bodies cannot be read from the text -
# confirm against the original script.
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1", "GPU:2", "GPU:3"])
with strategy.scope():
model = Demo()
model.compile()
model.summary()
# The dataset is produced by _dataset_fn and distributed by the strategy.
dataset = strategy.distribute_datasets_from_function(_dataset_fn)
for i, (key_tensors, replica_labels) in enumerate(dataset):
print("-" * 30, "step ", str(i), "-" * 30)
# Execute _step through the strategy with this batch and the in-memory model.
logit, vector = strategy.run(_step, args=(key_tensors, replica_labels, model))
# model(tf.keras.Input(shape=(10,), dtype=tf.int64))
model.save("demo")
# Load model within MirroredStrategy scope.
# Reloads the model saved at "demo" and repeats the same distributed loop,
# this time passing the loaded model (model2) to _step.
# NOTE(review): indentation was lost in this paste, so the extent of the
# `with strategy.scope():` and `for` bodies cannot be read from the text -
# confirm against the original script.
with strategy.scope():
model2 = tf.keras.models.load_model("demo")
# A fresh distributed dataset is created for the second pass.
dataset = strategy.distribute_datasets_from_function(_dataset_fn)
for i, (key_tensors, replica_labels) in enumerate(dataset):
print("-" * 30, "step ", str(i), "-" * 30)
# Execute _step through the strategy with this batch and the loaded model.
logit, vector = strategy.run(_step, args=(key_tensors, replica_labels, model2))
Actual log
------------------------------ step 0 ------------------------------
global_replica_id: Tensor("demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:0)
global_replica_id: Tensor("replica_1/demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:1)
global_replica_id: Tensor("replica_2/demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:2)
global_replica_id: Tensor("replica_3/demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:3)
2022-07-13 06:20:56.820402: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
------------------------------ step 0 ------------------------------
global_replica_id: 0
Expected log
------------------------------ step 0 ------------------------------
global_replica_id: Tensor("demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:0)
global_replica_id: Tensor("replica_1/demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:1)
global_replica_id: Tensor("replica_2/demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:2)
global_replica_id: Tensor("replica_3/demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:3)
2022-07-13 06:20:56.820402: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
------------------------------ step 0 ------------------------------
global_replica_id: Tensor("demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:0)
global_replica_id: Tensor("replica_1/demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:1)
global_replica_id: Tensor("replica_2/demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:2)
global_replica_id: Tensor("replica_3/demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:3)
The "global_replica_id" log lines are produced by the statement tf.print("global_replica_id: {}".format(global_replica_id)) inside TestLayer.call.