TFP with JAX backend: how to train a bijector?

Dear experts,
I was aiming to transpose this demo of training a user Bijector with the JAX backend.

I run on Google Colab. Here is part of the code which seems to work. After there are problems for the training code

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
mpl.rc('image', cmap='jet')
mpl.rcParams['font.size'] = 16

import jax
import jax.numpy as jnp
import numpy as np

import tensorflow.compat.v2 as tf
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions
tfb = tfp.bijectors


class Cubic(tfb.Bijector):
    def __init__(self, a, b, validate_args=False, name='Cubic'):
        self.a = jnp.atleast_1d(a)
        self.b = jnp.atleast_1d(b)
        super(Cubic, self).__init__(
            validate_args=validate_args, forward_min_event_ndims=0, name=name)
        
    def _forward(self,x):
        return jnp.squeeze(jnp.power(self.a*x + self.b,3))
    
    def _inverse(self,y):
        return (jnp.sign(y)*jnp.power(jnp.abs(y),1/3)-self.b)/self.a
    
    def _forward_log_det_jacobian(self,x):
        return jnp.log(3.*jnp.abs(self.a))+2.*jnp.log(jnp.abs(self.a*x+self.b))

cubic = Cubic([0.25],[-0.1])
x = jnp.linspace(-10,10,500).reshape(-1,1)
plt.plot(x,cubic.forward(x)) 
plt.show()

plt.plot(x,cubic.inverse(x),lw=3)
plt.plot(x,tfb.Invert(cubic).forward(x),ls="--",c="cyan")
plt.show()

plt.plot(x,cubic.forward_log_det_jacobian(x,event_ndims=0))
plt.show()

# Target distrib
probs = [0.45,0.55]
mix_gauss = tfd.MixtureSameFamily(
      mixture_distribution=tfd.Categorical(
          probs=probs),
      components_distribution=tfd.Normal(
        loc=[2.3, -0.8],       # One for each component.
        scale=[0.4, 0.4]))  # And same here.

x = jnp.linspace(-5.0,5.0,100)
plt.plot(x,mix_gauss.prob(x))
plt.title('Data distribution')
plt.show()

Now I would like to get some training/validation dataset so I proceed to

x_train = mix_gauss.sample(10000, seed = jax.random.PRNGKey(0))
x_train = tf.data.Dataset.from_tensor_slices(x_train)
x_train = x_train.batch(128)

x_valid = mix_gauss.sample(1000, seed = jax.random.PRNGKey(1))
x_valid = tf.data.Dataset.from_tensor_slices(x_valid)
x_valid = x_valid.batch(128)

Then,

trainable_inv_cubic = tfb.Invert(Cubic(a=0.25,b=-0.1))
# (1) Base distn
normal = tfd.Normal(loc=0.,scale=1.)
# trainable distrib
trainable_dist = tfd.TransformedDistribution(normal,trainable_inv_cubic)

x = jnp.linspace(-5,5,100)
plt.figure(figsize=(12,4))
plt.plot(x,mix_gauss.prob(x),label='data')
plt.plot(x,trainable_dist.prob(x),label='trainable')
plt.title('Data & Trainable distribution')
plt.show()

The problem is the following as there are no trainable_inv_cubic.trainable_variables the computation of grads cannot be done

num_epochs = 10
opt = tf.keras.optimizers.Adam()
train_losses = []
valid_losses = []

for epoch in range(num_epochs):
    print("Epoch {}...".format(epoch))
    train_loss = tf.keras.metrics.Mean()
    val_loss = tf.keras.metrics.Mean()
    
    # Train
    for train_batch in x_train:
        with tf.GradientTape() as tape:
            tape.watch(trainable_inv_cubic.trainable_variables)
            loss = -trainable_dist.log_prob(train_batch)
        train_loss(loss)
        grads = tape.gradient(loss, trainable_inv_cubic.trainable_variables)
        opt.apply_gradients(zip(grads, trainable_inv_cubic.trainable_variables))
    train_losses.append(train_loss.result().numpy())
        
    # Validation
    for valid_batch in x_valid:
        loss = -trainable_dist.log_prob(valid_batch)
        val_loss(loss)
    valid_losses.append(val_loss.result().numpy())

I’m used to perform some optimization with pure JAX code and I am not a TF expert at all, so if someone can help me in this translation I would be very grateful. Thanks

If You find a bug the bijector forward function weakly caches the result->input mapping to make downstream inverses and log-determinants fast. But somehow this is also interfering with the gradient. A workaround is adding a `del out.

Here is the error message

AttributeError                            Traceback (most recent call last)

<ipython-input-17-9a8348b0b336> in <module>
     15             loss = -trainable_dist.log_prob(train_batch)
     16         train_loss(loss)
---> 17         grads = tape.gradient(loss, trainable_inv_cubic.trainable_variables)
     18         opt.apply_gradients(zip(grads, trainable_inv_cubic.trainable_variables))
     19     train_losses.append(train_loss.result().numpy())

1 frames

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/imperative_grad.py in imperative_grad(tape, target, sources, output_gradients, sources_raw, unconnected_gradients)
     71       output_gradients,
     72       sources_raw,
---> 73       compat.as_str(unconnected_gradients.value))

AttributeError: 'DeviceArray' object has no attribute '_id'

@Babak_Zahedi,

We notice your response with an unusual hyperlink for the above query. Could you help us to understand the purpose? We are here to help you to resolve your problem. Thank you.

Heuu… I am not sure to understand if the comment is for my original post… My concern is to get TF with JAX backend working. Thanks

http://discuss.ai.google.dev/t/tfp-with-jax-backend-how-to-train-a-bijector/12026/3?u=chunduriv

For the above reply. Not to your post. Thank you.