Hi,
After running:

```python
losses = tfp.vi.fit_surrogate_posterior(
    target_log_prob_fn,
    surrogate_posterior,
    optimizer=optimizer,
    num_steps=1000,
    seed=42,
    sample_size=100)
```
I receive the following error:

```
ValueError: Dimensions must be equal, but are 1462 and 100 for '{{node monte_carlo_variational_loss/expectation/JointDistributionCoroutineAutoBatched_CONSTRUCTED_AT_top_level/log_prob/make_rank_polymorphic/loop_body/fn_of_vectorized_args/add}} = AddV2[T=DT_FLOAT](monte_carlo_variational_loss/expectation/JointDistributionCoroutineAutoBatched_CONSTRUCTED_AT_top_level/log_prob/make_rank_polymorphic/loop_body/fn_of_vectorized_args/GatherV2, monte_carlo_variational_loss/expectation/JointDistributionCoroutineAutoBatched_CONSTRUCTED_AT_top_level/log_prob/make_rank_polymorphic/loop_body/GatherV2_3)' with input shapes: [1462], [100].
```
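For reference, the message itself is a plain broadcasting failure: TensorFlow, like NumPy, can only add two tensors elementwise when their trailing dimensions are equal or one of them is 1. The 100 matches `sample_size`, and 1462 is presumably the number of observations (an assumption, not stated above). A minimal NumPy sketch of that rule:

```python
import numpy as np

n_obs = 1462       # assumed: number of rows in the data (the 1462 in the error)
sample_size = 100  # the `sample_size` passed to fit_surrogate_posterior

# Trailing dimensions 1462 and 100 are neither equal nor 1, so this cannot broadcast.
try:
    np.broadcast_shapes((n_obs,), (sample_size,))
    compatible = True
except ValueError:
    compatible = False
print(compatible)  # False

# A leading sample axis is fine as long as the trailing (observation) axes line up.
print(np.broadcast_shapes((sample_size, n_obs), (n_obs,)))  # (100, 1462)
```

This suggests that, inside the vectorized `log_prob`, one operand of the `add` carries the sample dimension where the other carries the observation dimension.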
The model integrates a neural net into a multilevel model:
```python
# Assumed imports (not shown in the original post). On TF >= 2.16,
# TFP layers need `import tf_keras as keras` instead of tensorflow.keras.
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow import keras

tfd = tfp.distributions

nn_model_layers = keras.Sequential([
    keras.layers.InputLayer(input_shape=(1,)),
    keras.layers.Dense(1),
    tfp.layers.DistributionLambda(lambda t: tfd.Normal(loc=t, scale=1)),
])
nn_model = keras.Model(inputs=nn_model_layers.inputs,
                       outputs=nn_model_layers.outputs)

def make_joint_distribution_coroutine(genre, year, num_years, num_observations, nn_model):
    def model():
        # Hyperpriors:
        # mu_alpha ~ Normal(0, 0.1)
        mu_alpha = yield tfd.Normal(loc=0., scale=.1, name='alpha_mu')
        # sigma_alpha ~ HalfNormal(0.1)
        sigma_alpha = yield tfd.HalfNormal(scale=.1, name='alpha_sigma')
        # Priors:
        # alpha ~ Normal(alpha_mu, alpha_sigma), one intercept per year
        alpha = yield tfd.Normal(loc=mu_alpha * tf.ones(num_years),
                                 scale=sigma_alpha,
                                 name='alpha')
        # beta ~ neural_network(X)
        beta = yield nn_model(genre)
        # sigma ~ HalfNormal(0.1)
        sigma = yield tfd.HalfNormal(scale=.1, name='sigma')
        # Likelihood: look up each observation's year intercept, add beta
        random_effect = tf.gather(alpha, year, axis=-1)
        mu = random_effect + beta
        yield tfd.Normal(loc=mu, scale=sigma, name='likelihood')
    return tfd.JointDistributionCoroutineAutoBatched(model)
```
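One shape detail worth checking in the model above: `Dense(1)` applied to inputs of shape `[N, 1]` returns `[N, 1]`, so the `Normal` built by `DistributionLambda` has batch shape `[N, 1]`, while `tf.gather(alpha, year, axis=-1)` gives `[N]`. Adding the two broadcasts to `[N, N]`. This may not be the direct cause of the error, but it is easy to rule out. A NumPy sketch with made-up data (`num_years = 10` and the `year` indices are hypothetical):

```python
import numpy as np

n_obs, num_years = 1462, 10                     # num_years is a made-up value
rng = np.random.default_rng(0)
year = rng.integers(0, num_years, size=n_obs)   # hypothetical year index per row
alpha = rng.normal(size=num_years)              # one draw of `alpha`, shape [num_years]

# Equivalent of tf.gather(alpha, year, axis=-1): one intercept per observation.
random_effect = np.take(alpha, year, axis=-1)
# Output of a Dense(1) layer on [n_obs, 1] inputs keeps the trailing axis.
beta = rng.normal(size=(n_obs, 1))

print(random_effect.shape)                   # (1462,)
print((random_effect + beta).shape)          # (1462, 1462): silent outer sum
print((random_effect + beta[..., 0]).shape)  # (1462,): trailing axis dropped first
```

If that turns out to matter, one option is to drop the trailing axis inside the layer, e.g. `tfd.Normal(loc=tf.squeeze(t, axis=-1), scale=1.)`.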
The line `beta = yield nn_model(genre)` seems to cause the error. `nn_model` returns a distribution, which is what `yield` should expect, right?
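On the `yield` question: in a coroutine joint distribution, every yielded distribution becomes a random variable of the joint. The driver either samples it (in `sample`) or feeds in a given value and scores it (in `log_prob`). So `beta = yield nn_model(genre)` makes `beta` a latent variable with the network's output shape, which the surrogate posterior presumably then has to provide as well; the network weights themselves are not random variables here. The toy driver below is not TFP's implementation, only a plain-Python sketch of the protocol using a hypothetical `ToyNormal` class:

```python
import random

class ToyNormal:
    """Hypothetical stand-in for a distribution: it can only draw one sample."""
    def __init__(self, loc, scale, name):
        self.loc, self.scale, self.name = loc, scale, name

    def sample(self, rng):
        return rng.gauss(self.loc, self.scale)

def toy_model():
    # Each `yield` hands a distribution to the driver and receives a value back.
    mu = yield ToyNormal(0.0, 0.1, "alpha_mu")
    beta = yield ToyNormal(2.0, 1.0, "beta")  # plays the role of nn_model(genre)
    yield ToyNormal(mu + beta, 0.1, "likelihood")

def run_coroutine(model_fn, rng):
    """Drive the generator: sample each yielded distribution and send the value back."""
    gen = model_fn()
    values = {}
    dist = next(gen)
    while True:
        value = dist.sample(rng)
        values[dist.name] = value
        try:
            dist = gen.send(value)
        except StopIteration:
            return values

draws = run_coroutine(toy_model, random.Random(0))
print(sorted(draws))  # ['alpha_mu', 'beta', 'likelihood']
```

Every yielded distribution, including the one standing in for the network output, ends up as its own named random variable in the result.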
Grateful for any hints on how to resolve this.