Issue Replicating TF-Lite Conv2D Quantized Inference Output

Issue Replicating TF-Lite Conv2D Quantized Inference Output

I am trying to reproduce the exact layer-wise output of a quantized EfficientNet model (TFLite model, TensorFlow 2.17) by re-implementing Conv2D, DepthwiseConv2D, FullyConnected, Add, Mul, Sub and Mean operations.
All layers match except Conv2D, where I consistently see a mismatch of -1 for about ~1% of values compared to TFLite’s output in many layers (not all).

The layer outputs are obtained from:

  • interpreter.invoke()

  • interpreter._get_ops_details()

  • interpreter.get_tensor(interpreter._get_tensor_details(_input, 0)["index"])

Example of mismatch

Input SRM Output DRM Output TF-Lite Output
5907 12 13 13
-19651 -31 -31 -30
-12264 -46 -46 -46

Quantization parameters

Iz (input zero-point) = -127
Oz (output zero-point) = -3
Is (input scale) = 0.15857075154781342
Os (output scale) = 0.5111709833145142
Fs (filter scale) = 0.010584672912955284
bS (bias scale) = 0.0016784195322543383
M (real multiplier) = 0.00328347971662879
M0 (quantized multiplier) = 1805112064
shift = -8
bias = -885

Quantized multiplication implementation

I implemented TFLite’s core math primitives, including:

  • SaturatingRoundingDoublingHighMul

  • RoundingDivideByPOT

  • Single-rounding (SRM) and double-rounding (DRM) variants

(Full code included below.)

Conv2D implementation

Conv2D uses the standard convolution formula:

acc = sum((x - Iz) * (w - Fz)) + bias
acc = MultiplyByQuantizedMultiplier(acc, M[channel], shift[channel])
acc = acc + Oz
acc = clamp(acc, act_min, act_max)

(Full code included below.)

Problem

DepthwiseConv2D, which uses the same quantized-multiplier code, matches exactly with TF-Lite. But Conv2D does not, and I see a ~1% discrepancy in many values.

Why does Conv2D mismatch TF-Lite outputs while DepthwiseConv2D matches perfectly?

There is a difference in the exact TFLite implementation, but I haven’t found it yet. We need help figuring it out.

Codes

QuantizationFunction is implemented to emulate multiplication by quantized multiplier. The code is:

class QuantizationUtils:
    INT32_MIN = -2**31
    INT32_MAX =  2**31 - 1
    bit = 31


    @staticmethod
    def quantize_multiplier_smaller_than_one(real_multiplier):
        """
        Given a real multiplier in (0, 1) or > 1,
        compute the quantized int32 multiplier and shift
        that approximates real_multiplier ≈ multiplier * 2^shift.
        """


        if real_multiplier == 0:
            return 0, 0


        shift = 0
        significand = real_multiplier


        # Normalize to [0.5, 1)
        while significand < 0.5:
            significand *= 2.0
            shift -= 1
        while significand >= 1.0:
            significand /= 2.0
            shift += 1


        # Convert to fixed-point 32-bit representation (Q31 format)
        q = int(round(significand * (1 << QuantizationUtils.bit)))


        if q == (1 << QuantizationUtils.bit):
            q //= 2
            shift += 1


        return q, shift


    @staticmethod
    def saturating_rounding_doubling_high_mul(a: int, b: int) -> int:
        """Python equivalent of TFLite's SaturatingRoundingDoublingHighMul."""
        # Special overflow case: INT32_MIN * INT32_MIN → INT32_MAX
        if a == QuantizationUtils.INT32_MIN and b == QuantizationUtils.INT32_MIN:
            return QuantizationUtils.INT32_MAX


        # Perform 64-bit multiplication
        a_64 = int(a)
        b_64 = int(b)
        ab_64 = a_64 * b_64


        # Rounding offset (nudge) according to sign
        nudge = (1 << 30) if ab_64 >= 0 else (1 - (1 << 30))


        # Multiply by 2 and extract high 32 bits (simulate doubling_high_mul)
        result = (ab_64 + nudge) >> 31


        # Saturate to int32 range
        if result > QuantizationUtils.INT32_MAX:
            result = QuantizationUtils.INT32_MAX
        elif result < QuantizationUtils.INT32_MIN:
            result = QuantizationUtils.INT32_MIN


        return int(result)


    @staticmethod
    def rounding_divide_by_pot(x: int, exponent: int) -> int:
        """Equivalent to TFLite's RoundingDivideByPOT (round-to-nearest, ties away from zero)."""
        assert 0 <= exponent <= 31


        mask = (1 << exponent) - 1
        remainder = x & mask
        threshold = (mask >> 1)


        if x < 0:
            threshold += 1


        # Round to nearest integer, ties away from zero
        result = (x >> exponent) + (1 if (remainder > threshold) else 0)


        return result


    @staticmethod
    def multiply_by_quantized_multiplier_DRM(x: int, quantized_multiplier: int, shift: int) -> int:
        """
        Simulates the core quantized multiply in TFLite:
        result = x * quantized_multiplier * 2^shift, with rounding and saturation.
        """
        # shift < 0 means right shift (division by power of two)
        # shift > 0 means left shift (multiplication by power of two)
        if shift < 0:
            return QuantizationUtils.rounding_divide_by_pot(
                QuantizationUtils.saturating_rounding_doubling_high_mul(x, quantized_multiplier), -shift
            )
        else:
            return QuantizationUtils.saturating_rounding_doubling_high_mul(x, quantized_multiplier) * (1 << shift)
       
   
    @staticmethod
    def multiply_by_quantized_multiplier_SRM(x: int, quantized_multiplier: int, shift: int) -> int:
        """
        Simulates the core quantized multiply in TFLite:
        result = x * quantized_multiplier * 2^shift, with rounding and saturation.
        """
        # shift < 0 means right shift (division by power of two)
        # shift > 0 means left shift (multiplication by power of two)
        # Perform 64-bit multiplication
        a_64 = int(x)
        b_64 = int(quantized_multiplier)
        ab_64 = a_64 * b_64


        if shift < 0:
            return QuantizationUtils.rounding_divide_by_pot(
                ab_64, 31-shift
            )
        else:
            return QuantizationUtils.rounding_divide_by_pot(ab_64, 31+shift)


    @staticmethod
    def mul_by_quantized_multiplier_smaller_than_one_exp(x, M, rshift, impl="single"):
        """
        Convenience wrapper; set impl="single" to do single rounding,
        or impl="double" for double rounding.
        """
        if impl == "single":
            return QuantizationUtils.multiply_by_quantized_multiplier_SRM(x, M, rshift)
        elif impl == "double":
            return QuantizationUtils.multiply_by_quantized_multiplier_DRM(x, M, rshift)
        else:
            raise ValueError("impl must be 'single' or 'double'")

The code for convolution (Conv2D) is:

def call(self, input):
    """
    Performs the forward pass of the Conv2D layer


    acc = sum((x - Iz) * (w - Fz)) + bias
    acc = MultiplyByQuantizedMultiplier(acc, M[channel], shift[channel])
    acc = acc + Oz
    acc = clamp(acc, act_min, act_max)
    """
    if self.verbose:
        print("Running Conv2D Layer (TFLite compatible)\n--------------------")


    # === 1. Pad and prepare input ===
    input = self.input_padding(input)
    _, self.H_in_pad, self.W_in_pad, _ = input.shape
    self.input_shape_pad = (self.H_in_pad, self.W_in_pad)


    # Output size calculation
    self.output_shape = tuple(
        ((x - y) // z + 1)
        for x, y, z in zip(self.input_shape_pad, self.filter_shape, self.strides)
    )
    self.H_out, self.W_out = self.output_shape


    # === 2. Initialize output ===
    final_output = np.zeros((self.batch, self.H_out, self.W_out, self.C_out_filter), dtype=np.int32)


    # === 3. Perform convolution per output channel ===
    for channel in range(self.C_out_filter):
        # Extract per-channel filter and bias
        filter_c = self.filter[:, :, :, channel]
        bias_c = int(self.bias[channel])


        # Per-channel multiplier & shift
        M_c = self.M[channel]
        shift_c = self.shift[channel]


        for b in range(self.batch):
            for i in range(self.H_out):
                for j in range(self.W_out):
                    # Region slice
                    h_start = i * self.H_stride
                    w_start = j * self.W_stride
                    h_end = h_start + self.H_filter
                    w_end = w_start + self.W_filter


                    region = input[b, h_start:h_end, w_start:w_end, :]


                    # === 4. True TFLite accumulation ===
                    # Subtract input and weight zero-points before multiply
                    acc = np.sum((region - self.Iz) * (filter_c)) + bias_c


                    # === 5. Apply quantized scaling ===
                    acc = QuantizationUtils.mul_by_quantized_multiplier_smaller_than_one_exp(
                        acc, M_c, shift_c, impl="single"
                    )


                    # === 6. Add output zero-point and clamp ===
                    acc += self.Oz


                    # ReLU6 or ReLU (clamp to activation range)
                    if self.relu_flag:
                        acc = np.maximum(acc, self.Oz)
                    acc = np.clip(acc, self.act_min, self.act_max)


                    final_output[b, i, j, channel] = acc


    return final_output.astype(np.int8)
1 Like