Hello,
I built a model for sequence classification on top of a Hugging Face model:
import tensorflow as tf
import tensorflow_probability as tfp
from transformers import BertConfig, TFAutoModelForSequenceClassification

bert_config = BertConfig.from_pretrained(MODEL_NAME)
bert_config.output_hidden_states = True
backbone = TFAutoModelForSequenceClassification.from_pretrained(MODEL_NAME, config=bert_config)

input_ids = tf.keras.layers.Input(shape=(MAX_LENGTH,), name='input_ids', dtype='int32')
features = backbone(input_ids)[1][-1]  # hidden states of the last encoder layer
pooling = tf.keras.layers.GlobalAveragePooling1D()(features)
dense1 = tfp.layers.DenseFlipout(512, activation=tf.nn.relu)(pooling)
final = tfp.layers.DenseFlipout(len(label2id), name='output', activation=tf.nn.softmax)(dense1)
model = tf.keras.Model(inputs=[input_ids], outputs=[final])

optimizer = tf.keras.optimizers.Adam(learning_rate=6e-6)
# loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)  # tried this too
loss = elbo_loss
model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
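For reference, a minimal sketch (run right after building the model above) of how to inspect the KL terms that the DenseFlipout layers register; Keras adds everything in model.losses on top of the compiled loss during training:

# each DenseFlipout layer registers the KL divergence between its weight
# posterior and its prior as a regularization loss on the model
for kl in model.losses:
    print(kl)
print('total KL:', tf.reduce_sum(model.losses))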
where elbo_loss is defined as:
@tf.function
def elbo_loss(labels, logits):
    # cross-entropy between the one-hot labels and the predictions
    loss_en = tf.nn.softmax_cross_entropy_with_logits(labels, logits)
    # KL divergence between the label distribution and the predictions
    loss_kl = tf.keras.losses.KLD(labels, logits)
    return tf.reduce_mean(loss_en + loss_kl)
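For concreteness, a toy invocation with made-up values (note that the model's output layer applies softmax, so what reaches the second argument of this function is a probability distribution):

# made-up one-hot labels and predicted probabilities, just to show the shapes
y_true = tf.constant([[1., 0., 0.], [0., 1., 0.]])
y_pred = tf.constant([[0.8, 0.1, 0.1], [0.2, 0.7, 0.1]])
print(elbo_loss(y_true, y_pred).numpy())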
After several epochs the training accuracy starts decreasing, even though the loss keeps going down. The same thing happens with CategoricalCrossentropy. I do not see this problem with the classical approach (without Bayesian layers).
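By the classical approach I mean roughly the following sketch (same backbone and pooling as above; only the head and the loss differ):

dense1 = tf.keras.layers.Dense(512, activation=tf.nn.relu)(pooling)
final = tf.keras.layers.Dense(len(label2id), name='output', activation=tf.nn.softmax)(dense1)
model = tf.keras.Model(inputs=[input_ids], outputs=[final])
model.compile(optimizer=optimizer,
              loss=tf.keras.losses.CategoricalCrossentropy(),
              metrics=['accuracy'])

Here is the training log of the Bayesian model: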
Epoch 16/100
1119/1119 [==============================] - 455s 407ms/step - loss: 1489082.6484 - accuracy: 0.8298
Epoch 17/100
1119/1119 [==============================] - 455s 407ms/step - loss: 1485046.3760 - accuracy: 0.8398
Epoch 18/100
1119/1119 [==============================] - 455s 407ms/step - loss: 1481018.1098 - accuracy: 0.8446
Epoch 19/100
1119/1119 [==============================] - 455s 407ms/step - loss: 1476998.0167 - accuracy: 0.8469
Epoch 20/100
1119/1119 [==============================] - 455s 407ms/step - loss: 1472984.8798 - accuracy: 0.8375
Epoch 21/100
1119/1119 [==============================] - 455s 407ms/step - loss: 1468978.1942 - accuracy: 0.8240
Epoch 22/100
1119/1119 [==============================] - 455s 407ms/step - loss: 1464977.2963 - accuracy: 0.7886
Epoch 23/100
1119/1119 [==============================] - 455s 407ms/step - loss: 1460981.7481 - accuracy: 0.7421
… and it continues like this, with the accuracy eventually dropping to zero.
What can be the reason for such behavior?