get_initial_state method in custom RNN cell overwrites user-supplied initial_state when stateful=True

I’m following the instructions for tf.keras.layers.RNN to define a custom RNN layer. When I call the layer with initial_state passed in as an argument and stateful=True, the initial state is overwritten by the output of my get_initial_state() method (which randomly initializes the initial state). The custom RNN cell is defined below:

import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import backend
import numpy as np
np.random.seed(42)

class MinimalRNNCell(keras.layers.Layer):

    def __init__(self, units, **kwargs):
        self.units = units
        self.state_size = units
        super(MinimalRNNCell, self).__init__(**kwargs)

    def build(self, input_shape):
        # input kernel initialized to zero, so the output depends only on the state
        self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
                                      initializer=tf.keras.initializers.Constant(value=0),
                                      name='kernel')
        # constant recurrent weights for testing
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units),
            initializer=tf.keras.initializers.Constant(value=1),
            name='recurrent_kernel')
        self.built = True

    def call(self, inputs, states):
        prev_output = states[0]
        h = backend.dot(inputs, self.kernel)
        output = h + backend.dot(prev_output, self.recurrent_kernel)
        return output, [output]

    # randomly initialized initial state
    def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
        initializer = tf.keras.initializers.RandomNormal(mean=0., stddev=1., seed=None)
        return initializer((batch_size, self.units))

batch_size = 8
input_dim = 4
state_size = 3
timestep = 2

inputs = tf.zeros((batch_size, timestep, input_dim))  # inputs are all zeros
init = tf.random.stateless_normal([batch_size, state_size], seed=(1, 2))

min_rnn = MinimalRNNCell(state_size)
layer = keras.layers.RNN(min_rnn, return_sequences=True, stateful=True)
out = layer(inputs, training=True, initial_state=init)
# with a zero input kernel, the first output should equal dot(init, recurrent_kernel)
out[:, 0, :] == backend.dot(init, min_rnn.recurrent_kernel)

Regarding the final comparison: when the layer is defined with stateful=False, it produces a boolean array of all True, indicating that the initial_state is being passed in properly. However, when stateful=True, that is no longer the case. I’m wondering how to fix this so that the user-supplied initial state always overrides the random initialization from my get_initial_state() method.
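For completeness, one workaround I’m considering (a sketch, not verified on 2.5.0): since a stateful RNN layer keeps its state in layer.states, those variables can be overwritten explicitly with reset_states() before running the layer, instead of passing initial_state into the call:

# sketch: force a user-chosen state in stateful mode by writing it into layer.states
layer = keras.layers.RNN(MinimalRNNCell(state_size), return_sequences=True, stateful=True)
layer.build((batch_size, timestep, input_dim))  # stateful build should create layer.states
layer.reset_states(states=init.numpy())          # overwrite the stored state
out = layer(inputs, training=True)               # no initial_state argument needed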


Not sure if it will fit your use case, but the TensorFlow NMT tutorial gives you an idea of setting the initial LSTM states of a decoder:
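Roughly, the relevant pattern from that tutorial is to feed the encoder’s final states as the decoder’s initial_state (a minimal sketch, not the tutorial’s exact code; the shapes and sizes here are made up):

import tensorflow as tf
from tensorflow import keras

# encoder: keep the final hidden and cell states
encoder_inputs = keras.Input(shape=(None, 16))
_, state_h, state_c = keras.layers.LSTM(32, return_state=True)(encoder_inputs)

# decoder: seed it with the encoder states instead of zeros
decoder_inputs = keras.Input(shape=(None, 16))
decoder_outputs = keras.layers.LSTM(32, return_sequences=True)(
    decoder_inputs, initial_state=[state_h, state_c])

model = keras.Model([encoder_inputs, decoder_inputs], decoder_outputs)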


The above behaviour was observed on TensorFlow 2.5.0. I just tested the exact same code on TensorFlow 2.3.0 and got the desired behaviour (the initial state passed in by the user in the call method overrides get_initial_state() of the custom RNN cell whether stateful is True or False), which seems to be consistent with the documentation:

A get_initial_state(inputs=None, batch_size=None, dtype=None) method that creates a tensor meant to be fed to call() as the initial state, if the user didn’t specify any initial state via other means.

However, on TensorFlow 2.3.0, if stateful=True, the layer ignores the get_initial_state() method even when no user initial_state is passed in, so the initial state is initialized to all zeros.
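A quick way to see this (a sketch, assuming eager execution and that the stateful build path creates the state variables): build the stateful layer without calling it and inspect layer.states directly.

# on 2.3.0 the stored state comes back zero-initialized,
# not drawn from get_initial_state()
layer = keras.layers.RNN(MinimalRNNCell(state_size), return_sequences=True, stateful=True)
layer.build((batch_size, timestep, input_dim))
print([s.numpy() for s in layer.states])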
