Low Accuracy and Stagnant Validation Accuracy in BERT Model for Multilabel Classification

Hello everyone,

I’m fine-tuning a BERT-based model (DistilBERT) for a multilabel classification task: predicting the Big Five personality traits from text. However, both my training accuracy and validation accuracy stay low (below 30%) and barely change during training. I’d greatly appreciate any insights or advice on how to address this.

Dataset:
I have a dataset containing 9,918 samples, where each sample is a text input and is labeled with binary values (0 or 1) for each of the Big Five personality traits. The traits are represented in the following columns:

E: Extraversion
N: Neuroticism
A: Agreeableness
C: Conscientiousness
O: Openness
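In a multilabel setup like this, each sample's target is a five-element multi-hot vector rather than a single class index. A minimal sketch, assuming each row is a dict keyed by the trait columns (the `TRAIT_COLUMNS` order and the `to_multi_hot` helper are hypothetical):

```python
# Hypothetical column order; it must match the order of the model's 5 output units.
TRAIT_COLUMNS = ["E", "N", "A", "C", "O"]


def to_multi_hot(row: dict) -> list[float]:
    """Turn one sample's 0/1 trait columns into a float multi-hot target vector."""
    return [float(row[trait]) for trait in TRAIT_COLUMNS]


# Illustrative sample: extraverted, agreeable and open, but not neurotic or conscientious.
sample = {"text": "example text", "E": 1, "N": 0, "A": 1, "C": 0, "O": 1}
print(to_multi_hot(sample))  # [1.0, 0.0, 1.0, 0.0, 1.0]
```

Several entries can be 1 at the same time, which is why the output layer uses five independent sigmoid units instead of a single softmax.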

Below is the architecture of my model:

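from transformers import TFDistilBertModel

NUM_CLASSES = 5  # one output per trait: E, N, A, C, O

# Load the pre-trained DistilBERT encoder (num_labels only configures classification heads, so it is dropped)
bert_model = TFDistilBertModel.from_pretrained('distilbert-base-uncased')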

from tensorflow.keras.layers import Input, Dense, Layer, Dropout
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.optimizers import Adam
import keras

# Wraps the Hugging Face model so it can be called inside a Keras functional model
@keras.saving.register_keras_serializable(package="BertLayer")
class BertLayer(Layer):
    def __init__(self, bert_model, **kwargs):
        super().__init__(**kwargs)
        self.bert_model = bert_model

    def call(self, inputs):
        input_ids, attention_mask = inputs
        # Index 0 is last_hidden_state: (batch, MAX_LENGTH, hidden_size)
        return self.bert_model(input_ids=input_ids, attention_mask=attention_mask)[0]

    def get_config(self):
        base_config = super().get_config()
        config = {
            "bert_model": keras.saving.serialize_keras_object(self.bert_model),
        }
        return {**base_config, **config}

    @classmethod
    def from_config(cls, config):
        bert_model_config = config.pop("bert_model")
        bert_model = keras.saving.deserialize_keras_object(bert_model_config)
        return cls(bert_model, **config)

def create_model(bert_model, MAX_LENGTH, NUM_CLASSES):
    input_ids = Input(shape=(MAX_LENGTH,), dtype='int32', name='input_ids')
    attention_mask = Input(shape=(MAX_LENGTH,), dtype='int32', name='attention_mask')
    bert_output = BertLayer(bert_model)([input_ids, attention_mask])
    # Hidden state at the first ([CLS]) position; DistilBERT has no pooled output
    cls_token = bert_output[:, 0, :]

    x = Dense(256, activation='relu')(cls_token)
    x = Dropout(0.1)(x)
    x = Dense(128, activation='relu')(x)
    x = Dropout(0.1)(x)

    output = Dense(NUM_CLASSES, activation='sigmoid', name='output')(x)

    return Model(inputs=[input_ids, attention_mask], outputs=output)

model = create_model(bert_model, MAX_LENGTH, NUM_CLASSES)
model.compile(optimizer=Adam(learning_rate=1e-5),
              loss='binary_crossentropy',
              metrics=['acc'])  # the 'acc' alias is mapped to a concrete accuracy metric by Keras itself

model.summary()

history = model.fit(
    train_dataset,  # if this is a tf.data.Dataset, batch it with .batch(64); fit() ignores batch_size for datasets
    epochs=5,
    validation_data=val_dataset,
)

Why might the accuracy and validation accuracy be stagnating at such a low level (~27%)? Is this expected for a dataset like mine, or could there be something wrong with the architecture/training process?
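One likely contributor, worth verifying against the installed Keras version: in tf.keras 2.x the string 'acc'/'accuracy' is mapped to a concrete metric from the output and target shapes, not from the loss. With a 5-unit output and 5-column targets it becomes categorical accuracy, which only checks whether the argmax of each prediction row matches the argmax of the target row, a comparison that does not fit multi-hot labels. The pure-Python sketch below (helper names are hypothetical) contrasts it with per-label binary accuracy and exact-match accuracy on toy predictions that are correct for every label:

```python
def argmax(values):
    """Index of the largest value (the first one on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


def categorical_accuracy(y_true, y_pred):
    """Fraction of rows where argmax(prediction) == argmax(target)."""
    hits = sum(argmax(t) == argmax(p) for t, p in zip(y_true, y_pred))
    return hits / len(y_true)


def binary_accuracy(y_true, y_pred, threshold=0.5):
    """Fraction of individual labels predicted correctly after thresholding."""
    pairs = [(t, p) for t_row, p_row in zip(y_true, y_pred) for t, p in zip(t_row, p_row)]
    hits = sum((p > threshold) == bool(t) for t, p in pairs)
    return hits / len(pairs)


def exact_match_accuracy(y_true, y_pred, threshold=0.5):
    """Fraction of rows where every label is predicted correctly."""
    hits = sum(
        all((p > threshold) == bool(t) for t, p in zip(t_row, p_row))
        for t_row, p_row in zip(y_true, y_pred)
    )
    return hits / len(y_true)


# Toy multi-hot targets (columns E, N, A, C, O) and sigmoid outputs that are right on every label.
y_true = [[1, 0, 1, 0, 1], [0, 1, 0, 0, 1], [1, 1, 1, 0, 0]]
y_pred = [[0.9, 0.2, 0.8, 0.1, 0.7], [0.1, 0.8, 0.3, 0.2, 0.6], [0.7, 0.6, 0.9, 0.4, 0.3]]

print(binary_accuracy(y_true, y_pred))       # 1.0
print(exact_match_accuracy(y_true, y_pred))  # 1.0
print(categorical_accuracy(y_true, y_pred))  # 0.666..., row 3 fails only because its argmax differs
```

If this is what is happening, passing `metrics=[tf.keras.metrics.BinaryAccuracy(threshold=0.5)]` to `model.compile` makes the reported number a per-label accuracy, which is easier to interpret for this task.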

Hi @Xuan5251, could you please try increasing the number of training samples and see whether the accuracy improves? Also, please let us know the number of samples per class. Thank you.

Hi @Kiran_Sai_Ramineni, thank you for your reply. Here is the number of samples labeled 1 for each trait (out of 9,918 samples in total):
O: 7370
C: 4556
E: 4210
A: 5268
N: 3717
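Assuming these counts are the number of samples labeled 1 for each trait (they sum to 25,121 over 9,918 samples, about 2.5 traits per sample, which fits a multilabel dataset), a useful reference point is the per-label accuracy of always predicting each trait's majority value. A minimal sketch:

```python
TOTAL = 9918
POSITIVES = {"O": 7370, "C": 4556, "E": 4210, "A": 5268, "N": 3717}

# Per-trait accuracy of a constant predictor that always outputs that trait's majority value (0 or 1).
baseline = {
    trait: max(pos, TOTAL - pos) / TOTAL
    for trait, pos in POSITIVES.items()
}
mean_baseline = sum(baseline.values()) / len(baseline)
avg_traits_per_sample = sum(POSITIVES.values()) / TOTAL

for trait, acc in baseline.items():
    print(f"{trait}: positive rate {POSITIVES[trait] / TOTAL:.3f}, majority baseline {acc:.3f}")
print(f"mean per-label baseline: {mean_baseline:.3f}")            # 0.603
print(f"average traits per sample: {avg_traits_per_sample:.2f}")  # 2.53
```

Because this constant predictor already reaches roughly 60% per-label accuracy, a reported ~27% is hard to explain as per-label binary accuracy, which suggests checking which accuracy metric Keras actually used.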

Hi @Xuan5251, initially you mentioned that the data was labeled with 0 or 1, which means 2 classes; in that case binary_crossentropy can be used as the loss function, with E, N, A, C and O as the features.

But here you mentioned that there are 5 classes. If that is the case, it is a multiclass classification problem, and the loss function should be categorical_crossentropy.

Could you please confirm the number of labels? Thank you.

Sorry for the misunderstanding. Let me clarify. The dataset contains one column for text and five columns representing the Big Five personality traits: Extraversion (E), Neuroticism (N), Agreeableness (A), Conscientiousness (C), and Openness (O). Each of these traits is labelled as either 0 or 1, indicating whether a given trait is present or not.

Since each sample can have multiple traits (labels) at once, I think it is a multilabel classification problem rather than a multiclass one.
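That reading matches the model as written: five sigmoid units with `binary_crossentropy` treat each trait as an independent yes/no decision, whereas softmax with `categorical_crossentropy` forces the five outputs to sum to 1, so at most one trait can exceed 0.5. A small pure-Python illustration with made-up logits (the values are hypothetical):

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def binary_crossentropy(y_true, probs):
    """Mean of the per-label log losses over the 5 labels of one sample."""
    losses = [-(t * math.log(p) + (1 - t) * math.log(1 - p)) for t, p in zip(y_true, probs)]
    return sum(losses) / len(losses)


logits = [2.0, -1.0, 1.5, -2.0, 3.0]  # hypothetical raw scores for E, N, A, C, O
target = [1, 0, 1, 0, 1]              # three traits present at once

sig = [sigmoid(z) for z in logits]
soft = softmax(logits)

print([p > 0.5 for p in sig])   # [True, False, True, False, True]: all three traits recovered
print([p > 0.5 for p in soft])  # [False, False, False, False, True]: softmax keeps only one
print(round(binary_crossentropy(target, sig), 4))  # 0.1634
```

With softmax, raising one trait's probability necessarily lowers the others, which is the wrong inductive bias when several traits can be present in the same text.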