TensorFlow Dataset reduce function too slow after skip

Hi everyone,
I’m trying to take 100 random elements per class from the CIFAR-10 dataset and reduce them to a single “image” by taking the mean.
The problem is that the reduce step becomes significantly slower once I use the skip function with a large value.

The code is the following:
for c in classes:
    skip_elem = 100
    class_ds = train_ds.filter(lambda x, y: tf.equal(tf.argmax(y[0]), int(c))).skip(skip_elem).take(100).unbatch()
    z = class_ds.reduce(tf.zeros(shape=(res//8, res//8, 4), dtype=tf.float32), lambda a, b: a + b[0])
    z /= 100

I’ve tried batching and rebatching, but it seems that if I skip only a few elements the reduction executes in a reasonable time. The take function also increases the reduce time, but I expected that.

Why does this happen? Are there other possibilities?

Thanks!

Hi @Matteo_Doria,

Sorry for the delay in response.
The slowdown is likely because skip has to read through the dataset sequentially: skip(n) still produces and discards the first n elements, and if the underlying data format is not optimized for random access, a large skip value means a lot of wasted reading. I suggest using shuffle instead of skip when you need random elements from a large dataset.

Code:

for c in classes:
    # Filter for the specific class and shuffle immediately
    class_ds = train_ds.filter(
        lambda x, y: tf.equal(tf.argmax(y[0]), int(c))
    ).shuffle(
        buffer_size=1000,  # Adjust this based on your dataset size
        seed=42  # Optional: set seed for reproducibility
    ).take(100).unbatch()

    # Same reduction as in your snippet: sum the 100 images, then take the mean
    z = class_ds.reduce(tf.zeros(shape=(res//8, res//8, 4), dtype=tf.float32), lambda a, b: a + b[0])
    z /= 100
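Two further notes. The loop above still re-scans train_ds once per class because of the filter, so with 10 classes you pay the sequential read 10 times; if the decoded dataset fits in memory, calling train_ds = train_ds.cache() before the loop makes those later scans read from memory instead of from disk.

Also, if what you ultimately want is a mean “image” over each whole class rather than over exactly 100 random samples, everything can be computed in a single pass with tf.math.unsorted_segment_sum. The following is only a rough sketch, not tested against your pipeline: num_classes is an assumption, and the (res//8, res//8, 4) element shape and one-hot labels are carried over from your snippet.

    num_classes = 10  # assumption: 10 classes, as in CIFAR-10
    # Running per-class sums and element counts
    sums = tf.zeros((num_classes, res//8, res//8, 4), dtype=tf.float32)
    counts = tf.zeros((num_classes,), dtype=tf.float32)

    # One sequential pass over the batched dataset
    for x, y in train_ds:
        labels = tf.argmax(y, axis=-1)  # one-hot rows -> integer class ids, shape (batch,)
        # Add every image in the batch into the accumulator row selected by its label
        sums += tf.math.unsorted_segment_sum(
            tf.cast(x, tf.float32), labels, num_segments=num_classes)
        counts += tf.math.unsorted_segment_sum(
            tf.ones_like(labels, dtype=tf.float32), labels, num_segments=num_classes)

    class_means = sums / tf.reshape(counts, (num_classes, 1, 1, 1))  # one mean "image" per class

Because every image is folded into its class accumulator as it streams by, the whole dataset is read exactly once, regardless of the number of classes.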

Please let us know if you run into any issues. Thank you.