generate_int8 function
The generate_int8 function quantizes the weights of a neural network model from floating-point precision (e.g., FP32 or FP16) to INT8 and computes the various scaling factors essential for the quantization process and for subsequent inference.
This function is particularly tailored to General Matrix Multiply (GEMM) operations, which are core to deep learning computations, especially in transformer-based models. Here is a detailed explanation of its components and functionality:
Purpose and Process
The function serves two main purposes:
Quantizing Weights: It converts model weights to INT8 format, which reduces memory footprint and can speed up inference on compatible hardware. The function supports either per-tensor or per-column (per-channel) quantization.
Computing Scaling Factors: It calculates several scaling factors needed to adjust the quantized weights and activations during inference, ensuring that the quantization process minimizes loss of accuracy.
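To make both purposes concrete, here is a minimal NumPy sketch of symmetric per-tensor and per-column INT8 quantization and the pair of scales it produces. The helper name and return values are illustrative assumptions, not the function's actual implementation.

```python
import numpy as np

def quantize_weights_int8(weights: np.ndarray, per_column: bool = False):
    """Minimal sketch of symmetric INT8 weight quantization (illustrative only)."""
    if per_column:
        # Per-column (per-channel): one scale per output column of the weight matrix.
        amax = np.abs(weights).max(axis=0, keepdims=True)
    else:
        # Per-tensor: a single scale shared by the whole weight matrix.
        amax = np.abs(weights).max()
    scale_orig_to_int8 = 127.0 / amax               # float -> INT8
    scale_int8_to_orig = 1.0 / scale_orig_to_int8   # INT8 -> float (for dequantization)
    w_int8 = np.clip(np.round(weights * scale_orig_to_int8), -127, 127).astype(np.int8)
    return w_int8, scale_orig_to_int8, scale_int8_to_orig

# Quantize a small FP32 weight matrix both ways.
w = np.random.randn(8, 4).astype(np.float32)
w_per_tensor, s_t, _ = quantize_weights_int8(w, per_column=False)
w_per_column, s_c, _ = quantize_weights_int8(w, per_column=True)
```

Per-column scales track the dynamic range of each output channel more tightly than a single per-tensor scale, which generally preserves more accuracy at the cost of storing one scale per column.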
Parameters
weights: The original floating-point weights of the model or a layer that need to be quantized.
act_range: A dictionary containing the ranges (maximum absolute values) of activations ("x"), weights ("w"), and possibly outputs ("y"), used to determine scaling factors.
is_qkv: A boolean indicating whether the weights belong to a QKV (Query, Key, Value) projection layer, common in transformer models. QKV layers have specific quantization considerations.
multi_query_mode: A boolean that, when combined with is_qkv, indicates a special mode of handling multiple queries simultaneously, affecting how quantization is applied.
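A hedged example of how such a function might be called is shown below; the import path, layer shape, and calibration values are assumptions made for illustration, and the call itself is commented out because the surrounding module is not specified on this page.

```python
import torch

# from some_quantization_module import generate_int8   # assumed import path

hidden_size = 1024
weights = torch.randn(hidden_size, hidden_size, dtype=torch.float16)  # FP16 layer weights

# Calibration statistics gathered beforehand: maximum absolute values of the
# activations ("x"), weights ("w"), and outputs ("y") observed for this layer.
act_range = {
    "x": torch.full((hidden_size,), 4.0),
    "w": weights.abs().max(dim=0).values,
    "y": torch.full((hidden_size,), 6.0),
}

# int8_results = generate_int8(weights, act_range, is_qkv=False, multi_query_mode=False)
# int8_results["weight.int8"] would then hold the quantized weights.
```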
Quantization and Scaling Logic
The function first detaches the weights, moves them to the CPU, and converts them to NumPy arrays for processing.
It differentiates between standard layers and QKV projection layers, applying a specific quantization strategy for each. For QKV layers, it treats the combined QKV matrix as three separate matrices, each with potentially different scaling factors.
For non-QKV or standard QKV layers, it computes a global (per-tensor) or local (per-column) scaling factor based on the provided activation range. For multi-query QKV layers, it computes separate scaling factors for Q, K, and V projections based on their respective activation ranges.
Scaling factors are computed for both directions: floating-point to INT8 (scale_w_orig_quant) and INT8 back to floating-point (scale_w_quant_orig), along with the specific scaling factors needed for GEMM operations using either the CUTLASS or cuBLAS APIs. CUTLASS requires separate scaling for activations and weights, while cuBLAS uses a combined scaling factor since it does not support per-row scaling.
The function accounts for the different scaling requirements of tensor and pipeline parallelism and adjusts scaling factors accordingly to ensure consistent model behavior across different numbers of GPUs.
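As a rough illustration of how these scale families relate, the sketch below computes per-tensor scales under assumed symmetric-quantization conventions (a 127 clip value); the variable names echo those mentioned above, but the formulas are a simplified reconstruction, not the function's actual code.

```python
import numpy as np

def compute_gemm_scales(weights: np.ndarray, act_x_max: float, act_y_max: float):
    """Per-tensor sketch of the scaling factors described above."""
    scale_w_orig_quant = 127.0 / np.abs(weights).max()   # float weight -> INT8
    scale_w_quant_orig = 1.0 / scale_w_orig_quant        # INT8 weight -> float

    scale_x_orig_quant = 127.0 / act_x_max               # float activation -> INT8
    scale_x_quant_orig = 1.0 / scale_x_orig_quant

    scale_y_orig_quant = 127.0 / act_y_max               # float output -> INT8

    # CUTLASS-style GEMM: keep the activation and weight dequantization scales
    # separate, so the INT32 accumulator is rescaled by their product (and the
    # weight scale can be per-column).
    cutlass_scales = (scale_x_quant_orig, scale_w_quant_orig)

    # cuBLAS-style GEMM: per-row scaling is unavailable, so one combined factor
    # folds the activation scale, weight scale, and requantization of the output.
    cublas_combined_scale = scale_y_orig_quant * scale_x_quant_orig * scale_w_quant_orig

    return scale_w_orig_quant, cutlass_scales, cublas_combined_scale

w = np.random.randn(16, 16).astype(np.float32)
scale_w, cutlass, cublas = compute_gemm_scales(w, act_x_max=3.0, act_y_max=5.0)
```

Folding everything into a single factor lets a cuBLAS GEMM apply one post-accumulation multiplication, whereas the separate CUTLASS scales allow per-channel weight scaling to be applied independently of the activation scale.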
Outputs
The function returns a dictionary containing:
Quantized weights in INT8 format ("weight.int8") for both global and column-specific quantization.
The computed scaling factors necessary for adjusting inputs, weights, and outputs during inference with quantized models. These factors are crucial for maintaining the accuracy of the model after quantization.
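A sketch of how a caller might consume the returned dictionary follows. Only the "weight.int8" key is named on this page; the scale key used here is a hypothetical placeholder for whichever scaling-factor entries the function actually returns, and the dictionary is stubbed so the snippet stands alone.

```python
import numpy as np

# Stand-in for the dictionary returned by generate_int8 for one layer.
int8_results = {
    "weight.int8": np.zeros((1024, 1024), dtype=np.int8),
    "scale_w_quant_orig": np.float32(0.01),  # hypothetical key: INT8 -> float weight scale
}

# The INT8 weights go into the quantized checkpoint or engine...
w_int8 = int8_results["weight.int8"]
assert w_int8.dtype == np.int8

# ...and the scaling factors are stored alongside them so the runtime can map
# INT8 values (and INT32 GEMM accumulators) back to floating point.
w_dequantized = w_int8.astype(np.float32) * int8_results["scale_w_quant_orig"]
```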
Importance of Quantization
Quantizing model weights and activations to INT8:
Reduces Model Size: Lower precision weights significantly reduce the memory footprint, making deployment on edge devices more feasible.
Increases Inference Speed: Many modern CPUs and GPUs have specialized instructions for INT8 arithmetic, leading to faster computations compared to higher precision formats.
Requires Careful Scaling: To preserve model accuracy, precise scaling factors must be applied during inference. This function automates the calculation of these factors based on the ranges of weights and activations.