Quantization

TensorRT Quantization

The post-training quantization process of TensorRT mainly includes two steps:

  • Prepare the calibration data, about 500 samples. The backend saves this data in PyTorch format.
  • Launch the quantization process with the prepared data and generate the model.
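As a rough illustration of the first step, the sketch below draws a reproducible random subset of about 500 samples from a larger dataset. The helper name `pick_calibration_subset` and the file names are hypothetical and not part of TorchPipe or TensorRT. Saving the data in PyTorch format is left out.

```python
import random

def pick_calibration_subset(items, num_samples=500, seed=0):
    """Return a reproducible random subset of at most num_samples items."""
    rng = random.Random(seed)
    if len(items) <= num_samples:
        return list(items)
    return rng.sample(items, num_samples)

# Hypothetical file names standing in for a real training set.
all_images = [f"img_{i:05d}.jpg" for i in range(10000)]
calib_images = pick_calibration_subset(all_images)
print(len(calib_images))
```

Fixing the seed keeps the calibration set stable between runs, which makes accuracy comparisons easier to reproduce.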

For more details, please refer to the example.

Training-based Quantization (based on pytorch_quantization)

In the official notebook, NVIDIA summarizes how to improve quantization accuracy through training-based quantization in PyTorch.

This document supplements and summarizes how to handle the backbone in the following three aspects:

  1. Ways to improve layer fusion and performance under int8.
  2. More model examples.
  3. TorchPipe.

If you need a directly usable solution, please refer to the resnet50 example.

For detection models, you can consider using the official complete tutorial for yolov7.

Pre-training of Quantization Parameters

In addition to the pretrained weights obtained from normal training, training-based quantization also needs initial quantization parameters, which are obtained from post-training quantization (ptq).

We have integrated calib_tools for reference.

  • Define calibrator:
calib = calib_tools.Calibrator("mse") # Optional: max (fastest quantization process, not recommended), mse (generally better accuracy), entropy, percentile

After this step, all calls to layers such as torch.nn.Conv2d will be hijacked to their corresponding quantized versions, such as quant_nn.Conv2d.
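The hijacking described above works by rebinding a name on a module. The following sketch shows the idea with stand-in classes and a stand-in `nn` namespace. All of them are hypothetical and only mimic the pattern. It is not the actual pytorch_quantization or calib_tools code.

```python
import types

# Stand-in for the torch.nn namespace (hypothetical, for illustration only).
nn = types.SimpleNamespace()

class Conv2d:
    def __init__(self, in_ch, out_ch):
        self.in_ch, self.out_ch = in_ch, out_ch

class QuantConv2d(Conv2d):
    """Hypothetical quantized variant that also carries quantizer state."""
    def __init__(self, in_ch, out_ch):
        super().__init__(in_ch, out_ch)
        self.input_quantizer = "int8"

nn.Conv2d = Conv2d

def build_model():
    # Model code keeps writing nn.Conv2d and is unaware of any patching.
    return nn.Conv2d(3, 64)

before = build_model()
nn.Conv2d = QuantConv2d  # the "hijack": rebind the name on the namespace
after = build_model()

print(type(before).__name__, type(after).__name__)
```

Because the name is rebound rather than the model code edited, a check such as `nn.Conv2d is QuantConv2d` is enough for model code to detect whether quantization is active.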

  • Modify backbone definition

To achieve better quantization, the model definition may need to be modified accordingly. We will gradually provide some modified backbones as presets.

q_model = ResidualBlock(64)

Merging Quantized Convolution and Residual Addition

ResidualBlock: An Add layer runs more efficiently when both of its inputs are int8 than when one input is fp32 and the other is int8.

import torch
import torch.nn as nn
from pytorch_quantization import nn as quant_nn

class ResidualBlock(nn.Module):
    def __init__(self, num_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, bias=False)
        # Add a quantizer for the skip branch only when nn.Conv2d has been
        # replaced by its quantized version, i.e. quantization is enabled.
        if torch.nn.Conv2d is quant_nn.Conv2d:
            self.residual_quantizer = quant_nn.TensorQuantizer(quant_nn.QuantConv2d.default_quant_desc_input)
        self.relu1 = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.conv1(x)
        if hasattr(self, "residual_quantizer"):
            # Quantize the skip branch so that both Add inputs are int8.
            out = self.relu1(self.residual_quantizer(x) + out)
        else:
            out = self.relu1(x + out)
        return out

  • Prepare the training set and run calibration:
calib.calibrate(q_model, train_dataloader, num_batches=16)
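To give a feel for how the "max" and "mse" calibrator options differ, here is a simplified pure-Python sketch. It assumes symmetric int8 quantization and a brute-force search over clipping thresholds. The real calibrators in pytorch_quantization work differently, so treat the function names and the search strategy as illustrative assumptions.

```python
import random

def fake_quantize(values, amax, num_bits=8):
    """Symmetric fake quantization: snap to the integer grid, then dequantize."""
    qmax = 2 ** (num_bits - 1) - 1  # 127 for int8
    scale = amax / qmax
    out = []
    for v in values:
        q = max(-qmax, min(qmax, round(v / scale)))
        out.append(q * scale)
    return out

def mean_squared_error(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)

def calibrate_max(values):
    """'max' style: the clipping range is the largest absolute value seen."""
    return max(abs(v) for v in values)

def calibrate_mse(values, num_candidates=100):
    """'mse' style: try thresholds in (0, max] and keep the lowest-error one."""
    vmax = calibrate_max(values)
    best_amax = vmax
    best_err = mean_squared_error(values, fake_quantize(values, vmax))
    for i in range(1, num_candidates):
        amax = vmax * i / num_candidates
        err = mean_squared_error(values, fake_quantize(values, amax))
        if err < best_err:
            best_amax, best_err = amax, err
    return best_amax

random.seed(0)
# Synthetic activations: mostly small values plus one large outlier.
activations = [random.gauss(0.0, 1.0) for _ in range(5000)] + [40.0]

amax_max = calibrate_max(activations)
amax_mse = calibrate_mse(activations)
err_max = mean_squared_error(activations, fake_quantize(activations, amax_max))
err_mse = mean_squared_error(activations, fake_quantize(activations, amax_mse))
print(amax_max, amax_mse, err_max, err_mse)
```

In this sketch the "mse" search starts from the "max" threshold and only replaces it with a threshold that lowers the error, so on the calibration data its reconstruction error is never higher. The price is one extra pass over the data per candidate threshold, which is consistent with "max" being the fastest option.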

Training-based Quantization (QAT) Process

If the model accuracy after the ptq process does not meet the requirements, you can fine-tune the model starting from the ptq results. This fine-tuning takes about 1/10 of the epochs used for normal training. Then save the model as onnx:

calib_tools.save_onnx(q_model, "model_name_qat.onnx")

Resnet

Note: The official training format is very simple and is only used as an example.

QAT Results

Direct Quantization without Modifying Backbone

Following the official example, we conducted step-by-step experiments on resnet:

  • Download training data: code
  • Train for 10 epochs to obtain the resnet50 model: code, accuracy 98.44%
  • (optional) PyTorch ptq: code, accuracy 96.64% (max)
  • (optional) PyTorch qat: code, accuracy 98.26%.

MSE + Residual Fusion

The resnet training above uses the max quantization method and does not fuse the Add layer, so the TensorRT running speed falls short of expectations. The following are the results after fusing Add under int8 and switching to the mse mode:

  • ptq: code, accuracy 94.34% (mse)
  • qat: code, accuracy 95.82%.

Summary of Results in PyTorch

| Model | Accuracy | Performance | Note |
| --- | --- | --- | --- |
| Baseline resnet50 | 98.44% | 2.11674 ms (0.982373 ms for int8) | Trained with fixed learning rate for 40 epochs (insufficient training) |
| ptq resnet50 | 96.64% (max), 98.41% (mse) | 1.33484 ms | When the model is sufficiently trained and has sufficient capacity, ptq may have very little decrease in accuracy |
| qat resnet50 | 98.26% | 1.38074 ms | Fine-tuned for 2 epochs |
| qat resnet50 + Residual Fusion | 98.62% | 1.03164 ms | Fine-tuned for 2 epochs, with residual fusion |
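To make the latency column easier to compare, the short sketch below converts the reported timings into speedups relative to the baseline. It uses only the numbers from the table above. The variable names are illustrative.

```python
# Latencies in milliseconds, copied from the table above.
latency_ms = {
    "Baseline resnet50": 2.11674,
    "ptq resnet50": 1.33484,
    "qat resnet50": 1.38074,
    "qat resnet50 + Residual Fusion": 1.03164,
}

baseline = latency_ms["Baseline resnet50"]
speedup = {name: baseline / ms for name, ms in latency_ms.items()}

for name, factor in speedup.items():
    print(f"{name}: {factor:.2f}x")
```

With these numbers, residual fusion brings the qat model from roughly 1.53x to roughly 2.05x of the baseline speed.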

Summary of Test Results in TorchPipe

The following tests were performed using the onnx generated by TorchPipe:

  • Export onnx: code
  • Load fp32-onnx with TorchPipe and perform ptq: code
  • Test with qat-onnx loaded with TorchPipe: code

| Model | Accuracy | Performance | Note |
| --- | --- | --- | --- |
| tensorrt's fp32 | 98.44% | - | |
| tensorrt's native int8 | 98.26% | - | |
| qat | 98.67% | - | Acc. under onnxruntime is 98.69%. |