How to train parameters of 2 different classes together?

How can I train the parameters of Class1 and Class2 together, i.e. the weights of self.linear1 and self.linear2 from Class1 along with the weight of Class2? Since Class1 calls Class2 as self.conv1 = Class2(w_in, w_out), the two are interlinked and form a chain during the forward pass. That's why I want to train them together. What should I pass in my training loop when computing the gradients? grads = tape.gradient(loss, ? )

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


class Class1(layers.Layer):

    def __init__(self, num_channels, w_in, w_out, num_class):
        super(Class1, self).__init__()

        self.num_channels = num_channels
        self.w_in = w_in
        self.w_out = w_out

        self.conv1 = Class2(w_in, w_out)

        self.linear1 = tf.keras.layers.Dense(self.w_out, input_shape=(self.w_out*self.num_channels, ), activation=None)
        self.linear2 = tf.keras.layers.Dense(self.num_class, input_shape=(self.w_out, ), activation=None)

    def call(self, A):
        a = self.conv1(A)
        return a


class Class2(tf.keras.layers.Layer):

    def __init__(self, in_channels, out_channels):
        super(Class2, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.weight = self.add_weight(
            shape=(out_channels, in_channels, 1, 1), initializer="random_normal", trainable=True)

    def call(self, A):
        print(A)
        A = tf.reduce_sum(A * tf.nn.softmax(self.weight, 1), 1)
        print(A)
        return A

Hi @Anshuman_Sinha,

To train the parameters of both classes at the same time, compute the gradients of the loss with respect to the trainable variables of both Class1 and Class2. Because Class2 is created inside Class1 (self.conv1 = Class2(w_in, w_out)), Keras tracks it as a sub-layer, so grads = tape.gradient(loss, model.trainable_variables) on the Class1 instance covers the weights of both classes. I have made a few changes to the code: since Class1 refers to Class2, Class2 is defined before Class1 is instantiated, and num_class is assigned to self. Please refer to the sample code for the training loop in the gist.
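As a minimal sketch of such a loop (assuming a Class1 instance named model, a tf.data.Dataset named dataset that yields (A, labels) batches, a sparse-categorical cross-entropy loss, and that call eventually produces class logits; the constructor arguments, num_epochs, and the data pipeline below are illustrative placeholders, not part of the original code):

import tensorflow as tf

# Hypothetical setup; shapes, loss, optimizer, and data are placeholders.
model = Class1(num_channels=3, w_in=16, w_out=8, num_class=4)
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

for epoch in range(num_epochs):              # num_epochs assumed to be defined
    for A_batch, y_batch in dataset:         # dataset assumed to yield (inputs, labels)
        with tf.GradientTape() as tape:
            logits = model(A_batch)          # forward pass through Class1, which calls Class2
            loss = loss_fn(y_batch, logits)

        # model.trainable_variables includes the weight created in Class2 and,
        # once they have been built by a call, the kernels and biases of the
        # Dense layers in Class1, so both classes are updated together.
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

Note that the Dense layers only appear in model.trainable_variables after they have been called (and therefore built), so make sure call actually uses self.linear1 and self.linear2 if you want their weights trained.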

Thank You