I posted this question on Stack Overflow last week, but it didn’t get any engagement, so I’m hoping that posting here will help me connect with someone who can answer my questions.
I am working with tf.distribute strategy scopes in custom training loops with Keras models. Consider this script, which closely follows this tutorial: Custom training with tf.distribute.Strategy | TensorFlow Core
import os
import numpy as np
import tensorflow as tf
print(tf.__version__)
fashion_mnist = tf.keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
# Adding a dimension to the array -> new shape == (28, 28, 1)
# We are doing this because the first layer in our model is a convolutional
# layer and it requires a 4D input (batch_size, height, width, channels).
# batch_size dimension will be added later on.
train_images = train_images[..., None]
test_images = test_images[..., None]
# Getting the images in [0, 1] range.
train_images = train_images / np.float32(255)
test_images = test_images / np.float32(255)
CREATE_MODEL_WITH_SCOPE = False
CREATE_MODEL_SEQUENTIAL = True
input_image = np.random.random((28, 28, 1))
target_image = np.random.random((28, 28, 1))
def data_generator():
while True:
yield input_image, target_image
# If the list of devices is not specified in the
# `tf.distribute.MirroredStrategy` constructor, it will be auto-detected.
strategy = tf.distribute.MirroredStrategy()
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
BUFFER_SIZE = 60000 # len(train_images)
BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
EPOCHS = 10
train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).shuffle(BUFFER_SIZE).batch(
GLOBAL_BATCH_SIZE)
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE)
train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)
# Create a checkpoint directory to store the checkpoints.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
class MyTask:
def __init__(self, strategy):
self.strategy = strategy
with self.strategy.scope():
# Set reduction to `none` so we can do the reduction afterwards and divide by
# global batch size.
self.loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True,
reduction=tf.keras.losses.Reduction.NONE
)
with self.strategy.scope():
self.test_loss = tf.keras.metrics.Mean(name='test_loss')
self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
name='train_accuracy')
self.test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
name='test_accuracy')
# model, optimizer, and checkpoint must be created under `strategy.scope`.
if CREATE_MODEL_WITH_SCOPE:
with self.strategy.scope():
self.model = self.create_model()
else:
self.model = self.create_model()
with self.strategy.scope():
self.optimizer = tf.keras.optimizers.Adam()
self.checkpoint = tf.train.Checkpoint(optimizer=self.optimizer, model=self.model)
self.discriminator = self.create_model()
def create_model(self):
input = tf.keras.Input(
shape=(28, 28, 1),
name="input"
)
layers = tf.keras.Sequential([
tf.keras.layers.Conv2D(32, 3, activation='relu'),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Conv2D(64, 3, activation='relu'),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dense(10)
])
if CREATE_MODEL_SEQUENTIAL:
return layers
output = layers(input)
model = tf.keras.Model(
inputs=input,
outputs=output
)
return model
def create_model_sequential(self):
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(32, 3, activation='relu'),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Conv2D(64, 3, activation='relu'),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dense(10)
])
return model
def train_step(self, inputs):
images, labels = inputs
with tf.GradientTape() as tape:
predictions = self.model(images, training=True)
loss = self.compute_loss(labels, predictions)
gradients = tape.gradient(loss, self.model.trainable_variables)
assert len(gradients) == len(self.model.trainable_variables)
self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
# train_accuracy.update_state(labels, predictions)
return loss
def compute_loss(self, labels, predictions):
per_example_loss = self.loss_object(labels, predictions)
return tf.nn.compute_average_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)
def test_step(self, inputs):
images, labels = inputs
predictions = self.model(images, training=False)
t_loss = self.loss_object(labels, predictions)
self.test_loss.update_state(t_loss)
self.test_accuracy.update_state(labels, predictions)
# `run` replicates the provided computation and runs it
# with the distributed input.
@tf.function
def distributed_train_step(self, dataset_inputs):
per_replica_losses = strategy.run(self.train_step, args=(dataset_inputs,))
reduced_losses = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
return reduced_losses
@tf.function
def distributed_test_step(self, dataset_inputs):
return strategy.run(self.test_step, args=(dataset_inputs,))
def fit(self):
for epoch in range(EPOCHS):
# TRAIN LOOP
total_loss = 0
num_batches = 0
for x in train_dist_dataset:
new_loss = self.train_step(x)
total_loss += new_loss
num_batches += 1
train_loss = total_loss / num_batches
# TEST LOOP
for x in test_dist_dataset:
self.distributed_test_step(x)
if epoch % 2 == 0:
self.checkpoint.save(checkpoint_prefix)
template = ("Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, "
"Test Accuracy: {}")
print(template.format(epoch + 1, train_loss,
self.train_accuracy.result() * 100, self.test_loss.result(),
self.test_accuracy.result() * 100))
# template = ("Epoch {}, Loss: {}")
# print (template.format(epoch+1, train_loss['loss1']))
self.test_loss.reset_states()
self.train_accuracy.reset_states()
self.test_accuracy.reset_states()
task = MyTask(strategy)
task.fit()
eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
name='eval_accuracy')
new_model = task.create_model()
new_optimizer = tf.keras.optimizers.Adam()
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE)
@tf.function
def eval_step(images, labels):
predictions = new_model(images, training=False)
eval_accuracy(labels, predictions)
checkpoint = tf.train.Checkpoint(optimizer=new_optimizer, model=new_model)
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
for images, labels in test_dataset:
eval_step(images, labels)
print('Accuracy after restoring the saved model without strategy: {}'.format(
eval_accuracy.result() * 100))
There are two flags in this script that control the behaviour:
- CREATE_MODEL_WITH_SCOPE controls whether the model is created under a with strategy.scope(): block. If this flag is False, no explicit scope is used when creating the model.
- CREATE_MODEL_SEQUENTIAL controls whether the model is a keras.Sequential model. If this flag is False, the model is wrapped with the Keras functional API, but it is otherwise the same model (see the short sketch after this list).
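For quick reference, the first flag simply toggles whether self.model = self.create_model() runs inside with self.strategy.scope():. For the second flag, as far as I can tell the relevant difference is when the layers get built: the functional wrapper calls layers(input) at construction time, while a plain Sequential with no Input layer is only built on its first call. Here is a minimal standalone snippet (separate from the script above) illustrating that, in case it is relevant:

import tensorflow as tf

# Plain Sequential with no Input: not built yet, so no variables exist at this point.
seq = tf.keras.Sequential([tf.keras.layers.Flatten(), tf.keras.layers.Dense(10)])
print(len(seq.weights))    # 0

# Functional wrapping of the same stack: calling seq(inp) builds the layers immediately,
# so the variables are created right here, under whatever scope is currently active.
inp = tf.keras.Input(shape=(28, 28, 1))
model = tf.keras.Model(inputs=inp, outputs=seq(inp))
print(len(model.weights))  # 2 (the Dense kernel and bias)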
Depending on the combination of flags that I use and the TensorFlow environment, this script either works or produces an error such as this one:
Traceback (most recent call last):
File "scratches/scratch_85.py", line 194, in <module>
task.fit()
File "scratches/scratch_85.py", line 168, in fit
new_loss = self.train_step(x)
File "scratches/scratch_85.py", line 132, in train_step
self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
File "anaconda3/envs/tf_2.4/lib/python3.8/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py", line 604, in apply_gradients
self._create_all_weights(var_list)
File "anaconda3/envs/tf_2.4/lib/python3.8/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py", line 783, in _create_all_weights
self._create_slots(var_list)
File "anaconda3/envs/tf_2.4/lib/python3.8/site-packages/tensorflow/python/keras/optimizer_v2/adam.py", line 127, in _create_slots
self.add_slot(var, 'm')
File "anaconda3/envs/tf_2.4/lib/python3.8/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py", line 838, in add_slot
raise ValueError(
ValueError: Trying to create optimizer slot variable under the scope for tf.distribute.Strategy (<tensorflow.python.distribute.distribute_lib._DefaultDistributionStrategy object at 0x7f597f1eadc0>), which is different from the scope used for the original variable (MirroredVariable:{
0: <tf.Variable 'conv2d/kernel:0' shape=(3, 3, 1, 32) dtype=float32, numpy=
array(...)
}). Make sure the slot variables are created under the same strategy scope. This may happen if you're restoring from a checkpoint outside the scope
The expected output from the script looks like this:
Epoch 1, Loss: 0.5163882374763489, Accuracy: 0.0, Test Loss: 0.4203348755836487, Test Accuracy: 85.15999603271484
Epoch 2, Loss: 0.342644065618515, Accuracy: 0.0, Test Loss: 0.36492371559143066, Test Accuracy: 86.51000213623047
Epoch 3, Loss: 0.2957099378108978, Accuracy: 0.0, Test Loss: 0.30147236585617065, Test Accuracy: 89.18000030517578
Epoch 4, Loss: 0.2637444734573364, Accuracy: 0.0, Test Loss: 0.2926381230354309, Test Accuracy: 89.77000427246094
Epoch 5, Loss: 0.24089021980762482, Accuracy: 0.0, Test Loss: 0.2793895900249481, Test Accuracy: 90.36000061035156
Epoch 6, Loss: 0.221912682056427, Accuracy: 0.0, Test Loss: 0.26250553131103516, Test Accuracy: 90.30999755859375
Epoch 7, Loss: 0.20427824556827545, Accuracy: 0.0, Test Loss: 0.2791960835456848, Test Accuracy: 90.05000305175781
Epoch 8, Loss: 0.18922416865825653, Accuracy: 0.0, Test Loss: 0.24758891761302948, Test Accuracy: 91.23999786376953
Epoch 9, Loss: 0.17512068152427673, Accuracy: 0.0, Test Loss: 0.24345894157886505, Test Accuracy: 91.27999877929688
Epoch 10, Loss: 0.16123878955841064, Accuracy: 0.0, Test Loss: 0.23703424632549286, Test Accuracy: 91.55999755859375
Accuracy after restoring the saved model without strategy: 91.27999877929688
Here’s what I’ve observed in two different environments:
- tf_2.4 = TensorFlow 2.4.1, Python 3.8, CUDA 10.1.243
- tf_2.5 = TensorFlow 2.5.0, Python 3.8, CUDA 11.2.2

Results for each flag combination:
- CREATE_MODEL_WITH_SCOPE=False, CREATE_MODEL_SEQUENTIAL=False
  - tf_2.4: Works correctly
  - tf_2.5: ValueError: Trying to create optimizer slot variable under the scope…
- CREATE_MODEL_WITH_SCOPE=True, CREATE_MODEL_SEQUENTIAL=False
  - tf_2.4: ValueError: Trying to create optimizer slot variable under the scope…
  - tf_2.5: Works correctly
- CREATE_MODEL_WITH_SCOPE=True, CREATE_MODEL_SEQUENTIAL=True
  - tf_2.4: Works correctly
  - tf_2.5: ValueError: Trying to create optimizer slot variable under the scope…
- CREATE_MODEL_WITH_SCOPE=False, CREATE_MODEL_SEQUENTIAL=True
  - tf_2.4: Works correctly
  - tf_2.5: ValueError: Trying to create optimizer slot variable under the scope…
From the documentation, I think that CREATE_MODEL_WITH_SCOPE should be True. But I would like to understand what is going on in the various cases. When are the slot variables being created, and how can I make sure they have the proper scope? Why is there different behaviour between the Sequential and Functional versions of the same model? Why do these two versions of TensorFlow produce opposite results in these tests?
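Part of what I’m trying to pin down is the timing of the slot-variable creation. As far as I can tell from the traceback, Adam’s m/v slots are created lazily inside the first apply_gradients call, so the strategy scope that is active at that moment would be what matters. Here is a minimal standalone snippet (separate from the script above, run under TF 2.4/2.5) showing what I mean:

import tensorflow as tf

var = tf.Variable(1.0)
opt = tf.keras.optimizers.Adam()
print(opt.get_slot_names())  # [] -- no slot variables exist yet
opt.apply_gradients([(tf.constant(0.1), var)])
print(opt.get_slot_names())  # ['m', 'v'] -- slots created lazily on the first apply_gradients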
Any help in understanding these questions would be greatly appreciated. They arose because I want to use a custom training loop with a more complicated Keras Functional API model, but I keep hitting this slot-variable error and don’t know how to work around it. I can’t figure out how to manage the scopes, and replacing the model with a Sequential one is not an option.