Hi guys, I’m trying to learn Tensorflow (2.3.0 with Python 3.7) to train a UNET model. I’m using the PASCAL VOC 2012 dataset (21 classes). However, my model doesn’t learn, the accuracy is always around 0.75 and the loss is also stable.
My labels are converted to a tensor with pixel values between 0 and 21 and converted to one-hot (I’m not sure that’s the way it is, but that’s what I get). Images are normalized from 0 to 1. I use cross entropy as loss function and Adam optimizer.
I’ve tried to fix the problem in a number of ways, but I really don’t know what else to do.
can anybody help me? Thanks.
The code is below.
MODEL:
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation, MaxPool2D, UpSampling2D, Concatenate
from tensorflow.keras.models import Model
def conv_block(inputs, filters, pool = True):
x = Conv2D(filters, 3, padding = "same")(inputs)
x = BatchNormalization()(x)
x = Activation("relu")(x)
x = Conv2D(filters, 3, padding = "same")(inputs)
x = BatchNormalization()(x)
x = Activation("relu")(x)
if pool == True:
p = MaxPool2D((2, 2))(x)
return x, p
else:
return x
def build_unet(shape, num_classes):
inputs = Input(shape)
#Encoder
x1, p1 = conv_block(inputs, 16, pool = True)
x2, p2 = conv_block(p1, 32, pool = True)
x3, p3 = conv_block(p2, 48, pool = True)
x4, p4 = conv_block(p3, 64, pool = True)
#Bridge
b1 = conv_block(p4, 128, pool = False)
#Decoder
u1 = UpSampling2D((2, 2), interpolation = "bilinear")(b1)
c1 = Concatenate()([u1, x4])
x5 = conv_block(c1, 64, pool = False)
u2 = UpSampling2D((2, 2), interpolation = "bilinear")(x5)
c2 = Concatenate()([u2, x3])
x6 = conv_block(c2, 48, pool = False)
u3 = UpSampling2D((2, 2), interpolation = "bilinear")(x6)
c3 = Concatenate()([u3, x2])
x7 = conv_block(c3, 32, pool = False)
u4 = UpSampling2D((2, 2), interpolation = "bilinear")(x7)
c4 = Concatenate()([u4, x1])
x8 = conv_block(c4, 16, pool = False)
#output layer
output = Conv2D(num_classes, 1, padding = "same", activation = "softmax")(x8) #sigmoid
return Model(inputs, output)
if __name__ == "__name__":
model = build_unet((256, 256, 3), 21)
model.summary()
TRAIN:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping
from data import load_data, tf_dataset
from model import build_unet
if __name__ == "__main__":
""" Seeding """
np.random.seed(42)
tf.random.set_seed(42)
""" Dataset """
path = "VOC2012/" #root
(train_x, train_y), (valid_x, valid_y), (test_x, test_y) = load_data(path)
print(f"Dataset: Train: {len(train_x)} - Valid: {len(valid_x)} - Test: {len(test_x)}")
""" Hyperparameters """
shape = (256, 256, 3)
num_classes = 21
lr = 1e-2
batch_size = 2
epochs = 100
""" Model """
model = build_unet(shape, num_classes)
model.compile(loss="categorical_crossentropy", optimizer=tf.keras.optimizers.Adam(lr), metrics = ['CategoricalAccuracy']) #CategoricalCrossentropy
train_dataset = tf_dataset(train_x, train_y, batch=batch_size)
valid_dataset = tf_dataset(valid_x, valid_y, batch=batch_size)
train_steps = len(train_x)//batch_size
valid_steps = len(valid_x)//batch_size
callbacks = [
ModelCheckpoint("model.h5", verbose=1, save_best_model=True),
ReduceLROnPlateau(monitor="val_loss", patience=3, factor=0.5, verbose=1, min_lr=5e-8),
EarlyStopping(monitor="val_loss", patience=8, verbose=1)
]
model.fit(train_dataset,
steps_per_epoch=train_steps,
validation_data=valid_dataset,
validation_steps=valid_steps,
epochs=epochs,
callbacks=callbacks
)
DATA:
import os
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import tensorflow as tf
import cv2
from mapeamento import augment_data
from mapeamento import rgb_to_mask
H = 256
W = 256
factor = 8
batch = 32
def process_data(data_path, file_path):
df = pd.read_csv(file_path, sep=" ", header=None)
names = df[0].values
#names = names[0:int(len(names)/factor)]
images = [os.path.join(data_path, f"JPEGImages/{name}.jpg") for name in names]
masks = [os.path.join(data_path, f"SegmentationClass/{name}.png") for name in names]
return images, masks
def load_data(path):
train_valid_path = os.path.join(path, "ImageSets/Segmentation/trainval.txt")
test_path = os.path.join(path, "ImageSets/Segmentation/val.txt")
train_x, train_y = process_data(path, train_valid_path)
test_x, test_y = process_data(path, test_path)
train_x, valid_x = train_test_split(train_x, test_size=0.15, random_state=42)
train_y, valid_y = train_test_split(train_y, test_size=0.15, random_state=42)
return (train_x, train_y), (valid_x, valid_y), (test_x, test_y)
def read_image(x):
x = cv2.imread(x, cv2.IMREAD_COLOR)
x = cv2.resize(x, (W, H), interpolation = cv2.INTER_NEAREST)
x = x/255.0
x = x.astype(np.float32)
return x
def read_mask(x):
x = cv2.imread(x, cv2.IMREAD_UNCHANGED)
x = cv2.cvtColor(x, cv2.COLOR_BGR2RGB) #masks colormap is RGB
x = cv2.resize(x, (W, H), interpolation = cv2.INTER_NEAREST)
x = rgb_to_mask(x)
#x = x.astype(np.int32)
return x
def tf_dataset(x, y, batch=batch):
dataset = tf.data.Dataset.from_tensor_slices((x, y))
dataset = dataset.shuffle(buffer_size=5000)
dataset = dataset.map(preprocess)
dataset = dataset.batch(batch)
dataset = dataset.repeat()
dataset = dataset.prefetch(8)
return dataset
def preprocess(x, y):
def f(x, y):
x = x.decode()
y = y.decode()
image = read_image(x)
mask = read_mask(y) #uint8
#print(f"\nMAX UINT8: {mask.max()} - MIN UINT8: {mask.min()}")
image, mask = augment_data(image, mask)
mask = mask.astype(np.int32)
#print(f"\nMAX INT32: {mask.max()} - MIN INT32: {mask.min()}")
return image, mask
image, mask = tf.numpy_function(f, [x, y], [tf.float32, tf.int32])
mask = tf.one_hot(mask, 21, dtype=tf.int32)
image.set_shape([H, W, 3])
mask.set_shape([H, W, 21])
return image, mask
if __name__ == "__main__":
path = "VOC2012/" #root
(train_x, train_y), (valid_x, valid_y), (test_x, test_y) = load_data(path)
print(f"Dataset: Train: {len(train_x)} - Valid: {len(valid_x)} - Test: {len(test_x)}")
dataset = tf_dataset(train_x, train_y, batch=batch)
for x, y in dataset:
print(x.shape, y.shape) ## (b, 256, 256, 3), (b, 256, 256, 21)
TRANSFORMS:
import numpy as np
from albumentations import HorizontalFlip, GridDistortion, OpticalDistortion, ChannelShuffle, CoarseDropout, CenterCrop, Crop, Rotate
import cv2
#from numba import jit
#import random
#@jit
def rgb_to_mask(img):
ClassIndex = {0:[0, 0, 0], 1:[128, 0, 0], 2:[0, 128, 0], 3:[128, 128, 0], 4:[0, 0, 128], 5:[128, 0, 128],
6:[0, 128, 128], 7:[128, 128, 128], 8:[64, 0, 0], 9:[192, 0, 0], 10:[64, 128, 0],
11:[192, 128, 0], 12:[64, 0, 128], 13:[192, 0, 128], 14:[64, 128, 128], 15:[192, 128, 128],
16:[0, 64, 0], 17:[128, 64, 0], 18:[0, 192, 0], 19:[128, 192, 0], 20:[0, 64, 128]}
h, w, c = img.shape
seg_labels = np.zeros((h, w), dtype = np.uint8)
mask_0 = (img == ClassIndex[0]).all(axis=2)
mask_1 = (img == ClassIndex[1]).all(axis=2)
mask_2 = (img == ClassIndex[2]).all(axis=2)
mask_3 = (img == ClassIndex[3]).all(axis=2)
mask_4 = (img == ClassIndex[4]).all(axis=2)
mask_5 = (img == ClassIndex[5]).all(axis=2)
mask_6 = (img == ClassIndex[6]).all(axis=2)
mask_7 = (img == ClassIndex[7]).all(axis=2)
mask_8 = (img == ClassIndex[8]).all(axis=2)
mask_9 = (img == ClassIndex[9]).all(axis=2)
mask_10 = (img == ClassIndex[10]).all(axis=2)
mask_11 = (img == ClassIndex[11]).all(axis=2)
mask_12 = (img == ClassIndex[12]).all(axis=2)
mask_13 = (img == ClassIndex[13]).all(axis=2)
mask_14 = (img == ClassIndex[14]).all(axis=2)
mask_15 = (img == ClassIndex[15]).all(axis=2)
mask_16 = (img == ClassIndex[16]).all(axis=2)
mask_17 = (img == ClassIndex[17]).all(axis=2)
mask_18 = (img == ClassIndex[18]).all(axis=2)
mask_19 = (img == ClassIndex[19]).all(axis=2)
mask_20 = (img == ClassIndex[20]).all(axis=2)
seg_labels[mask_0] = 0
seg_labels[mask_1] = 1
seg_labels[mask_2] = 2
seg_labels[mask_3] = 3
seg_labels[mask_4] = 4
seg_labels[mask_5] = 5
seg_labels[mask_6] = 6
seg_labels[mask_7] = 7
seg_labels[mask_8] = 8
seg_labels[mask_9] = 9
seg_labels[mask_10] = 10
seg_labels[mask_11] = 11
seg_labels[mask_12] = 12
seg_labels[mask_13] = 13
seg_labels[mask_14] = 14
seg_labels[mask_15] = 15
seg_labels[mask_16] = 16
seg_labels[mask_17] = 17
seg_labels[mask_18] = 18
seg_labels[mask_19] = 19
seg_labels[mask_20] = 20
return seg_labels
def augment_data(images, masks):
H = 256
W = 256
aug = HorizontalFlip(p=1.0)
augmented = aug(image=images, mask=masks)
x1 = augmented["image"]
y1 = augmented["mask"]
aug = ChannelShuffle(p=1)
augmented = aug(image=x1, mask=y1)
x3 = augmented['image']
y3 = augmented['mask']
aug = CoarseDropout(p=1, min_holes=3, max_holes=10, max_height=32, max_width=32)
augmented = aug(image=x3, mask=y3)
x4 = augmented['image']
y4 = augmented['mask']
aug = Rotate(limit=45, interpolation = cv2.INTER_NEAREST, p=1.0)
augmented = aug(image=x4, mask=y4)
x5 = augmented["image"]
y5 = augmented["mask"]
try:
""" Center Cropping """
aug = CenterCrop(H, W, p=1.0)
augmented = aug(image=x5, mask=y5)
i = augmented["image"]
m = augmented["mask"]
return i, m
except Exception as e:
return i, m