Skip to main content

量化

Tensorrt的量化

tensorrt 的训练后量化过程主要包含两步:

  • 准备量化数据,500份左右。这部分我们将进入后端的数据按照Pytorch的格式保存下来
  • 以准备好的数据启动量化过程,并生成模型。

详细可参见示例.

训练时量化(基于pytorch_quantization)

官方notebook中,NVIDIA总结了如何在Pytorch中通过训练时量化提升量化精度。

本文档从以下两个方面对如何处理backbone进行补充汇总:

  1. int8下改善层融合提升性能的方式
  2. 更多的模型示例
  3. torchpipe

如果需要直接可用的方案,请转至resnet50示例.

对于检测模型,可以考虑直接使用针对yolov7的官方完整教程.

量化参数的预训练

训练时量化除了需要正常训练的模型提供预训练参数,也需要训练后量化(ptq)提供量化的预训练参数。

我们集成了calib_tools,可做参考.

  • 定义calibrater:
calib=calib_tools.Calibrator("mse") # 可选max(量化过程最快,不推荐) mse(一般准确率较好) entropy percentile

此步骤后,所有对torch.nn.Conv2D等层的调用都将被劫持到对应的量化版本如quant_nn.Conv2D.

  • 修改backbone定义

为了更好的量化,对模型的定义可能需要针对性的修改。我们将逐渐预置部分修改好的backbone.

q_model = ResidualBlock(64)
量化卷积和残差相加的融合

ResidualBlock: 对于Add层,输入均是int8时,能比输入分别为fp32和int8更高效。

class ResidualBlock(nn.Module):
def __init__(self, num_channels):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, bias=False)
self.quant = quant
if torch.nn.Conv2d is quant_nn.Conv2d:
self.residual_quantizer = quant_nn.TensorQuantizer(quant_nn.QuantConv2d.default_quant_desc_input)
self.relu1 = nn.ReLU(inplace=True)

def forward(self, x):
out = self.conv1(x)
if hasattr(self, "residual_quantizer"):
out = self.relu1(self.residual_quantizer(x)+out)
else:
out = self.relu1(x+out)
  • 准备训练集,进行量化:
calib.calibrate(q_model, train_dataloader, num_batches=16)

训练时量化(QAT)流程

如果经过ptq流程的模型精度未达到要求,可以在ptq结果上进行finetune. 此步骤在ptq的基础上正常finetune大约1/10 总的epoch即可。然后保存为onnx:

calib_tools.save_onnx(q_model, f"model_name_qat.onnx")

Resnet

备注

官方的训练格式非常简单,仅仅是为了用作示例。

QAT 结果

不改变backbone,直接量化

仿照官方示例,我们对resnet进行了分步骤实验:

  • 下载训练数据:代码
  • 训练10个epoch获得resnet50模型:代码, 精度98.44%
  • (optinal)pytorch ptq:代码, 精度96.64%(max)
  • (optinal)pytorch qat:代码, 精度98.26%.

mse + 残差融合

以上resnet的训练,采用max方式量化,并且没有对Add进行融合,导致tensorrt运行速度未达预期。以下将Add在int8下进行融合并换用mse模式后的结果:

  • ptq:代码, 精度94.34%(mse)
  • qat:代码, 精度95.82%

pytorch下结果汇总

ModelAccuracyPerformance备注
Baseline resnet5098.44%2.11674ms (0.982373ms for int8)固定学习率训练40epoch(未充分训练)
ptq resnet5096.64%(max) 98.41%(mse)1.33484ms充分训练且模型能力足够的情况下ptq可能精度降低非常小
qat resnet5098.26%1.38074msfine-tune了2epoch
qat resnet50 + 残差融合98.62%1.03164msfine-tune了2epoch, 残差融合

torchpipe下测试结果汇总

以下使用torchpipe加载生成的onnx进行测试:

  • 导出onnx:代码
  • 使用torchpipe加载fp32-onnx并进行ptq: 代码
  • 使用torchpipe加载qat-onnx进行测试: 代码
ModelAccuracyPerformance备注
tensorrt's fp3298.44%-
tensorrt's native int898.26%-
qat98.67%-onnxruntime下精度为98.69%。