Home-made RealNVP with TF v2: training OK but wrong shape for "prob"

Hello,

As I have problems with the current TF2 RealNVP (see this post), I have borrowed some code from this GitHub repo.

tf.__version__, tfp.__version__ = 2.8.2, 0.16.0

import numpy as np

import tensorflow as tf
import tensorflow_probability as tfp

from tensorflow_probability import bijectors as tfb
from tensorflow_probability import distributions as tfd
from tensorflow import keras as tfk
from tensorflow.keras import layers as tfkl

tf.random.set_seed(0)

DTYPE=tf.float32
NP_DTYPE=np.float32
MODEL = 'NVP'  # 'MAF' or 'NVP' (home-made)
TARGET_DENSITY = 'GAUSS1' # Which dataset to model. 
USE_BATCHNORM = False # not yet implemented

# dataset-specific settings
settings = {
    'GAUSS1':{  # Banana-shaped use case of Papamakarios et al. (arXiv:1705.07057)
        'n_samples': 6000,
        'batch_size': 100,
        'num_bijectors': 1 if (MODEL == "MADE" or MODEL == "NVP") else 6,
        'n_epochs': 500
    },
}

from numpy.random import default_rng
from sklearn.model_selection import train_test_split

rng = default_rng()

n_samples = int(settings[TARGET_DENSITY]['n_samples'])
x2 = rng.standard_normal(n_samples).astype(dtype=np.float32) * 2.
x1 = rng.standard_normal(n_samples).astype(dtype=np.float32) + (x2 * x2 / 4.)
X = np.stack([x1, x2], axis=-1)
xlim, ylim = [-4, 8], [-6, 6]

X = X.astype(NP_DTYPE)
X, X_test = train_test_split(X, test_size=0.33, random_state=42)


#from tensorflow.keras.layers import Layer, Dense, BatchNormalization, ReLU
#from tensorflow.keras import Model

class NN(tfkl.Layer):
    def __init__(self, input_shape, n_hidden=[512, 512], activation="relu", name="nn"):
        super(NN, self).__init__(name=name)
        layer_list = []
        for i, hidden in enumerate(n_hidden):
            layer_list.append(tfkl.Dense(hidden, activation=activation, name='dense_{}_1'.format(i)))
            layer_list.append(tfkl.Dense(hidden, activation=activation, name='dense_{}_2'.format(i)))
        self.layer_list = layer_list
        self.log_s_layer = tfkl.Dense(input_shape, activation="tanh", name='log_s')
        self.t_layer = tfkl.Dense(input_shape, name='t')

    def call(self, x):
        y = x
        for layer in self.layer_list:
            y = layer(y)
        log_s = self.log_s_layer(y)
        t = self.t_layer(y)
        return log_s, t

#def nn_test():
#    nn = NN(1, [512, 512])
#    x = tf.keras.Input([1])
#    log_s, t = nn(x)
#    # Non trainable params: -> Batch Normalization's params
#    tf.keras.Model(x, [log_s, t], name="nn_test").summary()
#nn_test()


class RealNVP(tfb.Bijector):
    def __init__(
        self,
        input_shape,
        n_hidden=[512, 512],
        # this bijector operates on vector-valued (rank-1) events
        forward_min_event_ndims=1,
        validate_args: bool = False,
        name="real_nvp",
    ):
        """
        Args:
            input_shape: shape of the input event,
                e.g. [28, 28, 3] (image) or [2] (x-y vector)
        """
        super(RealNVP, self).__init__(
            validate_args=validate_args, forward_min_event_ndims=forward_min_event_ndims, name=name
        )

        assert input_shape[-1] % 2 == 0
        self.input_shape = input_shape
        nn_layer = NN(input_shape[-1] // 2, n_hidden)
        nn_input_shape = input_shape.copy()
        nn_input_shape[-1] = input_shape[-1] // 2
        x = tf.keras.Input(nn_input_shape)
        log_s, t = nn_layer(x)
        self.nn = tfk.Model(x, [log_s, t], name="nn_"+name)

    def _forward(self, x):
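        # Affine coupling: the second half x_b passes through unchanged and is
        # used to compute the scale exp(log_s) and shift t applied to the first half x_a.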
        x_a, x_b = tf.split(x, 2, axis=-1)
        y_b = x_b
        log_s, t = self.nn(x_b)
        s = tf.exp(log_s)
        y_a = s * x_a + t
        y = tf.concat([y_a, y_b], axis=-1)
        return y

    def _inverse(self, y):
        y_a, y_b = tf.split(y, 2, axis=-1)
        x_b = y_b
        log_s, t = self.nn(y_b)
        s = tf.exp(log_s)
        x_a = (y_a - t) / s
        x = tf.concat([x_a, x_b], axis=-1)
        return x

    def _forward_log_det_jacobian(self, x):
        _, x_b = tf.split(x, 2, axis=-1)
        log_s, t = self.nn(x_b)
        return log_s


#def realnvp_test():
#    realnvp = RealNVP(input_shape=[2], n_hidden=[512, 512])
#    x = tf.keras.Input([2])
#    y = realnvp.forward(x)
#    print('trainable_variables :', len(realnvp.trainable_variables))
#    tfk.Model(x, y, name="realnvp_test").summary()
#realnvp_test()


num_bijectors = settings[TARGET_DENSITY]['num_bijectors']
bijectors = []

for i in range(num_bijectors):

    if MODEL == 'NVP':
        nvp = RealNVP(
            input_shape=[2], n_hidden=[128, 128], name='myNVP%d' % i)
        bijectors.append(nvp)

    elif MODEL == 'MAF':
        hidden_units = [128, 128]
        bijectors.append(tfb.MaskedAutoregressiveFlow(name='MAF%d' % i,
            shift_and_log_scale_fn=tfb.AutoregressiveNetwork(
                params=2, hidden_units=hidden_units, activation=tf.nn.leaky_relu)))
    
    if USE_BATCHNORM:
        # BatchNorm helps to stabilize deep normalizing flows, esp. Real-NVP
        bijectors.append(tfb.BatchNormalization(name='BN%d' % i))

    #Permutation (don't forget)
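    # (swaps the two coordinates so the next bijector in the chain acts on the other ordering)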
    bijectors.append(tfb.Permute(permutation=[1, 0]))
# Discard the last Permute layer.
flow_bijector = tfb.Chain(list(reversed(bijectors[:-1])))

for bij in flow_bijector.bijectors:
  print(bij)

base_dist = tfd.MultivariateNormalDiag(loc=tf.zeros([2], DTYPE), name='base_dist')

class MyModel(tf.keras.models.Model):

    def __init__(self, *, base, bij, **kwargs):
        super().__init__(**kwargs)

        # Defining the flow
        self.flow = tfd.TransformedDistribution(
            distribution=base,
            bijector=bij)
        
        self._variables = self.flow.variables
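        # (the assignment above keeps an explicit reference to the flow's variables,
        #  presumably so that the Keras model tracks them as trainable)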

    def call(self, *inputs): 
        return self.flow.bijector.forward(*inputs)


    def getFlow(self, num):
        return self.flow.sample(num)

mymodel = MyModel(base=base_dist, bij=flow_bijector)

X.shape, mymodel.flow.prob(X).shape, base_dist.prob(X).shape

The output is

((4020, 2), TensorShape([4020, 4020]), TensorShape([4020]))

where mymodel.flow.prob(X).shape is not what I was expecting; with MODEL = 'MAF' I get TensorShape([4020]).
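
My guess (not verified) is that the extra dimension comes from the log-det-Jacobian of the home-made bijector: with forward_min_event_ndims=1 I believe TFP expects _forward_log_det_jacobian to return one scalar per sample (shape [batch], already reduced over the event dimension), whereas my log_s comes out of a Dense(1) layer and so has shape [batch, 1]. A quick shape check I would try (just a sketch, untested):

# Sketch of a diagnostic (assumption: the mismatch comes from the bijector's
# log-det-Jacobian, which then broadcasts against the base log-prob inside log_prob)
fldj = flow_bijector.forward_log_det_jacobian(X, event_ndims=1)
print(fldj.shape)  # I would expect [4020]; I suspect it is [4020, 1] here

But I may be misreading the Bijector contract, so I am not sure this is really the cause.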

Now, when asking for the number of trainable variables

x_ = tfkl.Input(shape=(2,), dtype=DTYPE)
log_prob_ = mymodel(x_)
model = tfk.Model(x_, log_prob_)
model.summary() 

I get something like

Model: "model_16"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_30 (InputLayer)       [(None, 2)]               0         
                                                                 
 my_model_5 (MyModel)        (None, 2)                 50050     
                                                                 
=================================================================
Total params: 50,050
Trainable params: 50,050
Non-trainable params: 0
_________________________________________________________________

and when I train it, the result is good.

So, can someone help me get the “prob” shape correct? Thanks