Hi everyone,
When training my model with model.fit(), using tf.data for the training and validation data, GPU usage dips to 0% after each epoch, even though I am calling prefetch() on the tf.data.Dataset.
Has anyone experienced something similar?
Unfortunately, I cannot share any code.
Thank you in advance.
My first two guesses would be:
- The dataset has to refill the shuffle buffer at the start of every epoch. That happens with model.fit(ds.shuffle(buffer_size).repeat()); calling model.fit(ds.repeat().shuffle(buffer_size), steps_per_epoch=N) instead fills the buffer once and keeps it full across epoch boundaries.
- Maybe something in the evaluation logic that runs at the end of each epoch?
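The refill stall from the first guess can be seen without TensorFlow at all. Below is a minimal pure-Python model of a shuffle buffer (a simplification of what Dataset.shuffle does; the function name and the driver code are mine, not part of the tf.data API):

```python
import itertools
import random

def shuffled(source, buffer_size, rng):
    """Simplified model of Dataset.shuffle(buffer_size): fill a buffer,
    then emit a random element and replace it with the next input item.
    Nothing is emitted until the buffer is full -- that is the stall."""
    buf = []
    for item in source:
        if len(buf) < buffer_size:
            buf.append(item)        # filling phase: no output yet
            continue
        i = rng.randrange(buffer_size)
        yield buf[i]
        buf[i] = item
    rng.shuffle(buf)                # drain what is left at end of input
    yield from buf

data = list(range(100))
rng = random.Random(0)

# shuffle(...).repeat(): every epoch restarts shuffled(), so the buffer
# is empty again and must be refilled before the first element appears.
epoch1 = list(shuffled(data, 100, rng))
epoch2 = list(shuffled(data, 100, rng))

# repeat().shuffle(): one shuffle over the concatenated epochs, so the
# buffer stays full across the epoch boundary (epochs blur together).
both = list(shuffled(itertools.chain(data, data), 100, rng))

print(len(epoch1), len(epoch2), len(both))  # 100 100 200
```

With a large buffer (here the full dataset), the filling phase is exactly the pause you would see at each epoch boundary in the first ordering, while the second ordering pays that cost only once.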
Thank you for your reply.
Currently I am using

model.fit(train_data, epochs=self.epochs, validation_data=val_data, verbose=1)

where train_data is a tf.data.Dataset built with

train_data = tf.data.Dataset.from_tensor_slices((train_ivs, train_logr, train_metric))
train_data = train_data.shuffle(buffer_size=train_ivs.shape[0], seed=self.seed,
reshuffle_each_iteration=True)
train_data = train_data.batch(self.batch_size)
train_data = train_data.prefetch(tf.data.AUTOTUNE)
The evaluation step does not seem to cause any problem either.
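For reference, the repeat-before-shuffle ordering suggested above could look like the sketch below. This is only an illustration under assumed shapes (tf.data.Dataset.range stands in for your from_tensor_slices dataset, and num_examples / batch_size are placeholder values), not your actual training setup:

```python
import tensorflow as tf

num_examples = 1000
batch_size = 32

ds = tf.data.Dataset.range(num_examples)
# repeat() first, then shuffle(): the shuffle buffer is filled once and
# stays full across epoch boundaries, at the cost of blurring epochs.
ds = ds.repeat().shuffle(buffer_size=num_examples, seed=0)
ds = ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)

# The repeated dataset is infinite, so model.fit() then needs
# steps_per_epoch to know where an "epoch" ends:
steps_per_epoch = num_examples // batch_size
```

You would then call model.fit(ds, epochs=..., steps_per_epoch=steps_per_epoch, ...); whether the dip disappears would confirm or rule out the shuffle-buffer guess.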