Training an image segmentation UNet for more than 20 epochs hits `loss: nan`

I am trying to build an image segmentation app using the UNet architecture. Here is my model (imports shown once for all of the snippets below):

import tensorflow as tf
from tensorflow.keras.layers import (Input, Conv2D, Conv2DTranspose, MaxPooling2D,
                                     Dropout, BatchNormalization, Normalization, concatenate)
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.optimizers import Adam

Encoder:

conv = Conv2D(n_filters, # Number of filters
                    3,   # Kernel size   
                    activation="relu",
                    padding="same",
                    kernel_initializer='he_normal')(inputs)
conv = BatchNormalization()(conv)
conv = Conv2D(n_filters, # Number of filters
                    3,   # Kernel size
                    activation="relu",
                    padding="same",
                    # set 'kernel_initializer' same as above
                    kernel_initializer='he_normal')(conv)
conv = BatchNormalization()(conv)
        
if dropout_prob > 0:
    conv = Dropout(dropout_prob)(conv)
            
next_layer = MaxPooling2D(2)(conv) if max_pooling else conv
skip_connection = conv
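
For reference, this block lives in a method shaped roughly like the following (abbreviated); the returned tuple is what cblockN[0] / cblockN[1] index into in the model code below:

def Encoder(self, inputs, n_filters, dropout_prob=0, max_pooling=True):
    # ... the two Conv2D + BatchNormalization layers and optional Dropout shown above ...
    next_layer = MaxPooling2D(2)(conv) if max_pooling else conv
    skip_connection = conv
    return next_layer, skip_connection  # [0] feeds the next block, [1] is the skip connection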

Decoder:

up = Conv2DTranspose(
                    n_filters,    # number of filters
                    3,    # Kernel size
                    strides=(2,2),
                    padding="same")(expansive_input)
        
merge = concatenate([up, contractive_input], axis=3)
conv = Conv2D(n_filters,   # Number of filters
                    3,     # Kernel size
                    activation="relu",
                    padding="same",
                    kernel_initializer='he_normal')(merge)
conv = BatchNormalization()(conv)
conv = Conv2D(n_filters,  # Number of filters
                    3,   # Kernel size
                    activation="relu",
                    padding="same",
                    # set 'kernel_initializer' same as above
                    kernel_initializer='he_normal')(conv)
conv = BatchNormalization()(conv)
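
The surrounding method has roughly this shape (abbreviated); unlike the encoder it returns a single tensor, which is chained directly as the next expansive_input:

def Decoder(self, expansive_input, contractive_input, n_filters):
    # ... the Conv2DTranspose, concatenate and Conv2D + BatchNormalization layers shown above ...
    return conv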

The model:

inputs = Input(self._input_size)
x = Normalization(axis=-1)(inputs)
# Contracting Path (encoding)
# Add an Encoder with the normalized inputs of the UNet model and n_filters
cblock1 = self.Encoder(x, self._n_filters)
# Chain the first element, [0], of the output of each block to be the input of the next Encoder. 
# Double the number of filters at each new step
cblock2 = self.Encoder(cblock1[0], self._n_filters * 2)
cblock3 = self.Encoder(cblock2[0], self._n_filters * 4)
cblock4 = self.Encoder(cblock3[0], self._n_filters * 8, dropout_prob=0.3) # Include a dropout_prob of 0.3 for this layer
# Include a dropout_prob of 0.3 for this layer, and avoid the max_pooling layer
cblock5 = self.Encoder(cblock4[0], self._n_filters * 16, dropout_prob=0.3, max_pooling=False)
            
# Expanding Path (decoding)
# Add the first Decoder.
# Use the cblock5[0] as expansive_input and cblock4[1] as contractive_input and n_filters * 8
ublock6 = self.Decoder(cblock5[0], cblock4[1],  self._n_filters * 8)
# Chain the output of the previous block as expansive_input and the corresponding contractive block output.
# Note that you must use the second element, [1], of the contractive block, i.e. the output before the max-pooling layer.
# At each step, use half the number of filters of the previous block 
ublock7 = self.Decoder(ublock6, cblock3[1],  self._n_filters * 4)
ublock8 = self.Decoder(ublock7, cblock2[1],  self._n_filters * 2)
ublock9 = self.Decoder(ublock8, cblock1[1],  self._n_filters)

conv9 = Conv2D(self._n_filters,
                        3,
                        activation='relu',
                        padding='same',
                        # set 'kernel_initializer' same as above
                        kernel_initializer='he_normal')(ublock9)

# Add a Conv2D layer with n_classes filters, a kernel size of 1 and 'same' padding
conv10 = Conv2D(self._n_classes, 1, padding="same")(conv9)
            
self._model = tf.keras.Model(inputs=inputs, outputs=conv10)
self._model.compile(optimizer=Adam(),
                        loss=SparseCategoricalCrossentropy(from_logits=True),
                        metrics=['accuracy'])
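
Since the final Conv2D outputs raw per-pixel logits (hence from_logits=True in the loss), the predicted mask is obtained at inference time with an argmax over the class axis, roughly like this (images here is just a placeholder batch of the expected input size):

logits = self._model(images, training=False)  # shape: (batch, H, W, n_classes)
pred_mask = tf.argmax(logits, axis=-1)        # integer class ID per pixel, shape: (batch, H, W)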

The data is 1,060 images from the CARLA self-driving car dataset.

When I train the model for more than 20 epochs, the loss becomes `nan`. Any insight is appreciated.

Hi @khteh, a NaN loss typically occurs due to exploding gradients caused by a high learning rate or unnormalized inputs. Try reducing the learning rate and applying gradient clipping to fix the gradient-explosion issue. Thanks!
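
For example, both tweaks fit in the compile step; a minimal sketch (the learning-rate and clipnorm values are illustrative, not tuned):

optimizer = Adam(learning_rate=1e-4, clipnorm=1.0)  # smaller LR, cap the gradient norm at 1.0
self._model.compile(optimizer=optimizer,
                    loss=SparseCategoricalCrossentropy(from_logits=True),
                    metrics=['accuracy'])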
