Applying skip connections to a pre-trained VGG19 in Keras

I am trying to use VGG19 as the encoder in a convolutional LSTM autoencoder. I want to add skip connections, as in U-Net, from the last convolutional layer of each VGG19 block to my decoder (which mirrors the VGG19 architecture, with upsampling in place of max pooling).

Since the inputs are time dependent, I wrapped the VGG19 with a TimeDistributed layer.
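The wrapping itself seems fine while the VGG19 has a single output; condensed from the full code below (seq_length and batch_size as in build_res_vgg):

vgg = VGG19(weights='imagenet', include_top=False)
frames = keras.Input(shape=(seq_length - 1, 224, 224, 3), batch_size=batch_size)
features = TimeDistributed(vgg)(frames)  # applies VGG19 to every frame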

However, when I initialize the model, I get a graph disconnection error:

ValueError: Graph disconnected: cannot obtain value for tensor KerasTensor(type_spec=TensorSpec(shape=(None, None, None, 3), dtype=tf.float32, name='input_2'), name='input_2', description="created by layer 'input_2'") at layer "block1_conv1". The following previous layers were accessed without issue:
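If I understand it correctly, Keras raises this error whenever a model's outputs depend on an Input tensor that is not among the model's declared inputs. A minimal illustration (my own example, not from my code):

a = keras.Input(shape=(4,))
b = keras.Input(shape=(4,))
out = layers.Dense(2)(b)
model = Model(a, out)  # ValueError: Graph disconnected, since `out` never touches `a`

I suspect input_2 here is the VGG19's own input, which my skip tensors still trace back to.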

My code is as follows:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.applications import VGG19
from tensorflow.keras.layers import (Add, BatchNormalization, Conv2D, Dense,
                                     RNN, TimeDistributed, UpSampling2D)
from tensorflow.keras.models import Model
from tensorflow.compat.v1.nn.rnn_cell import LSTMCell, MultiRNNCell


def build_lstm(inputs, lstm_h_dim):
    # Stack one LSTMCell per entry in lstm_h_dim and run them as a single RNN.
    lstm_cell = MultiRNNCell([LSTMCell(h_dim, forget_bias=1.0) for h_dim in lstm_h_dim],
                             state_is_tuple=True)
    x = RNN(lstm_cell, unroll=True, return_sequences=True, time_major=False)(inputs)
    return x
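# Aside: I believe the Keras-native equivalent of the compat-v1 cells above
# would be StackedRNNCells, e.g.:
#   cells = [layers.LSTMCell(h_dim) for h_dim in lstm_h_dim]
#   x = RNN(layers.StackedRNNCells(cells), unroll=True, return_sequences=True)(inputs)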



def build_decoder(inputs, skip_1, skip_2, skip_3, skip_4, skip_5):
    conv_size = 512
    ## 1st conv block
    x = TimeDistributed(UpSampling2D((2, 2)))(inputs)
    x = TimeDistributed(Conv2D(filters=conv_size, kernel_size=3, padding='same', activation='relu'))(x)
    x = TimeDistributed(Add())([skip_1, x])
    x = TimeDistributed(Conv2D(filters=conv_size, kernel_size=3, padding='same', activation='relu'))(x)
    x = TimeDistributed(Conv2D(filters=conv_size, kernel_size=3, padding='same', activation='relu'))(x)
    x = TimeDistributed(Conv2D(filters=conv_size, kernel_size=3, padding='same', activation='relu'))(x)
    x = TimeDistributed(UpSampling2D((2, 2)))(x)
    x = BatchNormalization()(x)
    ## 2nd conv block
    x = TimeDistributed(Conv2D(filters=conv_size, kernel_size=3, padding='same', activation='relu'))(x)
    x = TimeDistributed(Add())([skip_2, x])
    x = TimeDistributed(Conv2D(filters=conv_size, kernel_size=3, padding='same', activation='relu'))(x)
    x = TimeDistributed(Conv2D(filters=conv_size, kernel_size=3, padding='same', activation='relu'))(x)
    x = TimeDistributed(Conv2D(filters=conv_size, kernel_size=3, padding='same', activation='relu'))(x)
    x = TimeDistributed(UpSampling2D((2, 2)))(x)
    x = BatchNormalization()(x)
    ## 3rd conv block
    x = TimeDistributed(Conv2D(filters=conv_size // 2, kernel_size=3, padding='same', activation='relu'))(x)
    x = TimeDistributed(Add())([skip_3, x])
    x = TimeDistributed(Conv2D(filters=conv_size // 2, kernel_size=3, padding='same', activation='relu'))(x)
    x = TimeDistributed(Conv2D(filters=conv_size // 2, kernel_size=3, padding='same', activation='relu'))(x)
    x = TimeDistributed(Conv2D(filters=conv_size // 2, kernel_size=3, padding='same', activation='relu'))(x)
    x = TimeDistributed(UpSampling2D((2, 2)))(x)
    x = BatchNormalization()(x)
    ## 4th conv block
    x = TimeDistributed(Conv2D(filters=conv_size // 4, kernel_size=3, padding='same', activation='relu'))(x)
    x = TimeDistributed(Add())([skip_4, x])
    x = TimeDistributed(Conv2D(filters=conv_size // 4, kernel_size=3, padding='same', activation='relu'))(x)
    x = TimeDistributed(UpSampling2D((2, 2)))(x)
    x = BatchNormalization()(x)
    ## 5th conv block
    x = TimeDistributed(Conv2D(filters=conv_size // 8, kernel_size=3, padding='same', activation='relu'))(x)
    x = TimeDistributed(Add())([skip_5, x])
    x = TimeDistributed(Conv2D(filters=conv_size // 8, kernel_size=3, padding='same', activation='relu'))(x)
    x = TimeDistributed(Conv2D(filters=3, kernel_size=3, padding='same', activation='sigmoid'))(x)
    return x
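# Note: I am not sure TimeDistributed(Add()) accepts a list of inputs; once the
# skip tensors are time-distributed themselves, a plain Add() should work
# element-wise, e.g.:
#   x = Add()([skip_1, x])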
    
def build_res_vgg(batch_size, seq_length):
    rnn_size = 512
    conv_size = 512
    compressed_dim = 7
    img_dim = (224, 224)
    encoder_model = VGG19(weights='imagenet', include_top=False)
    time_series_input = keras.Input(shape=(seq_length - 1, img_dim[0], img_dim[1], 3),
                                    batch_size=batch_size)
    encoder = Model(inputs=encoder_model.input,
                    outputs=encoder_model.get_layer('block5_pool').output)

    time_wrapped_output = TimeDistributed(encoder)(time_series_input)
    time_encoder = Model(inputs=time_series_input, outputs=time_wrapped_output)

    encoder_output = TimeDistributed(layers.AveragePooling2D(pool_size=(7, 7)))(time_encoder.output)
    encoder_output = layers.Reshape(target_shape=(seq_length - 1, conv_size))(encoder_output)

    encoder_model.trainable = False
    # Last conv layer of each VGG19 block, deepest first.
    skip_1 = encoder.layers[20].output  # block5_conv4
    skip_2 = encoder.layers[15].output  # block4_conv4
    skip_3 = encoder.layers[10].output  # block3_conv4
    skip_4 = encoder.layers[5].output   # block2_conv2
    skip_5 = encoder.layers[2].output   # block1_conv2
    lstm_h_dim = [512, 512, 512]

    lstm_input = LSTM_VAR(rnn_size, rnn_size)(encoder_output, activation='relu')
    lstm_output = build_lstm(lstm_input, lstm_h_dim)
    lstm_output = LSTM_VAR(rnn_size, rnn_size)(lstm_output)
    lstm_output = LSTM_VAR(rnn_size, rnn_size // 2)(lstm_output)
    lstm_output = LSTM_VAR(rnn_size // 2, rnn_size // 4)(lstm_output)
    bottleneck = TimeDistributed(Dense(7 * 7 * conv_size, activation='relu'))(lstm_output)
    bottleneck = layers.Reshape(target_shape=(seq_length - 1, compressed_dim,
                                              compressed_dim, conv_size))(bottleneck)
    decoded_img = build_decoder(bottleneck, skip_1, skip_2, skip_3, skip_4, skip_5)

    lstm_model = Model(inputs=[time_encoder.inputs], outputs=decoded_img)

    return lstm_model

class LSTM_VAR(layers.Layer):
    """A trainable affine projection (W x + B) with an optional ReLU."""
    def __init__(self, input_size, output_size):
        super(LSTM_VAR, self).__init__()
        self.W = tf.Variable(tf.random.normal([int(input_size), int(output_size)]), trainable=True)
        self.B = tf.Variable(tf.random.normal([int(output_size)]), trainable=True)

    def call(self, inputs, activation=None):
        if activation == 'relu':
            output = tf.nn.relu(tf.matmul(inputs, self.W) + self.B)
        else:
            output = tf.matmul(inputs, self.W) + self.B
        return output
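For reference, this is how I construct the model (the batch size and sequence length here are just placeholder values):

model = build_res_vgg(batch_size=4, seq_length=11)  # raises the ValueError above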

The issue happens in the build_res_vgg function, when I call the Model constructor.

It only seems to work if I give lstm_model two inputs, but it should have a single input of shape (batch_size, timesteps, img_h, img_w, 3). The objective is to take in a sequence of images, encode each frame with the pre-trained VGG19, pass the encodings through the LSTM layers, and reconstruct the sequence at the original shape.
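For what it's worth, one pattern I have been considering is to expose each skip through its own TimeDistributed sub-model built on the shared VGG19 graph, so that every tensor traces back to the frame input. A sketch (untested; the layer names are the standard VGG19 ones):

vgg = VGG19(weights='imagenet', include_top=False)
vgg.trainable = False
# Last conv layer of each VGG19 block, deepest first.
skip_names = ['block5_conv4', 'block4_conv4', 'block3_conv4',
              'block2_conv2', 'block1_conv2']
frames = keras.Input(shape=(seq_length - 1, 224, 224, 3), batch_size=batch_size)
# The sub-models reuse the same underlying VGG19 layers, and each output
# now depends on frames, so Model(frames, ...) is fully connected.
skips = [TimeDistributed(Model(vgg.input, vgg.get_layer(name).output))(frames)
         for name in skip_names]
encoded = TimeDistributed(Model(vgg.input, vgg.get_layer('block5_pool').output))(frames)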

Hi @Yeo_Wei_jie

Welcome to the TensorFlow Forum!

Please let us know if this issue still persists. If so, could you please share standalone code along with the dataset type and shape you are using for model training, so we can better understand the issue? Thank you.