Quantization
TensorRT Quantization
The post-training quantization process of TensorRT mainly includes two steps:
- Prepare the calibration data, around 500 samples; the backend saves this data in PyTorch format.
- Launch the quantization process with the prepared data and generate the model.
For more details, please refer to the example.
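As a rough reference, preparing such a calibration set might look like the following (a minimal sketch; the `images/` and `calib_data/` paths and the preprocessing are illustrative assumptions, not a required format):

```python
import os
import torch
from PIL import Image
from torchvision import transforms

# Assumed preprocessing; match whatever your deployed pipeline actually uses.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

os.makedirs("calib_data", exist_ok=True)
for i, name in enumerate(sorted(os.listdir("images"))[:500]):  # ~500 samples
    tensor = preprocess(Image.open(os.path.join("images", name)).convert("RGB"))
    torch.save(tensor, f"calib_data/{i}.pt")  # saved in PyTorch format
```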
Training-based Quantization (based on pytorch_quantization)
In the official notebook, NVIDIA summarizes how to improve quantization accuracy through training-based quantization in PyTorch.
This document supplements and summarizes how to handle the backbone in the following aspects:
- Ways to improve layer fusion and performance under int8.
- More model examples.
- Usage with TorchPipe.
If you need a directly usable solution, please refer to the resnet50 example.
For detection models, consider the official end-to-end tutorial for yolov7.
Pre-training of Quantization Parameters
In addition to the pretrained weights obtained from normal training, training-based quantization also needs initial quantization parameters, which are produced by post-training quantization (PTQ).
We have integrated calib_tools for reference.
- Define the calibrator:

```python
calib = calib_tools.Calibrator("mse")  # options: max (fastest calibration, not recommended), mse (usually the best accuracy), entropy, percentile
```

After this step, all calls to layers such as `torch.nn.Conv2d` are hijacked by their corresponding quantized versions, such as `quant_nn.QuantConv2d`.
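Conceptually, this calibrator setup resembles the following `pytorch_quantization` recipe (a sketch under the assumption that `calib_tools` wraps `pytorch_quantization` in this way):

```python
from pytorch_quantization import quant_modules
from pytorch_quantization import nn as quant_nn
from pytorch_quantization.tensor_quant import QuantDescriptor

# "mse"/"entropy"/"percentile" are computed from histogram statistics;
# "max" would instead use QuantDescriptor(calib_method="max").
quant_desc_input = QuantDescriptor(calib_method="histogram")
quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)

# Monkey-patch torch.nn layers (Conv2d, Linear, ...) with their quantized versions.
quant_modules.initialize()
```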
- Modify the backbone definition

To achieve better quantization, the model definition may need to be modified accordingly. We will gradually provide some modified backbones as presets.

```python
q_model = ResidualBlock(64)
```
Merging Quantized Convolution and Residual Addition
ResidualBlock: the Add layer runs faster when both of its inputs are int8 than when one input is fp32 and the other int8.
```python
import torch
import torch.nn as nn
from pytorch_quantization import nn as quant_nn


class ResidualBlock(nn.Module):
    def __init__(self, num_channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, bias=False)
        # If quantization is active, nn.Conv2d has been replaced by quant_nn.QuantConv2d;
        # in that case, quantize the residual branch too, so both Add inputs are int8.
        if torch.nn.Conv2d is quant_nn.QuantConv2d:
            self.residual_quantizer = quant_nn.TensorQuantizer(quant_nn.QuantConv2d.default_quant_desc_input)
        self.relu1 = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.conv1(x)
        if hasattr(self, "residual_quantizer"):
            out = self.relu1(self.residual_quantizer(x) + out)
        else:
            out = self.relu1(x + out)
        return out
```
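Note the design choice: the residual branch gets its own `TensorQuantizer` so that both inputs of the Add arrive as int8 tensors. Without it, the skip connection stays in fp32 and TensorRT has to execute the addition in higher precision, which is why the fused variant measures faster in the results below.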
- Prepare the training data and run calibration:

```python
calib.calibrate(q_model, train_dataloader, num_batches=16)
```
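For intuition, a calibration pass of this kind typically does something like the following with `pytorch_quantization` (an assumption about `calib_tools.calibrate`'s internals, shown as a sketch; it assumes histogram-based calibration such as mse):

```python
import torch
from pytorch_quantization import nn as quant_nn


def calibrate(model, dataloader, num_batches=16):
    model.eval()
    # Switch all TensorQuantizers from quantizing to collecting statistics.
    for module in model.modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.disable_quant()
                module.enable_calib()
            else:
                module.disable()
    # Feed a few batches so the calibrators can observe activation ranges.
    with torch.no_grad():
        for i, (images, _) in enumerate(dataloader):
            model(images.cuda())
            if i + 1 >= num_batches:
                break
    # Compute amax (scale) values from the statistics and re-enable quantization.
    for module in model.modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.enable_quant()
                module.disable_calib()
                module.load_calib_amax("mse")  # the method arg applies to histogram calibrators
            else:
                module.enable()
```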
Training-based Quantization (QAT) Process
If the model accuracy after PTQ does not meet your requirements, you can fine-tune the model starting from the PTQ result. This fine-tuning typically needs only about 1/10 of the epochs used for normal training. Then save the model as onnx:

```python
calib_tools.save_onnx(q_model, "model_name_qat.onnx")
```
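For reference, the fine-tuning itself is ordinary training on the calibrated model with the quantizers left enabled (a minimal sketch; the optimizer, learning rate, and the shape of `train_dataloader` batches are assumptions):

```python
import torch

# QAT fine-tuning: the quantizers stay active, so gradients flow through the
# fake-quantization nodes (straight-through estimator).
optimizer = torch.optim.SGD(q_model.parameters(), lr=1e-4, momentum=0.9)
criterion = torch.nn.CrossEntropyLoss()
q_model.train()
for epoch in range(2):  # roughly 1/10 of the full training schedule
    for images, labels in train_dataloader:
        optimizer.zero_grad()
        loss = criterion(q_model(images.cuda()), labels.cuda())
        loss.backward()
        optimizer.step()
```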
Resnet
The official training recipe here is deliberately simple and serves only as an example.
QAT Results
Direct Quantization without Modifying the Backbone
Following the official example, we ran step-by-step experiments on resnet:
- Download training data: code
- Train for 10 epochs to obtain the resnet50 model: code, accuracy 98.44%
- (optional) PyTorch ptq: code, accuracy 96.64% (max)
- (optional) PyTorch qat: code, accuracy 98.26%.
MSE + Residual Fusion
The resnet training above used the max calibration method and did not fuse the Add layer, so the TensorRT speed fell short of expectations. Below are the results after fusing Add under int8 and switching to mse calibration:
Summary of Results in PyTorch
| Model | Accuracy | Performance | Note |
|---|---|---|---|
| Baseline resnet50 | 98.44% | 2.11674 ms (0.982373 ms for int8) | Trained with a fixed learning rate for 40 epochs (insufficient training) |
| ptq resnet50 | 96.64% (max), 98.41% (mse) | 1.33484 ms | When the model is sufficiently trained and has enough capacity, PTQ may lose very little accuracy |
| qat resnet50 | 98.26% | 1.38074 ms | Fine-tuned for 2 epochs |
| qat resnet50 + residual fusion | 98.62% | 1.03164 ms | Fine-tuned for 2 epochs, with residual fusion |
Summary of Test Results in TorchPipe
The following tests load the onnx models generated above with TorchPipe:
- Export onnx: code
- Load the fp32 onnx with TorchPipe and run ptq: code
- Load the qat onnx with TorchPipe and test: code
| Model | Accuracy | Performance | Note |
|---|---|---|---|
| TensorRT fp32 | 98.44% | - | |
| TensorRT native int8 | 98.26% | - | |
| qat | 98.67% | - | Accuracy under onnxruntime is 98.69%. |
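For completeness, loading one of these onnx files with TorchPipe for such a test might look roughly as follows (a sketch: the backend string and configuration keys are assumptions to be checked against the TorchPipe documentation):

```python
import torch
import torchpipe

# Assumed configuration keys; consult the TorchPipe docs for the exact names.
model = torchpipe.pipe({
    "backend": "SyncTensor[TensorrtTensor]",  # TensorRT-backed compute
    "model": "model_name_qat.onnx",
    "instance_num": 2,                        # number of parallel engine instances
})

inputs = {"data": torch.randn(1, 3, 224, 224).cuda()}
model(inputs)               # forward pass through the TensorRT engine
result = inputs["result"]   # TorchPipe writes the output back into the dict
```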