Mixed rounding behavior in a quantized model using the TFLite Interpreter

Hello,

I’m comparing TensorFlow Lite reference kernels (`tf.lite.Interpreter` with `OpResolverType.BUILTIN_REF`) to the CMSIS-NN int8 implementations. CMSIS-NN states that it is bit-exact with TFL/TFLM for many operators.

I wrote a Python script (attached below) that, for each operator, builds a minimal int8 model, finds inputs where double and single rounding disagree for the same int32 accumulator and quantized multiplier/shift, then checks which value the TFLite reference interpreter actually returns.

This is the output summary with TensorFlow 2.21.0:

| Operation         | Rounding behavior |
|-------------------|-------------------|
| MUL               | Double            |
| ADD               | Double            |
| SUB               | Double            |
| MEAN              | Single            |
| FULLY_CONNECTED   | Single            |
| CONV_2D           | Single            |
| DEPTHWISE_CONV_2D | Double            |

I had assumed that a given TensorFlow build would follow one rounding policy for all int8 requantization that goes through MultiplyByQuantizedMultiplier, controlled at compile time by something like TFLITE_SINGLE_ROUNDING. If that were globally true, I’d expect my script to classify all ops the same way (all double or all single).

Instead, different operators appear to match different rounding formulas on the reference path. That makes it hard to mirror TFL with a single CMSIS-NN switch: CMSIS-NN exposes CMSIS_NN_USE_SINGLE_ROUNDING, which is one global choice, not per-op mixed behavior.

  1. Am I correct there is mixed rounding behavior for BUILTIN_REF int8 kernels — i.e. some ops always use the double-rounding MultiplyByQuantizedMultiplier path and others the single-rounding path, even in the same build?
  2. Which TensorFlow / tflite_runtime version should I treat as the ground truth if the goal is to match CMSIS-NN?
  3. Is there a supported way to force one rounding strategy for all int8 ops that use MultiplyByQuantizedMultiplier, or is per-op behavior fixed by design?

Thank you!

Python script for testing TFL rounding behavior
"""
Probe TensorFlow Lite reference kernels (BUILTIN_REF) for MultiplyByQuantizedMultiplier
rounding: gemmlowp-style double rounding vs TFLITE_SINGLE_ROUNDING, per op.

Each case builds a tiny full-int8 model, searches for int8 inputs where the two
reference formulas disagree, then compares tf.lite.Interpreter output.

Ops covered here: MUL, ADD, SUB, MEAN, FULLY_CONNECTED, CONV_2D, DEPTHWISE_CONV_2D
(int8, one subgraph op each).

Calibration: MUL uses a constant representative dataset (every sample is 1.0); all
other ops use a float grid over [-2, 2] so the quantization parameters usually admit
a discriminating int8 input. If the brute-force scan finds none, the result is
INDETERMINATE (not "unknown rounding").
"""
"""TFLite BUILTIN_REF: double vs single rounding in MultiplyByQuantizedMultiplier, per op (int8 tiny models)."""

from __future__ import annotations

import itertools
import math
import os
from dataclasses import dataclass
from typing import Callable, Iterable

import numpy as np

os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")
import tensorflow as tf


def _rnd(x: float) -> int:
    """Round to the nearest integer with ties away from zero (C ``std::round``)."""
    magnitude = math.floor(abs(x) + 0.5)
    # Reapply the sign of x (copysign also handles -0.0 correctly).
    return int(-magnitude if math.copysign(1.0, x) < 0 else magnitude)


def _qmul(m: float) -> tuple[int, int]:
    """Split a real multiplier into a Q31 fixed-point mantissa and a power-of-two shift.

    Returns ``(quantized_multiplier, shift)`` such that
    ``m ~= quantized_multiplier * 2**(shift - 31)``.  Zero, and multipliers whose
    exponent falls below -31, map to ``(0, 0)``.
    """
    if m == 0.0:
        return 0, 0
    mantissa, exponent = math.frexp(m)
    fixed = _rnd(mantissa * float(1 << 31))
    if fixed == (1 << 31):
        # The mantissa rounded up to exactly 1.0: renormalise into [0.5, 1).
        fixed, exponent = fixed // 2, exponent + 1
    if exponent < -31:
        # Every bit would be shifted out; flush to zero.
        return 0, 0
    fixed = min(fixed, 2**31 - 1)
    return int(fixed), int(exponent)


def _qmul1(x: float) -> tuple[int, int]:
    """``_qmul`` restricted to multipliers strictly inside (0, 1).

    Raises:
        ValueError: if ``x`` is outside (0, 1) (NaN included) or the resulting
            shift is positive.
    """
    if 0.0 < x < 1.0:
        mult, shift = _qmul(x)
        if shift <= 0:
            return mult, shift
        raise ValueError(f"unexpected positive shift {shift} for {x}")
    raise ValueError(f"qmul1 expects x in (0,1), got {x}")


def _dhm(a: int, b: int) -> int:
    """Emulate gemmlowp ``SaturatingRoundingDoublingHighMul`` for int32 operands.

    Returns the high 32 bits of ``2 * a * b``, rounded with gemmlowp's nudge.
    The only overflowing input, ``a == b == INT32_MIN``, saturates to INT32_MAX.

    The C++ reference divides an int64 by ``1 << 31`` with ``/``, which
    truncates toward zero.  Python's ``//`` floors instead, which is off by one
    for most negative products.  The quotient is therefore computed on the
    magnitude and the sign reapplied.
    """
    a, b = int(a), int(b)
    if a == b == -(2**31):
        return 2**31 - 1
    ab = a * b
    nudge = (1 << 30) if ab >= 0 else (1 - (1 << 30))
    numerator = ab + nudge
    quotient = abs(numerator) >> 31  # truncating division by 2**31
    return quotient if numerator >= 0 else -quotient


def _rpot(x: int, e: int) -> int:
    """gemmlowp ``RoundingDivideByPOT``: ``x / 2**e`` rounded half away from zero.

    Raises:
        ValueError: if ``e`` is outside [0, 31].
    """
    if not 0 <= e <= 31:
        raise ValueError(f"rpot exponent out of range: {e}")
    mask = (1 << e) - 1
    remainder = x & mask
    # Negative values need a strictly larger remainder to round away from zero.
    threshold = (mask >> 1) + int(x < 0)
    bump = int(remainder > threshold)
    return (x >> e) + bump


def _mbq(x: int, q: int, sh: int) -> int:
    """Double-rounding ``MultiplyByQuantizedMultiplier`` model (gemmlowp style).

    A positive ``sh`` left-shifts ``x`` (computed as a numpy int32 product)
    before the doubling high multiply.  A negative ``sh`` becomes a rounding
    right shift applied afterwards.
    """
    left = sh if sh > 0 else 0
    right = 0 if sh > 0 else -sh
    scaled = int(np.int32(x) * np.int32(1 << left))
    high = _dhm(scaled, int(q))
    return _rpot(high, right)


def _mbq1(x: int, q: int, sh: int) -> int:
    """Single-rounding ``MultiplyByQuantizedMultiplier`` model.

    Computes ``(x * q + 2**(30 - sh)) >> (31 - sh)`` in int64 and narrows the
    result to int32.
    """
    total_shift = np.int64(31 - int(sh))
    product = np.int64(int(x)) * np.int64(int(q))
    half = np.int64(1) << (total_shift - np.int64(1))
    return int(np.int32((product + half) >> total_shift))


def _mean_ms(si: float, so: float, n: int) -> tuple[int, int]:
    """Fold the 1/n of an n-element MEAN into the output multiplier and shift.

    Presumably mirrors the TFLite reference MEAN rescaling (TODO confirm).
    ``mult(si/so)`` is pre-shifted left by up to floor(log2(n)) bits (capped at
    32 and at ``31 + shift``), integer-divided by ``n``, and the shift is
    reduced by the same amount.

    Raises:
        ValueError: if ``so`` is zero or ``n`` is not positive.
    """
    if so == 0.0:
        raise ValueError("MEAN: output scale zero")
    if n <= 0:
        raise ValueError(f"MEAN: bad n {n}")
    base_mult, base_shift = _qmul(si / so)
    msb = int(n).bit_length() - 1  # == 63 - clz64(n)
    extra = min(msb, 32, 31 + int(base_shift))
    scaled = int((np.int64(int(base_mult)) << extra) // n)
    return scaled, int(base_shift) - extra


def _c8(m, rep: Callable[[], Iterable[list]]):
    """Configure (but do not run) a full-int8 TFLite converter for Keras model ``m``.

    ``rep`` is the representative-dataset generator used for calibration.
    Inputs and outputs are int8, and only int8 builtin ops are allowed.
    """
    converter = tf.lite.TFLiteConverter.from_keras_model(m)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = rep
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    converter.inference_input_type = tf.int8
    converter.inference_output_type = tf.int8
    return converter


def _xy():
    """All 121 (x, y) pairs from an 11-point grid over [-2, 2], x varying slowest."""
    grid = [float(v) for v in np.linspace(-2.0, 2.0, 11)]
    return list(itertools.product(grid, repeat=2))


def _rep2c():
    """Calibration data for two (1, 1) inputs: 100 samples that are all 1.0."""
    for _ in range(100):
        first = np.full((1, 1), 1.0, dtype=np.float32)
        second = np.full((1, 1), 1.0, dtype=np.float32)
        yield [first, second]


def _rep2w():
    """Calibration data for two (1, 1) inputs: 300 samples cycling over ``_xy()``."""
    for x, y in itertools.islice(itertools.cycle(_xy()), 300):
        yield [np.array([[x]], np.float32), np.array([[y]], np.float32)]


def _rep12():
    """Calibration data for one (1, 2) input: 400 samples cycling over ``_xy()``."""
    for x, y in itertools.islice(itertools.cycle(_xy()), 400):
        yield [np.array([[x, y]], np.float32)]


def _rep_nhwc12():
    """Calibration data for one (1, 1, 1, 2) input: 400 samples cycling over ``_xy()``."""
    for x, y in itertools.islice(itertools.cycle(_xy()), 400):
        yield [np.array([[[[x, y]]]], np.float32)]


def _rep_nhwc21():
    """Calibration data for one (1, 2, 1, 1) input: 400 samples cycling over ``_xy()``."""
    for x, y in itertools.islice(itertools.cycle(_xy()), 400):
        yield [np.array([[[[x]], [[y]]]], np.float32)]


def _cv8(m, rep):
    """Convert Keras model ``m`` to a full-int8 TFLite flatbuffer (see ``_c8``)."""
    converter = _c8(m, rep)
    return converter.convert()


def _intr(buf: bytes) -> tf.lite.Interpreter:
    """Build a single-threaded interpreter that resolves ops to BUILTIN_REF kernels."""
    resolver = tf.lite.experimental.OpResolverType.BUILTIN_REF
    options = dict(
        model_content=buf,
        experimental_op_resolver_type=resolver,
        experimental_delegates=None,
        num_threads=1,
    )
    return tf.lite.Interpreter(**options)


def _zs(d):
    """Return ``(zero_point, scale)`` of a tensor-details dict (first entry only)."""
    params = d["quantization_parameters"]
    zero_point = int(params["zero_points"][0])
    scale = float(params["scales"][0])
    return zero_point, scale


def _rq(x, mult, shift, dbl):
    """Requantize ``x`` with the double-rounding (``dbl``) or single-rounding model."""
    requantize = _mbq if dbl else _mbq1
    return requantize(x, mult, shift)


def _c8o(y):
    """Convert ``y`` to int and clamp it to the int8 range [-128, 127]."""
    value = int(y)
    if value > 127:
        return 127
    if value < -128:
        return -128
    return value


def _prep_elm(z1, s1, z2, s2, zo, so):
    """Derive int8 ADD/SUB requantization parameters from tensor scales/zero points.

    Both inputs are rescaled against twice the larger input scale after a
    20-bit left shift.  The output multiplier undoes that scaling.

    Returns:
        ``(left_shift, -z1, -z2, zo, in1_mult, in1_shift, in2_mult, in2_shift,
        out_mult, out_shift)``.
    """
    left_shift = 20
    twice_max_scale = 2.0 * max(float(s1), float(s2))
    real_in1 = float(s1) / twice_max_scale
    real_in2 = float(s2) / twice_max_scale
    real_out = twice_max_scale / ((1 << left_shift) * float(so))
    in1 = _qmul1(real_in1)
    in2 = _qmul1(real_in2)
    out = _qmul1(real_out)
    return (left_shift, -z1, -z2, zo, *in1, *in2, *out)


def _relm(a, b, t, dbl, sub):
    """Model one int8 ADD (or SUB when ``sub``) element with the chosen rounding.

    ``t`` is the parameter tuple produced by ``_prep_elm``.
    """
    shift, off1, off2, out_zp, m1, s1, m2, s2, m_out, s_out = t
    lhs = _rq((int(a) + off1) << shift, m1, s1, dbl)
    rhs = _rq((int(b) + off2) << shift, m2, s2, dbl)
    combined = lhs - rhs if sub else lhs + rhs
    return _c8o(_rq(combined, m_out, s_out, dbl) + out_zp)


def _rmul(a, b, z0, z1, zo, s0, s1, so, dbl):
    """Model one int8 MUL element with the chosen rounding."""
    real_multiplier = float(s0) * float(s1) / float(so)
    mult, shift = _qmul(real_multiplier)
    product = (int(a) - z0) * (int(b) - z1)
    return _c8o(_rq(product, mult, shift, dbl) + zo)


def _rmean(vals, zpi, zpo, mult, shift, dbl):
    """Model one int8 MEAN output over ``vals`` with the chosen rounding.

    The zero-point-corrected sum is clamped to int32 before requantization.
    """
    lo, hi = np.iinfo(np.int32).min, np.iinfo(np.int32).max
    total = sum(int(v) for v in vals) - zpi * len(vals)
    total = min(max(total, lo), hi)
    return _c8o(_rq(total, mult, shift, dbl) + zpo)


def _conv_ins(intr, name):
    """Locate the first op named ``name`` and classify its input tensors.

    Negative (absent) inputs are dropped.  The first remaining input is the
    activation.  Among the rest, an int32 rank-1 tensor is taken as the bias and
    an int8 rank-4 tensor as the weights.

    Returns:
        ``(activation_index, weight_index, bias_index_or_None)``.

    Raises:
        RuntimeError: if no such op exists, it has fewer than two inputs, or no
            int8 4-D weight tensor is found.
    """
    target = next(
        (op for op in intr._get_ops_details() if op["op_name"] == name), None
    )
    if target is None:
        raise RuntimeError(f"no {name}")
    operands = [int(i) for i in np.asarray(target["inputs"]).ravel() if int(i) >= 0]
    if len(operands) < 2:
        raise RuntimeError(f"{name}: need >=2 inputs")
    by_index = {t["index"]: t for t in intr.get_tensor_details()}
    weight_idx = bias_idx = None
    for idx in operands[1:]:
        info = by_index[idx]
        rank = np.asarray(info["shape"]).size
        if info["dtype"] == np.int32 and rank == 1:
            bias_idx = idx
        elif info["dtype"] == np.int8 and rank == 4:
            weight_idx = idx
    if weight_idx is None:
        raise RuntimeError(f"{name}: no int8 4D weight in {operands}")
    return operands[0], weight_idx, bias_idx


def _fc_ins(intr):
    """Return ``(input, weight, bias_or_None)`` tensor indices of the first FULLY_CONNECTED op.

    The bias is the third input when present and non-negative.

    Raises:
        RuntimeError: if there is no FULLY_CONNECTED op or it has fewer than
            two inputs.
    """
    fc = next(
        (op for op in intr._get_ops_details() if op["op_name"] == "FULLY_CONNECTED"),
        None,
    )
    if fc is None:
        raise RuntimeError("no FULLY_CONNECTED")
    operands = [int(i) for i in np.asarray(fc["inputs"]).ravel()]
    if len(operands) < 2:
        raise RuntimeError("FC: bad inputs")
    has_bias = len(operands) > 2 and operands[2] >= 0
    return operands[0], operands[1], (operands[2] if has_bias else None)


def _rfc(vals, wrow, *, io, fo, pch, be, zpo, om, os_, dbl):
    """Model one int8 FC/conv output: ``bias + sum((w + fo) * (x + io))``, requantized.

    With per-channel weights (``pch``) the weight offset ``fo`` is not applied.
    The result gets ``zpo`` added and is clamped to int8.
    """
    weight_offset = 0 if pch else fo
    total = int(be)
    for d, x in enumerate(vals):
        total += (int(wrow[d]) + weight_offset) * (int(x) + io)
    return _c8o(_rq(int(total), om, os_, dbl) + zpo)


def _sc2(fn):
    """Scan all int8 pairs (a outer, b inner) for the first where ``fn`` disagrees.

    ``fn(a, b)`` returns ``(double_result, single_result)``.

    Returns:
        ``(a, b, double_result, single_result)`` for the first disagreement,
        or ``None`` if the two always agree.
    """
    int8_values = range(-128, 128)
    for a, b in itertools.product(int8_values, int8_values):
        dbl_out, sgl_out = fn(a, b)
        if dbl_out != sgl_out:
            return a, b, dbl_out, sgl_out
    return None


def _scn(n, fn):
    """Scan all int8 ``n``-tuples in lexicographic order for the first where ``fn`` disagrees.

    ``fn(*values)`` returns ``(double_result, single_result)``.

    Returns:
        ``(values_tuple, double_result, single_result)`` or ``None``.
    """
    for combo in itertools.product(range(-128, 128), repeat=n):
        dbl_out, sgl_out = fn(*combo)
        if dbl_out == sgl_out:
            continue
        return combo, dbl_out, sgl_out
    return None


def _onames(intr):
    """List the op names of the interpreter's graph in execution-detail order."""
    ops = intr._get_ops_details()
    return [entry["op_name"] for entry in ops]


def _mac2(wrow, *, io, fo, pch, be, zpo, om, os_):
    """Build a 2-input callback returning ``(double, single)`` ``_rfc`` results.

    The returned function takes two int8 values and is meant for ``_scn(2, ...)``.
    """
    params = dict(io=io, fo=fo, pch=pch, be=be, zpo=zpo, om=om, os_=os_)

    def both(v0, v1):
        vals = (v0, v1)
        return (
            _rfc(vals, wrow, **params, dbl=True),
            _rfc(vals, wrow, **params, dbl=False),
        )

    return both


@dataclass(repr=True, eq=True)
class RoundingProbeResult:
    """Outcome of probing one operator for its requantization rounding mode."""

    # Human-readable label, e.g. "MEAN (axis=-1, n=2)".
    op_label: str
    # Op names found in the converted graph.
    tflite_op_names: list[str]
    # One of "double", "single", "indeterminate", "unexpected".
    rounding: str
    # Free-form evidence (inputs used, interpreter vs. model outputs).
    detail: str


def _cls(t, ed, es):
    """Classify interpreter output ``t`` against expected double/single results.

    Returns "indeterminate" when the two models agree (no evidence), "double" or
    "single" when ``t`` matches exactly one model, and "unexpected" otherwise.
    """
    if ed == es:
        return "indeterminate"
    if t == ed:
        return "double"
    if t == es:
        return "single"
    return "unexpected"


def probe_mul():
    """Probe int8 MUL on a two-scalar-input model calibrated with constant 1.0 inputs.

    Finds the first int8 pair where the double- and single-rounding models
    differ, runs the reference interpreter on it and classifies the output.
    """
    lhs = tf.keras.Input(shape=(1,), dtype=tf.float32)
    rhs = tf.keras.Input(shape=(1,), dtype=tf.float32)
    model = tf.keras.Model([lhs, rhs], tf.keras.layers.Multiply()([lhs, rhs]))
    interp = _intr(_cv8(model, _rep2c))
    interp.allocate_tensors()
    in_det = interp.get_input_details()
    out_det = interp.get_output_details()
    z0, s0 = _zs(in_det[0])
    z1, s1 = _zs(in_det[1])
    zo, so = _zs(out_det[0])
    quant = (z0, z1, zo, s0, s1, so)

    def both(a, b):
        return _rmul(a, b, *quant, True), _rmul(a, b, *quant, False)

    fork = _sc2(both)
    names = _onames(interp)
    if fork is None:
        return RoundingProbeResult("MUL", names, "indeterminate", "no int8 fork")
    a, b, ed, es = fork
    interp.set_tensor(in_det[0]["index"], np.array([[a]], np.int8))
    interp.set_tensor(in_det[1]["index"], np.array([[b]], np.int8))
    interp.invoke()
    got = int(interp.get_tensor(out_det[0]["index"])[0, 0])
    return RoundingProbeResult(
        "MUL", names, _cls(got, ed, es), f"a={a} b={b} tfl={got} d={ed} s={es}"
    )


def _probe_addsub(sub: bool):
    """Probe int8 ADD (``sub=False``) or SUB (``sub=True``) on a two-input model.

    ``_relm`` treats ``a`` as the op's first operand and ``b`` as its second,
    and SUB is not commutative.  ``get_input_details()`` lists the subgraph's
    inputs, and that order is not guaranteed to follow the op's operand order.
    The interpreter inputs are therefore ordered by the op's own input tensor
    indices.  If they cannot be matched (for example, another op sits between
    a graph input and the ADD/SUB), the interpreter's order is used unchanged.

    Returns:
        RoundingProbeResult labelled "ADD" or "SUB".
    """
    name = "SUB" if sub else "ADD"
    i1 = tf.keras.Input(shape=(1,), dtype=tf.float32)
    i2 = tf.keras.Input(shape=(1,), dtype=tf.float32)
    layer = tf.keras.layers.Subtract if sub else tf.keras.layers.Add
    m = tf.keras.Model([i1, i2], layer()([i1, i2]))
    intr = _intr(_cv8(m, _rep2w))
    intr.allocate_tensors()

    def _operand_ordered_inputs():
        # Pair interpreter inputs with the op's first and second operands.
        details = intr.get_input_details()
        by_index = {int(d["index"]): d for d in details}
        for op in intr._get_ops_details():
            if op["op_name"] != name:
                continue
            operands = [int(i) for i in np.asarray(op["inputs"]).ravel()][:2]
            if len(operands) == 2 and all(i in by_index for i in operands):
                return [by_index[i] for i in operands]
            break
        return details

    ins, outs = _operand_ordered_inputs(), intr.get_output_details()
    t = _prep_elm(*_zs(ins[0]), *_zs(ins[1]), *_zs(outs[0]))

    def g(a, b):
        return _relm(a, b, t, True, sub), _relm(a, b, t, False, sub)

    fk = _sc2(g)
    on = _onames(intr)
    if fk is None:
        return RoundingProbeResult(name, on, "indeterminate", "no int8 fork")
    a, b, ed, es = fk
    intr.set_tensor(ins[0]["index"], np.array([[a]], np.int8))
    intr.set_tensor(ins[1]["index"], np.array([[b]], np.int8))
    intr.invoke()
    out = int(intr.get_tensor(outs[0]["index"])[0, 0])
    return RoundingProbeResult(name, on, _cls(out, ed, es), f"a={a} b={b} tfl={out} d={ed} s={es}")


def probe_add():
    """Probe the int8 ADD reference kernel."""
    return _probe_addsub(sub=False)


def probe_sub():
    """Probe the int8 SUB reference kernel."""
    return _probe_addsub(sub=True)


def probe_mean():
    """Probe int8 MEAN over the last axis of a 2-element input (keepdims=True).

    Uses ``_mean_ms`` to fold 1/2 into the multiplier.  It then finds the first
    int8 pair where the two rounding models disagree and classifies the
    interpreter's output for that pair.
    """
    label = "MEAN (axis=-1, n=2)"
    x_in = tf.keras.Input(shape=(2,), dtype=tf.float32)
    mean_layer = tf.keras.layers.Lambda(
        lambda x: tf.reduce_mean(x, axis=-1, keepdims=True)
    )
    model = tf.keras.Model(x_in, mean_layer(x_in))
    interp = _intr(_cv8(model, _rep12))
    interp.allocate_tensors()
    in_det = interp.get_input_details()[0]
    out_det = interp.get_output_details()[0]
    zi, si = _zs(in_det)
    zo, so = _zs(out_det)
    mult, shift = _mean_ms(si, so, 2)

    def both(v0, v1):
        vals = (v0, v1)
        return (
            _rmean(vals, zi, zo, mult, shift, True),
            _rmean(vals, zi, zo, mult, shift, False),
        )

    fork = _scn(2, both)
    names = _onames(interp)
    if fork is None:
        return RoundingProbeResult(label, names, "indeterminate", "no fork")
    vals, ed, es = fork
    interp.set_tensor(in_det["index"], np.array([list(vals)], np.int8))
    interp.invoke()
    got = int(interp.get_tensor(out_det["index"])[0, 0])
    return RoundingProbeResult(
        label, names, _cls(got, ed, es), f"vals={vals} tfl={got} d={ed} s={es}"
    )


def _wquant(w_t, w, out_c, odim):
    """Return ``(weight_offset, weight_scale, per_channel)`` for output channel ``out_c``.

    Per-channel weights (more than one scale) must have exactly ``odim`` scales.
    For them the offset is 0 and the scale is ``scales[out_c]``.  Per-tensor
    weights use ``-zero_point`` and the single scale.  ``w`` is not used.

    Raises:
        ValueError: if a per-channel scale count differs from ``odim``.
    """
    params = w_t["quantization_parameters"]
    scales = np.asarray(params["scales"], np.float64).ravel()
    zero_points = np.asarray(params["zero_points"], np.int32).ravel()
    if scales.size > 1:
        if scales.size != odim:
            raise ValueError(f"scales {scales.size} != od {odim}")
        return 0, float(scales[out_c]), True
    return -int(zero_points[0]), float(scales[0]), False


def probe_fully_connected():
    """Probe int8 FULLY_CONNECTED on a Dense(1) layer with a 2-element input.

    Reads the converted model's weights, bias and quantization parameters and
    models the int8 accumulator.  It then scans all int8 input pairs for the
    first one where the double- and single-rounding models disagree, runs the
    interpreter on that pair and classifies its output with ``_cls``.  Only
    that single disagreeing pair is used as evidence.  Structural surprises
    (weight rank, input index, quantization layout, input width) are returned
    as "unexpected" results instead of raised.
    """
    inp = tf.keras.Input(shape=(2,), dtype=tf.float32)
    m = tf.keras.Model(inp, tf.keras.layers.Dense(1, use_bias=True)(inp))
    intr = _intr(_cv8(m, _rep12))
    intr.allocate_tensors()
    # Activation, weight and (optional) bias tensor indices of the FC op.
    ai, wi, bi = _fc_ins(intr)
    ins, outs = intr.get_input_details()[0], intr.get_output_details()[0]
    on = _onames(intr)
    tm = {t["index"]: t for t in intr.get_tensor_details()}
    w_t, w = tm[wi], np.asarray(intr.get_tensor(wi), np.int8)
    if w.ndim != 2:
        return RoundingProbeResult("FULLY_CONNECTED (2→1)", on, "unexpected", f"W rank {w.ndim}")
    # Axis 0 is treated as output units and axis 1 as the accumulation depth;
    # row 0 of w is used as the weights of the single probed output.
    od, ad = int(w.shape[0]), int(w.shape[1])
    # The model input must feed the FC op directly, otherwise the
    # accumulator model below would not describe what the kernel sees.
    if ins["index"] != ai:
        return RoundingProbeResult("FULLY_CONNECTED (2→1)", on, "unexpected", "in idx")
    try:
        fo, swe, pch = _wquant(w_t, w, 0, od)
    except ValueError as e:
        return RoundingProbeResult("FULLY_CONNECTED (2→1)", on, "unexpected", str(e))
    if ad != 2:
        return RoundingProbeResult("FULLY_CONNECTED (2→1)", on, "unexpected", f"adim {ad}")
    zi = int(ins["quantization_parameters"]["zero_points"][0])
    si = float(ins["quantization_parameters"]["scales"][0])
    zo = int(outs["quantization_parameters"]["zero_points"][0])
    so = float(outs["quantization_parameters"]["scales"][0])
    # NOTE(review): the real multiplier is formed here in Python double
    # precision. Confirm it matches how the TFLite FC kernel derives its
    # multiplier (the product of the float32 scales may be formed in float32
    # there), since a 1-ulp difference can change the quantized multiplier
    # and hence the predicted outputs.
    om, osh = _qmul(si * swe / so)
    be = int(np.asarray(intr.get_tensor(bi)).ravel()[0]) if bi is not None else 0
    g = _mac2(w[0, :], io=-zi, fo=fo, pch=pch, be=be, zpo=zo, om=om, os_=osh)
    # First int8 input pair where double and single rounding disagree.
    fk = _scn(2, g)
    lb = "FULLY_CONNECTED (2→1)"
    if fk is None:
        return RoundingProbeResult(lb, on, "indeterminate", "no fork")
    tup, ed, es = fk
    intr.set_tensor(ins["index"], np.array([list(tup)], np.int8))
    intr.invoke()
    t = int(intr.get_tensor(outs["index"])[0, 0])
    return RoundingProbeResult(lb, on, _cls(t, ed, es), f"vals={tup} tfl={t} d={ed} s={es}")


def _probe_conv_dw(name, op, rep, wexp, wrow_fn, pack_in):
    """Shared probe for single-output CONV_2D / DEPTHWISE_CONV_2D models.

    Args:
        name: Result label.  Its first whitespace-separated word is used as the
            TFLite op name to look up (e.g. "CONV_2D").  A label starting with
            "CONV" selects weight axis 0 as the quantized dimension; any other
            label selects axis 3.
        op: Keras model to convert to full int8.
        rep: Representative-dataset generator for calibration.
        wexp: Expected weight tensor shape; any other shape is "unexpected".
        wrow_fn: Returns, from the weight tensor, the weight row paired with
            the two probed input values.
        pack_in: Packs a 2-tuple of int8 values into the model's input array.

    Returns:
        RoundingProbeResult.  Structural mismatches are reported as
        "unexpected" instead of raised.  Only the first disagreeing input pair
        is used as evidence.
    """
    intr = _intr(_cv8(op, rep))
    intr.allocate_tensors()
    lb = name
    on = _onames(intr)
    try:
        ai, wi, bi = _conv_ins(intr, name.split()[0])
    except RuntimeError as e:
        return RoundingProbeResult(lb, on, "unexpected", str(e))
    ins, outs = intr.get_input_details()[0], intr.get_output_details()[0]
    # The model input must be the op's activation operand.
    if ins["index"] != ai:
        return RoundingProbeResult(lb, on, "unexpected", "in idx")
    tm = {t["index"]: t for t in intr.get_tensor_details()}
    w_t, w = tm[wi], np.asarray(intr.get_tensor(wi), np.int8)
    if tuple(w.shape) != wexp:
        return RoundingProbeResult(lb, on, "unexpected", f"W {w.shape}")
    wrow = wrow_fn(w)
    zi = int(ins["quantization_parameters"]["zero_points"][0])
    si = float(ins["quantization_parameters"]["scales"][0])
    zo = int(outs["quantization_parameters"]["zero_points"][0])
    so = float(outs["quantization_parameters"]["scales"][0])
    try:
        # Output channel 0 is probed; the channel count is read from weight
        # axis 0 (CONV*) or axis 3 (otherwise).
        fo, swe, pch = _wquant(w_t, w, 0, int(w.shape[0] if name.startswith("CONV") else w.shape[3]))
    except ValueError as e:
        return RoundingProbeResult(lb, on, "unexpected", str(e))
    # Output multiplier for channel 0, formed in Python double precision.
    om, osh = _qmul(si * swe / so)
    # First bias element (channel 0), or 0 when the op has no bias input.
    be = int(np.asarray(intr.get_tensor(bi)).ravel()[0]) if bi is not None else 0
    g = _mac2(wrow, io=-zi, fo=fo, pch=pch, be=be, zpo=zo, om=om, os_=osh)
    # First int8 input pair where double and single rounding disagree.
    fk = _scn(2, g)
    if fk is None:
        return RoundingProbeResult(lb, on, "indeterminate", "no fork")
    tup, ed, es = fk
    intr.set_tensor(ins["index"], np.asarray(pack_in(tup), np.int8))
    intr.invoke()
    t = int(intr.get_tensor(outs["index"])[0, 0, 0, 0])
    return RoundingProbeResult(lb, on, _cls(t, ed, es), f"vals={tup} tfl={t} d={ed} s={es}")


def probe_conv2d():
    """Probe int8 CONV_2D with a 1x1 kernel mapping 2 input channels to 1 output."""
    x_in = tf.keras.Input(shape=(1, 1, 2), dtype=tf.float32)
    conv = tf.keras.layers.Conv2D(1, (1, 1), padding="valid", use_bias=True)
    model = tf.keras.Model(x_in, conv(x_in))

    def weight_row(w):
        # Expected filter shape (1, 1, 1, 2): the two input-channel taps.
        return w[0, 0, 0, :].copy()

    def pack(vals):
        # One NHWC pixel carrying both int8 values as channels.
        return np.array([[[[vals[0], vals[1]]]]])

    return _probe_conv_dw(
        "CONV_2D (1×1, 2→1)", model, _rep_nhwc12, (1, 1, 1, 2), weight_row, pack
    )


def probe_depthwise_conv2d():
    """Probe int8 DEPTHWISE_CONV_2D with a 2x1 kernel over a single channel."""
    x_in = tf.keras.Input(shape=(2, 1, 1), dtype=tf.float32)
    dw = tf.keras.layers.DepthwiseConv2D(
        (2, 1), padding="valid", depth_multiplier=1, use_bias=True
    )
    model = tf.keras.Model(x_in, dw(x_in))

    def weight_row(w):
        # Expected filter shape (1, 2, 1, 1): the two vertical taps.
        return w[0, :, 0, 0].copy()

    def pack(vals):
        # Two NHWC rows of one pixel each, one channel.
        return np.array([[[[vals[0]]], [[vals[1]]]]])

    return _probe_conv_dw(
        "DEPTHWISE_CONV_2D (2×1, 1ch)", model, _rep_nhwc21, (1, 2, 1, 1), weight_row, pack
    )


def main():
    """Run every probe, print its details, then print a one-line-per-op summary."""
    summary: list[tuple[str, str]] = []
    for probe in (
        probe_mul,
        probe_add,
        probe_sub,
        probe_mean,
        probe_fully_connected,
        probe_conv2d,
        probe_depthwise_conv2d,
    ):
        result = probe()
        header = f"--- {result.op_label} ---\nGraph ops: {result.tflite_op_names}\n"
        body = f"Result: {result.rounding.upper()}\n  {result.detail}\n"
        print(header + body)
        summary.append((result.op_label.split()[0], result.rounding))
    print("Summary:")
    for label, rounding in summary:
        print(f"  {label:22}  {rounding}")


if __name__ == "__main__":
    # Run the probes only when executed as a script, not when imported.
    main()