Why separate pruning and QAT APIs in the model optimization toolkit?

Happily using tfmot in various projects for a while now… but one aspect of the current design puzzles the heck out of me. Why are there disjoint APIs for QAT and pruning?

After all:

  • You can model pruning as “just another kind of Quantizer” (one that maps some values to 0.0 and leaves the rest unchanged).
  • Though less vital than for pruning, supporting scheduling of the degree of quantization applied by Quantizers can be useful in QAT (esp. when quantizing down to sub-8-bit bit-widths).
  • Pruning-as-a-kind-of-QAT also avoids the need for a special-cased two-step training process if QAT and pruning are to be combined.

A quick PoC implementation (a Quantizer composition operator + an incremental pruning “Quantizer”), created to simplify porting models that used a legacy internal library, seems to work just fine.
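For concreteness, a minimal sketch of the idea (not the actual PoC code), assuming the public `tfmot.quantization.keras.quantizers.Quantizer` interface (`build`/`__call__`). The class names, the `target_sparsity` parameter and the composition details are illustrative only:

```python
import tensorflow as tf
import tensorflow_model_optimization as tfmot


class IncrementalPruningQuantizer(tfmot.quantization.keras.quantizers.Quantizer):
  """Pruning expressed as a Quantizer: maps the smallest-magnitude values
  to 0.0 and leaves the rest unchanged, with a straight-through gradient."""

  def __init__(self, target_sparsity=0.5):
    self.target_sparsity = target_sparsity

  def build(self, tensor_shape, name, layer):
    # No extra variables needed for this sketch; a scheduled ("incremental")
    # version would keep a step counter / current-sparsity variable here.
    return {}

  def __call__(self, inputs, training, weights, **kwargs):
    abs_w = tf.abs(inputs)
    flat = tf.reshape(abs_w, [-1])
    # Keep the (1 - sparsity) fraction of largest-magnitude values.
    k = tf.cast(tf.cast(tf.size(flat), tf.float32)
                * (1.0 - self.target_sparsity), tf.int32)
    k = tf.maximum(k, 1)
    threshold = tf.math.top_k(flat, k=k).values[-1]
    mask = tf.cast(abs_w >= threshold, inputs.dtype)
    # Straight-through estimator: forward pass sees the pruned weights,
    # the gradient flows back to all weights unchanged.
    return inputs + tf.stop_gradient(inputs * mask - inputs)

  def get_config(self):
    return {'target_sparsity': self.target_sparsity}


class ComposedQuantizer(tfmot.quantization.keras.quantizers.Quantizer):
  """Composition operator: apply `inner` (e.g. pruning) then `outer`
  (e.g. a stock 8-bit quantizer). Assumes the two quantizers' weight
  dictionary keys do not clash."""

  def __init__(self, inner, outer):
    self.inner, self.outer = inner, outer

  def build(self, tensor_shape, name, layer):
    weights = {}
    weights.update(self.inner.build(tensor_shape, name + '_inner', layer))
    weights.update(self.outer.build(tensor_shape, name + '_outer', layer))
    return weights

  def __call__(self, inputs, training, weights, **kwargs):
    pruned = self.inner(inputs, training, weights, **kwargs)
    return self.outer(pruned, training, weights, **kwargs)

  def get_config(self):
    return {}
```

Dropping such a quantizer into a custom QuantizeConfig then lets the existing QAT machinery handle the training-time bookkeeping.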

On a similar note: the pruning API seems to go to some trouble to prune by over-writing pruned weight variables rather than simply “injecting” a masking operation (with a straight-through estimator for the gradient). Surely, given the constant folding applied when models are optimized for inference (e.g. by the tflite converter), the “end result” of masking would be the same for less coding effort?
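By “injecting” a masking operation I mean something along these lines (a sketch only; `mask` stands for a hypothetical non-trainable per-weight tensor maintained by the pruning schedule):

```python
import tensorflow as tf


@tf.custom_gradient
def masked_weights(weights, mask):
  """Forward pass: the network sees the pruned (masked) weights."""

  def grad(dy):
    # Straight-through estimator: pass the gradient back to the underlying
    # weight variable untouched; no gradient w.r.t. the mask.
    return dy, None

  return weights * mask, grad
```

After training, `weights * mask` is just a product of two constants from the converter's point of view, so constant folding should produce the same sparse tensor that an in-place overwrite would.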

What am I missing? Anyone from the tfmot Team(s) care to shed any light?

Hi @andrew_stevens1,

Although pruning and quantization are both used to reduce model size and computational cost for deployment on resource-constrained devices, they operate in different ways.

Pruning's main intention is to remove the weights that have minimal impact on the model's performance, making the neural network sparse. Typically, pruning is performed after training.
Quantization, on the other hand, reduces the precision of weights and activations by converting floating-point values to lower-precision data types (e.g., 8-bit, 4-bit). Quantization can be applied during training (QAT) or after training (PTQ).
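For reference, the two techniques are exposed through separate entry points in TFMOT today; roughly as follows (`model` stands for an existing tf.keras model, compile/training details elided):

```python
import tensorflow as tf
import tensorflow_model_optimization as tfmot

# `model` is an existing tf.keras.Model (placeholder).

# Pruning: separate API under tfmot.sparsity, applied before fine-tuning.
pruned_model = tfmot.sparsity.keras.prune_low_magnitude(model)

# QAT: separate API under tfmot.quantization, applied before (re)training.
qat_model = tfmot.quantization.keras.quantize_model(model)

# PTQ: applied after training, at conversion time.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()
```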

Overwriting the pruned weights makes the network permanently sparse, rather than relying on a masking operation.

Thank You

Thanks for responding, but the query was at a rather more detailed technical/SW-architectural level than this. I’m well aware of what pruning/quantization are (and of most mainstream and SOTA methods, for that matter).

  • Masking/overwriting are equivalent once constant-folding / optimization in the deployment toolchain is taken into account.
  • Pruning can be performed after training, but this is wholly suboptimal. Effective pruning requires at least fine-tuning/re-training to allow the network to adapt to the pruned weights (there is plenty of literature on this). The implementation machinery needed ‘under the hood’ is very similar to that for QAT (see the sketch below).
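For what it's worth, the existing tfmot pruning flow already assumes exactly this kind of fine-tuning loop; a rough sketch (the schedule parameters, `model` and `train_data` are placeholders):

```python
import tensorflow_model_optimization as tfmot

# Placeholder schedule parameters; tune for the model at hand.
schedule = tfmot.sparsity.keras.PolynomialDecay(
    initial_sparsity=0.0, final_sparsity=0.8,
    begin_step=0, end_step=10000)

pruned = tfmot.sparsity.keras.prune_low_magnitude(
    model, pruning_schedule=schedule)
pruned.compile(optimizer='adam', loss='sparse_categorical_crossentropy')

# Fine-tune so the network can adapt to the progressively pruned weights.
pruned.fit(train_data,
           epochs=2,
           callbacks=[tfmot.sparsity.keras.UpdatePruningStep()])

# Bake the sparsity in (remove the pruning wrappers) before export.
final = tfmot.sparsity.keras.strip_pruning(pruned)
```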