def bn_act(x, act=True):
x = keras.layers.BatchNormalization()(x)
if act == True:
x = keras.layers.Activation("relu")(x)
return x
def conv_block(x, filters, kernel_size=(3, 3), padding="same", strides=1):
conv = bn_act(x)
conv = keras.layers.Conv2D(filters, kernel_size, padding=padding, strides=strides)(conv)
return conv
def stem(x, filters, kernel_size=(3, 3), padding="same", strides=1):
conv = keras.layers.Conv2D(filters, kernel_size, padding=padding, strides=strides)(x)
conv = conv_block(conv, filters, kernel_size=kernel_size, padding=padding, strides=strides)
#identity mapping
shortcut = keras.layers.Conv2D(filters, kernel_size=(1, 1), padding=padding, strides=strides)(x)
shortcut = bn_act(shortcut, act=False)
output = keras.layers.Add()([conv, shortcut])
return output
def residual_block(x, filters, kernel_size=(3, 3), padding="same", strides=1):
res = conv_block(x, filters, kernel_size=kernel_size, padding=padding, strides=strides)
res = conv_block(res, filters, kernel_size=kernel_size, padding=padding, strides=1)
shortcut = keras.layers.Conv2D(filters, kernel_size=(1, 1), padding=padding, strides=strides)(x)
shortcut = bn_act(shortcut, act=False)
output = keras.layers.Add()([shortcut, res])
return output
def upsample_concat_block(x, xskip):
u = keras.layers.UpSampling2D((2, 2))(x)
c = keras.layers.Concatenate()([u, xskip])
return c
def ResUNet():
f = [16, 32, 64, 128, 256]
inputs = keras.layers.Input((image_size, image_size, 3))
## Encoder
e0 = inputs
e1 = stem(e0, f[0])
e2 = residual_block(e1, f[1], strides=2)
e3 = residual_block(e2, f[2], strides=2)
e4 = residual_block(e3, f[3], strides=2)
## Bridge
b0 = residual_block(e4, f[4], strides=2)
## Decoder
u1 = upsample_concat_block(b0, e4)
d1 = residual_block(u1, f[4])
u2 = upsample_concat_block(d1, e3)
d2 = residual_block(u2, f[3])
u3 = upsample_concat_block(d2, e2)
d3 = residual_block(u3, f[2])
u4 = upsample_concat_block(d3, e1)
d4 = residual_block(u4, f[1])
outputs = keras.layers.Conv2D(1, (1, 1), padding="same", activation="sigmoid")(d4)
model = keras.models.Model(inputs, outputs)
return model
model.summary()
I am trying to execute this model however it is executing in zero seconds which means its not executing and when u try to print model summary it gives me an error saying model is not defined