Dear all,
I am trying to implement Grad-CAM for a Siamese model, but I get all-zero gradients when I compute them with tf.GradientTape(). Since Grad-CAM weights each feature map of the last convolutional layer by the mean gradient of the prediction with respect to that map, zero gradients leave me with a blank heatmap.
Below is my code.
import tensorflow as tf
from tensorflow.keras import backend as K

def SiameseNetwork(input_shape):
    moving_input = tf.keras.Input(input_shape)
    ref_input = tf.keras.Input(input_shape)
    # 1st 3D conv block, which involves convolution, BN, activation and pooling
    x_1 = tf.keras.layers.Conv3D(32, (3,3,3), strides = (1,1,1), padding = 'same', kernel_regularizer = 'L2', name = 'conv3d_1')(moving_input)
    x_1_bn = tf.keras.layers.BatchNormalization(axis = -1)(x_1)
    x_1_bn_ac = tf.keras.layers.Activation('relu')(x_1_bn)
    x_2 = tf.keras.layers.Conv3D(32, (3,3,3), strides = (1,1,1), padding = 'same', kernel_regularizer = 'L2', name = 'conv3d_2')(x_1_bn_ac)
    x_2_bn = tf.keras.layers.BatchNormalization(axis = -1)(x_2)
    x_2_bn_ac = tf.keras.layers.Activation('relu')(x_2_bn)
    x_2_bn_ac_pooling = tf.keras.layers.MaxPooling3D(strides = (2, 2, 2))(x_2_bn_ac)
    # 2nd 3D conv block, which involves convolution, BN, activation and pooling
    x_3 = tf.keras.layers.Conv3D(64, (3,3,3), strides = (1,1,1), padding = 'same', kernel_regularizer = 'L2', name = 'conv3d_3')(x_2_bn_ac_pooling)
    x_3_bn = tf.keras.layers.BatchNormalization(axis = -1)(x_3)
    x_3_bn_ac = tf.keras.layers.Activation('relu')(x_3_bn)
    x_4 = tf.keras.layers.Conv3D(64, (3,3,3), strides = (1,1,1), padding = 'same', kernel_regularizer = 'L2', name = 'conv3d_4')(x_3_bn_ac)
    x_4_bn = tf.keras.layers.BatchNormalization(axis = -1)(x_4)
    x_4_bn_ac = tf.keras.layers.Activation('relu')(x_4_bn)
    x_4_bn_ac_pooling = tf.keras.layers.MaxPooling3D(strides = (2, 2, 2))(x_4_bn_ac)
    # 3rd 3D conv block, which involves convolution, BN, activation and pooling
    x_5 = tf.keras.layers.Conv3D(256, (3,3,3), strides = (1,1,1), padding = 'same', kernel_regularizer = 'L2', name = 'conv3d_5')(x_4_bn_ac_pooling)
    x_5_bn = tf.keras.layers.BatchNormalization(axis = -1)(x_5)
    x_5_bn_ac = tf.keras.layers.Activation('relu')(x_5_bn)
    x_5_pooling = tf.keras.layers.MaxPooling3D(strides = (2, 2, 2))(x_5_bn_ac)
    gap_layer = tf.keras.layers.GlobalAveragePooling3D()(x_5_pooling)
    #model.add(tf.keras.layers.Dropout(0.3))
    dense_layer = tf.keras.layers.Dense(1024, activation = 'relu', kernel_regularizer = 'L2')(gap_layer)
    # shared encoder, applied to both inputs so the weights are tied
    encoding_model = tf.keras.Model(inputs = moving_input, outputs = dense_layer)
    encoded_moving = encoding_model(moving_input)
    encoded_ref = encoding_model(ref_input)
    # element-wise L1 distance between the two 1024-d encodings
    L1_layer = tf.keras.layers.Lambda(lambda tensors: K.abs(tensors[0] - tensors[1]))
    #L2_layer = tf.keras.layers.Lambda(lambda tensors: K.l2_normalize((tensors[0] - tensors[1]), axis = 1))
    L1_distance = L1_layer([encoded_moving, encoded_ref]) # L1-norm
    #L2_distance = L2_layer([encoded_moving, encoded_ref]) # L2-norm or Euclidean norm
    #L2_distance = tf.keras.layers.Lambda(euclidean_distance, output_shape=eucl_dist_output_shape)([encoded_moving, encoded_ref])
    # similarity score from the distance vector
    prediction = tf.keras.layers.Dense(1, activation='sigmoid')(L1_distance)
    siamesenet = tf.keras.Model(inputs = [moving_input, ref_input], outputs = prediction)
    return siamesenet, encoding_model
img_shape = (30, 45, 30, 1)
siamese_model, base_model = SiameseNetwork(img_shape)
base_learning_rate = 0.00005
base_model.summary()
siamese_model.summary()
# recall_m, specificity, precision_m and f1_m are my custom metric functions (sketched after the fit() call below)
siamese_model.compile(optimizer = tf.keras.optimizers.Adam(learning_rate = base_learning_rate),
                      loss = 'binary_crossentropy',
                      metrics = ['accuracy', recall_m, specificity, precision_m, f1_m])
fine_tune_epochs = 20
history_fine = siamese_model.fit([left_input_cv, right_input_cv], targets_cv,
                                 batch_size = 32,
                                 epochs = fine_tune_epochs,
                                 shuffle = True,
                                 validation_split = 0.2)
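For reference, the metric helpers passed to compile() above are ordinary Keras-backend functions; I have left their definitions out, but they follow the usual pattern, roughly like this sketch (illustrative only, so the snippet reads as self-contained; specificity is analogous):

# Rough sketch of the custom metric helpers referenced in compile() above
# (illustrative only; the exact definitions are not the point of the question)
def recall_m(y_true, y_pred):
    tp = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    possible_p = K.sum(K.round(K.clip(y_true, 0, 1)))
    return tp / (possible_p + K.epsilon())

def precision_m(y_true, y_pred):
    tp = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    predicted_p = K.sum(K.round(K.clip(y_pred, 0, 1)))
    return tp / (predicted_p + K.epsilon())

def f1_m(y_true, y_pred):
    p, r = precision_m(y_true, y_pred), recall_m(y_true, y_pred)
    return 2 * (p * r) / (p + r + K.epsilon())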
To extract the last convolutional layer's output for Grad-CAM, I have used the following:
last_conv_layer_name = 'conv3d_5'
# 'model' is the auto-generated name of the inner encoding_model inside the loaded siamese model
s_model = tf.keras.models.Model(
    [loaded_siamese_model.inputs],
    [loaded_siamese_model.get_layer('model').get_layer(last_conv_layer_name).output,
     loaded_siamese_model.output]
)
with tf.GradientTape(persistent=True) as tape:
    last_conv_layer_output, prediction = s_model([left_test, right_test])
grads = tape.gradient(prediction[:, 0], last_conv_layer_output, unconnected_gradients='zero')
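For completeness, if the gradients were non-zero I would finish with the usual Grad-CAM steps, pooling the gradients over the spatial axes and using them to weight the conv feature maps. This continuation is just the standard recipe, sketched under the assumption that grads has the same 5-D shape as the conv output:

# Standard Grad-CAM continuation (sketch; assumes grads is (batch, d, h, w, channels))
pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2, 3))                      # one weight per channel
heatmap = tf.reduce_sum(last_conv_layer_output[0] * pooled_grads, axis=-1)   # weighted sum over channels
heatmap = tf.maximum(heatmap, 0) / (tf.reduce_max(heatmap) + 1e-8)           # ReLU, then normalize to [0, 1]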
The resulting grads tensor is all zeros. Any help is highly appreciated.
Thanks.