Multiclass semantic segmentation model does not learn

Hi guys, I’m trying to learn TensorFlow (2.3.0 with Python 3.7) by training a U-Net model. I’m using the PASCAL VOC 2012 dataset (21 classes), but my model doesn’t learn: the accuracy stays around 0.75 and the loss barely moves.
My labels are converted to a tensor with pixel values between 0 and 20 and then one-hot encoded (I’m not sure this is the right approach, but it’s what I have). Images are normalized to the range [0, 1]. I use categorical cross-entropy as the loss function and the Adam optimizer.
I’ve tried to fix the problem in a number of ways, but I really don’t know what else to do.
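
For concreteness, the encoding is just tf.one_hot applied to the integer mask; a toy example of the shapes I expect:

import numpy as np
import tensorflow as tf

toy = np.array([[0, 1], [20, 5]], dtype=np.int32)  # class ids in 0..20
print(tf.one_hot(toy, 21).shape)                   # (2, 2, 21): one channel per class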

Can anybody help me? Thanks.

The code is below.

MODEL:

from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation, MaxPool2D, UpSampling2D, Concatenate
from tensorflow.keras.models import Model

def conv_block(inputs, filters, pool = True):
    x = Conv2D(filters, 3, padding = "same")(inputs)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)

    x = Conv2D(filters, 3, padding = "same")(x)  # was (inputs): the second conv must take the first block's output
    x = BatchNormalization()(x)
    x = Activation("relu")(x)

    if pool:
        p = MaxPool2D((2, 2))(x)
        return x, p
    else:
        return x

def build_unet(shape, num_classes):
    inputs = Input(shape)

    #Encoder
    x1, p1 = conv_block(inputs, 16, pool = True)
    x2, p2 = conv_block(p1, 32, pool = True)
    x3, p3 = conv_block(p2, 48, pool = True)
    x4, p4 = conv_block(p3, 64, pool = True)

    #Bridge

    b1 = conv_block(p4, 128, pool = False)

    #Decoder
    u1 = UpSampling2D((2, 2), interpolation = "bilinear")(b1)
    c1 = Concatenate()([u1, x4])
    x5 = conv_block(c1, 64, pool = False)

    u2 = UpSampling2D((2, 2), interpolation = "bilinear")(x5)
    c2 = Concatenate()([u2, x3])
    x6 = conv_block(c2, 48, pool = False)

    u3 = UpSampling2D((2, 2), interpolation = "bilinear")(x6)
    c3 = Concatenate()([u3, x2])
    x7 = conv_block(c3, 32, pool = False)

    u4 = UpSampling2D((2, 2), interpolation = "bilinear")(x7)
    c4 = Concatenate()([u4, x1])
    x8 = conv_block(c4, 16, pool = False)

    #output layer

    output = Conv2D(num_classes, 1, padding = "same", activation = "softmax")(x8)  # softmax over the 21 classes; sigmoid would only suit binary masks

    return Model(inputs, output)

    if __name__ == "__name__":
        model = build_unet((256, 256, 3), 21)
        model.summary()
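
Since the encoder pools four times, the input height and width must be divisible by 2^4 = 16 for the skip connections to line up (256 is fine). A quick shape check:

model = build_unet((256, 256, 3), 21)
print(model.output_shape)  # (None, 256, 256, 21): spatial size preserved, 21 scores per pixel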

TRAIN:

import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping

from data import load_data, tf_dataset
from model import build_unet

if __name__ == "__main__":
    """ Seeding """
    np.random.seed(42)
    tf.random.set_seed(42)

    """ Dataset """
    path = "VOC2012/"   #root
    (train_x, train_y), (valid_x, valid_y), (test_x, test_y) = load_data(path)
    print(f"Dataset: Train: {len(train_x)} - Valid: {len(valid_x)} - Test: {len(test_x)}")

    """ Hyperparameters """
    shape = (256, 256, 3)
    num_classes = 21
    lr = 1e-2
    batch_size = 2
    epochs = 100

    """ Model """
    model = build_unet(shape, num_classes)
    model.compile(loss="categorical_crossentropy", optimizer=tf.keras.optimizers.Adam(lr), metrics=["categorical_accuracy"])

    train_dataset = tf_dataset(train_x, train_y, batch=batch_size)
    valid_dataset = tf_dataset(valid_x, valid_y, batch=batch_size)

    train_steps = len(train_x)//batch_size
    valid_steps = len(valid_x)//batch_size

    callbacks = [
        ModelCheckpoint("model.h5", verbose=1, save_best_model=True),
        ReduceLROnPlateau(monitor="val_loss", patience=3, factor=0.5, verbose=1, min_lr=5e-8),
        EarlyStopping(monitor="val_loss", patience=8, verbose=1)
    ]

    model.fit(train_dataset,
        steps_per_epoch=train_steps,
        validation_data=valid_dataset,
        validation_steps=valid_steps,
        epochs=epochs,
        callbacks=callbacks
    )
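
One sanity check I still want to run is overfitting a couple of images (a sketch reusing tf_dataset from data.py, not part of the script above): if the loss doesn’t approach zero here, the bug is in the model or the labels rather than in the training schedule.

debug_ds = tf_dataset(train_x[:2], train_y[:2], batch=2)
model.fit(debug_ds, steps_per_epoch=1, epochs=100)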

DATA:

import os
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import tensorflow as tf
import cv2
from mapeamento import augment_data, rgb_to_mask

H = 256
W = 256
factor = 8
batch = 32

def process_data(data_path, file_path):
    df = pd.read_csv(file_path, sep=" ", header=None)
    names = df[0].values
    #names = names[0:int(len(names)/factor)]

    images = [os.path.join(data_path, f"JPEGImages/{name}.jpg") for name in names]              
    masks = [os.path.join(data_path, f"SegmentationClass/{name}.png") for name in names]  

    return images, masks
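
For reference, each line of the VOC split files is just an image id (illustrative ids below), which is why column 0 of the DataFrame holds the names:

2007_000032
2007_000039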

def load_data(path):
    train_valid_path = os.path.join(path, "ImageSets/Segmentation/trainval.txt")                       
    test_path = os.path.join(path, "ImageSets/Segmentation/val.txt")                                  

    train_x, train_y = process_data(path, train_valid_path)
    test_x, test_y = process_data(path, test_path)

    train_x, valid_x, train_y, valid_y = train_test_split(train_x, train_y, test_size=0.15, random_state=42)

    return (train_x, train_y), (valid_x, valid_y), (test_x, test_y)

def read_image(x):
    x = cv2.imread(x, cv2.IMREAD_COLOR)
    x = cv2.resize(x, (W, H), interpolation = cv2.INTER_NEAREST)
    x = x/255.0
    x = x.astype(np.float32)
    return x

def read_mask(x):
    x = cv2.imread(x, cv2.IMREAD_UNCHANGED)
    x = cv2.cvtColor(x, cv2.COLOR_BGR2RGB) #masks colormap is RGB
    x = cv2.resize(x, (W, H), interpolation = cv2.INTER_NEAREST)
    x = rgb_to_mask(x)
    #x = x.astype(np.int32)
    return x

def tf_dataset(x, y, batch=batch):
    dataset = tf.data.Dataset.from_tensor_slices((x, y))
    dataset = dataset.shuffle(buffer_size=5000)
    dataset = dataset.map(preprocess)
    dataset = dataset.batch(batch)
    dataset = dataset.repeat()
    dataset = dataset.prefetch(8)
    return dataset

def preprocess(x, y):
    def f(x, y):
        x = x.decode()
        y = y.decode()

        image = read_image(x)
        mask = read_mask(y) #uint8
        #print(f"\nMAX UINT8: {mask.max()} - MIN UINT8: {mask.min()}")

        image, mask = augment_data(image, mask)
        mask = mask.astype(np.int32)
        #print(f"\nMAX INT32: {mask.max()} - MIN INT32: {mask.min()}")

        return image, mask

    image, mask = tf.numpy_function(f, [x, y], [tf.float32, tf.int32])
    mask = tf.one_hot(mask, 21, dtype=tf.int32)
    image.set_shape([H, W, 3])
    mask.set_shape([H, W, 21])

    return image, mask


if __name__ == "__main__":
    path = "VOC2012/"   #root
    (train_x, train_y), (valid_x, valid_y), (test_x, test_y) = load_data(path)
    print(f"Dataset: Train: {len(train_x)} - Valid: {len(valid_x)} - Test: {len(test_x)}")

    dataset = tf_dataset(train_x, train_y, batch=batch)
    for x, y in dataset.take(2):  # take() ends the loop; the dataset repeats indefinitely
        print(x.shape, y.shape)   # (batch, 256, 256, 3), (batch, 256, 256, 21)
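
To double-check the label range, printing the unique values of a few decoded masks should give only numbers from 0 to 20:

for y in train_y[:5]:
    print(np.unique(read_mask(y)))  # expect subsets of {0, ..., 20}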

TRANSFORMS:

import numpy as np
from albumentations import HorizontalFlip, GridDistortion, OpticalDistortion, ChannelShuffle, CoarseDropout, CenterCrop, Crop, Rotate
import cv2
#from numba import jit
#import random

#@jit
def rgb_to_mask(img):
    # VOC colormap: class index -> RGB color used for that class in the mask PNGs
    ClassIndex = {0:[0, 0, 0], 1:[128, 0, 0], 2:[0, 128, 0], 3:[128, 128, 0], 4:[0, 0, 128], 5:[128, 0, 128],
                  6:[0, 128, 128], 7:[128, 128, 128], 8:[64, 0, 0], 9:[192, 0, 0], 10:[64, 128, 0],
                  11:[192, 128, 0], 12:[64, 0, 128], 13:[192, 0, 128], 14:[64, 128, 128], 15:[192, 128, 128],
                  16:[0, 64, 0], 17:[128, 64, 0], 18:[0, 192, 0], 19:[128, 192, 0], 20:[0, 64, 128]}

    h, w, c = img.shape
    seg_labels = np.zeros((h, w), dtype = np.uint8)

    # Assign each pixel the index of the colormap entry it matches exactly.
    # Pixels matching no entry (e.g. the "void" boundary color VOC uses for
    # label 255) fall through and stay 0, i.e. background.
    for class_id, color in ClassIndex.items():
        seg_labels[(img == color).all(axis=2)] = class_id

    return seg_labels
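
A tiny usage example (the two colors below are background and class 1, aeroplane, in the VOC colormap):

demo = np.array([[[0, 0, 0], [128, 0, 0]]], dtype=np.uint8)
print(rgb_to_mask(demo))  # [[0 1]]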


def augment_data(images, masks):
    H = 256
    W = 256

    aug = HorizontalFlip(p=1.0)
    augmented = aug(image=images, mask=masks)
    x1 = augmented["image"]
    y1 = augmented["mask"]

    aug = ChannelShuffle(p=1)
    augmented = aug(image=x1, mask=y1)
    x3 = augmented['image']
    y3 = augmented['mask']

    aug = CoarseDropout(p=1, min_holes=3, max_holes=10, max_height=32, max_width=32)
    augmented = aug(image=x3, mask=y3)
    x4 = augmented['image']
    y4 = augmented['mask']

    aug = Rotate(limit=45, interpolation = cv2.INTER_NEAREST, p=1.0)
    augmented = aug(image=x4, mask=y4)
    x5 = augmented["image"]
    y5 = augmented["mask"]

    try:
        """ Center Cropping """
        aug = CenterCrop(H, W, p=1.0)
        augmented = aug(image=x5, mask=y5)
        i = augmented["image"]
        m = augmented["mask"]
        return i, m

    except Exception:
        # Fall back to the uncropped pair: the original `return i, m` here
        # raised NameError, since i and m are unbound when the crop fails.
        return x5, y5


Have you tried to reproduce the results by training a simpler 3-class U-Net like:

Check also:

Hi Bhack, thanks for the references.

Answering your question: yes, I was able to train the model successfully on the Oxford-IIIT Pet dataset using this same code structure. All I changed was the number of classes and, consequently, the range of the labels (0 to 2 for Oxford-IIIT Pet, 0 to 20 for PASCAL VOC). I don’t understand why the same approach doesn’t work here.
I’ll study the code you sent. Thanks.