Autoencoder returning broken results for deblurring

I am working on an autoencoder to remove motion blur from images. I am using a small dataset of 1816 blurry and 1816 sharp images.

This is my autoencoder with 6 layers:

physical_devices = tf.config.list_physical_devices('GPU')

if physical_devices:
  tf.config.experimental.set_memory_growth(physical_devices[0], True)

seed = 21
random.seed = seed
np.random.seed = seed

# Paths to the good images and the corresponding motion blurred images

good_frames = '/mnt/share/Datasets/BLUR_small/BLUR/sharp'
bad_frames = '/mnt/share/Datasets/BLUR_small/BLUR/motion_blurred'

# Network Parameters
dims = 128
input_shape = (dims, dims, 3)
batch_size = 32
kernel_size = 3
latent_dim = 256

# Below is a custom data loader.
def load_image(file, target_size):
    image = tf.keras.preprocessing.image.load_img(file, target_size=target_size)
    image = tf.keras.preprocessing.image.img_to_array(image).astype('float32') / 255
    return image

clean_frames = []
blurry_frames = []
extensions = ['.jpg', 'jpeg', '.png']

for file in tqdm(sorted(os.listdir(good_frames))):
    if any(extension in file for extension in extensions):
        file_path = os.path.join(good_frames, file)
        clean_frames.append(load_image(file_path, (dims,dims)))

clean_frames = np.array(clean_frames)

for file in tqdm(sorted(os.listdir(bad_frames))):
    if any(extension in file for extension in extensions):
        file_path = os.path.join(bad_frames, file)
        blurry_frames.append(load_image(file_path, (dims,dims)))

blurry_frames = np.array(blurry_frames)


with tf.device('GPU:0'):
    inputs = Input(shape = input_shape, name = 'encoder_input')

    x = inputs

    # Layers of the encoder

    x = Conv2D(filters=64, kernel_size=kernel_size, strides=2, activation='relu', padding='same')(x)
    x = Conv2D(filters=128, kernel_size=kernel_size, strides=2, activation='relu', padding='same')(x)
    x = Conv2D(filters=256, kernel_size=kernel_size, strides=2, activation='relu', padding='same')(x)

    shape = K.int_shape(x)
    x = Flatten()(x)
    latent = Dense(latent_dim, name='latent_vector')(x)

    encoder = Model(inputs, latent, name='encoder')

    latent_inputs = Input(shape=(latent_dim,), name='decoder_input')

    x = Dense(shape[1]*shape[2]*shape[3])(latent_inputs)
    x = Reshape((shape[1], shape[2], shape[3]))(x)

    x = Conv2DTranspose(filters=256,kernel_size=kernel_size, strides=2, activation='relu', padding='same')(x)
    x = Conv2DTranspose(filters=128,kernel_size=kernel_size, strides=2, activation='relu', padding='same')(x)
    x = Conv2DTranspose(filters=64,kernel_size=kernel_size, strides=2, activation='relu', padding='same')(x)


    outputs = Conv2DTranspose(filters=3, kernel_size=kernel_size, activation='sigmoid', padding='same', name='decoder_output')(x)

    decoder = Model(latent_inputs, outputs, name='decoder')

    autoencoder = Model(inputs, decoder(encoder(inputs)), name='autoencoder')

    autoencoder.compile(loss='mse', optimizer='adam',metrics=["acc"])

    # Automated Learning Rate reducer
    lr_reducer = ReduceLROnPlateau(factor=np.sqrt(0.1),
                                cooldown=0,
                                patience=5,
                                verbose=1,
                                min_lr=0.5e-6)

    callbacks = [lr_reducer]

    # Define the data generator
    data_gen = ImageDataGenerator(rotation_range=20,
                                width_shift_range=0.1,
                                height_shift_range=0.1,
                                zoom_range=0.1,
                                horizontal_flip=True)

    # Create the flow method to apply data augmentation
    train_gen = data_gen.flow(blurry_frames, clean_frames, batch_size=batch_size)

    # Begins training
    history = autoencoder.fit(train_gen, epochs=100, validation_data=(blurry_frames, clean_frames), batch_size=batch_size, callbacks=callbacks)

And this is how I’m running predictions:

path = '/mnt/share/Datasets/kaggle_motion_blur/motion_blurred/'
images_list = os.listdir(path)

model = load_model('trained_BLUR_Small_5L_autoencoder.h5')
counter = 0

for image in images_list:

    image = Image.open(os.path.join(path, image))
    image = image.resize((128, 128))
    image_array = img_to_array(image)
    image_array = np.expand_dims(image_array, axis=0)
    prediction = model.predict(image_array)
    prediction = np.squeeze(prediction, axis=0)
    prediction = prediction * 255.0
    prediction = prediction.astype(np.uint8)
    prediction = prediction.reshape(128, 128, 3)
    img = array_to_img(prediction, scale=False)

     # Plotting the input and predicted images side by side
    f, axarr = plt.subplots(1, 2)
    axarr[0].imshow(image)
    axarr[1].imshow(img)
    plt.savefig('./output/combined_image_{}.png'.format(counter))
    img.save('./output/saved_image_{}.png'.format(counter), format="PNG")
    counter+= 1

These are my image results (Discuss does not let me add images to my post so you can see what I’m talking about in the S/O link below:

Check this S/O post for images

This is the JSON output of the training epochs, loss, accuracy and learning rates:

{
  "loss": {
    "0": 0.0559637062,
    "1": 0.0473968722,
    "2": 0.0444460958,
    "3": 0.041732043,
    "4": 0.0403001383,
    "5": 0.0384916142,
    "6": 0.0371751711,
    "7": 0.0357408971,
    "8": 0.0342390947,
    "9": 0.0336149298,
    "10": 0.0324862674,
    "11": 0.0320564546,
    "12": 0.0311057456,
    "13": 0.0306660384,
    "14": 0.0306252893,
    "15": 0.0300185885,
    "16": 0.0297512896,
    "17": 0.0291082468,
    "18": 0.0286234375,
    "19": 0.0283738524,
    "20": 0.0281269737,
    "21": 0.0278169811,
    "22": 0.0271670725,
    "23": 0.027053671,
    "24": 0.0268216059,
    "25": 0.0265594069,
    "26": 0.0262372922,
    "27": 0.0260510016,
    "28": 0.025675552,
    "29": 0.0256246217,
    "30": 0.0249390211,
    "31": 0.0247799307,
    "32": 0.0243325885,
    "33": 0.0240763351,
    "34": 0.0239928197,
    "35": 0.0239428878,
    "36": 0.0235468745,
    "37": 0.0233788602,
    "38": 0.0230851565,
    "39": 0.0227795113,
    "40": 0.0226637572,
    "41": 0.0221564285,
    "42": 0.0219694152,
    "43": 0.0221202765,
    "44": 0.0215503499,
    "45": 0.0214253683,
    "46": 0.0210911259,
    "47": 0.0210189149,
    "48": 0.0211602859,
    "49": 0.0209334362,
    "50": 0.0206862725,
    "51": 0.0202936586,
    "52": 0.020195011,
    "53": 0.0203813333,
    "54": 0.0199731477,
    "55": 0.0199599732,
    "56": 0.0198422913,
    "57": 0.0200028252,
    "58": 0.020069683,
    "59": 0.0193502475,
    "60": 0.0193256326,
    "61": 0.0192854889,
    "62": 0.0191295687,
    "63": 0.0188199598,
    "64": 0.0187446959,
    "65": 0.0185292251,
    "66": 0.0184431728,
    "67": 0.0184016339,
    "68": 0.0186514556,
    "69": 0.0182920825,
    "70": 0.0180315617,
    "71": 0.0181470234,
    "72": 0.0180766061,
    "73": 0.0179452486,
    "74": 0.0177276377,
    "75": 0.0177405551,
    "76": 0.0177125093,
    "77": 0.017679546,
    "78": 0.0177526604,
    "79": 0.0168209467,
    "80": 0.016371103,
    "81": 0.0162614379,
    "82": 0.0160754975,
    "83": 0.0160399396,
    "84": 0.0158573128,
    "85": 0.0158148855,
    "86": 0.0157786086,
    "87": 0.0156954341,
    "88": 0.0155929513,
    "89": 0.015328614,
    "90": 0.0152604291,
    "91": 0.0153146992,
    "92": 0.015211042,
    "93": 0.0151338875,
    "94": 0.0151586076,
    "95": 0.0150965769,
    "96": 0.0151056759,
    "97": 0.0150074158,
    "98": 0.0149844224,
    "99": 0.0150296595
  },
  "acc": {
    "0": 0.4606803954,
    "1": 0.467546761,
    "2": 0.4655869901,
    "3": 0.4671278894,
    "4": 0.4661996067,
    "5": 0.4693737924,
    "6": 0.4923563302,
    "7": 0.5473389626,
    "8": 0.6008113027,
    "9": 0.6279739141,
    "10": 0.6372382045,
    "11": 0.6473849416,
    "12": 0.650372386,
    "13": 0.6581180692,
    "14": 0.6512622237,
    "15": 0.6611989737,
    "16": 0.6612861753,
    "17": 0.6657290459,
    "18": 0.6735184789,
    "19": 0.6756489277,
    "20": 0.6758747697,
    "21": 0.6802487969,
    "22": 0.6829707623,
    "23": 0.6869801283,
    "24": 0.6898216009,
    "25": 0.6890658736,
    "26": 0.6886766553,
    "27": 0.6913455725,
    "28": 0.6947268248,
    "29": 0.6963167787,
    "30": 0.6970059276,
    "31": 0.6994771361,
    "32": 0.7033333778,
    "33": 0.701754868,
    "34": 0.7039641738,
    "35": 0.7053239346,
    "36": 0.7035363913,
    "37": 0.7069551945,
    "38": 0.7062308192,
    "39": 0.7106642723,
    "40": 0.7096898556,
    "41": 0.7116783261,
    "42": 0.7139574289,
    "43": 0.7117571831,
    "44": 0.7149085999,
    "45": 0.7143892646,
    "46": 0.7177314162,
    "47": 0.7178704143,
    "48": 0.7152007222,
    "49": 0.7162288427,
    "50": 0.7176439762,
    "51": 0.7200078368,
    "52": 0.7232849598,
    "53": 0.7205632329,
    "54": 0.7218416929,
    "55": 0.7249743342,
    "56": 0.7220694423,
    "57": 0.7221901417,
    "58": 0.7248255014,
    "59": 0.7284849882,
    "60": 0.7274199128,
    "61": 0.7269214392,
    "62": 0.728218019,
    "63": 0.7291227579,
    "64": 0.7305186987,
    "65": 0.7319476008,
    "66": 0.7326906323,
    "67": 0.732211709,
    "68": 0.7313355207,
    "69": 0.7331151366,
    "70": 0.7346365452,
    "71": 0.7359617949,
    "72": 0.7359285951,
    "73": 0.7339394689,
    "74": 0.7341305614,
    "75": 0.7364740372,
    "76": 0.7368541956,
    "77": 0.7347199321,
    "78": 0.7379792929,
    "79": 0.7438659072,
    "80": 0.7457298636,
    "81": 0.7466909289,
    "82": 0.7475773692,
    "83": 0.747853756,
    "84": 0.7487729788,
    "85": 0.7476714253,
    "86": 0.7501927614,
    "87": 0.7497540116,
    "88": 0.7511804104,
    "89": 0.752410531,
    "90": 0.7530171275,
    "91": 0.7530553937,
    "92": 0.7526908517,
    "93": 0.7528964281,
    "94": 0.7526913285,
    "95": 0.753929615,
    "96": 0.7543320656,
    "97": 0.7543010116,
    "98": 0.7544336319,
    "99": 0.7536581159
  },
  "val_loss": {
    "0": 0.0491532497,
    "1": 0.0428819433,
    "2": 0.0403589457,
    "3": 0.0374303386,
    "4": 0.0374171548,
    "5": 0.0354377925,
    "6": 0.0335176662,
    "7": 0.0325156339,
    "8": 0.0312758237,
    "9": 0.0298724882,
    "10": 0.0290645473,
    "11": 0.0287973378,
    "12": 0.0283008274,
    "13": 0.0279917233,
    "14": 0.0277446155,
    "15": 0.0267075617,
    "16": 0.0269869734,
    "17": 0.0260658357,
    "18": 0.0261003803,
    "19": 0.0256356988,
    "20": 0.024952108,
    "21": 0.0248472486,
    "22": 0.0243210737,
    "23": 0.024381537,
    "24": 0.0239129663,
    "25": 0.0230953004,
    "26": 0.0232715681,
    "27": 0.0230326615,
    "28": 0.0231741108,
    "29": 0.0226652324,
    "30": 0.0221607611,
    "31": 0.0218544509,
    "32": 0.0217084009,
    "33": 0.0216371976,
    "34": 0.0214716587,
    "35": 0.0209531747,
    "36": 0.0207797866,
    "37": 0.0206740964,
    "38": 0.0203931443,
    "39": 0.0200728513,
    "40": 0.0201499071,
    "41": 0.0196091868,
    "42": 0.0195282493,
    "43": 0.0196591243,
    "44": 0.0192415901,
    "45": 0.0187332761,
    "46": 0.018796768,
    "47": 0.0186291263,
    "48": 0.0188825149,
    "49": 0.0184612796,
    "50": 0.018214304,
    "51": 0.018066233,
    "52": 0.0180678181,
    "53": 0.0183129609,
    "54": 0.0177096091,
    "55": 0.017653849,
    "56": 0.0177150834,
    "57": 0.0176380556,
    "58": 0.0175104905,
    "59": 0.018088758,
    "60": 0.0175425448,
    "61": 0.0171913933,
    "62": 0.0168936346,
    "63": 0.0169047564,
    "64": 0.0168141816,
    "65": 0.0170337521,
    "66": 0.016725529,
    "67": 0.0169604756,
    "68": 0.0165361054,
    "69": 0.0165681373,
    "70": 0.0163681172,
    "71": 0.0164794382,
    "72": 0.0163597036,
    "73": 0.0161324181,
    "74": 0.0161470957,
    "75": 0.0162882674,
    "76": 0.0163499508,
    "77": 0.0161468163,
    "78": 0.0160854217,
    "79": 0.0153935291,
    "80": 0.0151251443,
    "81": 0.0148611758,
    "82": 0.0150120985,
    "83": 0.0146467723,
    "84": 0.0146741169,
    "85": 0.0146548087,
    "86": 0.0147513337,
    "87": 0.0145894634,
    "88": 0.0145941339,
    "89": 0.0144529464,
    "90": 0.0142841199,
    "91": 0.0142779676,
    "92": 0.0142102633,
    "93": 0.0142730884,
    "94": 0.014272172,
    "95": 0.014208761,
    "96": 0.0141810579,
    "97": 0.0141390217,
    "98": 0.0140905399,
    "99": 0.0141332904
  },
  "val_acc": {
    "0": 0.4686006606,
    "1": 0.4674130976,
    "2": 0.4529431164,
    "3": 0.4659602046,
    "4": 0.4662201703,
    "5": 0.4551965594,
    "6": 0.5195012093,
    "7": 0.5835120082,
    "8": 0.6048460603,
    "9": 0.6395056844,
    "10": 0.6545616388,
    "11": 0.662117064,
    "12": 0.6643206477,
    "13": 0.6685555577,
    "14": 0.6260975599,
    "15": 0.6835178137,
    "16": 0.6557799578,
    "17": 0.6855746508,
    "18": 0.6705876589,
    "19": 0.6919317245,
    "20": 0.6861786842,
    "21": 0.6914568543,
    "22": 0.6953245401,
    "23": 0.690965116,
    "24": 0.7035027146,
    "25": 0.705488801,
    "26": 0.7045339942,
    "27": 0.7091979384,
    "28": 0.7046523094,
    "29": 0.7073927522,
    "30": 0.7071937323,
    "31": 0.7159055471,
    "32": 0.7098062038,
    "33": 0.7090770006,
    "34": 0.7168362737,
    "35": 0.7215492129,
    "36": 0.7177282572,
    "37": 0.7153795362,
    "38": 0.7185223699,
    "39": 0.7179871202,
    "40": 0.7176033854,
    "41": 0.7234653234,
    "42": 0.713513732,
    "43": 0.7200869322,
    "44": 0.7262759805,
    "45": 0.7287439704,
    "46": 0.7239111066,
    "47": 0.7238330245,
    "48": 0.7275136709,
    "49": 0.7328174114,
    "50": 0.7293641567,
    "51": 0.7332755923,
    "52": 0.7260206938,
    "53": 0.7235223651,
    "54": 0.7341402769,
    "55": 0.7222154737,
    "56": 0.7359579802,
    "57": 0.7379191518,
    "58": 0.7348037958,
    "59": 0.7349512577,
    "60": 0.7326335907,
    "61": 0.73770684,
    "62": 0.7337531447,
    "63": 0.7356146574,
    "64": 0.736933589,
    "65": 0.7397321463,
    "66": 0.7403889894,
    "67": 0.7353423238,
    "68": 0.7442782521,
    "69": 0.7409216762,
    "70": 0.7454639077,
    "71": 0.7416924238,
    "72": 0.7463977337,
    "73": 0.7433106303,
    "74": 0.7427838445,
    "75": 0.739207983,
    "76": 0.7410522699,
    "77": 0.7474275231,
    "78": 0.7450863123,
    "79": 0.7468729615,
    "80": 0.7513384223,
    "81": 0.7542303801,
    "82": 0.7525380254,
    "83": 0.7562230825,
    "84": 0.7542062998,
    "85": 0.7541561723,
    "86": 0.7559394836,
    "87": 0.7560024261,
    "88": 0.7567353249,
    "89": 0.7569227815,
    "90": 0.7569333911,
    "91": 0.7582103014,
    "92": 0.7579286695,
    "93": 0.7579852939,
    "94": 0.7573221326,
    "95": 0.7593886256,
    "96": 0.7584165335,
    "97": 0.7586974502,
    "98": 0.7584565282,
    "99": 0.7585714459
  },
  "lr": {
    "0": 0.001,
    "1": 0.001,
    "2": 0.001,
    "3": 0.001,
    "4": 0.001,
    "5": 0.001,
    "6": 0.001,
    "7": 0.001,
    "8": 0.001,
    "9": 0.001,
    "10": 0.001,
    "11": 0.001,
    "12": 0.001,
    "13": 0.001,
    "14": 0.001,
    "15": 0.001,
    "16": 0.001,
    "17": 0.001,
    "18": 0.001,
    "19": 0.001,
    "20": 0.001,
    "21": 0.001,
    "22": 0.001,
    "23": 0.001,
    "24": 0.001,
    "25": 0.001,
    "26": 0.001,
    "27": 0.001,
    "28": 0.001,
    "29": 0.001,
    "30": 0.001,
    "31": 0.001,
    "32": 0.001,
    "33": 0.001,
    "34": 0.001,
    "35": 0.001,
    "36": 0.001,
    "37": 0.001,
    "38": 0.001,
    "39": 0.001,
    "40": 0.001,
    "41": 0.001,
    "42": 0.001,
    "43": 0.001,
    "44": 0.001,
    "45": 0.001,
    "46": 0.001,
    "47": 0.001,
    "48": 0.001,
    "49": 0.001,
    "50": 0.001,
    "51": 0.001,
    "52": 0.001,
    "53": 0.001,
    "54": 0.001,
    "55": 0.001,
    "56": 0.001,
    "57": 0.001,
    "58": 0.001,
    "59": 0.001,
    "60": 0.001,
    "61": 0.001,
    "62": 0.001,
    "63": 0.001,
    "64": 0.001,
    "65": 0.001,
    "66": 0.001,
    "67": 0.001,
    "68": 0.001,
    "69": 0.001,
    "70": 0.001,
    "71": 0.001,
    "72": 0.001,
    "73": 0.001,
    "74": 0.001,
    "75": 0.001,
    "76": 0.001,
    "77": 0.001,
    "78": 0.001,
    "79": 0.0003162278,
    "80": 0.0003162278,
    "81": 0.0003162278,
    "82": 0.0003162278,
    "83": 0.0003162278,
    "84": 0.0003162278,
    "85": 0.0003162278,
    "86": 0.0003162278,
    "87": 0.0003162278,
    "88": 0.0003162278,
    "89": 0.0001,
    "90": 0.0001,
    "91": 0.0001,
    "92": 0.0001,
    "93": 0.0001,
    "94": 0.0001,
    "95": 0.0001,
    "96": 0.0000316228,
    "97": 0.0000316228,
    "98": 0.0000316228,
    "99": 0.0000316228
  }
}

When I change my architecture to this:

inputs = Input(shape = input_shape, name = 'encoder_input')

    x = inputs

    # Layers of the encoder

    x = Conv2D(filters=64, kernel_size=kernel_size, strides=2, activation='relu', padding=padding)(x)
    x = Conv2D(filters=128, kernel_size=kernel_size, strides=2, activation='relu', padding=padding)(x)
    x = Conv2D(filters=256, kernel_size=kernel_size, strides=2, activation='relu', padding=padding)(x)

    # Additional layers
    x = Conv2D(filters=512, kernel_size=kernel_size, strides=2, activation='relu', padding=padding)(x)
    x = Conv2D(filters=1024, kernel_size=kernel_size, strides=2, activation='relu', padding=padding)(x)


    shape = K.int_shape(x)
    x = Flatten()(x)
    latent = Dense(latent_dim, name='latent_vector')(x)

    encoder = Model(inputs, latent, name='encoder')

    latent_inputs = Input(shape=(latent_dim,), name='decoder_input')

    x = Dense(shape[1]*shape[2]*shape[3])(latent_inputs)
    x = Reshape((shape[1], shape[2], shape[3]))(x)

    # Layers of the dencoder

    # Additional layers
    x = Conv2DTranspose(filters=1024,kernel_size=kernel_size, strides=2, activation='relu', padding=padding)(x)
    x = Conv2DTranspose(filters=512,kernel_size=kernel_size, strides=2, activation='relu', padding=padding)(x)

    x = Conv2DTranspose(filters=256,kernel_size=kernel_size, strides=2, activation='relu', padding=padding)(x)
    x = Conv2DTranspose(filters=128,kernel_size=kernel_size, strides=2, activation='relu', padding=padding)(x)
    x = Conv2DTranspose(filters=64,kernel_size=kernel_size, strides=2, activation='relu', padding=padding)(x)


    outputs = Conv2DTranspose(filters=3, kernel_size=kernel_size, activation='sigmoid', padding=padding, name='decoder_output')(x)

And change the dimensions to 1024x1024x3, these are my results:

Check this S/O post for images