Getting very low accuracy with a Vision Transformer

Can anyone tell me what's wrong with my code? I am getting very low accuracy with my Vision Transformer. I have at least 600 samples for each of the 6 classes, and I used data augmentation as well. I am sharing the ipynb link too.

import os
from datetime import datetime

import tensorflow as tf
from tensorflow.keras import layers

# TARGET_IMAGE_SIZE = (256, 256)
inputs = layers.Input(shape=TARGET_IMAGE_SIZE + (3,))


# Get Patches
x = Patches(PATCH_SIZE)(inputs)

# PatchEncoding Network
x = PatchEncoder(NUM_PATCHES, PROJECTION_DIM)(x)

# Transformer Network
x = Transformer(TRANSFORMER_LAYER, NUM_HEADS, PROJECTION_DIM, TRANSFORMER_UNITS)(x)

# Output Network
x = layers.LayerNormalization(epsilon=1e-6)(x)

x = layers.Flatten()(x)

x = layers.Dropout(0.5)(x)

x = MLP(OUTPUT_UNITS, rate=0.5)(x)

# Output Layer (raw logits; the loss below uses from_logits=True)
outputs = layers.Dense(NUM_CLASSES)(x)


model = tf.keras.Model(inputs=inputs, outputs=outputs)
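For reference, `PATCH_SIZE` and `NUM_PATCHES` are defined elsewhere in the notebook; the two must be consistent, since `PatchEncoder` learns one position embedding per patch. A quick pure-Python sanity check, using an assumed `PATCH_SIZE` of 16 (substitute the value from your notebook):

```python
# Sanity check (assumed values): NUM_PATCHES must equal the number of
# non-overlapping PATCH_SIZE x PATCH_SIZE tiles in a 256x256 input.
TARGET_IMAGE_SIZE = (256, 256)
PATCH_SIZE = 16  # assumption; use your actual hyperparameter

patches_per_side = TARGET_IMAGE_SIZE[0] // PATCH_SIZE
NUM_PATCHES = patches_per_side ** 2

print(patches_per_side, NUM_PATCHES)  # 16 patches per side -> 256 patches
```

If `NUM_PATCHES` passed to `PatchEncoder` does not match this value, the position embeddings will not line up with the patch sequence.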

optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)

model.compile(
    optimizer=optimizer,
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[
        tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
        # NOTE: with only 6 classes, top-5 accuracy is not very informative
        # (chance level is already 5/6)
        tf.keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
    ],
)
checkpoint_filepath = os.path.join("checkpoint", "vision_transformer.h5")

checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_filepath,
    monitor="val_accuracy",
    save_best_only=True,
    save_weights_only=True,
)

FOLDER_NAME = "vision_transformer"

# Timestamped TensorBoard log directory
log_path = os.path.join(
    "logs",
    FOLDER_NAME,
    datetime.now().strftime("%Y%m%d-%H%M%S"),
)

tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=log_path,
    histogram_freq=1,
    write_graph=True,
)

# batch_size is not passed here: the generators already batch the data,
# and fit() does not accept batch_size together with a generator.
history = model.fit(
    train_generator,
    validation_data=valid_generator,
    epochs=NUM_EPOCHS,
    callbacks=[
        checkpoint_callback,
        tensorboard_callback,
        tf.keras.callbacks.EarlyStopping(
            patience=5,
            monitor="val_accuracy",
            mode="max",
            restore_best_weights=True,
        ),
    ],
)
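One thing worth checking before training (a common cause of a model that never beats chance): `SparseCategoricalCrossentropy` and the sparse accuracy metrics expect integer class ids of shape `(batch,)`, not one-hot vectors. A minimal NumPy sketch of the check, assuming the generators yield `(images, labels)` batches (the helper name is hypothetical):

```python
import numpy as np

def check_sparse_labels(labels, num_classes):
    """Return True if labels look like integer class ids in [0, num_classes)."""
    labels = np.asarray(labels)
    return labels.ndim == 1 and labels.min() >= 0 and labels.max() < num_classes

# Integer labels: what the sparse loss and metrics expect.
print(check_sparse_labels(np.array([0, 3, 5, 1]), 6))      # True
# One-hot labels: these would need CategoricalCrossentropy instead.
print(check_sparse_labels(np.eye(6)[[0, 3, 5, 1]], 6))     # False
```

You could run this on one batch, e.g. `_, y = next(iter(train_generator))`, to confirm the label format matches the loss.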

Output:

Epoch 1/10 164/164 [==============================] - 3077s 19s/step - loss: 58.1293 - accuracy: 0.1836 - top-5-accuracy: 0.8523 - val_loss: 1.7684 - val_accuracy: 0.1856 - val_top-5-accuracy: 0.9144

Epoch 2/10 164/164 [==============================] - 2334s 14s/step - loss: 1.7695 - accuracy: 0.1917 - top-5-accuracy: 0.9038 - val_loss: 1.7616 - val_accuracy: 0.1849 - val_top-5-accuracy: 0.9144

Epoch 3/10 164/164 [==============================] - 2380s 15s/step - loss: 1.7673 - accuracy: 0.1954 - top-5-accuracy: 0.9065 - val_loss: 1.7609 - val_accuracy: 0.2047 - val_top-5-accuracy: 0.9144

Epoch 4/10 164/164 [==============================] - 3009s 18s/step - loss: 1.7639 - accuracy: 0.2002 - top-5-accuracy: 0.9081 - val_loss: 1.7608 - val_accuracy: 0.1856 - val_top-5-accuracy: 0.9144

Epoch 5/10 164/164 [==============================] - 2817s 17s/step - loss: 1.7649 - accuracy: 0.1953 - top-5-accuracy: 0.9060 - val_loss: 1.7607 - val_accuracy: 0.1856 - val_top-5-accuracy: 0.9144

Epoch 6/10 164/164 [==============================] - 2496s 15s/step - loss: 1.7653 - accuracy: 0.2080 - top-5-accuracy: 0.9008 - val_loss: 1.7604 - val_accuracy: 0.1856 - val_top-5-accuracy: 0.9144

Epoch 7/10 164/164 [==============================] - 2826s 17s/step - loss: 1.7621 - accuracy: 0.2051 - top-5-accuracy: 0.9089 - val_loss: 1.7609 - val_accuracy: 0.1856 - val_top-5-accuracy: 0.9144

Epoch 8/10 164/164 [==============================] - 2199s 13s/step - loss: 1.7632 - accuracy: 0.1924 - top-5-accuracy: 0.9069 - val_loss: 1.7609 - val_accuracy: 0.1856 - val_top-5-accuracy: 0.9144
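For what it's worth, the loss plateaus near 1.76 and validation accuracy sticks at 0.1856, which is essentially chance level for 6 roughly balanced classes: a model that predicts a uniform (or constant) distribution has cross-entropy ln(6) and accuracy near 1/6, so the model does not appear to be learning anything beyond the class prior. A quick arithmetic check:

```python
import math

# Cross-entropy of a uniform prediction over 6 classes: -log(1/6) = log(6).
uniform_loss = math.log(6)
chance_accuracy = 1 / 6

print(round(uniform_loss, 4))     # 1.7918, close to the plateau loss ~1.76
print(round(chance_accuracy, 4))  # 0.1667, close to val_accuracy ~0.1856
```

The very large initial loss (58.1) followed by an immediate collapse to this plateau is also consistent with a learning rate or input scaling issue rather than with the model simply needing more data.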