Dear experts,
I am trying to port this demo of training a custom Bijector to the JAX backend of TensorFlow Probability. I am running it on Google Colab. Here is the part of the code that seems to work; the problems come later, in the training code.
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
mpl.rc('image', cmap='jet')
mpl.rcParams['font.size'] = 16
import jax
import jax.numpy as jnp
import numpy as np
import tensorflow.compat.v2 as tf
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions
tfb = tfp.bijectors
class Cubic(tfb.Bijector):
    """Bijector implementing y = (a*x + b)**3."""

    def __init__(self, a, b, validate_args=False, name='Cubic'):
        self.a = jnp.atleast_1d(a)
        self.b = jnp.atleast_1d(b)
        super(Cubic, self).__init__(
            validate_args=validate_args, forward_min_event_ndims=0, name=name)

    def _forward(self, x):
        return jnp.squeeze(jnp.power(self.a * x + self.b, 3))

    def _inverse(self, y):
        return (jnp.sign(y) * jnp.power(jnp.abs(y), 1/3) - self.b) / self.a

    def _forward_log_det_jacobian(self, x):
        # d/dx (a*x + b)**3 = 3*a*(a*x + b)**2, so
        # log|J| = log(3*|a|) + 2*log|a*x + b|.
        return jnp.log(3. * jnp.abs(self.a)) + 2. * jnp.log(jnp.abs(self.a * x + self.b))
cubic = Cubic([0.25],[-0.1])
x = jnp.linspace(-10,10,500).reshape(-1,1)
plt.plot(x,cubic.forward(x))
plt.show()
plt.plot(x,cubic.inverse(x),lw=3)
plt.plot(x,tfb.Invert(cubic).forward(x),ls="--",c="cyan")
plt.show()
plt.plot(x,cubic.forward_log_det_jacobian(x,event_ndims=0))
plt.show()
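As a quick sanity check (this snippet is my own addition, not part of the demo), I also verify numerically that the inverse undoes the forward map:

# inverse(forward(x)) should recover x up to numerical error
# (_forward squeezes the trailing axis, hence the squeeze on x).
print(jnp.allclose(cubic.inverse(cubic.forward(x)), x.squeeze(), atol=1e-5))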
# Target distribution: a two-component Gaussian mixture
probs = [0.45, 0.55]
mix_gauss = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=probs),
    components_distribution=tfd.Normal(
        loc=[2.3, -0.8],    # One for each component.
        scale=[0.4, 0.4]))  # And same here.
x = jnp.linspace(-5.0,5.0,100)
plt.plot(x,mix_gauss.prob(x))
plt.title('Data distribution')
plt.show()
Now I would like to build training and validation datasets, so I proceed with:
x_train = mix_gauss.sample(10000, seed = jax.random.PRNGKey(0))
x_train = tf.data.Dataset.from_tensor_slices(x_train)
x_train = x_train.batch(128)
x_valid = mix_gauss.sample(1000, seed = jax.random.PRNGKey(1))
x_valid = tf.data.Dataset.from_tensor_slices(x_valid)
x_valid = x_valid.batch(128)
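In case tf.data is not the natural tool on the JAX side, my guess at a TF-free alternative would be to keep the samples as plain jnp arrays and shuffle/batch them by hand; the helper below is only my own sketch, not something from the demo:

def shuffled_batches(key, data, batch_size=128):
    # Shuffle once, then yield fixed-size slices (the last partial batch is dropped).
    perm = jax.random.permutation(key, data.shape[0])
    shuffled = data[perm]
    for i in range(0, data.shape[0] - batch_size + 1, batch_size):
        yield shuffled[i:i + batch_size]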
Then,
trainable_inv_cubic = tfb.Invert(Cubic(a=0.25, b=-0.1))
# Base distribution
normal = tfd.Normal(loc=0., scale=1.)
# Trainable distribution
trainable_dist = tfd.TransformedDistribution(normal, trainable_inv_cubic)
x = jnp.linspace(-5,5,100)
plt.figure(figsize=(12,4))
plt.plot(x,mix_gauss.prob(x),label='data')
plt.plot(x,trainable_dist.prob(x),label='trainable')
plt.title('Data & Trainable distribution')
plt.show()
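Evaluating the initial fit works fine; for instance (the sample size and seed here are arbitrary choices of mine), the average negative log-likelihood of the untrained model on a fresh sample from the target is exactly the quantity the training loop is supposed to minimize:

x_check = mix_gauss.sample(2000, seed=jax.random.PRNGKey(2))
print(-jnp.mean(trainable_dist.log_prob(x_check)))  # initial NLL, before training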
The problem is the following: trainable_inv_cubic.trainable_variables is empty (with the JAX substrate, a and b are plain jnp arrays rather than tf.Variables), so the gradient computation in the training loop below cannot work:
num_epochs = 10
opt = tf.keras.optimizers.Adam()
train_losses = []
valid_losses = []

for epoch in range(num_epochs):
    print("Epoch {}...".format(epoch))
    train_loss = tf.keras.metrics.Mean()
    val_loss = tf.keras.metrics.Mean()

    # Train
    for train_batch in x_train:
        with tf.GradientTape() as tape:
            tape.watch(trainable_inv_cubic.trainable_variables)
            loss = -trainable_dist.log_prob(train_batch)
        train_loss(loss)
        grads = tape.gradient(loss, trainable_inv_cubic.trainable_variables)
        opt.apply_gradients(zip(grads, trainable_inv_cubic.trainable_variables))
    train_losses.append(train_loss.result().numpy())

    # Validation
    for valid_batch in x_valid:
        loss = -trainable_dist.log_prob(valid_batch)
        val_loss(loss)
    valid_losses.append(val_loss.result().numpy())
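For reference, this is roughly how I would have written the training step in pure JAX, treating a and b as an explicit parameter pytree and using optax as the optimizer (I am assuming optax is the natural replacement for tf.keras.optimizers.Adam; the learning rate, batch size and epoch count are arbitrary, and I do not know whether this is the intended way to make a TFP-on-JAX bijector trainable):

import optax

def make_dist(params):
    # Rebuild the transformed distribution from the current parameter values
    # so that gradients can flow through a and b.
    bij = tfb.Invert(Cubic(a=params['a'], b=params['b']))
    return tfd.TransformedDistribution(tfd.Normal(loc=0., scale=1.), bij)

def nll(params, batch):
    # Mean negative log-likelihood of one batch.
    return -jnp.mean(make_dist(params).log_prob(batch))

params = {'a': jnp.array([0.25]), 'b': jnp.array([-0.1])}
optimizer = optax.adam(1e-2)
opt_state = optimizer.init(params)

@jax.jit
def train_step(params, opt_state, batch):
    loss, grads = jax.value_and_grad(nll)(params, batch)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

data = mix_gauss.sample(10000, seed=jax.random.PRNGKey(0))
key = jax.random.PRNGKey(42)
for epoch in range(10):
    key, subkey = jax.random.split(key)
    perm = jax.random.permutation(subkey, data.shape[0])
    # Drop the last partial batch so the jitted step sees a fixed batch shape.
    for i in range(0, data.shape[0] - 128 + 1, 128):
        batch = data[perm[i:i + 128]]
        params, opt_state, loss = train_step(params, opt_state, batch)
    print("epoch", epoch, "last batch NLL:", float(loss))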
I'm used to performing this kind of optimization with pure JAX code and I am not a TF expert at all, so if someone could help me with this translation I would be very grateful. Thanks.