I wonder why repeatedly iterating over a tf.data.Dataset does not yield the same element order each time (see the code below). Is there a way to reset the dataset manually, the way it is reset at the start of each training epoch? This would be useful for manual model evaluation: I could reuse my val_dataset to call model.predict(val_dataset) and then compare the predictions with the true classes in y, e.g. to compute a balanced accuracy.
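Roughly what I have in mind (a sketch only; it assumes a trained binary classifier model, which is not shown here, uses scikit-learn's balanced_accuracy_score, and takes val_dataset as defined further below; it only makes sense if both passes traverse the dataset in the same order):

import numpy as np
from sklearn.metrics import balanced_accuracy_score

# First pass: collect the true labels batch by batch (assumes a stable order).
y_true = np.concatenate([y.numpy() for x, y in val_dataset]).ravel()

# Second pass: predict on the same dataset and threshold the sigmoid outputs.
y_pred = (model.predict(val_dataset) > 0.5).astype('float32').ravel()

print(balanced_accuracy_score(y_true, y_pred))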
For testing, I have created a synthetic dataset with the following script:
import os
import numpy as np
import PIL.Image

datapath = 'data-synthetic'
image_size = 224

for i_class in range(2):
    classpath = os.path.join(datapath, f'class_{i_class}')
    os.makedirs(classpath)
    for i_img in range(160):  # 320 files total; the 10% validation split is exactly one default batch of 32.
        # Class 0 gets dark pixels (0..127), class 1 bright pixels (128..255).
        pixels = np.random.randint(i_class * 128, (i_class + 1) * 128, (image_size, image_size))
        image = PIL.Image.fromarray(pixels.astype('uint8'), 'L').convert('RGB')
        image.save(os.path.join(classpath, f'{i_img:03d}.png'))
It is used in my test code:
import tensorflow as tf
from tensorflow.keras.preprocessing import image_dataset_from_directory

assert tf.__version__ == '2.9.1', f'Currently: {tf.__version__}'  # Remember to run pip-sync.

image_size = 224

def get_dataset(subset):
    print('get_dataset:', subset)
    return image_dataset_from_directory(
        'data-synthetic',
        labels="inferred",
        label_mode='binary',
        color_mode="rgb",
        batch_size=32,
        image_size=(image_size, image_size),
        shuffle=True,
        seed=1,
        validation_split=0.1,
        subset=subset,
        crop_to_aspect_ratio=False,
    )
val_dataset = get_dataset('validation')
for x, y in val_dataset:  # Iterating yields the single validation batch of 32.
    print(tf.transpose(y))

print('The dataset is reproducible:')
val_dataset = get_dataset('validation')
for x, y in val_dataset:
    print(tf.transpose(y))

print('... but not when just re-iterating:')
for x, y in val_dataset:
    print(tf.transpose(y))
The output shows that calling image_dataset_from_directory again yields the same element order, but simply re-iterating over the existing dataset does not:
get_dataset: validation
Found 320 files belonging to 2 classes.
Using 32 files for validation.
tf.Tensor(
[[0. 0. 1. 1. 0. 1. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 1. 1. 1. 0. 0.
1. 1. 1. 0. 0. 1. 1. 0.]], shape=(1, 32), dtype=float32)
The dataset is reproducible:
get_dataset: validation
Found 320 files belonging to 2 classes.
Using 32 files for validation.
tf.Tensor(
[[0. 0. 1. 1. 0. 1. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 1. 1. 0. 1. 1. 1. 0. 0.
1. 1. 1. 0. 0. 1. 1. 0.]], shape=(1, 32), dtype=float32)
... but not when just re-iterating:
tf.Tensor(
[[0. 0. 0. 0. 1. 0. 1. 1. 0. 0. 1. 0. 1. 1. 1. 1. 1. 0. 1. 0. 0. 0. 1. 0.
1. 1. 0. 0. 1. 0. 1. 0.]], shape=(1, 32), dtype=float32)
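For completeness, a single-pass sketch of the kind of workaround I could fall back on (again assuming a trained model; labels and predictions come from the same batches, so the iteration order between passes cannot matter), though I would still like to understand why re-iteration reshuffles:

import numpy as np

# One pass over the dataset: labels and predictions are taken from the
# same batch objects, so any reshuffling between iterations is irrelevant.
y_true, y_prob = [], []
for x, y in val_dataset:
    y_true.append(y.numpy())
    y_prob.append(model.predict_on_batch(x))

y_true = np.concatenate(y_true).ravel()
y_pred = (np.concatenate(y_prob) > 0.5).astype('float32').ravel()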