# Transfer-learning classifier: a frozen MobileNetV3Large backbone plus a small
# trainable classification head.
#
# Input range: with include_preprocessing=False, MobileNetV3 does NOT rescale its
# inputs itself; it expects float pixels in the [-1, 1] range. The input pipeline
# uses tf.image.convert_image_dtype(img, tf.float32), which yields [0, 1], so the
# model rescales [0, 1] -> [-1, 1] explicitly below.
#
# Outputs: the final Dense layer has no activation, so the model returns raw
# logits (unbounded, possibly negative, not summing to 1). This matches
# CategoricalCrossentropy(from_logits=True). Apply tf.nn.softmax to the model's
# outputs to obtain class probabilities.
base_model = MobileNetV3Large(
    input_shape=image_shape,
    include_top=False,  # <== Important: drop the ImageNet head, we add our own
    weights="imagenet",
    include_preprocessing=False,  # backbone therefore expects inputs in [-1, 1]
)
# Freeze the backbone so only the new head is trained.
base_model.trainable = False

inputs = tf.keras.Input(shape=image_shape)
# Apply data augmentation to the inputs.
x = data_augmentation(inputs)
# Map [0, 1] pixels to the [-1, 1] range the backbone expects (x * 2 - 1).
# NOTE(review): assumes data_augmentation keeps pixels in [0, 1] — confirm.
x = tf.keras.layers.Rescaling(scale=2.0, offset=-1.0, name="ScaleToMinusOneOne")(x)
# training=False keeps the backbone's BatchNormalization layers in inference mode
# (using their stored statistics instead of per-batch statistics).
x = base_model(x, training=False)
x = GlobalAveragePooling2D(name="AveragePooling")(x)
x = BatchNormalization()(x)
# Dropout with rate 0.2 to reduce overfitting.
x = Dropout(0.2, name="FinalDropout")(x)
# Prediction layer: one raw logit per class (no activation).
outputs = Dense(output_shape, kernel_regularizer=l2(0.1), name="output_layer")(x)
# Pass the name to the constructor: Model.name is a read-only property in
# tf.keras 2.x, so assigning model.name afterwards can raise AttributeError.
model = tf.keras.Model(inputs, outputs, name="DogBreedsClassification")
model.compile(
    optimizer=Adam(learning_rate),
    loss=CategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
I have scaled the input images to the [0, 1] range using `img = tf.image.convert_image_dtype(img, tf.float32)`. I have tried many different things and keep getting negative predictions.
# Sanity-check the model on a single training image.
# What does a single batch look like?
image_batch, label_batch = next(iter(train_data))
print(f"batch shapes: images {image_batch.shape}, labels {label_batch.shape}")

# Keep a leading batch dimension of 1 so the model accepts a single image.
single_image_input = tf.expand_dims(image_batch[0], axis=0)
print(f"image: min: {tf.reduce_min(single_image_input)}, max: {tf.reduce_max(single_image_input)}")

# Pass the image through the model. training=False makes inference mode explicit
# (dropout off), so repeated calls give the same output.
single_image_output_sequential = model(single_image_input, training=False)

# These are raw logits (the output layer has no activation and the loss uses
# from_logits=True), so negative values and a sum != 1 are expected.
print(f"output: min: {tf.reduce_min(single_image_output_sequential)}, max: {tf.reduce_max(single_image_output_sequential)}")
print(f"sum: {numpy.sum(single_image_output_sequential)}")

# Convert logits to class probabilities: non-negative and summing to 1 per row.
single_image_probabilities = tf.nn.softmax(single_image_output_sequential, axis=-1)
print(f"probabilities: min: {tf.reduce_min(single_image_probabilities)}, max: {tf.reduce_max(single_image_probabilities)}, sum: {tf.reduce_sum(single_image_probabilities)}")
print(f"predicted class index: {int(tf.argmax(single_image_probabilities, axis=-1)[0])}")
# Display the probabilities (notebook cell output).
single_image_probabilities
Output:
image: min: 0.0, max: 0.9990289211273193
output: min: -0.3796367049217224, max: 0.5337352752685547
sum: -0.8821430206298828
<tf.Tensor: shape=(1, 120), dtype=float32, numpy=
array([[-0.2188825 , 0.10653329, 0.00597369, 0.12478244, 0.26514864,
-0.26880372, 0.1822165 , -0.10658702, 0.00581348, -0.3208883 ,
-0.24656023, 0.07816146, 0.13889077, 0.06685722, 0.05403857,
0.03082676, 0.0735651 , 0.04353707, -0.14945427, 0.06788599,
-0.17124134, -0.09991369, -0.0794204 , 0.00860454, -0.28096104,
0.17883578, 0.40302822, -0.30102172, -0.31097123, -0.06332889,
0.04188699, -0.2644603 , -0.10414463, 0.14845969, -0.3796367 ,
0.12521605, -0.25900525, 0.03852078, 0.35168853, -0.06216867,
-0.17495483, 0.04499481, 0.22087093, 0.03582342, -0.24091361,
0.10360496, 0.185919 , -0.02944497, 0.24843341, 0.06673928,
-0.02435508, -0.17979601, -0.03893549, 0.04256336, -0.15558268,
0.1588718 , 0.13245626, -0.13164645, -0.14717901, 0.03356938,
0.00357068, -0.16383977, -0.2017885 , -0.17665851, -0.08628735,
0.0995516 , 0.14680424, -0.22888815, 0.14236785, -0.01733635,
0.00285829, -0.06281093, -0.08636647, 0.08349405, -0.05924843,
-0.0192999 , 0.06708179, 0.08897432, -0.17808396, -0.00832896,
-0.15415476, 0.01466018, -0.18801415, -0.04791185, 0.13846274,
-0.04429017, 0.12047836, -0.03919715, 0.5337353 , -0.08102562,
0.18035048, 0.1974282 , 0.44417682, 0.12379853, -0.040514 ,
0.10690069, 0.28111115, -0.24229927, 0.01829374, -0.00342152,
0.3781557 , -0.15650302, -0.281237 , -0.20091408, -0.17689967,
-0.19114447, -0.01090574, 0.28148118, 0.03928091, 0.06777059,
-0.31927913, -0.20672244, -0.13228607, -0.00281268, 0.14999568,
-0.01670685, -0.02973013, 0.0202262 , 0.16579026, 0.07190445]],
dtype=float32)>
What am I missing? Why are some predictions negative, and why do they not sum to 1?