# This is it: the final CNN definition and its training run.
def _conv_bn_relu(x, filters):
    """Apply an unbiased 3x3 Conv2D, BatchNormalization (momentum 0.9), then ReLU.

    The convolution has no bias because the BatchNormalization layer that
    immediately follows learns its own shift, which makes a conv bias redundant.
    Conv2D uses Keras' default 'valid' padding, so each call shrinks both
    spatial dimensions by 2.
    """
    x = layers.Conv2D(filters, (3, 3), use_bias=False)(x)
    x = layers.BatchNormalization(momentum=0.9)(x)
    return layers.Activation('relu')(x)


def create_simple_cnn(learning_rate=1e-4, dropout_rate=0.2):
    """Create and compile a simple CNN for binary image classification.

    The model is built with the Keras functional API. It has three stages of
    three conv-BN-ReLU blocks, each stage followed by 2x2 max pooling. These are
    followed by global average pooling, dropout and a single sigmoid unit.

    Parameters:
    - learning_rate: Adam learning rate (default 1e-4, the original value).
    - dropout_rate: dropout fraction applied before the output layer
      (default 0.2, the original value).

    Returns:
    - model: a compiled tf.keras Model instance. It uses binary cross-entropy
      loss and the accuracy metric.

    Notes:
    - The input is single-channel, shape (IMG_SIZE, IMG_SIZE, 1).
    - Every conv and pool uses 'valid' padding, so IMG_SIZE must be at least
      50. Smaller inputs shrink to zero before the final pooling layer.
    - NOTE(review): assumes `augment_pipeline()` returns a callable layer or
      model that preserves the input shape. Confirm against its definition.
    """
    input_layer = layers.Input(shape=(IMG_SIZE, IMG_SIZE, 1))
    x = augment_pipeline()(input_layer)

    # Filter counts per stage. Each stage ends with a 2x2 max pool.
    # Layers are created in exactly the same order as the original unrolled
    # version, so auto-generated layer names are unchanged.
    for stage_filters in ((8, 16, 32), (64, 128, 128), (128, 128, 256)):
        for filters in stage_filters:
            x = _conv_bn_relu(x, filters)
        x = layers.MaxPooling2D((2, 2))(x)

    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dropout(dropout_rate)(x)
    # One sigmoid unit gives the probability of the positive class.
    output = layers.Dense(1, activation='sigmoid')(x)

    model = models.Model(inputs=input_layer, outputs=output)
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    model.compile(
        optimizer=optimizer,
        loss='binary_crossentropy',
        metrics=['accuracy'],
    )
    return model
# Build and compile the model inside the distribution strategy's scope, so its
# variables (and the optimizer's) are created under `strategy`. Only creation
# and compile need the scope; Model.fit may run outside it.
with strategy.scope():
    model = create_simple_cnn()

# NOTE(review): cp_callback / reduce_lr are defined elsewhere. Presumably they
# are a checkpoint callback and a learning-rate scheduler; confirm against their
# definitions.
history = model.fit(
    train_dataset,
    epochs=15,
    callbacks=[cp_callback, reduce_lr],
    validation_data=valid_dataset,
)