CANN-昇腾NPU-量化训练-QAT和PTQ怎么选
模型量化有两种时机:训练时做(QAT,Quantization-Aware Training)和训练后做(PTQ,Post-Training Quantization)。在昇腾NPU上,QAT 用 torch_npu 的量化感知训练,PTQ 用 CANN 的 AMCT 工具。这篇讲清楚两者的适用场景和操作步骤。
PTQ:训练后量化
PTQ 不需要重新训练,直接把 fp16 模型量化成 int8/w8a8。适合快速上线、不想重新训练的场景。
fromamct_npuimportcreate_quant_config,quantize_model# 1. 准备校准数据集(100-1000 条代表性数据)calib_dataloader=get_calib_dataloader(num_samples=500)# 2. 创建量化配置config=create_quant_config(model_file="model.onnx",config_file="./quant_config.json",dst_json_path="./quant_ready.json",)# 3. 校准(跑一遍校准数据,统计激活分布)quant_model=quantize_model(model_file="model.onnx",quant_config_file="./quant_config.json",calib_dataloader=calib_dataloader,)# 4. 导出量化模型quant_model.export_quant_onnx("model_quant.onnx")PTQ 的关键:校准数据集要跟真实推理数据分布一致。用训练集做校准,推理时分布不同,精度损失会放大。
QAT:量化感知训练
QAT 在训练时模拟量化误差,让模型"适应"量化。精度损失比 PTQ 小 30-50%,但需要重新训练。
importtorchfromtorch_npu.contribimportQATWrapper model=AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",torch_dtype=torch.bfloat16,device_map="npu:0",)# 包装成 QAT 模型qat_model=QATWrapper(model,qconfig={"weight":"int8","activation":"int8","quantize_per_tensor":True,})# 正常训练(QAT 在 forward 时插入伪量化节点)optimizer=torch.optim.AdamW(qat_model.parameters(),lr=1e-5)fordataindataloader:loss=qat_model(data)loss.backward()optimizer.step()# 训练完成后转成真正量化模型quant_model=torch.ao.quantization.convert(qat_model)torch.save(quant_model.state_dict(),"model_qat.pt")精度损失对比
Llama2-7B,CANN 8.5,Atlas 800I A2:
| 量化方案 | WNLI (准确率) | GSM8K (准确率) | 推理速度 |
|---|---|---|---|
| fp16 (基准) | 78.5% | 56.2% | 1.0× |
| PTQ int8 | 76.1% (-2.4%) | 53.8% (-2.4%) | 1.8× |
| QAT int8 | 77.9% (-0.6%) | 55.6% (-0.6%) | 1.8× |
| PTQ int4 | 68.2% (-10.3%) | 44.1% (-12.1%) | 2.5× |
| QAT int4 | 74.8% (-3.7%) | 51.3% (-4.9%) | 2.5× |
QAT 的精度损失只有 PTQ 的 1/4。如果精度敏感(评测集、生产环境),优先 QAT。
选择建议
| 场景 | 推荐方案 | 理由 |
|---|---|---|
| 快速原型验证 | PTQ | 不需要训练,10 分钟完成 |
| 生产环境,精度敏感 | QAT | 精度损失小,训练成本可接受 |
| 显存严重不足 | PTQ int4 | 权重 4bit,显存减半 |
| 已有训练流水线 | QAT | 插入 QAT wrapper 即可,改动小 |
跟 ATB 的配合
ATB 的 LLM 接口直接支持量化模型:
fromatbimportLLM# PTQ 量化模型model_ptq=LLM("model_quant.onnx",device="npu:0",quantize="w8a8",# 对应 PTQ 的配置)# QAT 量化模型model_qat=LLM("model_qat.pt",device="npu:0",quantize="w8a8_qat",)ATB 内部会自动调用对应的量化 GEMM kernel。w8a8 的 GEMM 吞吐是 fp16 的 1.8-2.0×。
PTQ 快但精度损失大,QAT 慢但精度高。如果你的模型要上生产,多花 1-2 天做 QAT 是值得的。PTQ 适合快速验证和显存极度受限的场景。仓库在这里:
https://atomgit.com/cann/AMCT
