量化
Tensorrt的量化
tensorrt 的训练后量化过程主要包含两步:
- 准备量化数据,500份左右。这部分我们将进入后端的数据按照Pytorch的格式保存下来
- 以准备好的数据启动量化过程,并生成模型。
详细可参见示例.
训练时量化(基于pytorch_quantization)
在官方notebook中,NVIDIA总结了如何在Pytorch中通过训练时量化提升量化精度。
本文档从以下两个方面对如何处理backbone进行补充汇总:
- int8下改善层融合提升性能的方式
- 更多的模型示例
- torchpipe
如果需要直接可用的方案,请转至resnet50示例.
对于检测模型,可以考虑直接使用针对yolov7的官方完整教程.
量化参数的预训练
训练时量化除了需要正常训练的模型提供预训练参数,也需要训练后量化(ptq)提供量化的预训练参数。
我们集成了calib_tools,可做参考.
- 定义calibrater:
calib=calib_tools.Calibrator("mse") # 可选max(量化过程最快,不推荐) mse(一般准确率较好) entropy percentile
此步骤后,所有对torch.nn.Conv2D
等层的调用都将被劫持到对应的量化版本如quant_nn.Conv2D
.
- 修改backbone定义
为了更好的量化,对模型的定义可能需要针对性的修改。我们将逐渐预置部分修改好的backbone.
q_model = ResidualBlock(64)
量化卷积和残差相加的融合
ResidualBlock: 对于Add层,输入均是int8时,能比输入分别为fp32和int8更高效。
class ResidualBlock(nn.Module):
def __init__(self, num_channels):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, bias=False)
self.quant = quant
if torch.nn.Conv2d is quant_nn.Conv2d:
self.residual_quantizer = quant_nn.TensorQuantizer(quant_nn.QuantConv2d.default_quant_desc_input)
self.relu1 = nn.ReLU(inplace=True)
def forward(self, x):
out = self.conv1(x)
if hasattr(self, "residual_quantizer"):
out = self.relu1(self.residual_quantizer(x)+out)
else:
out = self.relu1(x+out)
- 准备训练集,进行量化:
calib.calibrate(q_model, train_dataloader, num_batches=16)
训练时量化(QAT)流程
如果经过ptq流程的模型精度未达到要求,可以在ptq结果上进行finetune. 此步骤在ptq的基础上正常finetune大约1/10 总的epoch即可。然后保存为onnx:
calib_tools.save_onnx(q_model, f"model_name_qat.onnx")
Resnet
备注
官方的训练格式非常简单,仅仅是为了用作示例。
QAT 结果
不改变backbone,直接量化
仿照官方示例,我们对resnet进行了分步骤实验:
- 下载训练数据:代码
- 训练10个epoch获得resnet50模型:代码, 精度98.44%
- (optinal)pytorch ptq:代码, 精度96.64%(max)
- (optinal)pytorch qat:代码, 精度98.26%.
mse + 残差融合
以上resnet的训练,采用max方式量化,并且没有对Add进行融合,导致tensorrt运行速度未达预期。以下将Add在int8下进行融合并换用mse模式后的结果:
pytorch下结果汇总
Model | Accuracy | Performance | 备注 |
---|---|---|---|
Baseline resnet50 | 98.44% | 2.11674ms (0.982373ms for int8) | 固定学习率训练40epoch(未充分训练) |
ptq resnet50 | 96.64%(max) 98.41%(mse) | 1.33484ms | 充分训练且模型能力足够的情况下ptq可能精度降低非常小 |
qat resnet50 | 98.26% | 1.38074ms | fine-tune了2epoch |
qat resnet50 + 残差融合 | 98.62% | 1.03164ms | fine-tune了2epoch, 残差融合 |
torchpipe下测试结果汇总
以下使用torchpipe加载生成的onnx进行测试:
Model | Accuracy | Performance | 备注 |
---|---|---|---|
tensorrt's fp32 | 98.44% | - | |
tensorrt's native int8 | 98.26% | - | |
qat | 98.67% | - | onnxruntime下精度为98.69%。 |