Data Leakage - image_dataset_from_directory()

Hi, I’m trying to build a very basic CNN for multiclass image classification and I’m getting a little stuck on one of the first steps: splitting the data. Following a YouTube tutorial, I initially created a dataset with tf.keras.utils.image_dataset_from_directory() and then split it into train/validation/test using .skip() and .take(). The model trained fine, but I noticed that the test set changed each time I used it (even when it was only 1 batch). My understanding is that with this method, every time you iterate over the full dataset it reshuffles and redraws all of the samples. So, 1. Is this a source of data leakage, in that the samples are redrawn at each training epoch and the model has effectively already seen the test set?

As a result, I decided to create a separate directory for test data that I never touch and only split the remaining data into train/validation. From reading online I realised I could do this with the validation_split keyword. However, that brings up my second question: if I split using validation_split (Method 1 below) I only reach a validation accuracy of about 0.5 during training, whereas with the skip()/take() method (Method 2) I can get up to 0.95. I’m clearly doing something different between these two methods but can’t see what. Could anyone explain what it is, and which method is better?

## METHOD 1 ##

validation_split = 0.2

train1 = tf.keras.utils.image_dataset_from_directory(
    train_dir,
    validation_split=validation_split,
    subset="training",
    seed=RANDOM_STATE)

val1 = tf.keras.utils.image_dataset_from_directory(
    train_dir,
    validation_split=validation_split,
    subset="validation",
    seed=RANDOM_STATE)

train1 = train1.map(lambda x, y: (x/255., y))
val1 = val1.map(lambda x, y: (x/255., y))

## METHOD 2 ##

data = tf.keras.utils.image_dataset_from_directory(train_dir)

# Scale the pixel data to between 0 and 1
data = data.map(lambda x,y: (x/255.,y))
          
# Split into train, validation and test samples
n_batchs = len(data)
train_size = int(n_batchs*0.8)
val_size = int(n_batchs*0.2)

# if rounding causes sizes to be less than amount of data, add spare data to the training set
total_size = train_size+val_size
if total_size < n_batchs:
  train_size += n_batchs - total_size

train2 = data.take(train_size)
val2 = data.skip(train_size).take(val_size)

Thank you so much for any help!

Hi @EmmaK, by default tf.keras.utils.image_dataset_from_directory shuffles the data because the shuffle argument is set to True, and the order is re-drawn each time the dataset is iterated. Because of this, when you split with .take() and .skip(), the same image can end up in both the train and validation datasets. It is recommended to use Method 1. Thank you.
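
To see concretely why that happens, here is a minimal toy sketch. It uses tf.data.Dataset.range in place of an image dataset so it runs without any files; the sizes and seed are just illustrative:

import tensorflow as tf

# Stand-in for the shuffled dataset returned by image_dataset_from_directory:
# Dataset.shuffle() re-shuffles on every iteration by default
# (reshuffle_each_iteration=True).
data = tf.data.Dataset.range(10).shuffle(buffer_size=10, seed=42)

train = data.take(8)
val = data.skip(8)

# Every pass over `val` draws from a freshly shuffled order, so the
# "validation" elements change between epochs and overlap with what
# `train` already saw.
print("val, 1st pass:", [int(x) for x in val])
print("val, 2nd pass:", [int(x) for x in val])

Running this a couple of times shows different elements landing in val, which is exactly the behaviour you observed with your test batch.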

TensorFlow 2.16 + Python 3.12 – JARaaS Hybrid RAG - 6/16/2024

The concern you raised about data leakage is valid. When splitting datasets, especially for deep learning tasks such as image classification, ensuring that the test set remains unseen by the model during training is crucial to avoid data leakage.

Here’s a detailed response addressing your questions:

Question 1: Is this a source of data leakage?

Yes, splitting a dataset with .skip() and .take() without handling the shuffling properly can cause data leakage. image_dataset_from_directory shuffles by default, and that shuffle is re-applied every time the dataset is iterated, so the batches that fall into your .take()/.skip() partitions change from epoch to epoch; samples that should stay in the test set end up in the training or validation sets, leaking information the model should never have.

To avoid this:

  • Always ensure that the test set is completely separate and untouched during training.
  • You can create separate directories for training, validation, and test sets before starting model training, as shown in the sketch below.
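
For example, a held-out test directory could be loaded once, unshuffled, and only touched at evaluation time. The directory name, image size, batch size, and the trained model below are placeholders, not taken from your code:

import tensorflow as tf

# Hypothetical directory that training never reads
test_ds = tf.keras.utils.image_dataset_from_directory(
    "data/test",           # placeholder path
    shuffle=False,         # keep the test batches deterministic
    image_size=(256, 256),
    batch_size=32)

test_ds = test_ds.map(lambda x, y: (x / 255., y))

# Evaluate only once training is completely finished
# loss, acc = model.evaluate(test_ds)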

Question 2: Difference in validation accuracy between the two methods

There are several possible reasons for the discrepancy in validation accuracy between the two methods:

  1. Different Splitting Methods:

    • Method 1 splits the file list once, using the seed, so the same files always end up in the training and validation subsets. The validation accuracy it reports is therefore measured on images the model has never trained on.
    • Method 2 lets images leak between the training and validation sets because the batches are reshuffled on every pass, so the 0.95 validation accuracy is likely inflated: the model has already trained on many of the images it is being validated on.
  2. Random State:

    • Method 2 does not pass a seed at all, so the shuffle order (and therefore the split) changes from run to run, making results hard to reproduce. Fix the random seed (RANDOM_STATE) consistently across all calls.
  3. Shuffling:

    • When using .skip() and .take(), make sure the dataset order is fixed before you split, and shuffle only the training portion afterwards; a stable order is what keeps the validation and test sets consistent (see the sketch after this list).
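
If you do want to keep the take()/skip() style, one way to make the split deterministic is to load the data unshuffled and unbatched, shuffle it once with reshuffle_each_iteration=False, split, and then re-shuffle only the training portion each epoch. This is only a sketch (the path, seed, and batch size are assumptions), and the full-size shuffle buffer holds decoded images in memory, so for large datasets the validation_split approach scales better:

import tensorflow as tf

SEED = 42  # assumed

# Load unbatched and unshuffled so every sample has a stable position
data = tf.keras.utils.image_dataset_from_directory(
    "data/train",          # placeholder path
    shuffle=False,
    batch_size=None)

data = data.map(lambda x, y: (x / 255., y))

n_samples = len(data)
train_size = int(n_samples * 0.8)

# One fixed shuffle: the same permutation is reused on every iteration,
# so take()/skip() always cut the dataset in the same place
data = data.shuffle(buffer_size=n_samples, seed=SEED,
                    reshuffle_each_iteration=False)

train2 = data.take(train_size)
val2 = data.skip(train_size)

# Re-shuffle only the training examples each epoch, then batch both splits
train2 = train2.shuffle(buffer_size=train_size, seed=SEED).batch(32)
val2 = val2.batch(32)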

Suggestions for Proper Dataset Splitting

Method 1 is generally preferred and avoids many common pitfalls associated with data leakage and inconsistent dataset splits. Here’s how you can improve your data splitting strategy:

import tensorflow as tf

# Use separate directories for training, validation, and test sets beforehand
train1 = tf.keras.utils.image_dataset_from_directory(
    train_dir,
    validation_split=0.2,
    subset="training",
    seed=RANDOM_STATE,
    image_size=(img_height, img_width),
    batch_size=batch_size
)

val1 = tf.keras.utils.image_dataset_from_directory(
    train_dir,
    validation_split=0.2,
    subset="validation",
    seed=RANDOM_STATE,
    image_size=(img_height, img_width),
    batch_size=batch_size
)

train1 = train1.map(lambda x, y: (x / 255., y))
val1 = val1.map(lambda x, y: (x / 255., y))

Provided both calls use the same validation_split, seed, and shuffle settings, the validation_split parameter gives you complementary, deterministic subsets, so there is no overlap and no leakage between training and validation data.
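
If you want to double-check that there is no overlap, recent TensorFlow versions attach a file_paths attribute to the dataset returned by image_dataset_from_directory (verify that your version exposes it). A quick sanity check could look like this, run before the .map() calls, since the attribute is not carried through dataset transformations:

# Sanity check: the same file should never appear in both subsets.
# Note: file_paths lives on the object returned by
# image_dataset_from_directory and is lost after .map(), so run this first.
train_paths = set(train1.file_paths)
val_paths = set(val1.file_paths)

print("train files:", len(train_paths))
print("val files:  ", len(val_paths))
print("overlap:    ", len(train_paths & val_paths))  # expected: 0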

Summarized Advice

  • Ensure Test Data is Separate: Use a separate directory for test data that the model does not access during training.
  • Consistent Splitting: Use validation_split and seed values to ensure deterministic splitting.
  • Check Overfitting: Compare the training and validation accuracy; if training accuracy is much higher than validation accuracy, the model is likely overfitting (see the sketch below).
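
For example, assuming you keep the History object returned by model.fit and compiled the model with metrics=["accuracy"], a quick comparison might look like:

history = model.fit(train1, validation_data=val1, epochs=10)

# A large gap between the two numbers (or diverging curves over epochs)
# suggests the model is overfitting the training data
print("final train accuracy:", history.history["accuracy"][-1])
print("final val accuracy:  ", history.history["val_accuracy"][-1])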

By following these guidelines, you should get reliable, repeatable dataset splits without the risk of data leakage.
