How to implement batch-wise data augmentation in TensorFlow 2.x custom training loop?

I’m working on a custom training loop in TensorFlow 2.x where I need to apply data augmentation (rotation, zoom, and horizontal flip) to my image batches during training. My current setup uses tf.data.Dataset for the input pipeline.

Here’s my current code:

import tensorflow as tf

batch_size = 32
input_shape = (224, 224, 3)

# x_train, y_train, and create_model are defined elsewhere
dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)

model = create_model(input_shape)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

I want to:

  1. Apply random augmentations to each batch during training
  2. Keep the augmentations consistent within the custom training loop
  3. Ensure the augmentations are performed on GPU for better performance

What’s the most efficient way to implement this in TensorFlow 2.x while maintaining the custom training loop structure?

Hi @KRows, I don’t think you need to change the custom training loop to augment batch data during training. You can define the data augmentation layers and include them in the model architecture; then every batch passed to the model during training is augmented automatically. For example (the augmentation factors below are illustrative, so tune them to your data):

from tensorflow import keras

data_augmentation = keras.Sequential([
    keras.layers.RandomFlip("horizontal"),
    keras.layers.RandomRotation(0.1),
    keras.layers.RandomZoom(0.1),
])

model = keras.Sequential([
    data_augmentation,
    # ... other model layers ...
])

Thank You.