I have created a simple model with two SimpleRNN layers to evaluate the Jacobian-vector product (JVP). To speed up the calculation, I decorated the function with @tf.function, as in the test code below.
import tensorflow as tf

x = tf.random.normal((64, 64, 1))

inputs = tf.keras.Input((None, 1))
features = tf.keras.layers.SimpleRNN(32, return_sequences=True)(inputs)
features = tf.keras.layers.SimpleRNN(32, return_sequences=True)(features)
outputs = tf.keras.layers.Dense(1)(features)
model = tf.keras.Model(inputs, outputs)

v = tf.ones((64, 64, 1))

@tf.function
def jac_vec_prod(net, inp, tangents):
    # Forward-mode autodiff: record the forward pass and read off the JVP.
    with tf.autodiff.ForwardAccumulator(primals=inp, tangents=tangents) as acc:
        out = net(inp)
    jvp = acc.jvp(out)
    return jvp

jvp = jac_vec_prod(model, inp=x, tangents=v)
The code fails with the following error:
Traceback (most recent call last):

  File ~/miniconda3/envs/spyder/lib/python3.11/site-packages/spyder_kernels/py3compat.py:356 in compat_exec
    exec(code, globals, locals)

  File ~/test.py:28
    jvp = jac_vec_prod(model, inp=x, tangents=v)

  File ~/miniconda3/envs/spyder/lib/python3.11/site-packages/tensorflow/python/util/traceback_utils.py:153 in error_handler
    raise e.with_traceback(filtered_tb) from None

  File /tmp/__autograph_generated_file618o9y0z.py:11 in tf__jac_vec_prod
    out = ag__.converted_call(ag__.ld(net), (ag__.ld(inp),), None, fscope)

  File ~/miniconda3/envs/spyder/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py:124 in error_handler
    del filtered_tb

TypeError: in user code:

    File "/test.py", line 24, in jac_vec_prod  *
        out = net(inp)
    File "/miniconda3/envs/spyder/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler  **
        raise e.with_traceback(filtered_tb) from None

    TypeError: Exception encountered when calling SimpleRNN.call().

    `dtype` <dtype: 'variant'> is not compatible with 1 of dtype int64.

    Arguments received by SimpleRNN.call():
      • sequences=tf.Tensor(shape=(64, 64, 1), dtype=float32)
      • initial_state=None
      • mask=None
      • training=False
However, if @tf.function is not used, the example above works fine, as shown below.
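For comparison, here is the eager version that runs without errors; it is the same function with only the decorator removed (the _eager suffix is just for this post):

# Identical function without @tf.function: runs eagerly and succeeds.
def jac_vec_prod_eager(net, inp, tangents):
    with tf.autodiff.ForwardAccumulator(primals=inp, tangents=tangents) as acc:
        out = net(inp)
    return acc.jvp(out)

jvp = jac_vec_prod_eager(model, inp=x, tangents=v)
print(jvp.shape)  # (64, 64, 1)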
I have also tried substituting the SimpleRNN layers with Dense layers, and in that case the code works as intended both with and without the decorator (see the sketch below).
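A minimal sketch of that Dense-only substitution (the 32-unit layer widths mirror the RNN model above; the variable names are mine):

# Dense layers in place of the SimpleRNN layers.
# This version runs with and without @tf.function.
inputs_d = tf.keras.Input((None, 1))
h = tf.keras.layers.Dense(32)(inputs_d)
h = tf.keras.layers.Dense(32)(h)
outputs_d = tf.keras.layers.Dense(1)(h)
dense_model = tf.keras.Model(inputs_d, outputs_d)

jvp_dense = jac_vec_prod(dense_model, inp=x, tangents=v)  # no error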