Take a look at the two model declarations below. Both are valid and, as far as I can tell, functionally equivalent, but the images show a clear difference: the first is an autoencoder built with the functional tf.keras.Model declaration and trained for 1 epoch, the second is the same autoencoder defined as a subclass of tf.keras.Model and trained for 10 epochs.
Model function:
import tensorflow as tf
from tensorflow.keras import layers

# Patches, PatchEncoder, TransformerBlock and residual_block are custom
# layers/helper functions defined elsewhere in my code.

def Generator(input_shape, patch_size, num_patches, projection_dim, num_heads, ff_dim):
    inputs = layers.Input(shape=(256, 256, 3))
    # Transformer encoder: patchify, embed, then four transformer blocks
    patches = Patches(patch_size)(inputs)
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
    x = TransformerBlock(64, num_heads, ff_dim)(encoded_patches)
    x = TransformerBlock(64, num_heads, ff_dim)(x)
    x = TransformerBlock(64, num_heads, ff_dim)(x)
    x = TransformerBlock(64, num_heads, ff_dim)(x)
    # Convolutional decoder: upsample 8x8 back up to a 256x256x3 image
    x = layers.Reshape((8, 8, 1024))(x)
    x = layers.Conv2DTranspose(512, (5, 5), strides=(2, 2), padding='same', use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU()(x)
    x = residual_block(x, downsample=False, filters=512)
    x = layers.Conv2DTranspose(256, (5, 5), strides=(2, 2), padding='same', use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU()(x)
    x = residual_block(x, downsample=False, filters=256)
    x = layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU()(x)
    x = residual_block(x, downsample=False, filters=64)
    x = layers.Conv2DTranspose(32, (5, 5), strides=(4, 4), padding='same', use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU()(x)
    x = residual_block(x, downsample=False, filters=32)
    x = layers.Conv2D(3, (3, 3), strides=(1, 1), padding='same', use_bias=False, activation='tanh')(x)
    return tf.keras.Model(inputs=inputs, outputs=x)
Model class:
class Generator(tf.keras.Model):
    def __init__(self, gen_input_shape, patch_size, num_patches, projection_dim, num_heads, ff_dim):
        super(Generator, self).__init__()
        self.patches = Patches(patch_size)
        self.patch_encoder = PatchEncoder(num_patches, projection_dim)
        self.trans1 = TransformerBlock(64, num_heads, ff_dim)
        self.trans2 = TransformerBlock(64, num_heads, ff_dim)
        self.trans3 = TransformerBlock(64, num_heads, ff_dim)
        self.trans4 = TransformerBlock(64, num_heads, ff_dim)
        self.reshape = layers.Reshape((8, 8, num_patches))
        self.conv1 = layers.Conv2DTranspose(512, (5, 5), strides=(2, 2), padding='same', use_bias=False)
        self.bat1 = layers.BatchNormalization()
        self.relu1 = layers.LeakyReLU()
        self.res = residual_block
        self.conv2 = layers.Conv2DTranspose(256, (5, 5), strides=(2, 2), padding='same', use_bias=False)
        self.bat2 = layers.BatchNormalization()
        self.relu2 = layers.LeakyReLU()
        self.conv3 = layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)
        self.bat3 = layers.BatchNormalization()
        self.relu3 = layers.LeakyReLU()
        self.conv4 = layers.Conv2DTranspose(32, (5, 5), strides=(4, 4), padding='same', use_bias=False)
        self.bat4 = layers.BatchNormalization()
        self.relu4 = layers.LeakyReLU()
        self.conv5 = layers.Conv2D(3, (3, 3), strides=(1, 1), padding='same', use_bias=False, activation='tanh')

    def call(self, input):
        x = self.patches(input)
        x = self.patch_encoder(x)
        x = self.trans1(x)
        x = self.trans2(x)
        x = self.trans3(x)
        x = self.trans4(x)
        x = self.reshape(x)
        x = self.conv1(x)
        x = self.bat1(x)
        x = self.relu1(x)
        x = self.res(x, downsample=False, filters=512)
        x = self.conv2(x)
        x = self.bat2(x)
        x = self.relu2(x)
        x = self.res(x, downsample=False, filters=256)
        x = self.conv3(x)
        x = self.bat3(x)
        x = self.relu3(x)
        x = self.res(x, downsample=False, filters=64)
        x = self.conv4(x)
        x = self.bat4(x)
        x = self.relu4(x)
        x = self.res(x, downsample=False, filters=32)
        x = self.conv5(x)
        return x
Prediction after 1 epoch with model function:
Prediction after 10 epochs with model class:
The class model and the model function work interchangeably, so there is no difference in the training code. Interestingly, the class model still achieves its purpose, but its output looks far worse, especially after the first epoch. I repeated this with many different seeds and it always comes out like this. And don't worry about the blur in the predictions; it clears up the more epochs you train for.
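As a sanity check that the two declarations really do build the same stack of layers, something along these lines should show matching parameter counts. This is only a sketch: it assumes the functional builder has been renamed to build_generator so both versions can be instantiated side by side with the same hyperparameters.

# Hypothetical equivalence check; build_generator is the functional
# version from above, renamed so it does not clash with the class.
fn_model = build_generator(input_shape, patch_size, num_patches, projection_dim, num_heads, ff_dim)
cls_model = Generator(input_shape, patch_size, num_patches, projection_dim, num_heads, ff_dim)

# The subclassed model only creates its variables after a forward pass.
_ = cls_model(tf.zeros((1, 256, 256, 3)))

# If the two declarations are equivalent, the trainable parameter counts
# (and the multiset of variable shapes) should match exactly.
print(fn_model.count_params(), cls_model.count_params())
print(sorted(tuple(v.shape) for v in fn_model.trainable_variables) ==
      sorted(tuple(v.shape) for v in cls_model.trainable_variables))

If those numbers disagree, the two declarations are not actually building the same network.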
Here’s the training code:
tf.config.run_functions_eagerly(False)

train_dataset = tf.data.Dataset.list_files(PATH + 'train/*.jpg')
train_dataset = train_dataset.map(utils.load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.batch(BATCH_SIZE)

generator = Generator(input_shape, patch_size, num_patches, projection_dim, num_heads, ff_dim)
optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

# Initialize weights prior to loading a checkpoint
_ = generator(tf.random.uniform((BATCH_SIZE, 256, 256, 3), minval=-1, maxval=1, dtype=tf.float32))

def train_step(input_image, target, epoch):
    with tf.device(device):
        with tf.GradientTape() as gen_tape:
            gen_output = generator(input_image, training=True)
            gen_total_loss = tf.reduce_mean(tf.abs(target - gen_output))
        generator_gradients = gen_tape.gradient(gen_total_loss, generator.trainable_variables)
        optimizer.apply_gradients(zip(generator_gradients, generator.trainable_variables))

def fit(train_ds, epochs):
    for epoch in range(epochs):
        print(f'Epoch: [{epoch}/{epochs}]')
        for n, (input_image, target) in train_ds.enumerate():
            train_step(input_image, target, epoch)
        generator.save_weights(f'{SAVE_PATH}/generator.h5')

fit(train_dataset, EPOCHS)
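For reference, restoring the saved weights on a later run follows the same pattern as the "initialize weights" comment above. Roughly (a sketch, assuming the same SAVE_PATH and hyperparameters are in scope):

# Sketch: rebuild the subclassed generator and run a dummy forward pass so
# every layer creates its variables before load_weights maps in the .h5 file.
generator = Generator(input_shape, patch_size, num_patches, projection_dim, num_heads, ff_dim)
_ = generator(tf.random.uniform((BATCH_SIZE, 256, 256, 3), minval=-1, maxval=1, dtype=tf.float32))
generator.load_weights(f'{SAVE_PATH}/generator.h5')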
What is the difference between these two models that makes their results contrast so sharply? I can follow up with more details if asked.