CANN-昇腾NPU-LoRA微调-显存只占5%怎么做到的
全量微调 Llama2-7B 需要更新 7B 参数,显存开销约 80GB。LoRA 只训练 0.5% 的参数(约 35M),显存开销约 4GB。在昇腾NPU上 LoRA 微调是性价比最高的方案。
LoRA 原理
在原始权重 W 旁边加一个低秩矩阵 ΔW = A × B:
原始:y = Wx LoRA:y = (W + AB)x = Wx + ABx A: [hidden, r], B: [r, hidden], r << hidden Llama2-7B: r=16, A=[4096,16], B=[16,4096] 每层 LoRA 参数: 2 × 4096 × 16 = 128K 全量参数: 4096 × 4096 = 16M 参数比: 0.8%训练时 W 冻结,只更新 A 和 B。推理时把 ΔW 合并回 W,零额外延迟。
昇腾NPU上的实现
frompeftimportLoraConfig,get_peft_modelimporttorchimporttorch_npu model=AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",torch_dtype=torch.bfloat16,device_map="npu:0")lora_config=LoraConfig(r=16,lora_alpha=32,target_modules=["q_proj","v_proj","k_proj","o_proj","gate_proj","up_proj","down_proj"],lora_dropout=0.05,)model=get_peft_model(model,lora_config)model.print_trainable_parameters()# 输出: trainable params: 35M || all params: 6.7B || trainable%: 0.52%显存对比
Llama2-7B 全量微调 vs LoRA 微调(单卡 Atlas 800I A2,64GB 显存):
| 项目 | 全量微调 | LoRA 微调 |
|---|---|---|
| 模型权重 | 14GB(可冻结) | 14GB(冻结) |
| LoRA 权重 | - | 0.2GB |
| 梯度 | 14GB(全量) | 0.2GB(仅 LoRA) |
| 优化器状态 | 28GB(全量 fp32) | 0.4GB(仅 LoRA) |
| 激活 | 20GB | 20GB |
| 总计 | 76GB | 34.8GB |
LoRA 微调的显存不到全量微调的一半。单卡 64GB 就能跑,不需要 8 卡 TP。
训练速度
| 配置 | 吞吐 (tokens/s) | 训练时间(1M tokens) |
|---|---|---|
| 全量微调(8 卡 TP) | 1,800 | 9 分钟 |
| LoRA(单卡) | 1,200 | 14 分钟 |
LoRA 单卡比全量 8 卡慢 55%。但如果你只有一张卡,LoRA 是唯一选择。
推理时合并 LoRA
# 方法 1:动态合并(每次推理时计算 Wx + ABx)# 优点:可以热切换不同 LoRA# 缺点:多一次矩阵乘法,约 3-5ms 额外延迟# 方法 2:预合并(训练完成后 W = W + AB,删除 A 和 B)# 优点:推理零开销# 缺点:不能切换 LoRAmodel=model.merge_and_unload()# 预合并预合并后的模型跟原始模型完全一样——ATB 的推理加速、FlashAttention、量化全部适用。
多 LoRA 推理
同一个基座模型挂多个 LoRA,不同请求用不同 LoRA:
fromatbimportLLM,MultiLoraConfig model=LLM("meta-llama/Llama-2-7b-hf",device="npu:0",multi_lora=MultiLoraConfig(lora_dirs=["lora_chat","lora_code","lora_translate"],max_loras=3,))# 不同请求用不同 LoRAresults=model.generate([("Hello","lora_chat"),("def fib(n):","lora_code"),("Translate:","lora_translate"),])ATB 的 Multi-LoRA 把多个 LoRA 的 ΔW 打包成 Batch GEMM,一次计算所有 LoRA 的增量。这比逐个 LoRA 计算高效得多。
精度影响
LoRA 的精度损失取决于 r 值(秩):
| r 值 | 可训练参数 | 微调精度损失 | 适用场景 |
|---|---|---|---|
| 8 | 17M | 0.5-1% | 简单任务(分类) |
| 16 | 35M | 0.1-0.3% | 通用(对话、代码) |
| 32 | 70M | <0.1% | 复杂任务(翻译、推理) |
| 64 | 140M | 几乎无损失 | 全量微调替代 |
r=16 是性价比最高的选择。除非任务特别复杂,否则不需要 r>32。
LoRA 微调是昇腾NPU上最实用的微调方案——显存少、速度快、合并后零推理开销。如果你的场景是"基座模型 + 领域适配",LoRA 几乎总是比全量微调更好的选择。仓库在这里:
https://atomgit.com/cann/torch_npu
