Hello everyone,
I’m fine-tuning a BERT-based model for a multi-label classification task: predicting the Big Five personality traits from text inputs. However, both my training accuracy and validation accuracy stay low (below 30%) and flat throughout training. I’d greatly appreciate any insights or advice on how to address this.
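For context, a quick way to sanity-check that number is a per-trait majority-class baseline (a minimal sketch, assuming the labels live in a pandas DataFrame df with binary columns E, N, A, C, O; names are placeholders):

# Assumed: df is already loaded with the five binary trait columns
for trait in ['E', 'N', 'A', 'C', 'O']:
    p = df[trait].mean()  # fraction of 1-labels for this trait
    print(f"{trait}: majority-class baseline = {max(p, 1 - p):.2f}")

If every per-trait baseline comes out well above 30%, a per-label accuracy below that would point at a metric or pipeline problem rather than a genuinely hard dataset.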
Dataset:
I have a dataset of 9,918 samples; each sample is a text input labeled with a binary value (0 or 1) for each of the Big Five personality traits, stored in the following columns:
E: Extraversion
N: Neuroticism
A: Agreeableness
C: Conscientiousness
O: Openness
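For completeness, here is a simplified sketch of how train_dataset and val_dataset (used in model.fit below) can be built from the texts and these label columns — train_texts, train_labels, val_texts, val_labels are placeholder names, and the labels are assumed to be a float32 array of shape (n_samples, 5):

import tensorflow as tf
from transformers import DistilBertTokenizerFast

tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

def make_dataset(texts, labels, shuffle=False, batch_size=64):
    # Pad/truncate every text to MAX_LENGTH so shapes match the model inputs
    enc = tokenizer(list(texts), padding='max_length', truncation=True,
                    max_length=MAX_LENGTH, return_tensors='np')
    ds = tf.data.Dataset.from_tensor_slices(
        ((enc['input_ids'], enc['attention_mask']), labels))
    if shuffle:
        ds = ds.shuffle(len(labels))
    return ds.batch(batch_size)

train_dataset = make_dataset(train_texts, train_labels, shuffle=True)
val_dataset = make_dataset(val_texts, val_labels)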
Below is the architecture of my model:
# Load the pre-trained DistilBERT encoder
from transformers import TFDistilBertModel

# Note: num_labels has no effect here, since the bare TFDistilBertModel has no classification head
bert_model = TFDistilBertModel.from_pretrained('distilbert-base-uncased', num_labels=NUM_CLASSES)
from tensorflow.keras.layers import Input, Dense, Layer, Dropout
from tensorflow.keras.models import Model, load_model
import keras
@keras.saving.register_keras_serializable(package="BertLayer")
class BertLayer(Layer):
    """Thin wrapper exposing the Hugging Face model as a Keras layer."""

    def __init__(self, bert_model, **kwargs):
        super().__init__(**kwargs)
        self.bert_model = bert_model

    def call(self, inputs):
        input_ids, attention_mask = inputs
        # [0] is the last hidden state: (batch, MAX_LENGTH, hidden_size)
        return self.bert_model(input_ids=input_ids, attention_mask=attention_mask)[0]

    def get_config(self):
        base_config = super().get_config()
        config = {
            "bert_model": keras.saving.serialize_keras_object(self.bert_model),
        }
        return {**base_config, **config}

    @classmethod
    def from_config(cls, config):
        bert_model_config = config.pop("bert_model")
        bert_model = keras.saving.deserialize_keras_object(bert_model_config)
        return cls(bert_model, **config)
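As a quick sanity check on the wrapper, its output shape can be verified on a dummy batch (illustrative only):

import numpy as np

dummy_ids = np.zeros((2, MAX_LENGTH), dtype='int32')
dummy_mask = np.ones((2, MAX_LENGTH), dtype='int32')
out = BertLayer(bert_model)([dummy_ids, dummy_mask])
print(out.shape)  # expected: (2, MAX_LENGTH, 768) for distilbert-base-uncased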
def create_model(bert_model, MAX_LENGTH, NUM_CLASSES):
    input_ids = Input(shape=(MAX_LENGTH,), dtype='int32', name='input_ids')
    attention_mask = Input(shape=(MAX_LENGTH,), dtype='int32', name='attention_mask')
    bert_output = BertLayer(bert_model)([input_ids, attention_mask])
    cls_token = bert_output[:, 0, :]  # embedding at the [CLS] position
    x = Dense(256, activation='relu')(cls_token)
    x = Dropout(0.1)(x)
    x = Dense(128, activation='relu')(x)
    x = Dropout(0.1)(x)
    # One sigmoid per trait -> independent probabilities for the multi-label output
    output = Dense(NUM_CLASSES, activation='sigmoid', name='output')(x)
    return Model(inputs=[input_ids, attention_mask], outputs=output)
model = create_model(bert_model, MAX_LENGTH, NUM_CLASSES)
from tensorflow.keras.optimizers import Adam

model.compile(optimizer=Adam(learning_rate=1e-5),
              loss='binary_crossentropy',
              # With binary_crossentropy, 'acc' is computed as per-label binary
              # accuracy averaged over the 5 traits, not exact-match accuracy
              metrics=['acc'])
model.summary()
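Since a single 'acc' number is hard to interpret for a multi-label head, the metrics can also be spelled out explicitly (a sketch; this would replace the metrics list in the compile call above):

from tensorflow.keras.metrics import BinaryAccuracy, AUC

model.compile(optimizer=Adam(learning_rate=1e-5),
              loss='binary_crossentropy',
              metrics=[BinaryAccuracy(name='bin_acc', threshold=0.5),
                       AUC(name='auc', multi_label=True)])  # multi_label=True averages AUC over the traits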
history = model.fit(
    train_dataset,
    epochs=5,
    validation_data=val_dataset,
    # Note: batch_size is ignored (or rejected) when train_dataset is an
    # already-batched tf.data.Dataset; batching belongs in the dataset itself
    batch_size=64
)
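In case it helps, here is a small check to see whether the outputs are collapsing (sketch):

import numpy as np

x_batch, y_batch = next(iter(val_dataset))
probs = model.predict(x_batch)
print(np.round(probs[:5], 2))  # rows stuck near 0.5 across all traits would suggest the network barely moved from initialization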
Why might the accuracy and validation accuracy be stagnating at such a low level (~27%)? Is this expected for a dataset like mine, or could there be something wrong with the architecture/training process?