Hello, I’m using the quantized MobileBERT model provided by the Google team, and I have a few questions about its usage.
- Using [Netron] to inspect the .tflite architecture, I found that the SoftMax ops are linearly quantized. Usually non-linear functions like this are kept in fp32 or quantized with a polynomial approximation, since plain linear quantization can hurt accuracy quite a bit. I couldn’t find any description or related documentation about this on the model’s website… (A short snippet for dumping those quantization parameters is right after this list.)
- Then there’s the F1/EM evaluation on the SQuAD dataset. I only got 1.03% EM and 8.02% F1, which is far below what I expected. I’m a beginner in the AI field, so I’m not sure whether this comes from a bug in my code or from the model itself. My code is attached below.
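For reference, here is roughly how the quantization parameters of the SoftMax tensors can be dumped with the TFLite Python API. The "softmax" name filter below is just my guess at how those tensors are named in this graph, so it may need adjusting after a look in Netron:

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="mobilebert_edgetpu_quant.tflite")
interpreter.allocate_tensors()

# Print dtype and (scale, zero_point) for every tensor whose name mentions softmax.
for t in interpreter.get_tensor_details():
    if "softmax" in t["name"].lower():
        print(t["name"], t["dtype"], t["quantization"])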
I’m wondering if anyone has used this model before. Any advice on the model usage or on the code would be appreciated.
Here’s my code:
import re
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from transformers import AutoTokenizer
import string
from collections import Counter
TFLITE_MODEL = "mobilebert_edgetpu_quant.tflite"
SPLIT = "validation"
BATCH_SIZE = 1
MAX_ANSWER_LEN = 30
N_BEST = 20
interpreter = tf.lite.Interpreter(model_path=TFLITE_MODEL)
interpreter.allocate_tensors()
in_det = interpreter.get_input_details()
out_det = interpreter.get_output_details()
print("Inputs:")
for d in in_det: print(d["name"], d["shape"], d["dtype"])
print("Outputs:")
for d in out_det: print(d["name"], d["shape"], d["dtype"])
MAX_LEN = int(in_det[0]["shape"][1])  # e.g. shape [1, 384] -> MAX_LEN = 384
def normalize_answer(s):
    """SQuAD-style normalization: lowercase, drop punctuation and articles, collapse whitespace."""
    s = s.lower()
    s = ''.join(ch for ch in s if ch not in set(string.punctuation))
    s = re.sub(r'\b(a|an|the)\b', ' ', s)
    return ' '.join(s.split())
def f1_em(pred, truths):
    """Return the best F1 and EM of the prediction against all reference answers."""
    pred_tokens = normalize_answer(pred).split()
    best_f1, best_em = 0.0, 0
    for t in truths:
        gold_tokens = normalize_answer(t).split()
        common = Counter(pred_tokens) & Counter(gold_tokens)
        num_same = sum(common.values())
        if len(pred_tokens) == 0 and len(gold_tokens) == 0:
            f1 = 1.0
        elif num_same == 0:
            f1 = 0.0
        else:
            prec = num_same / max(1, len(pred_tokens))
            rec = num_same / max(1, len(gold_tokens))
            f1 = 2 * prec * rec / (prec + rec)
        em = int(normalize_answer(pred) == normalize_answer(t))
        best_f1 = max(best_f1, f1)
        best_em = max(best_em, em)
    return best_f1, best_em
def get_best_span(start_logits, end_logits, tt_ids, attn_mask, offsets,
                  max_answer_len=MAX_ANSWER_LEN, n_best=N_BEST):
    """Pick the highest-scoring valid (start, end) span among the context tokens."""
    neg_inf = -1e9
    # Context tokens carry token_type_id 1 here (the question is passed first, the context second).
    is_ctx = (tt_ids == 1)
    has_text = (offsets[:, 1] > offsets[:, 0])
    is_valid = is_ctx & (attn_mask == 1) & has_text
    if not np.any(is_valid):
        return 0, 0
    valid_idx = np.where(is_valid)[0]
    # Mask everything outside the valid context positions before taking the top-k candidates.
    s_logits = np.full_like(start_logits, neg_inf, dtype=np.float32)
    e_logits = np.full_like(end_logits, neg_inf, dtype=np.float32)
    s_logits[valid_idx] = start_logits[valid_idx]
    e_logits[valid_idx] = end_logits[valid_idx]
    topk = int(min(n_best, len(valid_idx)))
    start_top = valid_idx[np.argpartition(s_logits[valid_idx], -topk)[-topk:]]
    end_top = valid_idx[np.argpartition(e_logits[valid_idx], -topk)[-topk:]]
    best_score, best_s, best_e = -np.inf, valid_idx[0], valid_idx[0]
    for s in start_top:
        e_max = s + max_answer_len - 1
        for e in end_top:
            if e < s or e > e_max:
                continue
            score = s_logits[s] + e_logits[e]
            if score > best_score:
                best_score, best_s, best_e = score, int(s), int(e)
    return best_s, best_e
tokenizer = AutoTokenizer.from_pretrained("google/mobilebert-uncased", use_fast=True)
ds_raw = tfds.load("squad", split=SPLIT, as_supervised=False)
SEP = b"\x1f"
def encode_fn(question_bytes, context_bytes, answers_bytes):
    """Tokenize one (question, context) pair and join the reference answers into one byte string."""
    question = question_bytes.numpy().decode("utf-8")
    context = context_bytes.numpy().decode("utf-8")
    enc = tokenizer(
        question, context,
        truncation="only_second",
        padding="max_length",
        max_length=MAX_LEN,
        return_offsets_mapping=True,
        return_tensors="np"
    )
    ans_np = answers_bytes.numpy()
    answers_bytes_list = [bytes(x) for x in ans_np.tolist()] if ans_np.size > 0 else []
    all_ans_joined = SEP.join(answers_bytes_list) if answers_bytes_list else b""
    return (
        enc["input_ids"][0],
        enc["attention_mask"][0],
        enc["token_type_ids"][0],
        enc["offset_mapping"][0].astype(np.int64),
        all_ans_joined,
        context.encode("utf-8")
    )
def tf_encode(ex):
    inp_ids, attn, tt, offsets, truth_bytes, context_bytes = tf.py_function(
        func=encode_fn,
        inp=[ex["question"], ex["context"], ex["answers"]["text"]],
        Tout=(tf.int32, tf.int32, tf.int32, tf.int64, tf.string, tf.string)
    )
    inp_ids.set_shape([MAX_LEN])
    attn.set_shape([MAX_LEN])
    tt.set_shape([MAX_LEN])
    offsets.set_shape([MAX_LEN, 2])
    truth_bytes.set_shape([])
    context_bytes.set_shape([])
    return {
        "input_ids": inp_ids,
        "input_mask": attn,
        "input_type_ids": tt,
        "offsets": offsets,
        "context_b": context_bytes
    }, truth_bytes

val_ds = ds_raw.map(tf_encode).batch(BATCH_SIZE).prefetch(1)
sum_f1, sum_em, total = 0.0, 0, 0
for feat, truth_b in val_ds:
    # Recover the reference answers and the raw context for this example.
    raw = truth_b[0].numpy()
    answers_list = [x.decode("utf-8") for x in raw.split(SEP) if len(x) > 0]
    if not answers_list:
        answers_list = [""]
    context_str = feat["context_b"][0].numpy().decode("utf-8")
    offsets = feat["offsets"][0].numpy()
    # Feed the three model inputs, matching them to features by tensor name.
    for d in in_det:
        name, idx, dtype = d["name"], d["index"], d["dtype"]
        if "word_ids" in name or "input_ids" in name:
            data = feat["input_ids"].numpy()
        elif "mask" in name:
            data = feat["input_mask"].numpy()
        elif "type_ids" in name:
            data = feat["input_type_ids"].numpy()
        else:
            continue
        interpreter.set_tensor(idx, data.astype(dtype))
    interpreter.invoke()
    # This assumes out_det[0] holds the end logits and out_det[1] the start logits;
    # the output names printed above should confirm the order for this model.
    end_raw = interpreter.get_tensor(out_det[0]["index"])[0]
    start_raw = interpreter.get_tensor(out_det[1]["index"])[0]
    # Dequantize the logits if the outputs are quantized (scale == 0 means float outputs).
    es, ez = out_det[0]["quantization"]
    ss, sz = out_det[1]["quantization"]
    if es != 0:
        end_logits = (end_raw.astype(np.float32) - ez) * es
        start_logits = (start_raw.astype(np.float32) - sz) * ss
    else:
        end_logits, start_logits = end_raw.astype(np.float32), start_raw.astype(np.float32)
    attn = feat["input_mask"].numpy()[0]
    tt_ids = feat["input_type_ids"].numpy()[0]
    s_idx, e_idx = get_best_span(start_logits, end_logits, tt_ids, attn, offsets)
    # Map the predicted token span back to a character span in the original context.
    char_s = offsets[s_idx][0]
    char_e = offsets[e_idx][1]
    pred = context_str[char_s:char_e]
    f1, em = f1_em(pred, answers_list)
    sum_f1 += f1
    sum_em += em
    total += 1

print(f"SQuAD-{SPLIT} EM = {sum_em/total*100:.2f}%, F1 = {sum_f1/total*100:.2f}%")
