Regenerate dataset after n epochs

Is there an easy way to construct a tf.data pipeline that regenerates the dataset itself after it has been used for n Keras training epochs?

Background

I am training a Keras Sequential CNN model, pretty much in a textbook-example way; I am a newbie here.

My Python generator spits out any number of random training samples (each a pair of an RGB image and a coordinate pair) in the following manner:

game = Game()
dataset = tf.data.Dataset.from_generator(
    game.generate_training_image,
    output_signature=(
        tf.TensorSpec(shape=(768, 768, 3), dtype=tf.uint8),
        tf.TensorSpec(shape=(2,), dtype=tf.float32)
    ))

The generator runs indefinitely, so my data pipeline definition is something like this (I use small numbers for testing):

train_batches = dataset.take(20).cache().shuffle(20).batch(5).prefetch(tf.data.AUTOTUNE)
val_batches = dataset.take(10).cache().batch(5).prefetch(tf.data.AUTOTUNE)

The problem and the goal

The problem itself is that Keras’ model.fit always consumes the entire dataset for an epoch unless you define steps_per_epoch. In the former case, adding repeat just increases the length of a single epoch; in the latter, training simply halts once the dataset is exhausted.
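Concretely, this is what I mean (a rough sketch; model stands for my compiled Sequential model, and the small test pipeline above is assumed):

# Former case: no steps_per_epoch -- one epoch is one full pass over
# train_batches, and adding .repeat(50) only makes that single pass 50x longer.
model.fit(train_batches, validation_data=val_batches, epochs=50)

# Latter case: steps_per_epoch defined -- epochs get a fixed length, but with
# the finite 4-batch dataset above Keras warns that the input ran out of data
# and interrupts training after the first epoch.
model.fit(train_batches, validation_data=val_batches,
          steps_per_epoch=4, epochs=50)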

In my past experience (using 2000 samples) the model converges fine for about 50 epochs, at which point it stalls and starts to overfit; feeding it a fresh dataset, however, helps to keep increasing its precision.

The goal is to easily replace the dataset after (here:) every 50 epochs and keep training indefinitely.

My naive approaches

1. Running the training for 50 epochs then manually regenerating the dataset

Meh…

2. Adding a repeat(50)

This is seemingly equivalent to the goal; however, I am not happy about 50 epochs getting squished into a single one, with validation occurring only once every half hour.

3. Writing a custom Keras callback

Hooking into on_epoch_end, keeping a count, and replacing the dataset on every 50th epoch sounds like an option (I have not tried it yet, though). I am not happy about the complexity creep, either.
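For reference, the least magical version of 1. (and roughly what the callback in 3. would automate) that I can picture is an outer loop over model.fit that rebuilds the pipeline each round and uses initial_epoch to keep the epoch counter monotonic. A rough, untested sketch; build_datasets is a hypothetical helper wrapping the pipeline above:

# Hypothetical helper: rebuilds the tf.data pipeline from a fresh generator run,
# so the in-memory cache is filled with new samples each round.
def build_datasets(game, n_train=20, n_val=10, batch_size=5):
    dataset = tf.data.Dataset.from_generator(
        game.generate_training_image,
        output_signature=(
            tf.TensorSpec(shape=(768, 768, 3), dtype=tf.uint8),
            tf.TensorSpec(shape=(2,), dtype=tf.float32)
        ))
    train = dataset.take(n_train).cache().shuffle(n_train).batch(batch_size).prefetch(tf.data.AUTOTUNE)
    val = dataset.take(n_val).cache().batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return train, val

EPOCHS_PER_ROUND = 50
for round_idx in range(10):  # 10 rounds of 50 epochs; raise the range to train longer
    train_batches, val_batches = build_datasets(game)
    model.fit(
        train_batches,
        validation_data=val_batches,
        initial_epoch=round_idx * EPOCHS_PER_ROUND,   # keeps the epoch counter monotonic
        epochs=(round_idx + 1) * EPOCHS_PER_ROUND)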

Bonus question

Is there a difference between consecutively training on 2000 × 5 samples vs. running all 10,000 at once? If not, the problem instantly reduces to a much easier pipeline/cache problem.

Your thoughts are greatly appreciated!


Hi @Marton_Krauter & welcome to the TensorFlow forum.
I don’t quite understand why one would do what you’re trying to do (this may be just me!). Either the distribution of that new dataset is the same as the one you started training your model with, in which case plugging it in will not help in any way; or the distribution is different, and then that is an issue, no?
Basically, why aren’t you satisfied, given that your model “converges fine”? What is “training indefinitely”?
Isn’t one simple way to achieve what you want saving your model / parameters at the desired stage and loading it again to continue training with another dataset?
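If it helps, the save / reload idea could look roughly like this (the checkpoint path and the second dataset are placeholders):

# Train on the first dataset, then checkpoint the whole model.
model.fit(train_batches, validation_data=val_batches, epochs=50)
model.save("my_checkpoint")  # SavedModel directory; path is a placeholder

# ... later, possibly in a new session: reload and continue with fresh data.
model = tf.keras.models.load_model("my_checkpoint")
model.fit(new_train_batches, validation_data=new_val_batches, epochs=50)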


Hi @tagoma, thank you for the kind message and apologies for the belated reply.

Let me rephrase my question:
Given you have an unlimited source of (uniform distribution) training data, what would be your strategy to train a CNN effectively?

In my (rather limited) understanding, you need to find an equilibrium between:

  • keeping the samples in the training loop long enough for the model to learn from them
  • occasionally replacing the samples to induce generalization and prevent overfitting

Yes, manually replacing the dataset after n epochs obviously works, though I was wondering whether there is a way to accomplish this with less babysitting.
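One tf.data-only idea I have been sketching (untested, and it keeps the snapshot in Python memory, so the numbers are my small test values): let the generator itself hold a snapshot of samples and replay it for n epochs before drawing a fresh one, so a single model.fit call with steps_per_epoch can run for as long as I like. This assumes game.generate_training_image() returns the (image, coordinates) generator used above.

import tensorflow as tf

def snapshot_generator(game, samples_per_epoch=20, epochs_per_snapshot=50):
    # Draw a snapshot from the infinite source, replay it for
    # `epochs_per_snapshot` passes, then silently swap in fresh samples.
    source = game.generate_training_image()
    while True:
        snapshot = [next(source) for _ in range(samples_per_epoch)]
        for _ in range(epochs_per_snapshot):
            for image, coords in snapshot:
                yield image, coords

dataset = tf.data.Dataset.from_generator(
    lambda: snapshot_generator(game),
    output_signature=(
        tf.TensorSpec(shape=(768, 768, 3), dtype=tf.uint8),
        tf.TensorSpec(shape=(2,), dtype=tf.float32)
    ))

train_batches = dataset.shuffle(20).batch(5).prefetch(tf.data.AUTOTUNE)

# One Keras epoch == one pass over the snapshot (20 samples / batch size 5):
# model.fit(train_batches, steps_per_epoch=4, epochs=500, ...)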