Hi folks.
I have a requirement that each batch of data contain an equal number of samples from each of the given classes.
I am currently implementing it the naive way for CIFAR-10:
import numpy as np
import tensorflow as tf

AUTO = tf.data.AUTOTUNE  # BATCH_SIZE, sampled_train, sampled_labels, agumentation defined earlier

def support_sampler():
    idx_dict = dict()
    for class_id in np.arange(0, 10):
        # Global indices of the samples belonging to this class.
        subset_idx = np.where(sampled_labels == class_id)[0]
        # Draw 16 of them at random, without replacement.
        idx_dict[class_id] = np.random.choice(subset_idx, 16, replace=False)
    return np.concatenate(list(idx_dict.values()))
def get_support_ds():
    # Draw a fresh class-balanced set of indices (16 per class, 160 total).
    random_balanced_idx = support_sampler()
    temp_train, temp_labels = (
        sampled_train[random_balanced_idx],
        sampled_labels[random_balanced_idx],
    )
    support_ds = tf.data.Dataset.from_tensor_slices((temp_train, temp_labels))
    support_ds = (
        support_ds
        .shuffle(BATCH_SIZE * 1000)
        .map(agumentation, num_parallel_calls=AUTO)
        .batch(BATCH_SIZE)
    )
    return support_ds
Is there a better way, particularly one using pure TF ops with tf.data?
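For context, the closest thing I have found is tf.data.experimental.sample_from_datasets. Here is a minimal, untested sketch of what I mean, reusing the names from above. Note that this only balances batches approximately, since classes are drawn stochastically rather than exactly 16 per batch:

# One repeating, shuffled dataset per class, interleaved with uniform
# weights so every batch is (approximately) class-balanced, all in tf.data.
per_class = [
    tf.data.Dataset.from_tensor_slices(
        (sampled_train[sampled_labels == class_id],
         sampled_labels[sampled_labels == class_id])
    ).repeat().shuffle(1024)
    for class_id in range(10)
]
balanced_ds = (
    tf.data.experimental.sample_from_datasets(per_class, weights=[0.1] * 10)
    .map(agumentation, num_parallel_calls=AUTO)
    .batch(BATCH_SIZE)
)

Would something like this be the recommended direction, or is there a cleaner way to get exactly balanced batches?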