Modifying the sub-graph within conditional blocks implemented with functionals in TF 2.x

[ not a contribution ]
Hi,

Tensorflow version : TF 2.6
LSTM Layer: tf.keras.layers.LSTM

With TF 2.x, the conditional blocks - namely the "while" block inside an RNN - are implemented as functionals (see tensorflow/tensorflow/core/ops/functional_ops.cc on GitHub, and the attached snapshot from TensorBoard). We need to modify the sub-graph executed within the while context to include some additional TF operations. What would be the recommended way to achieve this?

In TF 1.x, we were able to access the sub-graph within the while context and modify it to include additional operations in the while block's execution context. With TF 2.x, we don't seem to have access to this internal sub-graph anymore. Could you please advise?
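To illustrate what we are seeing (a minimal sketch, not from our actual model): in TF 2.x the loop body no longer appears as ops in the outer graph; the while loop is emitted as a single functional op whose cond/body sub-graphs live in the graph's function library.

```python
import tensorflow as tf

# Sketch: a tf.while_loop traced under tf.function becomes one
# While/StatelessWhile op; its "cond" and "body" are FunctionDefs in the
# function library, which is why walking the outer graph's ops (the TF 1.x
# approach) no longer reaches the loop's internal sub-graph.
@tf.function
def count_to(n):
    i = tf.constant(0)
    return tf.while_loop(lambda i: i < n, lambda i: i + 1, [i])

graph = count_to.get_concrete_function(tf.constant(5)).graph

# The outer graph contains only the functional While op...
while_ops = [op for op in graph.get_operations() if "While" in op.type]

# ...while the loop body and condition are functions in the library.
body_fns = [f.signature.name for f in graph.as_graph_def().library.function]
print(len(while_ops), len(body_fns))
```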

Example code block with Keras LSTM layer:

– Code Block –
def test_lstm():
    tf.compat.v1.reset_default_graph()
    with tf.device('/cpu:0'):
        model = tf.keras.Sequential()
        # Add an Embedding layer expecting input vocab of size 1000, and
        # output embedding dimension of size 64.
        model.add(tf.keras.layers.Embedding(input_dim=1000, output_dim=64))
        # Add an LSTM layer with 128 internal units.
        model.add(tf.keras.layers.LSTM(128))
        # Add a Dense layer with 10 units.
        model.add(tf.keras.layers.Dense(10))
        model.summary()
        session = tf.compat.v1.Session()
        writer = tf.compat.v1.summary.FileWriter('./lstm', session.graph)
– Code Block –

LSTM Layer with conditionals:

Thanks,
Sangeetha

Hi @quic-ssiddego

Welcome to the TensorFlow Forum!

You will not be able to modify the sub-graph inside the RNN in TensorFlow 2.x. However, you can achieve this behavior with a custom RNN cell: implement a cell that inherits from tf.keras.layers.LSTMCell, add your additional TF operations there, and then use this custom cell inside an RNN layer in the model, as below:

class MyCustomLSTMCell(tf.keras.layers.LSTMCell):
    def __init__(self, units, **kwargs):
        super().__init__(units, **kwargs)
        # You can define any additional state variables or operations here

    def call(self, inputs, states, training=None):
        lstm_output, next_states = super().call(inputs, states, training=training)
        # You can add your custom operations here, as shown in this line
        custom_output = tf.nn.relu(lstm_output)
        return custom_output, next_states

model = tf.keras.Sequential([
  tf.keras.layers.Embedding(input_dim=1000, output_dim=64),
  tf.keras.layers.RNN(MyCustomLSTMCell(128)),  # <-- the custom cell
  tf.keras.layers.Dense(10)])
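For a quick sanity check, here is a self-contained version of the sketch above with a forward pass (the batch size and sequence length are arbitrary choices for illustration):

```python
import numpy as np
import tensorflow as tf

# Custom cell as in the sketch above: run the standard LSTMCell step,
# then apply an extra op (ReLU here) to its output.
class MyCustomLSTMCell(tf.keras.layers.LSTMCell):
    def call(self, inputs, states, training=None):
        lstm_output, next_states = super().call(inputs, states, training=training)
        return tf.nn.relu(lstm_output), next_states

model = tf.keras.Sequential([
    tf.keras.layers.Embedding(input_dim=1000, output_dim=64),
    tf.keras.layers.RNN(MyCustomLSTMCell(128)),
    tf.keras.layers.Dense(10),
])

# Batch of 4 sequences, each 20 token ids in [0, 1000)
x = np.random.randint(0, 1000, size=(4, 20))
y = model(x)
print(tuple(y.shape))  # (4, 10)
```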