Hello everyone, I’m trying to define a custom layer in my Keras model. It’s purpose is to take a motion prediction vector and apply it to an image. This is done using image warping. The warped frame is then used in the loss function to see how well it matches the second frame in the series. The model should in theory try and find the best motion vectors to describe this transition, from frame T to frame T+1.
However, as it is the model complains that there are no gradients. This is likely due to the operations in the custom layer not being differentiable. Which is fine. The only thing I need from those operations is the warped frame to guide the model.
So my question is, how do I prevent the gradients from flowing into this custom layer? I only need to update the CNN portions of the model with respect to the warp frame and the loss.
class WarpFrameLayer(Layer):
def __init__(self, **kwargs):
super(WarpFrameLayer, self).__init__(**kwargs)
def call(self, inputs):
frame, dx_dy = inputs
warped_frame = self.warp_function(frame, dx_dy)
return warped_frame
@tf.function
def warp_function(self, frame, dx_dy):
num_images = tf.shape(frame)[0]
translations = tf.reshape(dx_dy, [num_images, 2])
zeros = tf.zeros([num_images, 1], dtype=tf.float32)
ones = tf.ones([num_images, 1], dtype=tf.float32)
transforms = tf.concat([ones, zeros, translations[:, 0:1], zeros, ones, translations[:, 1:2], zeros, zeros],
axis=1)
output_shape = tf.shape(frame)[1:3]
warped_frame = tf.raw_ops.ImageProjectiveTransformV3(images=frame, transforms=transforms,
output_shape=output_shape, interpolation="BILINEAR",
fill_value=0)
return warped_frame
def camera_translation_model(input_shape):
# Frame inputs
frame_t = Input(shape=input_shape, name='frame_t')
frame_t_plus_1 = Input(shape=input_shape, name='frame_t_plus_1')
# Motion prediction model
conv1_t = Conv2D(16, (3, 3), activation='relu', padding='same')(frame_t)
conv1_t_plus_1 = Conv2D(16, (3, 3), activation='relu', padding='same')(frame_t_plus_1)
concat_features = concatenate([conv1_t, conv1_t_plus_1])
conv2 = Conv2D(32, (3, 3), activation='relu', padding='same')(concat_features)
conv3 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv2)
flattened = Flatten()(conv3)
motion_pred = Dense(2, activation='linear', name='motion_pred')(flattened)
warped_frame = WarpFrameLayer()([frame_t, motion_pred])
# warped frame guides the model's motion estimation between frames
model = Model(inputs=[frame_t, frame_t_plus_1], outputs=warped_frame)
return model