Hi everyone!
I’m looking to train several image classification models on a large dataset, which I want to store as TFRecord files. I already know how to do this when I save the image (tf.string) and the class number (int64). Instead, I’d like to save the label as a string and set the class number from it during preprocessing.
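For reference, the writing side with a string label could look like the minimal sketch below. It assumes, as the parser in the code further down does, that each image is stored as a serialized uint8 tensor (tf.io.serialize_tensor) under "image" and the label as UTF-8 bytes under "label"; make_example and write_tfrecord are hypothetical helper names.

```python
import tensorflow as tf


def make_example(image, label):
    """Serialize one (uint8 image tensor, label string) pair as a tf.train.Example."""
    # serialize_tensor pairs with tf.io.parse_tensor on the reading side.
    image_bytes = tf.io.serialize_tensor(image).numpy()
    feature = {
        "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_bytes])),
        # The label is kept as a plain string; no class number is baked in.
        "label": tf.train.Feature(bytes_list=tf.train.BytesList(value=[label.encode("utf-8")])),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()


def write_tfrecord(path, samples):
    """Write an iterable of (image, label) pairs to a single TFRecord file."""
    with tf.io.TFRecordWriter(path) as writer:
        for image, label in samples:
            writer.write(make_example(image, label))
```

Since the stored label is a string, the same files can later be mapped to different class numbers without rewriting them.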
Motivation: the dataset I’m working on contains images that can be classified into a general category (e.g. dog, cat) as well as a more specific one (labrador, bulldog, persian cat). Most labels are the specific ones, so an image can go through a cascade: first a general model (cat/dog/negative), then a more nuanced one (labrador/bulldog/negative).
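Because only the string is stored, every model in the cascade can reuse the same TFRecord files and differ only in its string-to-int map. A minimal sketch in plain Python, assuming a hypothetical hierarchy dict and a hypothetical build_label_map helper, with the negative class always at index 0 as in labels_to_class_int:

```python
# Hypothetical hierarchy: stored (specific) label -> general category.
hierarchy = {
    "bulldog": "dog",
    "labrador": "dog",
    "persian cat": "cat",
    "cow": "neg",
}


def build_label_map(labels, classes, to_class=lambda label: label):
    """Map each stored label to an int in 0..len(classes)-1.

    classes[0] is treated as the negative class: any label whose (optionally
    coarsened) name is not in `classes` falls back to 0.
    """
    class_int = {name: i for i, name in enumerate(classes)}
    return {label: class_int.get(to_class(label), 0) for label in labels}


# Stage 1: general model (neg/dog/cat), coarsening labels through the hierarchy.
general_map = build_label_map(hierarchy, ["neg", "dog", "cat"], to_class=hierarchy.get)
# Stage 2: dog-breed model (neg/labrador/bulldog), using the specific labels directly.
dog_map = build_label_map(hierarchy, ["neg", "labrador", "bulldog"])
```

Here general_map is {"bulldog": 1, "labrador": 1, "persian cat": 2, "cow": 0} and dog_map sends both breeds to 1 or 2 and everything else to 0.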
I’m relatively new to the framework. I was able to express what I want in eager execution mode using the .numpy() method (code below), but the same code breaks when passed to dataset.map.
My question is: how can this mapping be done efficiently, and which methods should I look into?
Thanks
import tensorflow as tf
import functools
# example: this model classifies dogs and cats
example_label_map = {
    "bulldog": "dog",
    "labrador": "dog",
    "persian cat": "cat",
    "cow": "neg",
}
# map the N classes to ints in 0..N-1
labels_to_class_int = {
    "neg": 0,  # negative class
    "dog": 1,
    "cat": 2,
}
def parse_tfrec_function(example, labels_map):
    """
    labels_map maps label strings to class numbers, e.g. "bulldog" -> 1, "labrador" -> 1, "persian cat" -> 2, ...
    """
    image_feature_description = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "label": tf.io.FixedLenFeature([], tf.string),
    }
    features = tf.io.parse_single_example(example, image_feature_description)
    image = tf.io.parse_tensor(features["image"], tf.uint8)
    # .numpy() only works eagerly; inside dataset.map the label is a symbolic tensor
    label = features["label"].numpy().decode('utf-8')
    class_num = labels_map[label]
    return image, class_num
tfrec_file = ['/srv/mickey/build-cls/test/data/dataset.tfrec']
dataset = tf.data.TFRecordDataset(tfrec_file)
labels_map = {}
for k, v in example_label_map.items():
    labels_map[k] = labels_to_class_int[v]
print(labels_map)
parser_fn = functools.partial(parse_tfrec_function, labels_map=labels_map)
parsed_dataset = dataset.map(parser_fn)  # this breaks with an AttributeError on the Tensor object (no .numpy() in graph mode)
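One graph-compatible way to do the lookup, offered as a sketch rather than the only option, is to move the Python dict into a tf.lookup.StaticHashTable. Its lookup() accepts the symbolic string tensor that dataset.map passes in, so .numpy() is no longer needed. make_label_table is a hypothetical helper, and mapping unknown labels to 0 (the negative class) is an assumption.

```python
import functools

import tensorflow as tf


def make_label_table(labels_map, default_value=0):
    """Turn a Python {label string: class int} dict into an in-graph lookup table."""
    keys = tf.constant(list(labels_map.keys()), dtype=tf.string)
    values = tf.constant(list(labels_map.values()), dtype=tf.int64)
    initializer = tf.lookup.KeyValueTensorInitializer(keys, values)
    # Labels missing from the dict map to default_value (assumed: 0 = "neg").
    return tf.lookup.StaticHashTable(initializer, default_value=default_value)


def parse_tfrec_function(example, label_table):
    image_feature_description = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "label": tf.io.FixedLenFeature([], tf.string),
    }
    features = tf.io.parse_single_example(example, image_feature_description)
    image = tf.io.parse_tensor(features["image"], tf.uint8)
    # lookup() works on symbolic string tensors, so this runs inside dataset.map.
    class_num = label_table.lookup(features["label"])
    return image, class_num


labels_map = {"bulldog": 1, "labrador": 1, "persian cat": 2, "cow": 0}
label_table = make_label_table(labels_map)
parser_fn = functools.partial(parse_tfrec_function, label_table=label_table)
# With tfrec_file as in the question:
# parsed_dataset = tf.data.TFRecordDataset(tfrec_file).map(parser_fn, num_parallel_calls=tf.data.AUTOTUNE)
```

Wrapping the original function in tf.py_function would also avoid the error, but it runs Python for every element, which is typically slower and limits parallelism; the table keeps the whole pipeline in the graph.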