Zero gradients when computing Grad-CAM for a Siamese network

Dear all,

I am trying to implement Grad-CAM for a Siamese model; however, I get zero gradients when I try to compute them with GradientTape().

Below is my code for reference.

import tensorflow as tf
from tensorflow.keras import backend as K

def SiameseNetwork(input_shape):

    moving_input = tf.keras.Input(input_shape)
    ref_input = tf.keras.Input(input_shape)

    # 1st 3D conv block: convolution, BN, activation and pooling
    x_1 = tf.keras.layers.Conv3D(32, (3, 3, 3), strides=(1, 1, 1), padding='same', kernel_regularizer='L2', name='conv3d_1')(moving_input)
    x_1_bn = tf.keras.layers.BatchNormalization(axis=-1)(x_1)
    x_1_bn_ac = tf.keras.layers.Activation('relu')(x_1_bn)
    x_2 = tf.keras.layers.Conv3D(32, (3, 3, 3), strides=(1, 1, 1), padding='same', kernel_regularizer='L2', name='conv3d_2')(x_1_bn_ac)
    x_2_bn = tf.keras.layers.BatchNormalization(axis=-1)(x_2)
    x_2_bn_ac = tf.keras.layers.Activation('relu')(x_2_bn)
    x_2_bn_ac_pooling = tf.keras.layers.MaxPooling3D(strides=(2, 2, 2))(x_2_bn_ac)

    # 2nd 3D conv block: convolution, BN, activation and pooling
    x_3 = tf.keras.layers.Conv3D(64, (3, 3, 3), strides=(1, 1, 1), padding='same', kernel_regularizer='L2', name='conv3d_3')(x_2_bn_ac_pooling)
    x_3_bn = tf.keras.layers.BatchNormalization(axis=-1)(x_3)
    x_3_bn_ac = tf.keras.layers.Activation('relu')(x_3_bn)
    x_4 = tf.keras.layers.Conv3D(64, (3, 3, 3), strides=(1, 1, 1), padding='same', kernel_regularizer='L2', name='conv3d_4')(x_3_bn_ac)
    x_4_bn = tf.keras.layers.BatchNormalization(axis=-1)(x_4)
    x_4_bn_ac = tf.keras.layers.Activation('relu')(x_4_bn)
    x_4_bn_ac_pooling = tf.keras.layers.MaxPooling3D(strides=(2, 2, 2))(x_4_bn_ac)

    # 3rd 3D conv block: convolution, BN, activation and pooling
    x_5 = tf.keras.layers.Conv3D(256, (3, 3, 3), strides=(1, 1, 1), padding='same', kernel_regularizer='L2', name='conv3d_5')(x_4_bn_ac_pooling)
    x_5_bn = tf.keras.layers.BatchNormalization(axis=-1)(x_5)
    x_5_bn_ac = tf.keras.layers.Activation('relu')(x_5_bn)
    x_5_pooling = tf.keras.layers.MaxPooling3D(strides=(2, 2, 2))(x_5_bn_ac)

    gap_layer = tf.keras.layers.GlobalAveragePooling3D()(x_5_pooling)
    # model.add(tf.keras.layers.Dropout(0.3))
    dense_layer = tf.keras.layers.Dense(1024, activation='relu', kernel_regularizer='L2')(gap_layer)

    # Shared encoder, applied to both inputs
    encoding_model = tf.keras.Model(inputs=moving_input, outputs=dense_layer)
    encoded_moving = encoding_model(moving_input)
    encoded_ref = encoding_model(ref_input)

    # L1 distance between the two encodings
    L1_layer = tf.keras.layers.Lambda(lambda tensors: K.abs(tensors[0] - tensors[1]))
    # L2_layer = tf.keras.layers.Lambda(lambda tensors: K.l2_normalize((tensors[0] - tensors[1]), axis=1))
    L1_distance = L1_layer([encoded_moving, encoded_ref])  # L1-norm
    # L2_distance = L2_layer([encoded_moving, encoded_ref])  # L2-norm or Euclidean norm
    # L2_distance = tf.keras.layers.Lambda(euclidean_distance, output_shape=eucl_dist_output_shape)([encoded_moving, encoded_ref])

    prediction = tf.keras.layers.Dense(1, activation='sigmoid')(L1_distance)
    siamesenet = tf.keras.Model(inputs=[moving_input, ref_input], outputs=prediction)

    return siamesenet, encoding_model

img_shape = (30, 45, 30, 1)

siamese_model, base_model = SiameseNetwork(img_shape)
base_learning_rate = 0.00005
base_model.summary()
siamese_model.summary()

# accuracy, recall_m, specificity, precision_m and f1_m are custom metric functions defined elsewhere
siamese_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=base_learning_rate),
                      loss='binary_crossentropy',
                      metrics=[accuracy, recall_m, specificity, precision_m, f1_m])

fine_tune_epochs = 20
history_fine = siamese_model.fit([left_input_cv, right_input_cv], targets_cv,
                                 batch_size=32,
                                 epochs=fine_tune_epochs,
                                 shuffle=True,
                                 validation_split=0.2)

To extract the last convolution layer, I used the following:

last_conv_layer_name = 'conv3d_5'

# loaded_siamese_model is the trained siamese model loaded back from disk
s_model = tf.keras.models.Model(
    [loaded_siamese_model.inputs],
    [loaded_siamese_model.get_layer('model').get_layer(last_conv_layer_name).output,
     loaded_siamese_model.output]
)

with tf.GradientTape(persistent=True) as tape:
    last_conv_layer_output, prediction = s_model([left_test, right_test])

grads = tape.gradient(prediction[:, 0], last_conv_layer_output, unconnected_gradients='zero')

grads comes back as all zeros. Any help is highly appreciated.
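In case it is useful, this is a sketch of how I intend to use grads once they are non-zero, following the standard Grad-CAM recipe (average the gradients per channel, weight the feature maps, ReLU, normalise); with the current all-zero gradients it naturally gives an empty heatmap:

# Sketch of the intended Grad-CAM post-processing (not the source of the problem,
# just what I plan to do with grads once they are non-zero)
pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2, 3))              # mean gradient per channel
conv_output = last_conv_layer_output[0]                              # feature maps for the first test pair
heatmap = tf.reduce_sum(conv_output * pooled_grads, axis=-1)         # channel-weighted sum of feature maps
heatmap = tf.maximum(heatmap, 0) / (tf.reduce_max(heatmap) + 1e-8)   # ReLU and normalise to [0, 1]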

Thanks.