CANN-昇腾NPU-模型压缩-剪枝和蒸馏怎么用
模型压缩有三板斧:量化(已讲)、剪枝(Pruning)、蒸馏(Distillation)。剪枝去掉不重要的权重,蒸馏用小模型学大模型的输出。在昇腾NPU上,剪枝和蒸馏可以联合使用,把 7B 模型压到 3B 精度损失 ❤️%。
剪枝(Pruning)
去掉权重矩阵中绝对值最小的连接(结构化剪枝去掉整个神经元或 attention head)。
fromtorch_npu.contribimportprune model=AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",torch_dtype=torch.bfloat16,device_map="npu:0",)# 结构化剪枝:去掉 25% 的 attention headpruner=prune.StructuredPruner(model,pruning_ratio=0.25,# 剪掉 25% 的 headcriterion="l1_norm",# 按 L1 范数排序,小的去掉blocks=["self_attn"],# 只剪 attention)# 剪枝pruned_model=pruner.prune()# 微调恢复精度train(pruned_model,dataloader,epochs=3)# 保存剪枝后的模型(结构变了,需要特殊格式)pruner.save("llama2-7b-pruned-25pct.pt")剪枝后的模型结构变了(attention head 数减少),不能直接用标准 LLM 加载。需要 ATB 支持动态 head 数——目前 CANN 8.5 还不支持,需要等后续版本。
蒸馏(Distillation)
大模型(Teacher)指导小模型(Student)训练:
importtorchimporttorch_npufromtransformersimportAutoModelForCausalLM# Teacher: Llama2-13Bteacher=AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-13b-hf",torch_dtype=torch.bfloat16,device_map="npu:0,1",# 13B 需要 2 卡)# Student: Llama2-7Bstudent=AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",torch_dtype=torch.bfloat16,device_map="npu:0",)# 蒸馏训练optimizer=torch.optim.AdamW(student.parameters(),lr=1e-5)forbatchindataloader:# Teacher 前向(不计算梯度)withtorch.no_grad():teacher_logits=teacher(batch["input_ids"]).logits# Student 前向student_logits=student(batch["input_ids"]).logits# 蒸馏损失:KL 散度T=2.0# 温度loss=torch.nn.functional.kl_div(torch.nn.functional.log_softmax(student_logits/T,dim=-1),torch.nn.functional.softmax(teacher_logits/T,dim=-1),reduction="batchmean",)*(T*T)loss.backward()optimizer.step()optimizer.zero_grad()蒸馏 + 剪枝联合
先蒸馏训练一个较大的 Student(如 7B),再剪枝到更小(如 4B),最后微调恢复精度:
步骤 1:蒸馏(13B → 7B),训练 3 轮 步骤 2:结构化剪枝(7B → 4B),去掉 40% 的 FFN 神经元 步骤 3:用蒸馏损失继续微调 1 轮,恢复精度Llama2-13B → 7B → 4B 的精度损失:
| 阶段 | WNLI (准确率) | GSM8K (准确率) | 模型大小 |
|---|---|---|---|
| Teacher (13B) | 78.5% | 56.2% | 26GB |
| Student (7B,蒸馏) | 76.8% (-1.7%) | 54.1% (-2.1%) | 14GB |
| Pruned (4B) | 74.2% (-4.3%) | 51.3% (-4.9%) | 8GB |
| Pruned + 微调 | 75.9% (-2.6%) | 53.7% (-2.5%) | 8GB |
最终 4B 模型精度损失 2.6%,大小减到原来的 30%。
昇腾NPU上的加速
蒸馏训练需要同时跑 Teacher 和 Student(前向 + 反向)。显存需求:
Teacher (13B, fp16): 26GB(冻结,不存梯度) Student (7B, fp16): 14GB + 28GB(优化器)= 42GB 激活: 10GB 总计: 26 + 42 + 10 = 78GB → 需要 2 卡(128GB)13B Teacher 放在 2 卡 TP,7B Student 放在 1 卡,剩余 1 卡做梯度累积。
跟量化的配合
蒸馏/剪枝后的模型再做量化,进一步压缩:
原始 7B: 14GB (fp16) → 蒸馏到 4B: 8GB (fp16) → W4A16 量化: 2.5GB (int4)2.5GB 的模型可以放到 Jetson Orin(32GB 显存)上跑推理,延迟 < 50ms/token。
剪枝和蒸馏是模型压缩的高阶技巧。剪枝在 CANN 8.5 上支持有限(动态 head 数待支持),蒸馏已经可以完整跑通。两者联合使用,7B → 4B 精度损失 ❤️%。仓库在这里:
https://atomgit.com/cann/torch_npu
