tf.autodiff.ForwardAccumulator() in a function decorated with tf.function doesn't work for RNNs

I have created a simple model with two SimpleRNN layers to evaluate the JVP (Jacobian-vector product). To speed up the calculations, I decorated the function with tf.function (as in the test code below).

import tensorflow as tf

x = tf.random.normal((64, 64, 1))

inputs = tf.keras.Input((None, 1))
features = tf.keras.layers.SimpleRNN(32, return_sequences=True)(inputs)
features = tf.keras.layers.SimpleRNN(32, return_sequences=True)(features)
outputs = tf.keras.layers.Dense(1)(features)
model = tf.keras.Model(inputs, outputs)

v = tf.ones((64, 64, 1))

@tf.function
def jac_vec_prod(net, inp, tangents):
    # Forward-mode JVP of the network output with respect to the inputs
    with tf.autodiff.ForwardAccumulator(primals=inp, tangents=tangents) as acc:
        out = net(inp)
    jvp = acc.jvp(out)
    return jvp

jvp = jac_vec_prod(model, inp=x, tangents=v)

Running the code produces the following error:

Traceback (most recent call last):

File ~/miniconda3/envs/spyder/lib/python3.11/site-packages/spyder_kernels/py3compat.py:356 in compat_exec
exec(code, globals, locals)

File ~/test.py:28
jvp = jac_vec_prod(model, inp=x, tangents=v)

File ~/miniconda3/envs/spyder/lib/python3.11/site-packages/tensorflow/python/util/traceback_utils.py:153 in error_handler
raise e.with_traceback(filtered_tb) from None

File /tmp/autograph_generated_file618o9y0z.py:11 in tf__jac_vec_prod
out = ag__.converted_call(ag__.ld(net), (ag__.ld(inp),), None, fscope)

File ~/miniconda3/envs/spyder/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py:124 in error_handler
del filtered_tb

TypeError: in user code:

File "/test.py", line 24, in jac_vec_prod  *
    out = net(inp)
File "/miniconda3/envs/spyder/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler  **
    raise e.with_traceback(filtered_tb) from None

TypeError: Exception encountered when calling SimpleRNN.call().

`dtype` <dtype: 'variant'> is not compatible with 1 of dtype int64.

Arguments received by SimpleRNN.call():
  • sequences=tf.Tensor(shape=(64, 64, 1), dtype=float32)
  • initial_state=None
  • mask=None
  • training=False

However, if @tf.function is not used in the example above, everything works fine.
I have also tried substituting the SimpleRNN layers with Dense layers, and in that case the code works as intended both with and without the decorator (see the sketch below).
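For completeness, here is a minimal sketch of the two cases that do work. The names jac_vec_prod_eager and dense_model are mine, added only for illustration; x, v, model, and jac_vec_prod are the same as above.

# Same computation without @tf.function: runs eagerly and works.
def jac_vec_prod_eager(net, inp, tangents):
    with tf.autodiff.ForwardAccumulator(primals=inp, tangents=tangents) as acc:
        out = net(inp)
    return acc.jvp(out)

jvp_eager = jac_vec_prod_eager(model, inp=x, tangents=v)

# Dense-only model: works both with and without the decorator.
dense_inputs = tf.keras.Input((None, 1))
dense_features = tf.keras.layers.Dense(32)(dense_inputs)
dense_features = tf.keras.layers.Dense(32)(dense_features)
dense_outputs = tf.keras.layers.Dense(1)(dense_features)
dense_model = tf.keras.Model(dense_inputs, dense_outputs)

jvp_dense = jac_vec_prod(dense_model, inp=x, tangents=v)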

Hi @Francesco_Munzone,

A warm welcome to the TensorFlow forum!

The error occurs because, under tf.function, SimpleRNN.call() receives a tensor whose dtype does not match what it expects (the traceback reports a variant dtype where int64 is expected). In graph execution mode, TensorFlow builds a computational graph and expects tensor dtypes to remain consistent throughout execution. For the case above, I suggest casting the inputs to int64 using tf.cast and executing it again.
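A minimal illustration of the suggested tf.cast call (a sketch only; the appropriate target dtype ultimately depends on what the failing op expects, and int64 here simply mirrors the error message):

import tensorflow as tf

x = tf.random.normal((64, 64, 1))   # float32 by default
x_cast = tf.cast(x, tf.int64)       # cast to the dtype reported in the error
print(x_cast.dtype)                 # tf.int64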

Here’s a working code gist attached that implements the above in both eager and graph mode. Kindly check out this documentation for more information about tf.function.

Thank You.