Issues with the quantized MobileBERT model

Hello, I’m using the quantized MobileBERT provided by the Google team. I have a few questions about its usage.

  1. Using Netron to inspect the .tflite architecture, I found that the SoftMax ops are linearly quantized. Usually this kind of non-linear function is kept in fp32 or quantized with a polynomial approximation, since plain linear quantization can hurt accuracy significantly (a small sketch after this list illustrates what I mean by linear quantization). I couldn’t find any description or related documentation about this on the model’s website…
  2. Then there is the F1/EM test on the SQuAD dataset. I only got 1.03% EM and 8.02% F1, far below what is expected. I’m a beginner in the AI field, so I’m not sure whether this comes from a bug in my code or from the model itself. My code is attached below.
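
To make the first point concrete, here is a small illustrative sketch of what I mean by linearly (affine) quantizing softmax outputs; the scale and zero point below are example values, not read from the actual model:

import numpy as np

# Illustrative affine (linear) int8 quantization of softmax outputs.
# The scale and zero point are example values, not taken from the model.
scale, zero_point = 1.0 / 256, -128

logits = np.array([2.0, 1.0, 0.1, -1.5], dtype=np.float32)
probs = np.exp(logits - logits.max())
probs /= probs.sum()

# Quantize to int8, then dequantize back to float to see the rounding error.
q = np.clip(np.round(probs / scale) + zero_point, -128, 127).astype(np.int8)
deq = (q.astype(np.float32) - zero_point) * scale

print("float softmax :", probs)
print("dequantized   :", deq)
print("max abs error :", np.abs(probs - deq).max())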

I’m wondering if anyone has used this model before. Any advice on the model usage or the code would be appreciated.

Here’s my code:

import re
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from transformers import AutoTokenizer
import string
from collections import Counter

TFLITE_MODEL    = "mobilebert_edgetpu_quant.tflite"
SPLIT           = "validation"
BATCH_SIZE      = 1
MAX_ANSWER_LEN  = 30
N_BEST          = 20

# Load the quantized TFLite model and inspect its input/output tensor details.
interpreter = tf.lite.Interpreter(model_path=TFLITE_MODEL)
interpreter.allocate_tensors()
in_det  = interpreter.get_input_details()
out_det = interpreter.get_output_details()
print("Inputs:")
for d in in_det:  print(d["name"], d["shape"], d["dtype"])
print("Outputs:")
for d in out_det: print(d["name"], d["shape"], d["dtype"])
_, MAX_LEN = in_det[0]["shape"]   # e.g. [1,384] -> MAX_LEN=384

def normalize_answer(s):
    # SQuAD-style normalization: lowercase, strip punctuation, drop articles, squeeze whitespace.
    s = s.lower()
    s = ''.join(ch for ch in s if ch not in set(string.punctuation))
    s = re.sub(r'\b(a|an|the)\b', ' ', s) 
    return ' '.join(s.split())

def f1_em(pred, truths):
    # Token-level F1 and exact match, taking the best score over all gold answers.
    pred_tokens = normalize_answer(pred).split()
    best_f1, best_em = 0.0, 0
    for t in truths:
        gold_tokens = normalize_answer(t).split()
        common = Counter(pred_tokens) & Counter(gold_tokens)
        num_same = sum(common.values())
        if len(pred_tokens)==0 and len(gold_tokens)==0:
            f1 = 1.0
        elif num_same==0:
            f1 = 0.0
        else:
            prec = num_same / max(1, len(pred_tokens))
            rec  = num_same / max(1, len(gold_tokens))
            f1   = 2*prec*rec/(prec+rec)
        em = int(normalize_answer(pred) == normalize_answer(t))
        best_f1 = max(best_f1, f1)
        best_em = max(best_em, em)
    return best_f1, best_em

def get_best_span(start_logits, end_logits, tt_ids, attn_mask, offsets,
                  max_answer_len=MAX_ANSWER_LEN, n_best=N_BEST):
    # Mask out invalid positions, then search the top-k start/end logits
    # for the highest-scoring (start, end) span within max_answer_len.
    neg_inf = -1e9
    is_ctx   = (tt_ids == 0)
    has_text = (offsets[:,1] > offsets[:,0])
    is_valid = is_ctx & (attn_mask == 1) & has_text

    if not np.any(is_valid):
        return 0, 0  

    valid_idx = np.where(is_valid)[0]
    s_logits = np.full_like(start_logits, neg_inf, dtype=np.float32)
    e_logits = np.full_like(end_logits,   neg_inf, dtype=np.float32)
    s_logits[valid_idx] = start_logits[valid_idx]
    e_logits[valid_idx] = end_logits[valid_idx]

    topk = int(min(n_best, len(valid_idx)))
    start_top = valid_idx[np.argpartition(s_logits[valid_idx], -topk)[-topk:]]
    end_top   = valid_idx[np.argpartition(e_logits[valid_idx],   -topk)[-topk:]]

    best_score, best_s, best_e = -np.inf, valid_idx[0], valid_idx[0]
    for s in start_top:
        e_max = s + max_answer_len - 1
        for e in end_top:
            if e < s or e > e_max:  
                continue
            score = s_logits[s] + e_logits[e]
            if score > best_score:
                best_score, best_s, best_e = score, int(s), int(e)
    return best_s, best_e


tokenizer = AutoTokenizer.from_pretrained("google/mobilebert-uncased", use_fast=True)
ds_raw = tfds.load("squad", split=SPLIT, as_supervised=False)

SEP = b"\x1f"  # unit-separator byte used to join multiple gold answers into one string

def encode_fn(question_bytes, context_bytes, answers_bytes):
    # Tokenize (question, context) with offset mapping and pack the gold answers into one byte string.
    question = question_bytes.numpy().decode("utf-8")
    context  = context_bytes.numpy().decode("utf-8")
    enc = tokenizer(
        question, context,
        truncation="only_second",
        padding="max_length",
        max_length=MAX_LEN,
        return_offsets_mapping=True,
        return_tensors="np"
    )
    ans_np = answers_bytes.numpy()                      
    answers_bytes_list = [bytes(x) for x in ans_np.tolist()] if ans_np.size > 0 else []
    all_ans_joined = SEP.join(answers_bytes_list) if answers_bytes_list else b""
    return (
        enc["input_ids"][0],
        enc["attention_mask"][0],
        enc["token_type_ids"][0],
        enc["offset_mapping"][0].astype(np.int64),
        all_ans_joined,
        context.encode("utf-8")        
    )


def tf_encode(ex):
    # Wrap encode_fn for tf.data via tf.py_function and restore static shapes.
    inp_ids, attn, tt, offsets, truth_bytes, context_bytes = tf.py_function(
        func=encode_fn,
        inp=[ex["question"], ex["context"], ex["answers"]["text"]],
        Tout=(tf.int32, tf.int32, tf.int32, tf.int64, tf.string, tf.string)
    )

    inp_ids.set_shape([MAX_LEN])
    attn.set_shape([MAX_LEN])
    tt.set_shape([MAX_LEN])
    offsets.set_shape([MAX_LEN,2])
    truth_bytes.set_shape([])
    context_bytes.set_shape([])


    return {
        "input_ids":      inp_ids,
        "input_mask":     attn,
        "input_type_ids": tt,
        "offsets":        offsets,
        "context_b":      context_bytes
    }, truth_bytes

val_ds = ds_raw.map(tf_encode).batch(BATCH_SIZE).prefetch(1)


sum_f1, sum_em, total = 0.0, 0, 0

for feat, truth_b in val_ds:

    # Decode the gold answers and the original context for this example.
    raw = truth_b[0].numpy()
    answers_list = [x.decode("utf-8") for x in raw.split(SEP) if len(x) > 0]
    if not answers_list:                          
        answers_list = [""]
    context_str = feat["context_b"][0].numpy().decode("utf-8")
    offsets     = feat["offsets"][0].numpy()  

    # Feed each model input tensor by matching its name to the corresponding feature.
    for d in in_det:
        name, idx, dtype = d["name"], d["index"], d["dtype"]
        if "word_ids" in name or "input_ids" in name:
            data = feat["input_ids"].numpy()
        elif "mask" in name:
            data = feat["input_mask"].numpy()
        elif "type_ids" in name:
            data = feat["input_type_ids"].numpy()
        else:
            continue
        interpreter.set_tensor(idx, data.astype(dtype))
    interpreter.invoke()

    # Assumes out_det[0] holds the end logits and out_det[1] the start logits,
    # then dequantizes each tensor with its own (scale, zero_point).
    end_raw   = interpreter.get_tensor(out_det[0]["index"])[0]
    start_raw = interpreter.get_tensor(out_det[1]["index"])[0]
    es, ez = out_det[0]["quantization"]
    ss, sz = out_det[1]["quantization"]
    if es!=0:
        end_logits   = (end_raw.astype(np.float32)-ez)*es
        start_logits = (start_raw.astype(np.float32)-sz)*ss
    else:
        end_logits, start_logits = end_raw.astype(np.float32), start_raw.astype(np.float32)

    attn = feat["input_mask"].numpy()[0]
    tt_ids = feat["input_type_ids"].numpy()[0]
    s_idx, e_idx = get_best_span(start_logits, end_logits, tt_ids, attn, offsets)


    # Map the predicted token span back to characters in the original context.
    char_s = offsets[s_idx][0]
    char_e = offsets[e_idx][1]
    pred   = context_str[char_s:char_e]

    f1, em = f1_em(pred, answers_list)
    sum_f1 += f1
    sum_em += em
    total += 1

print(f"SQuAD-{SPLIT}  EM = {sum_em/total*100:.2f}%,  F1 = {sum_f1/total*100:.2f}%")

Hi, @Yike_Li
Welcome to the community! I apologize for the delay in my response. As per my current understanding, the quantized MobileBERT EdgeTPU model uses Quantization-Aware Training to fully quantize even SoftMax without accuracy loss; the model can achieve ~90% F1 on SQuAD. As far as I know, the near-zero F1/EM scores are caused by a logic error in your get_best_span function: you filter for tt_ids == 0 (question tokens) instead of tt_ids == 1 (context tokens), so the answer span is never found. I think simply changing is_ctx = (tt_ids == 0) to is_ctx = (tt_ids == 1) should restore the expected performance. For a more robust solution, use the LiteRT Task Library’s BertQuestionAnswerer, which handles pre- and post-processing automatically and avoids such issues. Please refer to the official documentation on how to integrate the BERT question answerer.
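
For reference, a minimal usage sketch in Python with the tflite-support Task Library (the model filename here is a placeholder, and the example context/question are made up):

# pip install tflite-support
from tflite_support.task import text

# Placeholder path; point it at the quantized MobileBERT question-answering model.
answerer = text.BertQuestionAnswerer.create_from_file("mobilebert_qa.tflite")

context = ("The Amazon rainforest is a moist broadleaf forest that covers "
           "most of the Amazon basin of South America.")
question = "Where is the Amazon rainforest?"

result = answerer.answer(context, question)
for ans in result.answers:
    print(ans.text)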

Please give it a try and let us know whether it works as expected, and feel free to share any observations or findings from your end.

Thank you for your cooperation and patience.

Hi, @Rahul_Gaikwad
Thanks for your reply. The QAT technique indeed opens up the possibility of using a linearly quantized SoftMax. I’m also wondering whether TFLite has introduced any new PTQ approaches for non-linear functions. As far as I know, the TFLite stack only supports full int8 and 16x8 (int16 activations with int8 weights) precision configurations, while non-linear operators are otherwise kept in floating-point format; a rough sketch of the two configurations follows below. This is not ideal when deploying fully quantized models on edge devices.
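
To be explicit about the two configurations I mean, here is a rough sketch of the standard TFLiteConverter PTQ setup (the saved-model path and the representative dataset are placeholders):

import tensorflow as tf

def representative_dataset():
    # Placeholder: yield a few real (input_ids, input_mask, input_type_ids) batches here.
    for _ in range(100):
        yield [tf.zeros([1, 384], dtype=tf.int32)] * 3

converter = tf.lite.TFLiteConverter.from_saved_model("saved_model_dir")  # placeholder path
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset

# Full-integer mode: int8 activations x int8 weights.
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]

# Or the 16x8 mode: int16 activations x int8 weights.
# converter.target_spec.supported_ops = [
#     tf.lite.OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
# ]

tflite_model = converter.convert()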

As for the code, I modified it as you suggested and reran the test on the SQuAD dataset. However, the result is still not as expected.

SQuAD-validation  EM = 0.66%,  F1 = 3.59%

I would still like to run the model inference step by step instead of using the packaged API, as a way to learn the LLM inference process. I’ll keep debugging the code, and any advice on the model usage or the code would be appreciated.