Hi,
I need to compute the gradient of a scalar-valued function with vector inputs, for many different inputs. I am currently doing something like the code below, but no matter what value I set for `parallel_iterations` in `tf.while_loop`, it only computes the gradient for one input at a time. What am I missing?
import tensorflow as tf
import time
@tf.function
def f(x):
    """Sum `x` after an rfft/irfft round trip, print the sum and return it.

    Args:
        x: Input passed to ``tf.signal.rfft``.

    Returns:
        The ``tf.reduce_sum`` of the reconstructed signal.
    """
    # Forward real FFT followed immediately by its inverse, then a full sum.
    total = tf.reduce_sum(tf.signal.irfft(tf.signal.rfft(x)))
    tf.print(total)
    return total
@tf.function
def ode_fn(x):
    """Return ``-x + 1.0``, the right-hand side stepped by `integrate`."""
    negated = -x
    return negated + 1.0
@tf.function
def integrate(x, nsteps, time_step):
    """Advance `x` through `nsteps` forward-Euler updates of ``ode_fn``.

    Every step applies ``state <- state + time_step * ode_fn(state)`` and the
    state reached after each step is recorded.

    Args:
        x: Initial state. It is converted to a tensor, and its dtype is used
            for the recorded trajectory.
        nsteps: Number of update steps, also the size of the trajectory.
        time_step: Factor multiplying ``ode_fn(state)`` in each update.

    Returns:
        The recorded states stacked along a new leading axis of length
        `nsteps`.
    """
    state = tf.convert_to_tensor(x)
    # Follow the input dtype instead of assuming float32; writing a state of
    # any other dtype into a float32 TensorArray is rejected.
    trajectory = tf.TensorArray(dtype=state.dtype, size=nsteps)
    for step in tf.range(nsteps):
        state = state + time_step * ode_fn(state)
        trajectory = trajectory.write(step, state)
    return trajectory.stack()
@tf.function
def compute_grads(x, nsteps=20000, time_step=0.1, parallel_iterations=5):
    """Compute, row by row, the gradient of the summed trajectory w.r.t. `x`.

    For each row ``x[i]`` the scalar ``tf.reduce_sum(integrate(x[i], nsteps,
    time_step))`` is formed under a gradient tape, printed, and differentiated
    with respect to that row.

    Args:
        x: Tensor whose leading dimension indexes the independent inputs. The
            leading dimension is read from the static shape.
        nsteps: Number of integration steps forwarded to `integrate`.
        time_step: Step size forwarded to `integrate`.
        parallel_iterations: Forwarded unchanged to ``tf.while_loop``. It is
            an upper bound on overlapping iterations, not a guarantee that
            iterations run concurrently.

    Returns:
        The per-row gradients stacked along the leading axis.
    """
    batch_size = x.shape[0]
    # The buffer follows x's dtype rather than assuming float32, because the
    # gradient with respect to a row has that row's dtype.
    row_grads = tf.TensorArray(dtype=x.dtype, size=batch_size)

    def cond(i, row_grads):
        return tf.less(i, batch_size)

    def body(i, row_grads):
        x_i = x[i]
        with tf.GradientTape() as tape:
            # x_i is a plain tensor slice, so it must be watched explicitly.
            tape.watch(x_i)
            lnp_i = tf.reduce_sum(integrate(x_i, nsteps, time_step))
        tf.print(lnp_i)
        # Differentiate after leaving the tape context so the gradient
        # computation itself is not recorded.
        row_grads = row_grads.write(i, tape.gradient(lnp_i, x_i))
        return i + 1, row_grads

    _, row_grads = tf.while_loop(
        cond,
        body,
        loop_vars=[tf.constant(0, dtype=tf.int32), row_grads],
        parallel_iterations=parallel_iterations,
    )
    return row_grads.stack()
# Ten input vectors of length 5000 drawn from a normal distribution with
# mean 10 and standard deviation 1.
x = tf.random.normal(
    shape=[10, 5000],
    mean=10.0,
    stddev=1.0,
    dtype=tf.float32,
)

# perf_counter is monotonic and high resolution, so the difference below is a
# reliable elapsed time even if the system clock is adjusted meanwhile.
# This is the first call of compute_grads, so the measurement also includes
# the time tf.function spends tracing it.
start = time.perf_counter()
fx_grads = compute_grads(x)
end = time.perf_counter()
print(f"Elapsed {end - start} seconds")