I’m filtering the dataset according to certain labels. Calling the filtering method itself works fine, but once I call next(iter(dataset)), for certain label values it keeps processing for more than 12 hours, while for other values it returns the result right away.
My filtering code is:
import tensorflow as tf

def balanced_dataset(dataset, labels_list, sample_size=1000):
    datasets_list = []
    for label in labels_list:
        print(f'Preparing the {label} dataset')
        # keep only examples whose label tensor contains the current label
        filtered = dataset.filter(
            lambda x, y: tf.greater(
                tf.reduce_sum(tf.cast(tf.equal(tf.constant(label, dtype=tf.int64), y), tf.float32)),
                tf.constant(0.)))
        datasets_list.append(filtered.take(sample_size))
    # build a dataset of datasets, then interleave them into one balanced dataset
    ds = tf.data.Dataset.from_tensor_slices(datasets_list)
    concat_ds = ds.interleave(lambda x: x, cycle_length=len(labels_list),
                              num_parallel_calls=tf.data.AUTOTUNE, deterministic=False)
    return concat_ds
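To show how the hang is triggered, here is a minimal sketch of how I call it. The toy dataset, the label values [0, 1, 2] and sample_size=10 are placeholders, not my real data:

# hypothetical toy dataset of (feature, label) pairs with int64 labels
features = tf.random.uniform((100, 8))
labels = tf.random.uniform((100,), maxval=3, dtype=tf.int64)
toy_ds = tf.data.Dataset.from_tensor_slices((features, labels))

balanced = balanced_dataset(toy_ds, labels_list=[0, 1, 2], sample_size=10)
# this is the call that sometimes keeps running for hours on my real data
first_example = next(iter(balanced))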