I am working on an autoencoder to remove motion blur from images. I am using a small dataset of 1816 blurry and 1816 sharp images.
This is my autoencoder with 6 layers:
physical_devices = tf.config.list_physical_devices('GPU')
if physical_devices:
tf.config.experimental.set_memory_growth(physical_devices[0], True)
seed = 21
random.seed = seed
np.random.seed = seed
# Paths to the good images and the corresponding motion blurred images
good_frames = '/mnt/share/Datasets/BLUR_small/BLUR/sharp'
bad_frames = '/mnt/share/Datasets/BLUR_small/BLUR/motion_blurred'
# Network Parameters
dims = 128
input_shape = (dims, dims, 3)
batch_size = 32
kernel_size = 3
latent_dim = 256
# Below is a custom data loader.
def load_image(file, target_size):
image = tf.keras.preprocessing.image.load_img(file, target_size=target_size)
image = tf.keras.preprocessing.image.img_to_array(image).astype('float32') / 255
return image
clean_frames = []
blurry_frames = []
extensions = ['.jpg', 'jpeg', '.png']
for file in tqdm(sorted(os.listdir(good_frames))):
if any(extension in file for extension in extensions):
file_path = os.path.join(good_frames, file)
clean_frames.append(load_image(file_path, (dims,dims)))
clean_frames = np.array(clean_frames)
for file in tqdm(sorted(os.listdir(bad_frames))):
if any(extension in file for extension in extensions):
file_path = os.path.join(bad_frames, file)
blurry_frames.append(load_image(file_path, (dims,dims)))
blurry_frames = np.array(blurry_frames)
with tf.device('GPU:0'):
inputs = Input(shape = input_shape, name = 'encoder_input')
x = inputs
# Layers of the encoder
x = Conv2D(filters=64, kernel_size=kernel_size, strides=2, activation='relu', padding='same')(x)
x = Conv2D(filters=128, kernel_size=kernel_size, strides=2, activation='relu', padding='same')(x)
x = Conv2D(filters=256, kernel_size=kernel_size, strides=2, activation='relu', padding='same')(x)
shape = K.int_shape(x)
x = Flatten()(x)
latent = Dense(latent_dim, name='latent_vector')(x)
encoder = Model(inputs, latent, name='encoder')
latent_inputs = Input(shape=(latent_dim,), name='decoder_input')
x = Dense(shape[1]*shape[2]*shape[3])(latent_inputs)
x = Reshape((shape[1], shape[2], shape[3]))(x)
x = Conv2DTranspose(filters=256,kernel_size=kernel_size, strides=2, activation='relu', padding='same')(x)
x = Conv2DTranspose(filters=128,kernel_size=kernel_size, strides=2, activation='relu', padding='same')(x)
x = Conv2DTranspose(filters=64,kernel_size=kernel_size, strides=2, activation='relu', padding='same')(x)
outputs = Conv2DTranspose(filters=3, kernel_size=kernel_size, activation='sigmoid', padding='same', name='decoder_output')(x)
decoder = Model(latent_inputs, outputs, name='decoder')
autoencoder = Model(inputs, decoder(encoder(inputs)), name='autoencoder')
autoencoder.compile(loss='mse', optimizer='adam',metrics=["acc"])
# Automated Learning Rate reducer
lr_reducer = ReduceLROnPlateau(factor=np.sqrt(0.1),
cooldown=0,
patience=5,
verbose=1,
min_lr=0.5e-6)
callbacks = [lr_reducer]
# Define the data generator
data_gen = ImageDataGenerator(rotation_range=20,
width_shift_range=0.1,
height_shift_range=0.1,
zoom_range=0.1,
horizontal_flip=True)
# Create the flow method to apply data augmentation
train_gen = data_gen.flow(blurry_frames, clean_frames, batch_size=batch_size)
# Begins training
history = autoencoder.fit(train_gen, epochs=100, validation_data=(blurry_frames, clean_frames), batch_size=batch_size, callbacks=callbacks)
And this is how I’m running predictions:
path = '/mnt/share/Datasets/kaggle_motion_blur/motion_blurred/'
images_list = os.listdir(path)
model = load_model('trained_BLUR_Small_5L_autoencoder.h5')
counter = 0
for image in images_list:
image = Image.open(os.path.join(path, image))
image = image.resize((128, 128))
image_array = img_to_array(image)
image_array = np.expand_dims(image_array, axis=0)
prediction = model.predict(image_array)
prediction = np.squeeze(prediction, axis=0)
prediction = prediction * 255.0
prediction = prediction.astype(np.uint8)
prediction = prediction.reshape(128, 128, 3)
img = array_to_img(prediction, scale=False)
# Plotting the input and predicted images side by side
f, axarr = plt.subplots(1, 2)
axarr[0].imshow(image)
axarr[1].imshow(img)
plt.savefig('./output/combined_image_{}.png'.format(counter))
img.save('./output/saved_image_{}.png'.format(counter), format="PNG")
counter+= 1
These are my image results (Discuss does not let me add images to my post so you can see what I’m talking about in the S/O link below:
Check this S/O post for images
This is the JSON output of the training epochs, loss, accuracy and learning rates:
{
"loss": {
"0": 0.0559637062,
"1": 0.0473968722,
"2": 0.0444460958,
"3": 0.041732043,
"4": 0.0403001383,
"5": 0.0384916142,
"6": 0.0371751711,
"7": 0.0357408971,
"8": 0.0342390947,
"9": 0.0336149298,
"10": 0.0324862674,
"11": 0.0320564546,
"12": 0.0311057456,
"13": 0.0306660384,
"14": 0.0306252893,
"15": 0.0300185885,
"16": 0.0297512896,
"17": 0.0291082468,
"18": 0.0286234375,
"19": 0.0283738524,
"20": 0.0281269737,
"21": 0.0278169811,
"22": 0.0271670725,
"23": 0.027053671,
"24": 0.0268216059,
"25": 0.0265594069,
"26": 0.0262372922,
"27": 0.0260510016,
"28": 0.025675552,
"29": 0.0256246217,
"30": 0.0249390211,
"31": 0.0247799307,
"32": 0.0243325885,
"33": 0.0240763351,
"34": 0.0239928197,
"35": 0.0239428878,
"36": 0.0235468745,
"37": 0.0233788602,
"38": 0.0230851565,
"39": 0.0227795113,
"40": 0.0226637572,
"41": 0.0221564285,
"42": 0.0219694152,
"43": 0.0221202765,
"44": 0.0215503499,
"45": 0.0214253683,
"46": 0.0210911259,
"47": 0.0210189149,
"48": 0.0211602859,
"49": 0.0209334362,
"50": 0.0206862725,
"51": 0.0202936586,
"52": 0.020195011,
"53": 0.0203813333,
"54": 0.0199731477,
"55": 0.0199599732,
"56": 0.0198422913,
"57": 0.0200028252,
"58": 0.020069683,
"59": 0.0193502475,
"60": 0.0193256326,
"61": 0.0192854889,
"62": 0.0191295687,
"63": 0.0188199598,
"64": 0.0187446959,
"65": 0.0185292251,
"66": 0.0184431728,
"67": 0.0184016339,
"68": 0.0186514556,
"69": 0.0182920825,
"70": 0.0180315617,
"71": 0.0181470234,
"72": 0.0180766061,
"73": 0.0179452486,
"74": 0.0177276377,
"75": 0.0177405551,
"76": 0.0177125093,
"77": 0.017679546,
"78": 0.0177526604,
"79": 0.0168209467,
"80": 0.016371103,
"81": 0.0162614379,
"82": 0.0160754975,
"83": 0.0160399396,
"84": 0.0158573128,
"85": 0.0158148855,
"86": 0.0157786086,
"87": 0.0156954341,
"88": 0.0155929513,
"89": 0.015328614,
"90": 0.0152604291,
"91": 0.0153146992,
"92": 0.015211042,
"93": 0.0151338875,
"94": 0.0151586076,
"95": 0.0150965769,
"96": 0.0151056759,
"97": 0.0150074158,
"98": 0.0149844224,
"99": 0.0150296595
},
"acc": {
"0": 0.4606803954,
"1": 0.467546761,
"2": 0.4655869901,
"3": 0.4671278894,
"4": 0.4661996067,
"5": 0.4693737924,
"6": 0.4923563302,
"7": 0.5473389626,
"8": 0.6008113027,
"9": 0.6279739141,
"10": 0.6372382045,
"11": 0.6473849416,
"12": 0.650372386,
"13": 0.6581180692,
"14": 0.6512622237,
"15": 0.6611989737,
"16": 0.6612861753,
"17": 0.6657290459,
"18": 0.6735184789,
"19": 0.6756489277,
"20": 0.6758747697,
"21": 0.6802487969,
"22": 0.6829707623,
"23": 0.6869801283,
"24": 0.6898216009,
"25": 0.6890658736,
"26": 0.6886766553,
"27": 0.6913455725,
"28": 0.6947268248,
"29": 0.6963167787,
"30": 0.6970059276,
"31": 0.6994771361,
"32": 0.7033333778,
"33": 0.701754868,
"34": 0.7039641738,
"35": 0.7053239346,
"36": 0.7035363913,
"37": 0.7069551945,
"38": 0.7062308192,
"39": 0.7106642723,
"40": 0.7096898556,
"41": 0.7116783261,
"42": 0.7139574289,
"43": 0.7117571831,
"44": 0.7149085999,
"45": 0.7143892646,
"46": 0.7177314162,
"47": 0.7178704143,
"48": 0.7152007222,
"49": 0.7162288427,
"50": 0.7176439762,
"51": 0.7200078368,
"52": 0.7232849598,
"53": 0.7205632329,
"54": 0.7218416929,
"55": 0.7249743342,
"56": 0.7220694423,
"57": 0.7221901417,
"58": 0.7248255014,
"59": 0.7284849882,
"60": 0.7274199128,
"61": 0.7269214392,
"62": 0.728218019,
"63": 0.7291227579,
"64": 0.7305186987,
"65": 0.7319476008,
"66": 0.7326906323,
"67": 0.732211709,
"68": 0.7313355207,
"69": 0.7331151366,
"70": 0.7346365452,
"71": 0.7359617949,
"72": 0.7359285951,
"73": 0.7339394689,
"74": 0.7341305614,
"75": 0.7364740372,
"76": 0.7368541956,
"77": 0.7347199321,
"78": 0.7379792929,
"79": 0.7438659072,
"80": 0.7457298636,
"81": 0.7466909289,
"82": 0.7475773692,
"83": 0.747853756,
"84": 0.7487729788,
"85": 0.7476714253,
"86": 0.7501927614,
"87": 0.7497540116,
"88": 0.7511804104,
"89": 0.752410531,
"90": 0.7530171275,
"91": 0.7530553937,
"92": 0.7526908517,
"93": 0.7528964281,
"94": 0.7526913285,
"95": 0.753929615,
"96": 0.7543320656,
"97": 0.7543010116,
"98": 0.7544336319,
"99": 0.7536581159
},
"val_loss": {
"0": 0.0491532497,
"1": 0.0428819433,
"2": 0.0403589457,
"3": 0.0374303386,
"4": 0.0374171548,
"5": 0.0354377925,
"6": 0.0335176662,
"7": 0.0325156339,
"8": 0.0312758237,
"9": 0.0298724882,
"10": 0.0290645473,
"11": 0.0287973378,
"12": 0.0283008274,
"13": 0.0279917233,
"14": 0.0277446155,
"15": 0.0267075617,
"16": 0.0269869734,
"17": 0.0260658357,
"18": 0.0261003803,
"19": 0.0256356988,
"20": 0.024952108,
"21": 0.0248472486,
"22": 0.0243210737,
"23": 0.024381537,
"24": 0.0239129663,
"25": 0.0230953004,
"26": 0.0232715681,
"27": 0.0230326615,
"28": 0.0231741108,
"29": 0.0226652324,
"30": 0.0221607611,
"31": 0.0218544509,
"32": 0.0217084009,
"33": 0.0216371976,
"34": 0.0214716587,
"35": 0.0209531747,
"36": 0.0207797866,
"37": 0.0206740964,
"38": 0.0203931443,
"39": 0.0200728513,
"40": 0.0201499071,
"41": 0.0196091868,
"42": 0.0195282493,
"43": 0.0196591243,
"44": 0.0192415901,
"45": 0.0187332761,
"46": 0.018796768,
"47": 0.0186291263,
"48": 0.0188825149,
"49": 0.0184612796,
"50": 0.018214304,
"51": 0.018066233,
"52": 0.0180678181,
"53": 0.0183129609,
"54": 0.0177096091,
"55": 0.017653849,
"56": 0.0177150834,
"57": 0.0176380556,
"58": 0.0175104905,
"59": 0.018088758,
"60": 0.0175425448,
"61": 0.0171913933,
"62": 0.0168936346,
"63": 0.0169047564,
"64": 0.0168141816,
"65": 0.0170337521,
"66": 0.016725529,
"67": 0.0169604756,
"68": 0.0165361054,
"69": 0.0165681373,
"70": 0.0163681172,
"71": 0.0164794382,
"72": 0.0163597036,
"73": 0.0161324181,
"74": 0.0161470957,
"75": 0.0162882674,
"76": 0.0163499508,
"77": 0.0161468163,
"78": 0.0160854217,
"79": 0.0153935291,
"80": 0.0151251443,
"81": 0.0148611758,
"82": 0.0150120985,
"83": 0.0146467723,
"84": 0.0146741169,
"85": 0.0146548087,
"86": 0.0147513337,
"87": 0.0145894634,
"88": 0.0145941339,
"89": 0.0144529464,
"90": 0.0142841199,
"91": 0.0142779676,
"92": 0.0142102633,
"93": 0.0142730884,
"94": 0.014272172,
"95": 0.014208761,
"96": 0.0141810579,
"97": 0.0141390217,
"98": 0.0140905399,
"99": 0.0141332904
},
"val_acc": {
"0": 0.4686006606,
"1": 0.4674130976,
"2": 0.4529431164,
"3": 0.4659602046,
"4": 0.4662201703,
"5": 0.4551965594,
"6": 0.5195012093,
"7": 0.5835120082,
"8": 0.6048460603,
"9": 0.6395056844,
"10": 0.6545616388,
"11": 0.662117064,
"12": 0.6643206477,
"13": 0.6685555577,
"14": 0.6260975599,
"15": 0.6835178137,
"16": 0.6557799578,
"17": 0.6855746508,
"18": 0.6705876589,
"19": 0.6919317245,
"20": 0.6861786842,
"21": 0.6914568543,
"22": 0.6953245401,
"23": 0.690965116,
"24": 0.7035027146,
"25": 0.705488801,
"26": 0.7045339942,
"27": 0.7091979384,
"28": 0.7046523094,
"29": 0.7073927522,
"30": 0.7071937323,
"31": 0.7159055471,
"32": 0.7098062038,
"33": 0.7090770006,
"34": 0.7168362737,
"35": 0.7215492129,
"36": 0.7177282572,
"37": 0.7153795362,
"38": 0.7185223699,
"39": 0.7179871202,
"40": 0.7176033854,
"41": 0.7234653234,
"42": 0.713513732,
"43": 0.7200869322,
"44": 0.7262759805,
"45": 0.7287439704,
"46": 0.7239111066,
"47": 0.7238330245,
"48": 0.7275136709,
"49": 0.7328174114,
"50": 0.7293641567,
"51": 0.7332755923,
"52": 0.7260206938,
"53": 0.7235223651,
"54": 0.7341402769,
"55": 0.7222154737,
"56": 0.7359579802,
"57": 0.7379191518,
"58": 0.7348037958,
"59": 0.7349512577,
"60": 0.7326335907,
"61": 0.73770684,
"62": 0.7337531447,
"63": 0.7356146574,
"64": 0.736933589,
"65": 0.7397321463,
"66": 0.7403889894,
"67": 0.7353423238,
"68": 0.7442782521,
"69": 0.7409216762,
"70": 0.7454639077,
"71": 0.7416924238,
"72": 0.7463977337,
"73": 0.7433106303,
"74": 0.7427838445,
"75": 0.739207983,
"76": 0.7410522699,
"77": 0.7474275231,
"78": 0.7450863123,
"79": 0.7468729615,
"80": 0.7513384223,
"81": 0.7542303801,
"82": 0.7525380254,
"83": 0.7562230825,
"84": 0.7542062998,
"85": 0.7541561723,
"86": 0.7559394836,
"87": 0.7560024261,
"88": 0.7567353249,
"89": 0.7569227815,
"90": 0.7569333911,
"91": 0.7582103014,
"92": 0.7579286695,
"93": 0.7579852939,
"94": 0.7573221326,
"95": 0.7593886256,
"96": 0.7584165335,
"97": 0.7586974502,
"98": 0.7584565282,
"99": 0.7585714459
},
"lr": {
"0": 0.001,
"1": 0.001,
"2": 0.001,
"3": 0.001,
"4": 0.001,
"5": 0.001,
"6": 0.001,
"7": 0.001,
"8": 0.001,
"9": 0.001,
"10": 0.001,
"11": 0.001,
"12": 0.001,
"13": 0.001,
"14": 0.001,
"15": 0.001,
"16": 0.001,
"17": 0.001,
"18": 0.001,
"19": 0.001,
"20": 0.001,
"21": 0.001,
"22": 0.001,
"23": 0.001,
"24": 0.001,
"25": 0.001,
"26": 0.001,
"27": 0.001,
"28": 0.001,
"29": 0.001,
"30": 0.001,
"31": 0.001,
"32": 0.001,
"33": 0.001,
"34": 0.001,
"35": 0.001,
"36": 0.001,
"37": 0.001,
"38": 0.001,
"39": 0.001,
"40": 0.001,
"41": 0.001,
"42": 0.001,
"43": 0.001,
"44": 0.001,
"45": 0.001,
"46": 0.001,
"47": 0.001,
"48": 0.001,
"49": 0.001,
"50": 0.001,
"51": 0.001,
"52": 0.001,
"53": 0.001,
"54": 0.001,
"55": 0.001,
"56": 0.001,
"57": 0.001,
"58": 0.001,
"59": 0.001,
"60": 0.001,
"61": 0.001,
"62": 0.001,
"63": 0.001,
"64": 0.001,
"65": 0.001,
"66": 0.001,
"67": 0.001,
"68": 0.001,
"69": 0.001,
"70": 0.001,
"71": 0.001,
"72": 0.001,
"73": 0.001,
"74": 0.001,
"75": 0.001,
"76": 0.001,
"77": 0.001,
"78": 0.001,
"79": 0.0003162278,
"80": 0.0003162278,
"81": 0.0003162278,
"82": 0.0003162278,
"83": 0.0003162278,
"84": 0.0003162278,
"85": 0.0003162278,
"86": 0.0003162278,
"87": 0.0003162278,
"88": 0.0003162278,
"89": 0.0001,
"90": 0.0001,
"91": 0.0001,
"92": 0.0001,
"93": 0.0001,
"94": 0.0001,
"95": 0.0001,
"96": 0.0000316228,
"97": 0.0000316228,
"98": 0.0000316228,
"99": 0.0000316228
}
}
When I change my architecture to this:
inputs = Input(shape = input_shape, name = 'encoder_input')
x = inputs
# Layers of the encoder
x = Conv2D(filters=64, kernel_size=kernel_size, strides=2, activation='relu', padding=padding)(x)
x = Conv2D(filters=128, kernel_size=kernel_size, strides=2, activation='relu', padding=padding)(x)
x = Conv2D(filters=256, kernel_size=kernel_size, strides=2, activation='relu', padding=padding)(x)
# Additional layers
x = Conv2D(filters=512, kernel_size=kernel_size, strides=2, activation='relu', padding=padding)(x)
x = Conv2D(filters=1024, kernel_size=kernel_size, strides=2, activation='relu', padding=padding)(x)
shape = K.int_shape(x)
x = Flatten()(x)
latent = Dense(latent_dim, name='latent_vector')(x)
encoder = Model(inputs, latent, name='encoder')
latent_inputs = Input(shape=(latent_dim,), name='decoder_input')
x = Dense(shape[1]*shape[2]*shape[3])(latent_inputs)
x = Reshape((shape[1], shape[2], shape[3]))(x)
# Layers of the dencoder
# Additional layers
x = Conv2DTranspose(filters=1024,kernel_size=kernel_size, strides=2, activation='relu', padding=padding)(x)
x = Conv2DTranspose(filters=512,kernel_size=kernel_size, strides=2, activation='relu', padding=padding)(x)
x = Conv2DTranspose(filters=256,kernel_size=kernel_size, strides=2, activation='relu', padding=padding)(x)
x = Conv2DTranspose(filters=128,kernel_size=kernel_size, strides=2, activation='relu', padding=padding)(x)
x = Conv2DTranspose(filters=64,kernel_size=kernel_size, strides=2, activation='relu', padding=padding)(x)
outputs = Conv2DTranspose(filters=3, kernel_size=kernel_size, activation='sigmoid', padding=padding, name='decoder_output')(x)
And change the dimensions to 1024x1024x3, these are my results: