Is there an easy way to construct a tf.data pipeline that regenerates the dataset itself after being used for n Keras training epochs?
Background
I am training a Keras Sequential CNN model, pretty much in a textbook-example way; I am a newbie here.
My Python generator spits out any number of random training samples (each a pair of an RGB image and a coordinate pair) in the following manner:
import tensorflow as tf

game = Game()
dataset = tf.data.Dataset.from_generator(
    game.generate_training_image,
    output_signature=(
        tf.TensorSpec(shape=(768, 768, 3), dtype=tf.uint8),
        tf.TensorSpec(shape=(2,), dtype=tf.float32)
    ))
The generator runs indefinitely, so my data pipeline definition looks something like this (I use small numbers for testing):
train_batches = dataset.take(20).cache().shuffle(20).batch(5).prefetch(tf.data.AUTOTUNE)
val_batches = dataset.take(10).cache().batch(5).prefetch(tf.data.AUTOTUNE)
The problem and the goal
The problem is that Keras’ model.fit always consumes the entire dataset in a single epoch unless you define steps_per_epoch. In the former case, adding repeat just increases the length of a single epoch; in the latter, training simply halts once the dataset is exhausted.
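To make that concrete with the small numbers above (20 samples, batch size 5, i.e. 4 batches per pass), the two cases play out roughly like this (an untested sketch; model stands for my compiled Sequential model):

# Case 1: no steps_per_epoch -- one epoch is one full pass over whatever
# fit() receives, so repeat(50) just makes every epoch 50 passes long.
model.fit(train_batches.repeat(50), validation_data=val_batches, epochs=10)

# Case 2: steps_per_epoch set, but no repeat() -- the 4 batches are used up
# during the first epoch, and Keras then stops early, warning that the
# input ran out of data.
model.fit(train_batches, validation_data=val_batches,
          epochs=10, steps_per_epoch=4)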
In my past experience (with 2000 samples), the model converges nicely for about 50 epochs, at which point it stalls and starts to overfit; feeding it a fresh dataset, however, keeps improving its precision.
The goal is to easily replace the dataset after every 50 epochs (in this example) and keep training indefinitely.
My naive approaches
1. Running the training for 50 epochs, then manually regenerating the dataset
Meh…
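For clarity, "manually regenerating" would look roughly like this (an untested sketch; build_pipelines is just a helper name I made up):

def build_pipelines():
    # A fresh take() from the generator-backed dataset pulls brand-new samples;
    # cache() then pins them down for the next 50 epochs.
    train = dataset.take(20).cache().shuffle(20).batch(5).prefetch(tf.data.AUTOTUNE)
    val = dataset.take(10).cache().batch(5).prefetch(tf.data.AUTOTUNE)
    return train, val

while True:
    train_batches, val_batches = build_pipelines()
    model.fit(train_batches, validation_data=val_batches, epochs=50)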
2. Adding a repeat(50)
This is seemingly equivalent to the goal; however, I am not happy about 50 epochs getting squashed into a single one, with validation occurring only once every half an hour.
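Concretely, something like this (again untested), where one "epoch" then covers 50 passes over the same cached samples and validation only runs at the end of each of them:

train_batches = dataset.take(20).cache().shuffle(20).batch(5).repeat(50).prefetch(tf.data.AUTOTUNE)
model.fit(train_batches, validation_data=val_batches, epochs=10)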
3. Writing a custom Keras callback
Hooking onto on_epoch_end, keeping a count, and replacing the dataset on every 50th epoch sounds like an option (I have not tried it yet, though). I am not happy about the complexity creep, either.
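For reference, this is roughly what I had in mind (a completely untested sketch; as far as I can tell, the dataset handed to fit() cannot simply be swapped from inside a callback, so the sketch refills a buffer that the pipeline reads from instead; fill_buffer, buffer_reader and RefreshData are names I made up):

import numpy as np

N_SAMPLES, REFRESH_EVERY = 20, 50   # small test numbers again

# Mutable buffers: the pipeline reads from them, the callback refills them.
buffer_images = np.zeros((N_SAMPLES, 768, 768, 3), dtype=np.uint8)
buffer_coords = np.zeros((N_SAMPLES, 2), dtype=np.float32)

def fill_buffer():
    # Draw N_SAMPLES fresh samples from my Game generator.
    samples = game.generate_training_image()
    for i in range(N_SAMPLES):
        image, coords = next(samples)
        buffer_images[i] = image
        buffer_coords[i] = coords

def buffer_reader():
    # Re-reads the current buffer contents every time the dataset is iterated.
    for i in range(N_SAMPLES):
        yield buffer_images[i], buffer_coords[i]

buffered = tf.data.Dataset.from_generator(
    buffer_reader,
    output_signature=(
        tf.TensorSpec(shape=(768, 768, 3), dtype=tf.uint8),
        tf.TensorSpec(shape=(2,), dtype=tf.float32)
    ))
train_batches = buffered.shuffle(N_SAMPLES).batch(5).prefetch(tf.data.AUTOTUNE)

class RefreshData(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # epoch is 0-based: refill the buffer after every REFRESH_EVERY-th epoch.
        if (epoch + 1) % REFRESH_EVERY == 0:
            fill_buffer()

fill_buffer()
model.fit(train_batches, validation_data=val_batches,
          epochs=1000, callbacks=[RefreshData()])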
Bonus question
Is there a difference between consecutively training on 2000 x 5 samples vs. running all 10,000 at once? If not, the problem instantly reduces to a much easier pipeline/cache problem.
Your thoughts are greatly appreciated!