Hello,
I am not an expert on TF, and I am trying to implement some normalizing flows to set up exercises on density estimation. I am working with TF 2.8.2 and import the usual TF libraries:
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability import bijectors as tfb
from tensorflow_probability import distributions as tfd
from tensorflow import keras as tfk
from tensorflow.keras import layers as tfkl
Now I run into a problem when I try to use a simple IAF layer:
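DTYPE = tf.float32  # dtype of the base distribution; float32 matches the Keras Input used below
base_dist = tfd.MultivariateNormalDiag(loc=tf.zeros([2], DTYPE), name='base_dist')
# IAF = inverted MAF: sampling is a single pass, log_prob is autoregressive (sequential)
flow_bijector = tfb.Invert(tfb.MaskedAutoregressiveFlow(
    name='IAF',
    shift_and_log_scale_fn=tfb.AutoregressiveNetwork(
        params=2, hidden_units=[512, 512], activation='relu')))
trans_dist = tfd.TransformedDistribution(
    distribution=base_dist,
    bijector=flow_bijector)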
While the following calls seem to work fine
x_test = trans_dist.sample()
trans_dist.log_prob(x_test)
giving
<tf.Tensor: shape=(), dtype=float32, numpy=-3.4625406>
setting up the model for training raises an error:
x_ = tfkl.Input(shape=(2,), dtype=tf.float32)
log_prob_ = trans_dist.log_prob(x_)
model = tfk.Model(x_, log_prob_)
model.compile(optimizer=tf.optimizers.Adam(learning_rate=1e-4),
              loss=lambda _, log_prob: -tf.reduce_mean(log_prob))
with the following error:
----> 2 log_prob_ = trans_dist.log_prob(x_)
TypeError: You are passing KerasTensor(type_spec=TensorSpec(shape=(), dtype=tf.int32, name=None),
inferred_value=[2], name='tf.math.reduce_prod_2/Prod:0', description="created by layer
'tf.math.reduce_prod_2'"), an intermediate Keras symbolic input/output, to a TF API that does not allow
registering custom dispatchers, such as `tf.cond`, `tf.function`, gradient tapes, or `tf.map_fn`. Keras
Functional model construction only supports TF API calls that *do* support dispatching, such as
`tf.math.add` or `tf.reshape`. Other APIs cannot be called directly on symbolic Kerasinputs/outputs.
You can work around this limitation by putting the operation in a custom Keras layer `call` and calling
that layer on this symbolic input/output.
Note that if I use a MAF layer instead of the IAF,
flow_bijector = tfb.MaskedAutoregressiveFlow(
    name='MAF',
    shift_and_log_scale_fn=tfb.AutoregressiveNetwork(
        params=2, hidden_units=[512, 512], activation='relu'))
then no error is raised.
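One plausible reason for the difference: in an affine autoregressive flow, the MAF's density direction needs a single conditioner pass, while the opposite direction (which `tfb.Invert` turns into the IAF's density direction) needs one pass per dimension. The numpy sketch below illustrates this with a hypothetical fixed, strictly lower-triangular linear conditioner standing in for the MADE network; it follows the TFP convention that a MAF's `forward` is the sequential direction.

```python
import numpy as np

# Hypothetical conditioner standing in for tfb.AutoregressiveNetwork: strictly
# lower-triangular weights mean the shift/log-scale for dimension i depend only
# on dimensions < i, which is the autoregressive property MADE enforces.
D = 4
rng = np.random.default_rng(0)
W_SHIFT = np.tril(rng.normal(size=(D, D)), k=-1)
W_LOG_SCALE = 0.1 * np.tril(rng.normal(size=(D, D)), k=-1)


def shift_and_log_scale(u):
    """Return (shift, log_scale) for every dimension from the conditioning vector u."""
    return W_SHIFT @ u, W_LOG_SCALE @ u


def maf_inverse(y):
    """y -> x in ONE conditioner pass, because the conditioning input y is fully known.
    MAF.log_prob goes this way, so it needs no loop."""
    shift, log_scale = shift_and_log_scale(y)
    return (y - shift) * np.exp(-log_scale)


def maf_forward(x):
    """x -> y needs D conditioner passes: y[i] depends on y[:i], which is only
    known after earlier passes. Invert(MAF).log_prob (the IAF) goes this way."""
    y = np.zeros_like(x)
    for _ in range(D):  # after pass k, y[:k] is exact
        shift, log_scale = shift_and_log_scale(y)
        y = x * np.exp(log_scale) + shift
    return y


x = rng.normal(size=D)
y = maf_forward(x)
print(np.allclose(maf_inverse(y), x))  # the two directions are exact inverses
```

In TFP, `MaskedAutoregressiveFlow` runs that per-dimension loop with `tf.while_loop`, bounded by the event size; the `reduce_prod` tensor with inferred value `[2]` in the traceback looks like that event size, and `tf.while_loop` belongs to the family of APIs the message says cannot take Keras symbolic inputs. The constructor also accepts `unroll_loop=True`, which replaces the `tf.while_loop` with a static Python loop when the event size is known; whether that alone is enough for the functional model above has not been checked here.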
If someone could help me, that would be great.
Thanks