One-hot encoding a segmentation label map with pure TensorFlow

How do I one-hot encode a segmentation map with pure TensorFlow ops?

Currently, we are doing something like this:

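import numpy as np

n_classes = 12
h, w, _ = image.shape  # image: (h, w, 3); labels: (h, w) integer class IDs
one_hot_encoded_labels = np.zeros((h, w, n_classes), dtype=np.float32)
for i in range(n_classes):
    # Set channel i to 1 wherever the label map equals class i.
    one_hot_encoded_labels[labels == i, i] = 1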

(Say image has shape (443, 300, 3) and labels has shape (443, 300), holding one integer class ID per pixel.)

We then wrap this in tf.py_function so it works with the rest of our TensorFlow code.
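For reference, the per-class loop above can be vectorized in NumPy by indexing an identity matrix with the label map: row k of np.eye(n_classes) is the one-hot vector for class k. This is a minimal sketch on a hypothetical random label map (not the poster's data), and it assumes every label is in [0, n_classes); unlike the loop, np.eye(...)[labels] raises an IndexError for out-of-range values such as an "ignore" label.

```python
import numpy as np

n_classes = 12
# Hypothetical label map: integer class IDs in [0, n_classes).
rng = np.random.default_rng(0)
labels = rng.integers(0, n_classes, size=(443, 300))

# Loop version from the question.
loop_ohe = np.zeros(labels.shape + (n_classes,), dtype=np.float32)
for i in range(n_classes):
    loop_ohe[labels == i, i] = 1

# Vectorized version: indexing the (n_classes, n_classes) identity matrix with
# the (h, w) label map yields an (h, w, n_classes) array.
eye_ohe = np.eye(n_classes, dtype=np.float32)[labels]

print(eye_ohe.shape)  # (443, 300, 12)
print(np.array_equal(loop_ohe, eye_ohe))  # True
```

This still runs outside the graph, so inside a TensorFlow pipeline it would need tf.py_function just like the loop does.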

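tf.one_hot handles this case directly: it appends a new last axis of size depth, so a (443, 300) integer label map becomes a (443, 300, n_classes) tensor, with no NumPy or tf.py_function needed.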


God!

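import tensorflow as tf

n_classes = 12
labels = tf.ones((443, 300), dtype=tf.int32)  # placeholder label map (every pixel is class 1)
ohe_labels = tf.one_hot(labels, n_classes)  # shape (443, 300, 12), dtype float32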

Voila!
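To drop tf.py_function entirely, tf.one_hot can be called inside a tf.data map function, since it is an ordinary graph op. The sketch below is illustrative only: the in-memory images and labels arrays and the encode helper are hypothetical stand-ins for a real dataset. It also shows that labels outside [0, depth), such as a 255 "ignore" value, come out as all-zero vectors rather than raising an error.

```python
import numpy as np
import tensorflow as tf

N_CLASSES = 12  # from the question

# Hypothetical stand-ins for a real (image, label map) dataset.
images = np.zeros((4, 443, 300, 3), dtype=np.float32)
labels = np.random.randint(0, N_CLASSES, size=(4, 443, 300)).astype(np.int32)

def encode(image, label):
    # label: (443, 300) int32 -> (443, 300, N_CLASSES) float32, computed in-graph.
    return image, tf.one_hot(label, depth=N_CLASSES)

ds = tf.data.Dataset.from_tensor_slices((images, labels)).map(encode)

for image, ohe in ds.take(1):
    print(ohe.shape)  # (443, 300, 12)

# Out-of-range indices (here 255 with depth=3) produce an all-zero row.
print(tf.one_hot([0, 255], depth=3).numpy())
```

Because the encoding now lives in the graph, tf.data can parallelize it (for example with num_parallel_calls on map), which a Python-bound tf.py_function cannot do as effectively.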
