Blockwise distribution of MultivariateNormal and TransformedDistribution

Hi everyone,

I’m trying to use the Blockwise distribution to sample jointly from a TransformedDistribution and a MultivariateNormal and get a single concatenated vector as output. The code is the following:

import tensorflow as tf
from tensorflow_probability import bijectors as tfb
from tensorflow_probability import distributions as tfd

# Reproduction inputs: location and scale values, each of shape (32, 256).
# NOTE(review): sigma is drawn from a normal distribution, so some entries are
# negative; the scale arguments below presumably need to be positive -- confirm.
mean = tf.random.normal(shape=(32, 256))
sigma = tf.random.normal(shape=(32, 256))
# First component: a Normal built from column 0 only, so loc and scale each
# have shape (32,), pushed through a BatchNormalization bijector.
# NOTE(review): the traceback ends in BatchNormalization._de_normalize, where
# reshaping moving_mean (32 values) to a 1-element shape fails. It looks like
# the bijector treats the size-32 axis as its feature axis; slicing with
# `[:, :1]` to keep a trailing axis of size 1 may avoid this -- TODO confirm.
transformed_distribution = tfd.TransformedDistribution(
    distribution=tfd.Normal(loc=mean[:, 0], scale=sigma[:, 0]),
    bijector=tfb.BatchNormalization()
)
# Second component: a diagonal multivariate normal over the remaining 255
# columns (loc and scale_diag each have shape (32, 255)).
multivariate_normal = tfd.MultivariateNormalDiag(
    loc=mean[:, 1:], 
    scale_diag=sigma[:, 1:]
)
# Combine both components into one distribution; the sample() call below is
# the one that raises the InvalidArgumentError shown in the traceback.
blockwise = tfd.Blockwise([transformed_distribution, multivariate_normal])
samples = blockwise.sample()  # Here I expect samples.shape = (32, 256)

But it does not execute successfully. Here’s the traceback:

Traceback (most recent call last):
File “/.venv/lib/python3.11/site-packages/tensorflow_probability/python/distributions/distribution.py”, line 1205, in sample
return self._call_sample_n(sample_shape, seed, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/tensorflow_probability/python/distributions/distribution.py”, line 1182, in _call_sample_n
samples = self._sample_n(
^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/tensorflow_probability/python/distributions/blockwise.py”, line 350, in _sample_n
self._distribution.sample(n, seed=seed))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/tensorflow_probability/python/distributions/distribution.py”, line 1205, in sample
return self._call_sample_n(sample_shape, seed, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/tensorflow_probability/python/distributions/joint_distribution.py”, line 956, in _call_sample_n
return self._sample_n(
^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/tensorflow_probability/python/internal/distribution_util.py”, line 1350, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/tensorflow_probability/python/distributions/joint_distribution.py”, line 699, in _sample_n
xs = self._call_execute_model(
^^^^^^^^^^^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/tensorflow_probability/python/distributions/joint_distribution.py”, line 850, in _call_execute_model
return self._execute_model(
^^^^^^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/tensorflow_probability/python/distributions/joint_distribution.py”, line 1005, in _execute_model
next_value, traced_values = sample_and_trace_fn(
^^^^^^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/tensorflow_probability/python/distributions/joint_distribution.py”, line 132, in trace_values_only
ret = trace_distributions_and_values(dist, sample_shape, seed, value)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/tensorflow_probability/python/distributions/joint_distribution.py”, line 115, in trace_distributions_and_values
value = dist.sample(sample_shape, seed=seed)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/tensorflow_probability/python/distributions/distribution.py”, line 1205, in sample
return self._call_sample_n(sample_shape, seed, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/tensorflow_probability/python/distributions/transformed_distribution.py”, line 338, in _call_sample_n
return self.bijector.forward(x, **bijector_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/tensorflow_probability/python/bijectors/bijector.py”, line 1326, in forward
return self._call_forward(x, name, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/tensorflow_probability/python/bijectors/bijector.py”, line 1308, in _call_forward
return self._cache.forward(x, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/tensorflow_probability/python/internal/cache_util.py”, line 334, in forward
return self._lookup(x, self._forward_name, self._inverse_name, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/tensorflow_probability/python/internal/cache_util.py”, line 493, in _lookup
self._invoke(input, forward_name, kwargs, attrs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/tensorflow_probability/python/internal/cache_util.py”, line 532, in _invoke
return getattr(self.bijector, fn_name)(input, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/tensorflow_probability/python/bijectors/batch_normalization.py”, line 224, in _forward
return self._de_normalize(x)
^^^^^^^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/tensorflow_probability/python/bijectors/batch_normalization.py”, line 216, in _de_normalize
mean = broadcast_fn(self.batchnorm.moving_mean)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/tensorflow_probability/python/bijectors/batch_normalization.py”, line 204, in _broadcast
return tf.reshape(v, broadcast_shape)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/tensorflow/python/ops/weak_tensor_ops.py”, line 88, in wrapper
return op(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File “/.venv/lib/python3.11/site-packages/tensorflow/python/util/traceback_utils.py”, line 153, in error_handler
raise e.with_traceback(filtered_tb) from None
File “/.venv/lib/python3.11/site-packages/tensorflow/python/eager/execute.py”, line 53, in quick_execute
tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
tensorflow.python.framework.errors_impl.InvalidArgumentError:
{{function_node _wrapped__Reshape_device/job:localhost/replica:0/task:0/device:GPU:0}}
Input to reshape is a tensor with 32 values, but the requested shape has 1 [Op:Reshape]
python-BaseException

I assume it has to do with the way I’m building the Blockwise distribution and with the batch and event shapes of its components, but I’m not able to fix it.

Thanks for your help.