I have a huge TFRecord file with more than 4 million entries. The dataset is highly unbalanced: a few labels account for most of the entries, while others have very few relative to the whole. I want to take a limited number of entries per label so that I end up with a balanced dataset. Below is my attempt, but it takes more than 24 hours to filter 1,000 entries from each of the 33 labels. Is there a faster way to do this?
import tensorflow as tf

# Use a TPU if one is available, otherwise fall back to the default strategy
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect()
    print("Device:", tpu.master())
    strategy = tf.distribute.TPUStrategy(tpu)
except ValueError:
    strategy = tf.distribute.get_strategy()
print("Number of replicas:", strategy.num_replicas_in_sync)
# Reading order does not matter here, so allow non-deterministic reads
ignore_order = tf.data.Options()
ignore_order.experimental_deterministic = False
dataset = tf.data.TFRecordDataset('/test.tfrecord')
dataset = dataset.with_options(ignore_order)
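(I am not sure whether raw read throughput is even the bottleneck; a variant I have been considering is giving the reader a larger buffer. The values below are untuned guesses, and num_parallel_reads mainly pays off when the data is sharded across several files, which mine is not:)

# Variant I'm considering: larger read buffer (untuned guess); with a single
# file, num_parallel_reads adds little, it mainly helps across shards
dataset = tf.data.TFRecordDataset(
    '/test.tfrecord',
    buffer_size=8 * 1024 * 1024,          # 8 MiB
    num_parallel_reads=tf.data.AUTOTUNE)
dataset = dataset.with_options(ignore_order)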
# detect_schema is my helper that infers the context/sequence feature specs
features, feature_lists = detect_schema(dataset)

# Decode the serialized SequenceExample records
def decode_data(serialized):
    X, y = tf.io.parse_single_sequence_example(
        serialized,
        context_features=features,
        sequence_features=feature_lists)
    return X['title'], y['subject']
dataset = dataset.map(
    lambda x: tf.py_function(func=decode_data, inp=[x], Tout=(tf.string, tf.string)))
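(As far as I understand, tf.py_function falls back into the Python interpreter for every record, which also prevents tf.data from parallelizing the map. Since tf.io.parse_single_sequence_example is a regular graph op, I think the parser could be mapped directly; an untested sketch:)

# Possible alternative: map the parser as a graph function, no py_function
dataset = dataset.map(decode_data, num_parallel_calls=tf.data.AUTOTUNE)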
# Filter and concatenate a fixed-size sample per label
def balanced_dataset(dataset, labels_list, sample_size=1000):
    datasets_list = []
    for label in labels_list:
        # Keep only the entries matching the current label; bind `label` as a
        # default argument so each predicate captures its own value
        label_dataset = dataset.filter(
            lambda x, y, label=label: tf.reduce_any(tf.equal(y, label)))
        # Take a limited sample of that label
        datasets_list.append(label_dataset.take(sample_size))
    # Concatenate the per-label samples into one dataset
    concat_dataset = datasets_list[0]
    for dset in datasets_list[1:]:
        concat_dataset = concat_dataset.concatenate(dset)
    return concat_dataset
balanced_data = balanced_dataset(dataset, labels_list=list(decod_dic.values()), sample_size=1000)
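My current understanding of why this is slow: every filter(...).take(sample_size) pipeline restarts from the beginning of the file, so 33 labels means roughly 33 sequential scans over the 4M records, each paying the py_function cost on top. A single-pass alternative I sketched in plain Python is below; it assumes y comes back as a scalar string label, which may not hold for my sequence feature:

from collections import defaultdict

def balanced_sample_single_pass(dataset, labels_list, sample_size=1000):
    # One sequential pass over the data, bucketing examples per label
    wanted = set(labels_list)
    buckets = defaultdict(list)
    for x, y in dataset:                  # eager iteration
        label = y.numpy()                 # assumes a scalar label; may need .decode()
        if label in wanted and len(buckets[label]) < sample_size:
            buckets[label].append((x.numpy(), label))
        # Stop early once every label has enough samples
        if all(len(buckets[l]) >= sample_size for l in wanted):
            break
    return buckets

Or would something like tf.data.experimental.sample_from_datasets or tf.data.experimental.rejection_resample be the more idiomatic route for this kind of balancing?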