Hello,
I’m comparing TensorFlow Lite reference kernels (tf.lite.Interpreter with OpResolverType.BUILTIN_REF) to CMSIS-NN int8 implementations. CMSIS-NN says it is bit-exact relative to TFL/TFLM for many operators.
I wrote a Python script (attached below) that, for each operator, builds a minimal int8 model, finds inputs where double and single rounding disagree for the same int32 accumulator and quantized multiplier/shift, then checks which value the TFLite reference interpreter actually returns.
This is the output summary when using TensorFlow 2.21.0:
| Operation | Rounding behavior |
|---|---|
| MUL | Double |
| ADD | Double |
| SUB | Double |
| MEAN | Single |
| FULLY_CONNECTED | Single |
| CONV_2D | Single |
| DEPTHWISE_CONV_2D | Double |
I had assumed that a given TensorFlow build would follow one rounding policy for all int8 requantization that goes through MultiplyByQuantizedMultiplier, controlled at compile time by something like TFLITE_SINGLE_ROUNDING. If that were globally true, I’d expect my script to classify all ops the same way (all double or all single).
Instead, different operators appear to match different rounding formulas on the reference path. That makes it hard to mirror TFL with a single CMSIS-NN switch: CMSIS-NN exposes CMSIS_NN_USE_SINGLE_ROUNDING, which is one global choice, not per-op mixed behavior.
- Am I correct there is mixed rounding behavior for BUILTIN_REF int8 kernels — i.e. some ops always use the double-rounding MultiplyByQuantizedMultiplier path and others the single-rounding path, even in the same build?
- Which TensorFlow / tflite_runtime version should I treat as the ground truth if the goal is to match CMSIS-NN?
- Is there a supported way to force one rounding strategy for all int8 ops that use MultiplyByQuantizedMultiplier, or is per-op behavior fixed by design?
Thank you!
Python script for testing TFL rounding behavior
"""
Probe TensorFlow Lite reference kernels (BUILTIN_REF) for MultiplyByQuantizedMultiplier
rounding: gemmlowp-style double rounding vs TFLITE_SINGLE_ROUNDING, per op.
Each case builds a tiny full-int8 model, searches for int8 inputs where the two
reference formulas disagree, then compares tf.lite.Interpreter output.
Ops covered here: MUL, ADD, SUB, MEAN, FULLY_CONNECTED, CONV_2D, DEPTHWISE_CONV_2D
(int8, one subgraph op each).
Calibration: MUL uses a fixed rep dataset; ADD/SUB/MEAN use a wider float grid so
quantization parameters often admit a discriminating int8 case. If none exists in
the brute-force scan, the result is INDETERMINATE (not "unknown rounding").
"""
"""TFLite BUILTIN_REF: double vs single rounding in MultiplyByQuantizedMultiplier, per op (int8 tiny models)."""
from __future__ import annotations
import itertools
import math
import os
from dataclasses import dataclass
from typing import Callable, Iterable
import numpy as np
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")
import tensorflow as tf
def _rnd(x: float) -> int:
return int(math.copysign(math.floor(abs(x) + 0.5), x))
def _qmul(m: float) -> tuple[int, int]:
    """Decompose *m* into TFLite (quantized_multiplier, shift): m ≈ q * 2**(e-31)."""
    if m == 0.0:
        return 0, 0
    frac, exp = math.frexp(m)
    q = _rnd(frac * float(1 << 31))
    if q == (1 << 31):
        # Rounding carried |frac| up to exactly 1.0; renormalize into [0.5, 1).
        q, exp = q // 2, exp + 1
    if exp < -31:
        # Multiplier underflows the representable range entirely.
        return 0, 0
    return min(q, 2**31 - 1), int(exp)
def _qmul1(x: float) -> tuple[int, int]:
    """Like _qmul, but restricted to x in (0, 1) so the shift must be <= 0."""
    if not (0.0 < x < 1.0):
        raise ValueError(f"qmul1 expects x in (0,1), got {x}")
    mult, shift = _qmul(x)
    if shift > 0:
        raise ValueError(f"unexpected positive shift {shift} for {x}")
    return mult, shift
def _dhm(a: int, b: int) -> int:
ov = a == b == -(2**31)
ab = int(np.int64(a) * np.int64(b))
ng = (1 << 30) if ab >= 0 else (1 - (1 << 30))
h = int((ab + ng) // (1 << 31))
return (2**31 - 1) if ov else h
def _rpot(x: int, e: int) -> int:
if e < 0 or e > 31:
raise ValueError(f"rpot exponent out of range: {e}")
m = (1 << e) - 1
r = x & m
t = (m >> 1) + (1 if x < 0 else 0)
return (x >> e) + (1 if r > t else 0)
def _mbq(x: int, q: int, sh: int) -> int:
    """Double-rounding MultiplyByQuantizedMultiplier (gemmlowp reference path)."""
    if sh > 0:
        left, right = sh, 0
    else:
        left, right = 0, -sh
    # The int32 wraparound of the left shift is intentional; it mirrors the C kernel.
    shifted = int(np.int32(x) * np.int32(1 << left))
    return _rpot(_dhm(shifted, int(q)), right)
def _mbq1(x: int, q: int, sh: int) -> int:
x, q, sh = int(x), int(q), int(sh)
ts = 31 - sh
rnd = np.int64(1) << np.int64(ts - 1)
return int(np.int32((np.int64(x) * np.int64(q) + rnd) >> np.int64(ts)))
def _mean_ms(si: float, so: float, n: int) -> tuple[int, int]:
    """Fold the divide-by-n into the MEAN requant (multiplier, shift) pair."""
    if so == 0.0:
        raise ValueError("MEAN: output scale zero")
    if n <= 0:
        raise ValueError(f"MEAN: bad n {n}")
    mult, shift = _qmul(si / so)
    # Headroom for the extra left shift applied before the integer divide by n.
    sa = min(63 - (64 - int(n).bit_length()), 32, 31 + int(shift))
    scaled_mult = (np.int64(int(mult)) << sa) // n
    return int(scaled_mult), int(shift) - sa
def _c8(m, rep: Callable[[], Iterable[list]]):
    """Build a full-int8 TFLiteConverter for Keras model *m*, calibrated by *rep*."""
    conv = tf.lite.TFLiteConverter.from_keras_model(m)
    conv.optimizations = [tf.lite.Optimize.DEFAULT]
    conv.representative_dataset = rep
    conv.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    conv.inference_input_type = tf.int8
    conv.inference_output_type = tf.int8
    return conv
def _xy():
g = np.linspace(-2.0, 2.0, 11)
return [(float(x), float(y)) for x in g for y in g]
def _rep2c():
for _ in range(100):
yield [np.array([[1.0]], np.float32), np.array([[1.0]], np.float32)]
def _rep2w():
    """Two-input rep dataset cycling through the wide calibration grid."""
    pts = _xy()
    for i in range(300):
        a, b = pts[i % len(pts)]
        yield [np.array([[a]], np.float32), np.array([[b]], np.float32)]
def _rep12():
    """Single-input rep dataset: shape (1, 2) rows drawn from the grid."""
    pts = _xy()
    for i in range(400):
        a, b = pts[i % len(pts)]
        yield [np.array([[a, b]], np.float32)]
def _rep_nhwc12():
    """Rep dataset with NHWC shape (1, 1, 1, 2), for the CONV_2D probe."""
    pts = _xy()
    for i in range(400):
        a, b = pts[i % len(pts)]
        yield [np.array([[[[a, b]]]], np.float32)]
def _rep_nhwc21():
    """Rep dataset with NHWC shape (1, 2, 1, 1), for the DEPTHWISE_CONV_2D probe."""
    pts = _xy()
    for i in range(400):
        a, b = pts[i % len(pts)]
        yield [np.array([[[[a]], [[b]]]], np.float32)]
def _cv8(m, rep):
    """Convert Keras model *m* to an int8 TFLite flatbuffer, calibrated by *rep*."""
    converter = _c8(m, rep)
    return converter.convert()
def _intr(buf: bytes) -> tf.lite.Interpreter:
    """Single-threaded interpreter pinned to the reference (BUILTIN_REF) kernels."""
    resolver = tf.lite.experimental.OpResolverType.BUILTIN_REF
    return tf.lite.Interpreter(
        model_content=buf,
        experimental_op_resolver_type=resolver,
        experimental_delegates=None,
        num_threads=1,
    )
def _zs(d):
q = d["quantization_parameters"]
return int(q["zero_points"][0]), float(q["scales"][0])
def _rq(x, mult, shift, dbl):
    """Requantize *x*: double-rounding path when *dbl* is true, else single-rounding."""
    if dbl:
        return _mbq(x, mult, shift)
    return _mbq1(x, mult, shift)
def _c8o(y):
return max(-128, min(127, int(y)))
def _prep_elm(z1, s1, z2, s2, zo, so):
    """Precompute ADD/SUB requant params: (left shift, input offsets, output zero
    point, per-input multipliers/shifts, output multiplier/shift)."""
    left_shift = 20
    twice_max = 2.0 * max(float(s1), float(s2))
    m1, sh1 = _qmul1(float(s1) / twice_max)
    m2, sh2 = _qmul1(float(s2) / twice_max)
    mo, sho = _qmul1(twice_max / ((1 << left_shift) * float(so)))
    return (left_shift, -z1, -z2, zo, m1, sh1, m2, sh2, mo, sho)
def _relm(a, b, t, dbl, sub):
    """Emulate the reference ADD/SUB int8 kernel for one element pair."""
    left_shift, off1, off2, zo, m1, sh1, m2, sh2, mo, sho = t
    x1 = _rq((int(a) + off1) << left_shift, m1, sh1, dbl)
    x2 = _rq((int(b) + off2) << left_shift, m2, sh2, dbl)
    raw = x1 - x2 if sub else x1 + x2
    return _c8o(_rq(raw, mo, sho, dbl) + zo)
def _rmul(a, b, z0, z1, zo, s0, s1, so, dbl):
    """Emulate the reference MUL int8 kernel for one element pair."""
    mult, shift = _qmul(float(s0) * float(s1) / float(so))
    prod = (int(a) - z0) * (int(b) - z1)
    return _c8o(_rq(prod, mult, shift, dbl) + zo)
def _rmean(vals, zpi, zpo, mult, shift, dbl):
    """Emulate the reference MEAN int8 kernel over *vals*."""
    count = len(vals)
    acc = sum(int(v) for v in vals) - zpi * count
    # Saturate the accumulator to int32 before requantizing, like the C kernel.
    lo, hi = np.iinfo(np.int32).min, np.iinfo(np.int32).max
    acc = min(hi, max(lo, acc))
    return _c8o(_rq(acc, mult, shift, dbl) + zpo)
def _conv_ins(intr, name):
for op in intr._get_ops_details():
if op["op_name"] != name:
continue
ix = [int(x) for x in np.asarray(op["inputs"]).ravel() if int(x) >= 0]
if len(ix) < 2:
raise RuntimeError(f"{name}: need >=2 inputs")
tm = {t["index"]: t for t in intr.get_tensor_details()}
wi = bi = None
for i in ix[1:]:
info = tm[i]
sh = tuple(np.asarray(info["shape"]).ravel())
dt = info["dtype"]
if dt == np.int32 and len(sh) == 1:
bi = i
elif dt == np.int8 and len(sh) == 4:
wi = i
if wi is None:
raise RuntimeError(f"{name}: no int8 4D weight in {ix}")
return ix[0], wi, bi
raise RuntimeError(f"no {name}")
def _fc_ins(intr):
for op in intr._get_ops_details():
if op["op_name"] != "FULLY_CONNECTED":
continue
ix = [int(x) for x in np.asarray(op["inputs"]).ravel()]
if len(ix) < 2:
raise RuntimeError("FC: bad inputs")
bi = ix[2] if len(ix) > 2 and ix[2] >= 0 else None
return ix[0], ix[1], bi
raise RuntimeError("no FULLY_CONNECTED")
def _rfc(vals, wrow, *, io, fo, pch, be, zpo, om, os_, dbl):
    """Emulate an int8 MAC + requant (FC / conv reference kernel) over one row."""
    acc = int(be)
    for filt, x in zip(wrow, vals):
        f_val, i_val = int(filt), int(x)
        if pch:
            # Per-channel weights are symmetric: no filter offset is applied.
            acc += f_val * (i_val + io)
        else:
            acc += (f_val + fo) * (i_val + io)
    return _c8o(_rq(int(acc), om, os_, dbl) + zpo)
def _sc2(fn):
for a in range(-128, 128):
for b in range(-128, 128):
yd, ys = fn(a, b)
if yd != ys:
return a, b, yd, ys
return None
def _scn(n, fn):
for tup in itertools.product(range(-128, 128), repeat=n):
yd, ys = fn(*tup)
if yd != ys:
return tup, yd, ys
return None
def _onames(intr):
return [o["op_name"] for o in intr._get_ops_details()]
def _mac2(wrow, *, io, fo, pch, be, zpo, om, os_):
    """Bind MAC parameters; return fn(v0, v1) -> (double, single) predictions."""
    # Safe to hoist: _rfc receives an unpacked copy and never mutates this dict.
    params = dict(io=io, fo=fo, pch=pch, be=be, zpo=zpo, om=om, os_=os_)
    def paired(v0, v1):
        vals = (v0, v1)
        return (
            _rfc(vals, wrow, **params, dbl=True),
            _rfc(vals, wrow, **params, dbl=False),
        )
    return paired
@dataclass
class RoundingProbeResult:
    """Outcome of probing one operator's requantization rounding behavior."""
    # Human-readable probe label, e.g. "MEAN (axis=-1, n=2)".
    op_label: str
    # Op names actually present in the converted graph.
    tflite_op_names: list[str]
    # One of: "double", "single", "indeterminate", "unexpected".
    rounding: str
    # Free-form diagnostics: the discriminating inputs and observed values.
    detail: str
def _cls(t, ed, es):
if t == ed != es:
return "double"
if t == es != ed:
return "single"
if ed == es:
return "indeterminate"
return "unexpected"
def probe_mul():
    """Probe the MUL reference kernel: build a scalar int8 MUL model, find a
    discriminating input pair, and classify the observed rounding."""
    i1 = tf.keras.Input(shape=(1,), dtype=tf.float32)
    i2 = tf.keras.Input(shape=(1,), dtype=tf.float32)
    m = tf.keras.Model([i1, i2], tf.keras.layers.Multiply()([i1, i2]))
    intr = _intr(_c8(m, _rep2c).convert())
    intr.allocate_tensors()
    ins, outs = intr.get_input_details(), intr.get_output_details()
    # Quantization parameters for both inputs and the output.
    z0, s0 = _zs(ins[0])
    z1, s1 = _zs(ins[1])
    zo, so = _zs(outs[0])
    def g(a, b):
        # (double-rounding prediction, single-rounding prediction) for one pair.
        return _rmul(a, b, z0, z1, zo, s0, s1, so, True), _rmul(
            a, b, z0, z1, zo, s0, s1, so, False
        )
    fk = _sc2(g)
    on = _onames(intr)
    if fk is None:
        # No int8 pair separates the two formulas under these quant params.
        return RoundingProbeResult("MUL", on, "indeterminate", "no int8 fork")
    a, b, ed, es = fk
    intr.set_tensor(ins[0]["index"], np.array([[a]], np.int8))
    intr.set_tensor(ins[1]["index"], np.array([[b]], np.int8))
    intr.invoke()
    t = int(intr.get_tensor(outs[0]["index"])[0, 0])
    return RoundingProbeResult("MUL", on, _cls(t, ed, es), f"a={a} b={b} tfl={t} d={ed} s={es}")
def _probe_addsub(sub: bool):
    """Shared ADD/SUB probe: two scalar int8 inputs, wide-grid calibration.

    *sub* selects SUB (True) or ADD (False).
    """
    name = "SUB" if sub else "ADD"
    i1 = tf.keras.Input(shape=(1,), dtype=tf.float32)
    i2 = tf.keras.Input(shape=(1,), dtype=tf.float32)
    layer = tf.keras.layers.Subtract if sub else tf.keras.layers.Add
    m = tf.keras.Model([i1, i2], layer()([i1, i2]))
    intr = _intr(_cv8(m, _rep2w))
    intr.allocate_tensors()
    ins, outs = intr.get_input_details(), intr.get_output_details()
    # Precompute the broadcast-add requant parameters from the quant params.
    t = _prep_elm(*_zs(ins[0]), *_zs(ins[1]), *_zs(outs[0]))
    def g(a, b):
        # (double-rounding prediction, single-rounding prediction) for one pair.
        return _relm(a, b, t, True, sub), _relm(a, b, t, False, sub)
    fk = _sc2(g)
    on = _onames(intr)
    if fk is None:
        return RoundingProbeResult(name, on, "indeterminate", "no int8 fork")
    a, b, ed, es = fk
    intr.set_tensor(ins[0]["index"], np.array([[a]], np.int8))
    intr.set_tensor(ins[1]["index"], np.array([[b]], np.int8))
    intr.invoke()
    out = int(intr.get_tensor(outs[0]["index"])[0, 0])
    return RoundingProbeResult(name, on, _cls(out, ed, es), f"a={a} b={b} tfl={out} d={ed} s={es}")
def probe_add():
    """Probe the ADD reference kernel."""
    return _probe_addsub(False)
def probe_sub():
    """Probe the SUB reference kernel."""
    return _probe_addsub(True)
def probe_mean():
    """Probe the MEAN reference kernel (reduce over the last axis of 2 elements)."""
    inp = tf.keras.Input(shape=(2,), dtype=tf.float32)
    m = tf.keras.Model(
        inp,
        tf.keras.layers.Lambda(lambda x: tf.reduce_mean(x, axis=-1, keepdims=True))(inp),
    )
    intr = _intr(_cv8(m, _rep12))
    intr.allocate_tensors()
    ins, outs = intr.get_input_details(), intr.get_output_details()
    zi, si = _zs(ins[0])
    zo, so = _zs(outs[0])
    # Multiplier/shift with the divide-by-2 folded in, as the kernel computes it.
    mm, sh = _mean_ms(si, so, 2)
    def g(v0, v1):
        tp = (v0, v1)
        # (double-rounding prediction, single-rounding prediction) for one tuple.
        return _rmean(tp, zi, zo, mm, sh, True), _rmean(tp, zi, zo, mm, sh, False)
    fk = _scn(2, g)
    on = _onames(intr)
    if fk is None:
        return RoundingProbeResult("MEAN (axis=-1, n=2)", on, "indeterminate", "no fork")
    tup, ed, es = fk
    intr.set_tensor(ins[0]["index"], np.array([list(tup)], np.int8))
    intr.invoke()
    t = int(intr.get_tensor(outs[0]["index"])[0, 0])
    return RoundingProbeResult(
        "MEAN (axis=-1, n=2)", on, _cls(t, ed, es), f"vals={tup} tfl={t} d={ed} s={es}"
    )
def _wquant(w_t, w, out_c, odim):
wq = w_t["quantization_parameters"]
sc = np.asarray(wq["scales"], np.float64).ravel()
zp = np.asarray(wq["zero_points"], np.int32).ravel()
pch = sc.size > 1
if pch and sc.size != odim:
raise ValueError(f"scales {sc.size} != od {odim}")
if pch:
return 0, float(sc[out_c]), True
return -int(zp[0]), float(sc[0]), False
def probe_fully_connected():
    """Probe FULLY_CONNECTED (Dense 2→1, int8): rebuild both rounding predictions
    from the extracted weights/quant params and compare against the interpreter."""
    inp = tf.keras.Input(shape=(2,), dtype=tf.float32)
    m = tf.keras.Model(inp, tf.keras.layers.Dense(1, use_bias=True)(inp))
    intr = _intr(_cv8(m, _rep12))
    intr.allocate_tensors()
    ai, wi, bi = _fc_ins(intr)
    ins, outs = intr.get_input_details()[0], intr.get_output_details()[0]
    on = _onames(intr)
    tm = {t["index"]: t for t in intr.get_tensor_details()}
    w_t, w = tm[wi], np.asarray(intr.get_tensor(wi), np.int8)
    if w.ndim != 2:
        return RoundingProbeResult("FULLY_CONNECTED (2→1)", on, "unexpected", f"W rank {w.ndim}")
    od, ad = int(w.shape[0]), int(w.shape[1])
    if ins["index"] != ai:
        return RoundingProbeResult("FULLY_CONNECTED (2→1)", on, "unexpected", "in idx")
    try:
        fo, swe, pch = _wquant(w_t, w, 0, od)
    except ValueError as e:
        return RoundingProbeResult("FULLY_CONNECTED (2→1)", on, "unexpected", str(e))
    if ad != 2:
        return RoundingProbeResult("FULLY_CONNECTED (2→1)", on, "unexpected", f"adim {ad}")
    # Input/output quantization parameters.
    zi = int(ins["quantization_parameters"]["zero_points"][0])
    si = float(ins["quantization_parameters"]["scales"][0])
    zo = int(outs["quantization_parameters"]["zero_points"][0])
    so = float(outs["quantization_parameters"]["scales"][0])
    # Output requant multiplier from the effective scale si*sw/so.
    om, osh = _qmul(si * swe / so)
    be = int(np.asarray(intr.get_tensor(bi)).ravel()[0]) if bi is not None else 0
    g = _mac2(w[0, :], io=-zi, fo=fo, pch=pch, be=be, zpo=zo, om=om, os_=osh)
    fk = _scn(2, g)
    lb = "FULLY_CONNECTED (2→1)"
    if fk is None:
        return RoundingProbeResult(lb, on, "indeterminate", "no fork")
    tup, ed, es = fk
    intr.set_tensor(ins["index"], np.array([list(tup)], np.int8))
    intr.invoke()
    t = int(intr.get_tensor(outs["index"])[0, 0])
    return RoundingProbeResult(lb, on, _cls(t, ed, es), f"vals={tup} tfl={t} d={ed} s={es}")
def _probe_conv_dw(name, op, rep, wexp, wrow_fn, pack_in):
    """Shared CONV_2D / DEPTHWISE_CONV_2D probe.

    name: probe label; its first word is the builtin op name to look up.
    op: Keras model wrapping the single conv layer.
    rep: representative dataset generator for calibration.
    wexp: expected 4-D weight tensor shape.
    wrow_fn: extracts the 2-element weight row from the weight tensor.
    pack_in: packs the scanned 2-tuple into the op's NHWC input array.
    """
    intr = _intr(_cv8(op, rep))
    intr.allocate_tensors()
    lb = name
    on = _onames(intr)
    try:
        ai, wi, bi = _conv_ins(intr, name.split()[0])
    except RuntimeError as e:
        return RoundingProbeResult(lb, on, "unexpected", str(e))
    ins, outs = intr.get_input_details()[0], intr.get_output_details()[0]
    if ins["index"] != ai:
        return RoundingProbeResult(lb, on, "unexpected", "in idx")
    tm = {t["index"]: t for t in intr.get_tensor_details()}
    w_t, w = tm[wi], np.asarray(intr.get_tensor(wi), np.int8)
    if tuple(w.shape) != wexp:
        return RoundingProbeResult(lb, on, "unexpected", f"W {w.shape}")
    wrow = wrow_fn(w)
    # Input/output quantization parameters.
    zi = int(ins["quantization_parameters"]["zero_points"][0])
    si = float(ins["quantization_parameters"]["scales"][0])
    zo = int(outs["quantization_parameters"]["zero_points"][0])
    so = float(outs["quantization_parameters"]["scales"][0])
    try:
        # CONV_2D stores output channels in dim 0, DEPTHWISE_CONV_2D in dim 3.
        fo, swe, pch = _wquant(w_t, w, 0, int(w.shape[0] if name.startswith("CONV") else w.shape[3]))
    except ValueError as e:
        return RoundingProbeResult(lb, on, "unexpected", str(e))
    om, osh = _qmul(si * swe / so)
    be = int(np.asarray(intr.get_tensor(bi)).ravel()[0]) if bi is not None else 0
    g = _mac2(wrow, io=-zi, fo=fo, pch=pch, be=be, zpo=zo, om=om, os_=osh)
    fk = _scn(2, g)
    if fk is None:
        return RoundingProbeResult(lb, on, "indeterminate", "no fork")
    tup, ed, es = fk
    intr.set_tensor(ins["index"], np.asarray(pack_in(tup), np.int8))
    intr.invoke()
    t = int(intr.get_tensor(outs["index"])[0, 0, 0, 0])
    return RoundingProbeResult(lb, on, _cls(t, ed, es), f"vals={tup} tfl={t} d={ed} s={es}")
def probe_conv2d():
    """Probe CONV_2D: 1x1 kernel, 2 input channels → 1 output channel."""
    inp = tf.keras.Input(shape=(1, 1, 2), dtype=tf.float32)
    op = tf.keras.Model(inp, tf.keras.layers.Conv2D(1, (1, 1), padding="valid", use_bias=True)(inp))
    return _probe_conv_dw(
        "CONV_2D (1×1, 2→1)",
        op,
        _rep_nhwc12,
        (1, 1, 1, 2),
        lambda w: w[0, 0, 0, :].copy(),
        lambda t: np.array([[[[t[0], t[1]]]]]),
    )
def probe_depthwise_conv2d():
    """Probe DEPTHWISE_CONV_2D: 2x1 kernel, single channel, depth multiplier 1."""
    inp = tf.keras.Input(shape=(2, 1, 1), dtype=tf.float32)
    op = tf.keras.Model(
        inp,
        tf.keras.layers.DepthwiseConv2D((2, 1), padding="valid", depth_multiplier=1, use_bias=True)(
            inp
        ),
    )
    return _probe_conv_dw(
        "DEPTHWISE_CONV_2D (2×1, 1ch)",
        op,
        _rep_nhwc21,
        (1, 2, 1, 1),
        lambda w: w[0, :, 0, 0].copy(),
        lambda t: np.array([[[[t[0]]], [[t[1]]]]]),
    )
def main():
    """Run every probe, print per-op details, then a summary table."""
    summary = []
    for probe in (
        probe_mul,
        probe_add,
        probe_sub,
        probe_mean,
        probe_fully_connected,
        probe_conv2d,
        probe_depthwise_conv2d,
    ):
        result = probe()
        print(
            f"--- {result.op_label} ---\nGraph ops: {result.tflite_op_names}\n"
            f"Result: {result.rounding.upper()}\n {result.detail}\n"
        )
        summary.append((result.op_label.split()[0], result.rounding))
    print("Summary:")
    for op_name, rounding in summary:
        print(f" {op_name:22} {rounding}")
if __name__ == "__main__":
main()