Good morning everyone,
I’m trying a U-Net-based AI model to segment rat brains in MRI images. To do so, I’m training the model on 574 MRI images (192x192x1, .tiff) and masks of the same dimensions (.nii) that I segmented manually in ImageJ. The code is based on the January 2020 paper “Automatic Skull Stripping of Rat and Mouse Brain Data Using U-Net”. More specifically:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Dropout, UpSampling2D, concatenate, BatchNormalization
def _unet_conv_block(x, n_filters):
    """Apply two 3x3 same-padded ReLU convolutions, each followed by BatchNormalization."""
    for _ in range(2):
        x = Conv2D(n_filters, 3, activation='relu', padding='same')(x)
        x = BatchNormalization()(x)
    return x


def _unet_up_block(x, skip, n_filters):
    """Upsample by 2, project to `n_filters` channels, merge with `skip`, then refine.

    The concatenation order is ``[skip, upsampled]``, matching the original model.
    """
    x = UpSampling2D(size=(2, 2))(x)
    x = Conv2D(n_filters, 3, activation='relu', padding='same')(x)
    x = BatchNormalization()(x)
    x = concatenate([skip, x], axis=-1)
    return _unet_conv_block(x, n_filters)


def model4(input_size=(192, 192, 1), *, filters=(32, 64, 96, 128, 256, 512)):
    """Build a 2D U-Net with a single-channel sigmoid output.

    Encoder: one conv block per entry of ``filters[:-1]``, each followed by
    2x2 max pooling; the output of every block is kept as a skip connection.
    Bottleneck: a conv block with ``filters[-1]`` channels.
    Decoder: for each encoder level (deepest first), upsample + 3x3 conv,
    concatenate with the matching skip, then a conv block.
    Head: 1x1 convolution with sigmoid activation, one output channel.

    Args:
        input_size: ``(height, width, channels)`` of one input image.
        filters: channel count per level, shallowest first; the last entry is
            the bottleneck. The default reproduces the original architecture.

    Returns:
        An uncompiled ``tensorflow.keras.Model``.

    Raises:
        ValueError: if fewer than two levels are given, or if a known spatial
            dimension is not divisible by ``2 ** (len(filters) - 1)`` (the skip
            concatenations would otherwise fail with a shape mismatch).

    NOTE(review): the sigmoid head yields per-pixel probabilities in (0, 1).
    Targets presumably need to be binary {0, 1} masks (not e.g. {0, 255}) for
    BCE/Dice-style losses, and predictions usually need thresholding (e.g.
    ``> 0.5``) before display — confirm against the data pipeline.
    """
    filters = tuple(filters)
    if len(filters) < 2:
        raise ValueError("filters must contain at least two levels")
    factor = 2 ** (len(filters) - 1)
    for dim in input_size[:2]:
        if dim is not None and dim % factor:
            raise ValueError(
                f"spatial size {input_size[:2]} must be divisible by {factor} "
                f"for {len(filters)} levels"
            )

    inputs = Input(input_size)

    # Encoder: remember each level's features for the skip connections.
    skips = []
    x = inputs
    for n_filters in filters[:-1]:
        x = _unet_conv_block(x, n_filters)
        skips.append(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)

    # Bottleneck (the original computed an extra, unused pooling here; dropped).
    x = _unet_conv_block(x, filters[-1])

    # Decoder: walk back up, pairing each level with its skip connection.
    for n_filters, skip in zip(reversed(filters[:-1]), reversed(skips)):
        x = _unet_up_block(x, skip, n_filters)

    outputs = Conv2D(1, 1, activation='sigmoid')(x)
    return Model(inputs=inputs, outputs=outputs)
# Build the network with the default 192x192x1 input.
# NOTE(review): this rebinds the name `model4` from the builder function to the
# built Model instance, so the builder cannot be called again afterwards;
# consider a distinct name (e.g. `model`) if later code can be updated too.
model4 = model4()
Furthermore, I normalized the MRI images using intensity and spatial normalization. The loss functions I’ve been applying are Dice loss, BCE, and focal Tversky loss with different alpha/beta values.
The result is an entirely white image, with no masks. I don’t know what I’m doing wrong. I’m very new to all this, so I’m probably missing something important.
By the way, the batch size and the learning rate are 16 and 0.001, respectively.
Thanks in advance and sorry for any grammatical mistakes.