Hello, I am working on code that was written in TF1. I am studying it and trying to convert it to TF2. Today I am struggling to find the equivalent of the zero_state method from tf.compat.v1.nn.rnn_cell.RNNCell. I would like to compute this zero-filled state tensor with native TF2 code (using Keras).
In TensorFlow 2.x, the closest equivalent of zero_state for Keras RNN cells is the get_initial_state method.
To use get_initial_state, first create an instance of the RNN cell, then call get_initial_state on the cell with the batch size and dtype as arguments.
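import tensorflow as tf

cell = tf.keras.layers.LSTMCell(128)  # 128 units
batch_size = 32  # example value; use your own
# For an LSTM cell the result is a list of two tensors: [hidden state h, cell state c].
# Note: releases that ship Keras 3 accept only batch_size here, so drop dtype there.
initial_state = cell.get_initial_state(batch_size=batch_size, dtype=tf.float32)
for state in initial_state:
    print(state.shape)  # (32, 128)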
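This creates an LSTM cell with 128 units and then calls get_initial_state to get the initial state for the given batch size and data type. For an LSTM cell the result is not a single tensor but a list of two zero-filled tensors (the hidden state and the cell state), each of shape (batch_size, 128).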
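If you want to confirm that the returned state really is all zeros, a minimal sketch along these lines should work. The variable names (units, shapes, all_zero) and the batch size of 4 are just illustrative choices, and the try/except is an assumption about version differences: older Keras versions expect a dtype argument, while Keras 3 accepts only batch_size.

```python
import numpy as np
import tensorflow as tf

units = 128
batch_size = 4  # arbitrary example value

cell = tf.keras.layers.LSTMCell(units)

# Older Keras versions expect dtype; Keras 3 accepts only batch_size.
try:
    initial_state = cell.get_initial_state(batch_size=batch_size, dtype=tf.float32)
except TypeError:
    initial_state = cell.get_initial_state(batch_size=batch_size)

# An LSTM cell has two state tensors: hidden state h and cell state c.
h, c = initial_state
shapes = [tuple(s.shape.as_list()) for s in (h, c)]

# np.any returns False only when every element is zero.
all_zero = all(not np.any(s.numpy()) for s in (h, c))

print(shapes)
print(all_zero)
```

Both tensors should have shape (batch_size, units), and all_zero should be True, which matches what zero_state produced in TF1.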
It was not clear to me from the RNN documentation how to make sure the result is a tensor full of zeros.
Anyway, I checked, and this gives the same results as yours. Also, thanks for the link about TF1 to TF2 migration. I am a novice in TensorFlow, so starting by translating TF1 to TF2 is not the optimal way of learning it! But I have to do it anyway, and at least it forces me to go deep into the documentation.
Best regards.