I’m following the instructions for tf.keras.layers.RNN to define a custom RNN layer. When I call the layer with initial_state passed in as an argument and stateful=True, that initial state is overwritten by the output of my get_initial_state() method (which randomly initializes the initial state). The custom RNN cell is defined below:
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import backend
import numpy as np
# Seed NumPy's global random number generator.
# NOTE(review): the random values in this snippet come from TensorFlow
# (a Keras initializer and tf.random.stateless_normal), which presumably are
# not affected by this seed -- confirm whether tf.random.set_seed is wanted.
np.random.seed(42)
class MinimalRNNCell(keras.layers.Layer):
    """Minimal linear RNN cell intended for use with ``keras.layers.RNN``.

    Each step computes ``inputs @ kernel + prev_state @ recurrent_kernel``
    and uses that value both as the step output and as the next state.
    The weights are constants (kernel = 0, recurrent kernel = 1) so the
    result of a step can be checked by hand.
    """

    def __init__(self, units, **kwargs):
        """Store the cell width.

        Args:
            units: Size of the output and of the single recurrent state.
            **kwargs: Forwarded unchanged to ``keras.layers.Layer``.
        """
        self.units = units
        # Read by the RNN wrapper to learn the size of the carried state.
        self.state_size = units
        super().__init__(**kwargs)

    def build(self, input_shape):
        """Create the two weight matrices.

        Args:
            input_shape: Shape of one time step of input; only the last
                dimension (the feature size) is used.
        """
        # Input-to-hidden weights, all zeros so the input contributes nothing.
        self.kernel = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer=tf.keras.initializers.Constant(value=0),
            name='kernel',
        )
        # Hidden-to-hidden weights, all ones (constant for testing).
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units),
            initializer=tf.keras.initializers.Constant(value=1),
            name='recurrent_kernel',
        )
        self.built = True

    def call(self, inputs, states):
        """Run one time step.

        Args:
            inputs: Input for the current time step.
            states: Sequence whose first element is the previous state.

        Returns:
            A pair ``(output, [output])``: the step output and the new
            state list, which holds that same tensor.
        """
        prev_output = states[0]
        h = backend.dot(inputs, self.kernel)
        output = h + backend.dot(prev_output, self.recurrent_kernel)
        return output, [output]

    def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
        """Return a randomly initialised state of shape (batch_size, units).

        Args:
            inputs: Accepted for interface compatibility; not used.
            batch_size: Number of rows in the returned state.
            dtype: Requested dtype of the state. It is forwarded to the
                initializer, which picks its own default when this is None.

        Returns:
            A tensor drawn from a normal distribution (mean 0, stddev 1).

        NOTE(review): with ``stateful=True`` the Keras RNN wrapper appears
        to keep its stored states whenever they are non-zero, even if
        ``initial_state`` is passed to the call. A non-zero default such as
        this one would then take precedence over the caller's value --
        confirm against the installed Keras version.
        """
        initializer = tf.keras.initializers.RandomNormal(
            mean=0., stddev=1., seed=None)
        # Honour the dtype requested by the caller instead of dropping it.
        return initializer(shape=(batch_size, self.units), dtype=dtype)
# Dimensions used by the reproduction below.
batch_size = 8
input_dim = 4
state_size = 3
timestep = 2

# All-zero input sequence and a reproducible, user-chosen initial state.
input = tf.zeros(shape=(batch_size, timestep, input_dim))
init = tf.random.stateless_normal([batch_size, state_size], seed=(1, 2))

# Wrap the custom cell in a stateful RNN layer and run it once, handing the
# chosen state in through `initial_state`.
min_rnn = MinimalRNNCell(state_size)
layer = keras.layers.RNN(min_rnn, stateful=True, return_sequences=True)
out = layer(input, initial_state=init, training=True)

# Element-wise check of the first time step against the value expected if
# `init` was really used: the input is zero, so the first output should be
# init @ recurrent_kernel.
out[:, 0, :] == backend.dot(init, min_rnn.recurrent_kernel)
Regarding the last four lines: when the layer is defined with stateful=False, the last line produces a boolean array of all True, indicating that initial_state is being passed in properly. However, when stateful=True, that is no longer the case. I’m wondering how to fix this so that the user-supplied initial state always overrides the random initialization from my get_initial_state() method.