tf.data.Dataset with tf.distribute

Hello, TensorFlow community. I'm trying to train a bidirectional LSTM model on a machine with 8 A100 GPUs, using tf.distribute.MirroredStrategy. My dataset is huge, millions of samples, and I initially load it as a NumPy array. Following the official documentation, I create the dataset with tf.data.Dataset.from_tensor_slices(), but I get a memory overflow error. I don't know why, but apparently TF tries to load the entire dataset onto the GPU (and apparently onto a single GPU) before converting it to a Dataset. So I decided to create the tf.data.Dataset like this:

with tf.device('CPU'):
    dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
And it worked. Also, when creating the dataset, I set the batch size equal to the global batch size. After starting training, though, I noticed it was somehow slow, so I downsized the original dataset in order to test distributed training without creating a tf.data.Dataset.
Indeed, if NumPy arrays are passed directly to the model's fit method, training is 2-3 times faster than with the tf.data.Dataset. Can you tell me why this is happening, and what are the best practices for using tf.data.Dataset in distributed training? To train the model on all my data I will still have to switch to a generator, since otherwise I get the GPU memory overflow error.
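
For reference, here is a simplified sketch of my current setup; the shapes, batch size and model below are just placeholders, not my real data or architecture:

import numpy as np
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # picks up all 8 GPUs

# Placeholder data; the real x_train/y_train are millions of samples.
x_train = np.random.rand(10_000, 50, 128).astype("float32")
y_train = np.random.randint(0, 2, size=(10_000, 1)).astype("float32")

per_replica_batch_size = 64
global_batch_size = per_replica_batch_size * strategy.num_replicas_in_sync

# Building the dataset on the CPU avoids the GPU memory overflow.
with tf.device("CPU"):
    dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    dataset = dataset.batch(global_batch_size)

with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Bidirectional(
            tf.keras.layers.LSTM(64), input_shape=(50, 128)),
        tf.keras.layers.Dense(1, activation="sigmoid"),
    ])
    model.compile(optimizer="adam", loss="binary_crossentropy")

model.fit(dataset, epochs=2)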

Hello @geometryk,
Thank you for using TensorFlow.
When you use tf.data.Dataset.from_tensor_slices(), the whole dataset is loaded into memory, and that is what causes the overflow error. It's better to build the dataset on the CPU, as you did; TensorFlow then feeds it to the GPU batch by batch without the error.
The best practice is to use from_generator, which produces data on the fly so the entire dataset never has to be held in memory. Adding prefetch(buffer_size=tf.data.AUTOTUNE) lets the next batch be loaded while the current batch is training. Shuffling and parallelizing the data loading can also help speed up training.
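
Here is a rough sketch of what such a pipeline could look like; the generator, shapes and buffer sizes below are just placeholders for your actual data loading:

import numpy as np
import tensorflow as tf

# Placeholder generator: in practice it would read samples from disk
# (e.g. memory-mapped .npy files) instead of keeping everything in RAM.
def sample_generator():
    for _ in range(100_000):
        x = np.random.rand(50, 128).astype("float32")
        y = np.random.randint(0, 2, size=(1,)).astype("float32")
        yield x, y

strategy = tf.distribute.MirroredStrategy()
global_batch_size = 64 * strategy.num_replicas_in_sync

dataset = (
    tf.data.Dataset.from_generator(
        sample_generator,
        output_signature=(
            tf.TensorSpec(shape=(50, 128), dtype=tf.float32),
            tf.TensorSpec(shape=(1,), dtype=tf.float32),
        ),
    )
    .shuffle(10_000)             # shuffle within a buffer of samples
    .batch(global_batch_size)
    .prefetch(tf.data.AUTOTUNE)  # prepare the next batch while the current one trains
)

# Any per-sample preprocessing can go in a .map(fn, num_parallel_calls=tf.data.AUTOTUNE)
# before .batch() to parallelize it. The resulting dataset can then be passed
# straight to model.fit inside strategy.scope().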