Tf.keras.model vs model class declaration

Take a look at the 2 model declarations below. Both are valid and functionally equivalent but if you take a look at the images, you’ll see the difference between an autoencoder trained for 1 epoch with the tf.keras.model declaration, and the other is an autoencoder defined by a class, trained for 10 epochs.

Model function:

def Generator(input_shape, patch_size, num_patches, projection_dim, num_heads, ff_dim):
    inputs = layers.Input(shape=(256, 256, 3))

    patches = Patches(patch_size)(inputs)
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)

    x = TransformerBlock(64, num_heads, ff_dim)(encoded_patches)
    x = TransformerBlock(64, num_heads, ff_dim)(x)
    x = TransformerBlock(64, num_heads, ff_dim)(x)
    x = TransformerBlock(64, num_heads, ff_dim)(x)

    x = layers.Reshape((8, 8, 1024))(x)

    x = layers.Conv2DTranspose(512, (5, 5), strides=(2, 2), padding='same', use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU()(x)

    x = residual_block(x, downsample=False, filters=512)

    x = layers.Conv2DTranspose(256, (5, 5), strides=(2, 2), padding='same', use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU()(x)

    x = residual_block(x, downsample=False, filters=256)

    x = layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU()(x)

    x = residual_block(x, downsample=False, filters=64)

    x = layers.Conv2DTranspose(32, (5, 5), strides=(4, 4), padding='same', use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU()(x)

    x = residual_block(x, downsample=False, filters=32)

    x = layers.Conv2D(3, (3, 3), strides=(1, 1), padding='same', use_bias=False, activation='tanh')(x)

    return tf.keras.Model(inputs=inputs, outputs=x)

Model Class:

class Generator(tf.keras.Model):
    def __init__(self, gen_input_shape, patch_size, num_patches, projection_dim, num_heads, ff_dim):
        super(Generator, self).__init__()

        self.patches = Patches(patch_size)
        self.patch_encoder = PatchEncoder(num_patches, projection_dim)

        self.trans1 = TransformerBlock(64, num_heads, ff_dim)
        self.trans2 = TransformerBlock(64, num_heads, ff_dim)
        self.trans3 = TransformerBlock(64, num_heads, ff_dim)
        self.trans4 = TransformerBlock(64, num_heads, ff_dim)

        self.reshape = layers.Reshape((8, 8, num_patches))

        self.conv1 = layers.Conv2DTranspose(512, (5, 5), strides=(2, 2), padding='same', use_bias=False)
        self.bat1 = layers.BatchNormalization()
        self.relu1 = layers.LeakyReLU()

        self.res = residual_block

        self.conv2 = layers.Conv2DTranspose(256, (5, 5), strides=(2, 2), padding='same', use_bias=False)
        self.bat2 = layers.BatchNormalization()
        self.relu2 = layers.LeakyReLU()

        self.conv3 = layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)
        self.bat3 = layers.BatchNormalization()
        self.relu3 = layers.LeakyReLU()

        self.conv4 = layers.Conv2DTranspose(32, (5, 5), strides=(4, 4), padding='same', use_bias=False)
        self.bat4 = layers.BatchNormalization()
        self.relu4 = layers.LeakyReLU()

        self.conv5 = layers.Conv2D(3, (3, 3), strides=(1, 1), padding='same', use_bias=False, activation='tanh')

    def call(self, input):
        x = self.patches(input)
        x = self.patch_encoder(x)
        x = self.trans1(x)
        x = self.trans2(x)
        x = self.trans3(x)
        x = self.trans4(x)
        x = self.reshape(x)
        x = self.conv1(x)
        x = self.bat1(x)
        x = self.relu1(x)
        x = self.res(x, downsample=False, filters=512)
        x = self.conv2(x)
        x = self.bat2(x)
        x = self.relu2(x)
        x = self.res(x, downsample=False, filters=256)
        x = self.conv3(x)
        x = self.bat3(x)
        x = self.relu3(x)
        x = self.res(x, downsample=False, filters=64)
        x = self.conv4(x)
        x = self.bat4(x)
        x = self.relu4(x)
        x = self.res(x, downsample=False, filters=32)
        x = self.conv5(x)

        return x

Prediction after 1 epoch with model function:


Prediction after 10 epochs with model class:

The class model and model function work interchangibly so there is no difference in the training code. Interestingly the class model still achieves its purpose but looks far worse, especially on the first epoch. I tried this repeatedly for many seeds but it always looked like this. Oh and don’t worry about the blurr on the prediction model, it clears up the more epochs you train for.
Here’s the training code:

tf.config.run_functions_eagerly(False)

train_dataset = tf.data.Dataset.list_files(PATH + 'train/*.jpg')
train_dataset = train_dataset.map(utils.load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.batch(BATCH_SIZE)

generator = Generator(input_shape, patch_size, num_patches, projection_dim, num_heads, ff_dim)
optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

# Initialize weights prior to loading a checkpoint
_ = generator(tf.random.uniform((BATCH_SIZE, 256, 256, 3), minval=-1, maxval=1, dtype=tf.float32))

def train_step(input_image, target, epoch):
    with tf.device(device):
        with tf.GradientTape() as gen_tape:
            gen_output = generator(input_image, training=True)

            gen_total_loss = tf.reduce_mean(tf.abs(target - gen_output))

        generator_gradients = gen_tape.gradient(gen_total_loss, generator.trainable_variables)

        optimizer.apply_gradients(zip(generator_gradients, generator.trainable_variables))


def fit(train_ds, epochs):
    for epoch in range(epochs):
        print(f'Epoch: [{epoch}/{epochs}]')

        for n, (input_image, target) in train_ds.enumerate():
            train_step(input_image, target, epoch)

        generator.save_weights(f'{SAVE_PATH}/generator.h5')

fit(train_dataset, EPOCHS)

What is the difference between these 2 models that are in such a contrast to one another? I can follow up with details if asked.

I found the issue. After inspecting both model declarations with a model.summary(), my class declaration had less layers and parameters as the model function. Make sure to use classes for all of your custom layers if you are making a model class. In my case, residual_block was a function that created the layers and passed the input through it but they never get saved, unlike when residual_block is a class, in which case it behaves properly.

TL;DR: model.summary() is a helpful debugging tool