How can I visualize the embeddings produced by each individual branch of the model below?
def branch(self):
    """Build one branch sub-model used by ``multi_modal``.

    The sub-model takes three inputs, in this order:

    * ``res_fea_len``: shape ``(self.res_num, self.res_fea_len)``
    * ``res_inp``: shape ``(self.res_num, self.max_nbr)``
    * ``res_type``: shape ``(self.res_num,)``

    ``res_type`` is embedded into ``self.res_fea_len`` features. The result
    goes through ``self.ncov`` rounds of ``self.cov_layer`` followed by ReLU,
    and is then projected to ``self.num_target`` units by a Dense layer.

    Returns:
        A Keras ``Model``. Call it with a single list,
        ``model([res_fea_len, res_inp, res_type])``.
    """
    res_inp = Input(shape=(self.res_num, self.max_nbr))
    res_fea_len = Input(shape=(self.res_num, self.res_fea_len))
    res_type = Input(shape=(self.res_num,))

    # Embedding requires integer ids in [0, 200). Assumes res_type holds
    # residue-type ids below 200; TODO confirm.
    res_fea = Embedding(input_dim=200, output_dim=self.res_fea_len)(res_type)

    for _ in range(self.ncov):
        # NOTE(review): the original passed an undefined name ``a_inp`` here.
        # ``res_fea_len`` is the only declared-but-otherwise-unused input, so
        # it is assumed to be the intended first argument. Confirm against
        # cov_layer's signature.
        res_fea = self.cov_layer(res_fea_len, res_inp, res_fea)
        res_fea = Activation('relu')(res_fea)

    output = Dense(self.num_target)(res_fea)
    return Model(inputs=[res_fea_len, res_inp, res_type], outputs=output)
def multi_modal(self):
    """Build and compile a two-branch model whose branches share one sub-model.

    Both branches apply the same ``self.branch()`` sub-model, so they share
    weights. The ``res_type`` input is also shared by both branches. Their
    outputs are concatenated into the single output of ``self.model``.

    To make per-branch embeddings extractable after training or after loading
    weights, this also stores:

    * ``self.branch_model``: the shared branch sub-model. Call it as
      ``self.branch_model.predict([fea_len, inp, types])``.
    * ``self.branch_1_model`` / ``self.branch_2_model``: models from each
      branch's inputs to that branch's output.

    Example::

        emb_1 = self.branch_1_model.predict([fea_len_1, inp_1, types])
    """
    res_inp_1 = Input(shape=(self.res_num, self.max_nbr))
    res_fea_len_1 = Input(shape=(self.res_num, self.res_fea_len))
    res_type = Input(shape=(self.res_num,))
    res_inp_2 = Input(shape=(self.res_num, self.max_nbr))
    res_fea_len_2 = Input(shape=(self.res_num, self.res_fea_len))

    # Build the branch once and keep a handle to it. Calling the same Model
    # twice reuses (shares) its weights across both branches.
    model_branch = self.branch()
    self.branch_model = model_branch

    # A Keras Model takes all of its inputs as ONE list, in the same order as
    # the ``inputs`` it was built with. Extra positional arguments would be
    # interpreted as ``training`` / ``mask`` instead.
    branch_1 = model_branch([res_fea_len_1, res_inp_1, res_type])
    branch_2 = model_branch([res_fea_len_2, res_inp_2, res_type])

    # Sub-models exposing each branch's output on its own. They reuse the
    # layers of the combined graph, so they always reflect the current
    # (trained or loaded) weights of self.model.
    self.branch_1_model = Model(
        inputs=[res_fea_len_1, res_inp_1, res_type], outputs=branch_1)
    self.branch_2_model = Model(
        inputs=[res_fea_len_2, res_inp_2, res_type], outputs=branch_2)

    combined = concatenate([branch_1, branch_2])
    self.model = Model(
        inputs=[res_fea_len_1, res_inp_1, res_type, res_fea_len_2, res_inp_2],
        outputs=combined)
    # NOTE(review): no ``loss`` is passed to compile(). ``self.losses`` is set
    # afterwards and is not used in this method. Also, 'mqse' is not a
    # built-in Keras loss name, and a two-element loss list does not match the
    # single concatenated output. Confirm the intended loss.
    self.model.compile(optimizer=self.optimizer)
    self.losses = ['mse', 'mqse']
I tried saving the weights, but I was not able to extract the embeddings from an individual branch.