Understanding RNN Base Layer

Hi, I am trying to understand the RNN base layer from the keras example - https://keras.io/keras_core/api/layers/recurrent_layers/rnn/

import tensorflow as tf
import keras
from keras.layers import RNN


class MinimalRNNCell(keras.layers.Layer):

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        # The RNN wrapper reads state_size to learn the size of the hidden state.
        self.state_size = units

    def build(self, input_shape):
        # input_shape is the per-timestep shape: (batch_size, features).
        self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
                                      initializer='uniform',
                                      name='kernel')
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units),
            initializer='uniform',
            name='recurrent_kernel')
        self.built = True
        #print("Self kernel shape:", self.kernel)
        #print("Self recurrent kernel shape:", self.recurrent_kernel)

    # def get_initial_state(self):
    #   return self.initial_state

    def call(self, inputs, states):
        #print("Units:", self.units)
        print("States:", states)
        # states is a list; its first entry is the previous hidden state.
        prev_output = states[0]
        h = tf.matmul(inputs, self.kernel)
        #print("h shape:", h.shape)
        output = h + tf.matmul(prev_output, self.recurrent_kernel)
        return output, [output]

Now, if my input is like this:

cell = MinimalRNNCell(32)
x = tf.ones((2, 10, 5))
layer = RNN(cell, return_sequences=True)
y = layer(x)

does it mean that the RNN sequence has 10 timesteps, where the input size at each step is (1, 5) and the size of the hidden state is (1, 32)? Am I right? Another question: how can I define the initial hidden state? I went through the RNN documentation but could not understand how to define the initial state for the first call.
Thanks.

Hi @Arkaprava_Majumdar,

You're on the right track, but a few points need correcting.

  1. RNN sequence structure:
  • Your input sequence has 10 timesteps.
  • At each timestep, the input has shape (batch_size, 5). In your case batch_size is 2, so the input at each timestep is (2, 5), not (1, 5).
  • The hidden state has shape (batch_size, 32), so in your case it is (2, 32), not (1, 32).
  2. Defining the initial hidden state: You can define the initial state in a few ways:
  a) Implement the get_initial_state method in your RNN cell.
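
To make the shapes and the role of the initial state concrete, here is a rough sketch in plain NumPy rather than Keras. It only mimics the recurrence from your call method (output = inputs · kernel + prev_output · recurrent_kernel); the function name run_minimal_rnn, the random weights, and the all-ones h0 are made up for illustration. As far as I know, the zero default matches what the Keras RNN layer does when no initial state is given, and the Keras counterpart of the initial_state argument here is the initial_state keyword you can pass when calling the layer.

```python
import numpy as np


def run_minimal_rnn(x, kernel, recurrent_kernel, initial_state=None):
    """Unroll h_t = x_t @ kernel + h_{t-1} @ recurrent_kernel over time.

    Hypothetical helper for illustration; not part of Keras.
    """
    batch_size, timesteps, _ = x.shape
    units = kernel.shape[1]

    # Assumption: start from zeros when no initial state is supplied.
    if initial_state is None:
        state = np.zeros((batch_size, units))
    else:
        state = initial_state

    outputs = []
    for t in range(timesteps):
        step_input = x[:, t, :]  # shape (batch_size, features)
        state = step_input @ kernel + state @ recurrent_kernel
        outputs.append(state)    # shape (batch_size, units)

    # Stack along the time axis, like return_sequences=True.
    return np.stack(outputs, axis=1), state


rng = np.random.default_rng(0)
x = np.ones((2, 10, 5))  # (batch_size, timesteps, features)
kernel = rng.uniform(-0.05, 0.05, size=(5, 32))
recurrent_kernel = rng.uniform(-0.05, 0.05, size=(32, 32))

# Default: zero initial state.
y_zero, last_zero = run_minimal_rnn(x, kernel, recurrent_kernel)

# Custom initial state: one row per batch element, one column per unit.
h0 = np.ones((2, 32))
y_custom, last_custom = run_minimal_rnn(
    x, kernel, recurrent_kernel, initial_state=h0)

print(x[:, 0, :].shape)  # (2, 5)
print(last_zero.shape)   # (2, 32)
print(y_zero.shape)      # (2, 10, 32)
```

The point to take away is that the initial state must have one row per batch element, so for your input it is (2, 32), and that changing it changes the outputs from the very first timestep.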

I have replicated the code incorporating these points and am attaching the gist for your reference.
You can also refer to the official documentation.

Hope this helps,

Thank you!