Can anyone tell me what’s wrong in my code, I am getting very less accuracy in vision transformer, I have at least of 600 samples of all 6 classes, I even used data augmentation as well. I am sharing ipynb link as well
# TARGET_IMAGE_SIZE = (256, 256)
inputs = layers.Input(shape=TARGET_IMAGE_SIZE+(3,))
# Get Patches
x = Patches(PATCH_SIZE)(inputs)
# PatchEncoding Network
x = PatchEncoder(NUM_PATCHES, PROJECTION_DIM)(x)
# Transformer Network
x = Transformer(TRANSFORMER_LAYER, NUM_HEADS, PROJECTION_DIM, TRANSFORMER_UNITS)(x)
# Output Network
x = layers.LayerNormalization(epsilon=1e-6)(x)
x = layers.Flatten()(x)
x = layers.Dropout(0.5)(x)
x = MLP(OUTPUT_UNITS, rate=0.5)(x)
# Ouput Layer
outputs = layers.Dense(NUM_CLASSES)(x)
model = tf.keras.Model(
inputs=[inputs],
outputs=[outputs],
)
optimizer = tf.keras.optimizers.Adam(
learning_rate=LEARNING_RATE
)
model.compile(
optimizer=optimizer,
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[
tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
tf.keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
],
)
checkpoint_filepath = os.path.join("checkpoint", "vision_tranformer.h5")
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
checkpoint_filepath,
monitor="val_accuracy",
save_best_only=True,
save_weights_only=True,
)
LOG_DIR = lambda f_name : os.path.join(
"logs",
f_name,
datetime.now().strftime("%Y%m%d-%H%M%S")
)
FOLDER_NAME = 'vision_transformer'
log_path = LOG_DIR(FOLDER_NAME)
tensorboard_callback = lambda log_path : tf.keras.callbacks.TensorBoard(
log_dir=log_path,
histogram_freq=1,
write_graph=True
)
history = model.fit(
train_generator,
validation_data=valid_generator,
batch_size=BATCH_SIZE,
epochs=NUM_EPOCHS,
callbacks=[
checkpoint_callback,
tensorboard_callback(log_path),
tf.keras.callbacks.EarlyStopping(
patience=5,
monitor='val_accuracy',
mode='max',
restore_best_weights=True
)
],
)
output
Epoch 1/10 164/164 [==============================] - 3077s 19s/step - loss: 58.1293 - accuracy: 0.1836 - top-5-accuracy: 0.8523 - val_loss: 1.7684 - val_accuracy: 0.1856 - val_top-5-accuracy: 0.9144
Epoch 2/10 164/164 [==============================] - 2334s 14s/step - loss: 1.7695 - accuracy: 0.1917 - top-5-accuracy: 0.9038 - val_loss: 1.7616 - val_accuracy: 0.1849 - val_top-5-accuracy: 0.9144
Epoch 3/10 164/164 [==============================] - 2380s 15s/step - loss: 1.7673 - accuracy: 0.1954 - top-5-accuracy: 0.9065 - val_loss: 1.7609 - val_accuracy: 0.2047 - val_top-5-accuracy: 0.9144
Epoch 4/10 164/164 [==============================] - 3009s 18s/step - loss: 1.7639 - accuracy: 0.2002 - top-5-accuracy: 0.9081 - val_loss: 1.7608 - val_accuracy: 0.1856 - val_top-5-accuracy: 0.9144
Epoch 5/10 164/164 [==============================] - 2817s 17s/step - loss: 1.7649 - accuracy: 0.1953 - top-5-accuracy: 0.9060 - val_loss: 1.7607 - val_accuracy: 0.1856 - val_top-5-accuracy: 0.9144
Epoch 6/10 164/164 [==============================] - 2496s 15s/step - loss: 1.7653 - accuracy: 0.2080 - top-5-accuracy: 0.9008 - val_loss: 1.7604 - val_accuracy: 0.1856 - val_top-5-accuracy: 0.9144
Epoch 7/10 164/164 [==============================] - 2826s 17s/step - loss: 1.7621 - accuracy: 0.2051 - top-5-accuracy: 0.9089 - val_loss: 1.7609 - val_accuracy: 0.1856 - val_top-5-accuracy: 0.9144
Epoch 8/10 164/164 [==============================] - 2199s 13s/step - loss: 1.7632 - accuracy: 0.1924 - top-5-accuracy: 0.9069 - val_loss: 1.7609 - val_accuracy: 0.1856 - val_top-5-accuracy: 0.9144