Hello,
I am trying to practice the "Multi-worker training with Keras" example in Colab:
import json
import os
import sys
# Disable all GPUs so that every worker in this demo runs on the CPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

# Reset the TF_CONFIG environment variable left over from any earlier run.
os.environ.pop('TF_CONFIG', None)

# Make sure the current directory is on the import path, so that modules
# written next to the notebook can be imported.
if '.' not in sys.path:
    sys.path.insert(0, '.')
!pip install tf-nightly
import tensorflow as tf
%%writefile mnist_setup.py
import os
import tensorflow as tf
import numpy as np
def mnist_dataset(batch_size):
    """Build the MNIST training input pipeline.

    Args:
        batch_size: Number of examples per batch of the returned dataset.

    Returns:
        A `tf.data.Dataset` of `(image, label)` batches built from the MNIST
        training split, shuffled with a 60000-element buffer and repeated
        indefinitely (callers must bound training with `steps_per_epoch`).
    """
    (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
    # The `x` arrays are in uint8 and have values in the [0, 255] range.
    # You need to convert them to float32 with values in the [0, 1] range.
    x_train = x_train / np.float32(255)
    y_train = y_train.astype(np.int64)
    train_dataset = tf.data.Dataset.from_tensor_slices(
        (x_train, y_train)).shuffle(60000).repeat().batch(batch_size)
    return train_dataset
def build_and_compile_cnn_model():
    """Create and compile a small CNN classifier for 28x28 inputs.

    Returns:
        A compiled `tf.keras.Sequential` model that outputs 10 logits per
        example, trained with SGD (learning rate 0.001) on sparse categorical
        cross-entropy and reporting accuracy.
    """
    model = tf.keras.Sequential([
        tf.keras.layers.InputLayer(input_shape=(28, 28)),
        # Add a trailing channel axis so that Conv2D receives (28, 28, 1).
        tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
        tf.keras.layers.Conv2D(32, 3, activation='relu'),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation='relu'),
        # No activation here: the loss below is configured for raw logits.
        tf.keras.layers.Dense(10),
    ])
    model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
        metrics=['accuracy'])
    return model
# Single-worker sanity check of the dataset and model helpers.
# NOTE: run this in its own notebook cell, after the
# `%%writefile mnist_setup.py` cell has written the module to disk.
import mnist_setup

batch_size = 64
single_worker_dataset = mnist_setup.mnist_dataset(batch_size)
single_worker_model = mnist_setup.build_and_compile_cnn_model()
# The dataset repeats forever, so steps_per_epoch bounds each epoch.
single_worker_model.fit(single_worker_dataset, epochs=3, steps_per_epoch=70)
# Cluster description for two local workers; this process is worker 0.
tf_config = {
    'cluster': {
        'worker': ['localhost:12345', 'localhost:23456'],
    },
    'task': {'type': 'worker', 'index': 0},
}
# In a notebook cell this displays the JSON-serialized configuration.
json.dumps(tf_config)
# Demonstration variable, set in this process's environment.
os.environ['GREETINGS'] = 'Hello TensorFlow!'
!echo ${GREETINGS}
strategy = tf.distribute.MultiWorkerMirroredStrategy()

with strategy.scope():
    # Model building/compiling need to be within `strategy.scope()`.
    multi_worker_model = mnist_setup.build_and_compile_cnn_model()
%%writefile main.py
import os
import json
import tensorflow as tf
import mnist_setup
per_worker_batch_size = 64
# The cluster layout is read from the TF_CONFIG environment variable, which
# must be set before this script is launched (KeyError otherwise).
tf_config = json.loads(os.environ['TF_CONFIG'])
num_workers = len(tf_config['cluster']['worker'])

strategy = tf.distribute.MultiWorkerMirroredStrategy()

# The dataset is batched with the global batch size, i.e. the per-worker
# batch size multiplied by the number of workers.
global_batch_size = per_worker_batch_size * num_workers
multi_worker_dataset = mnist_setup.mnist_dataset(global_batch_size)

with strategy.scope():
    # Model building/compiling need to be within `strategy.scope()`.
    multi_worker_model = mnist_setup.build_and_compile_cnn_model()

multi_worker_model.fit(multi_worker_dataset, epochs=3, steps_per_epoch=70)
!ls *.py
# Serialize the cluster configuration into this process's environment, so
# that processes launched from the notebook inherit it.
os.environ['TF_CONFIG'] = json.dumps(tf_config)
%killbgscripts
!python main.py &> job_0.log
The last step (`!python main.py &> job_0.log`) keeps running and never finishes.
Please tell me how to solve this. Thank you very much.