
Fine-Tuning Large Models: MoELoRA

Contents

      • Core Components of MOELoRA
      • The Role of MOE in Multi-task Learning
      • LoRA's Contribution to Parameter-Efficient Fine-tuning
      • How MOELoRA's Components Work Together

https://arxiv.org/pdf/2310.18339
When MOE Meets LLMs: Parameter Efficient Fine-tuning for Multi-task Medical Applications


Core Components of MOELoRA

MOELoRA builds on two key techniques: the Mixture-of-Experts architecture (MOE) and Low-Rank Adaptation (LoRA). MOE handles task routing and expert collaboration in multi-task learning, while LoRA provides parameter-efficient fine-tuning of the model.

The Role of MOE in Multi-task Learning

The MOE architecture uses a dynamic routing (gating) mechanism to dispatch inputs to different expert modules, each of which specializes in a particular task or data subset. This design lets a model handle multi-task scenarios flexibly without a significant increase in parameter count. MOE's strength is that it can adapt the allocation of expert capacity to task complexity, improving performance under limited data and compute budgets.
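The routing idea can be illustrated with a minimal sketch (toy numbers, not the paper's implementation): a gate turns raw per-expert scores into a softmax distribution, and the experts' outputs are mixed with those weights.

```python
import math

def softmax(scores):
    """Normalize raw gate scores into expert weights that sum to 1."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical gate scores for 4 experts, given some task representation.
gate_scores = [2.0, 0.5, 0.5, -1.0]
weights = softmax(gate_scores)

# Each expert produces an output; the gate mixes them into one result.
expert_outputs = [1.0, 2.0, 3.0, 4.0]  # toy scalar outputs
mixed = sum(w * o for w, o in zip(weights, expert_outputs))

print([round(w, 3) for w in weights])
print(round(mixed, 3))
```

Because the weights come from a softmax, the highest-scoring expert dominates the mixture, but low-scoring experts still contribute a little, which is what lets capacity shift smoothly between tasks.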

LoRA's Contribution to Parameter-Efficient Fine-tuning

LoRA uses low-rank matrix factorization to add a small number of trainable parameters on top of a pretrained model, sharply reducing the resource cost of fine-tuning. Concretely, LoRA factorizes the weight update ΔW into the product of two low-rank matrices (ΔW = BA, where the shared rank of B and A is much smaller than the dimensions of the original weight matrix). This preserves the pretrained model's knowledge while enabling efficient task adaptation.
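The parameter saving is easy to quantify. A quick sketch, assuming an illustrative square d×d weight and rank r (the numbers are hypothetical, not taken from the paper):

```python
# Full fine-tuning updates every entry of a d x d weight matrix;
# LoRA instead trains only B (d x r) and A (r x d), with r << d.
d, r = 4096, 8  # illustrative hidden size and LoRA rank

full_params = d * d          # trainable params for a full update
lora_params = d * r + r * d  # trainable params for B plus A

print(full_params, lora_params, full_params // lora_params)
```

With these numbers LoRA trains 256× fewer parameters for this layer, which is the source of the memory and compute savings during fine-tuning.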

How MOELoRA's Components Work Together

MOELoRA combines MOE's task-routing ability with LoRA's parameter efficiency in a layered structure: a gate identifies the task type and weights the corresponding expert modules, and each expert is internally a LoRA module. This design reduces interference between tasks while sharing the frozen base-model parameters to avoid redundancy.
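A minimal numeric sketch of the combination (tiny hypothetical matrices, not the repo code): each expert i contributes a low-rank update BᵢAᵢ, and the task's gate weight gᵢ scales it, so the effective weight update is ΔW = Σᵢ gᵢ · BᵢAᵢ.

```python
def matmul(X, Y):
    """Plain-Python matrix product for small illustrative matrices."""
    return [[sum(X[i][k] * Y[k][j] for k in range(len(Y)))
             for j in range(len(Y[0]))] for i in range(len(X))]

def scale_add(W, D, g):
    """Return W + g * D, elementwise."""
    return [[W[i][j] + g * D[i][j] for j in range(len(W[0]))]
            for i in range(len(W))]

d = 4  # tiny hidden size; each expert uses rank 1 here
experts = [
    # (B: d x 1, A: 1 x d) for each of two hypothetical experts
    ([[1.0], [0.0], [0.0], [0.0]], [[0.0, 1.0, 0.0, 0.0]]),
    ([[0.0], [1.0], [0.0], [0.0]], [[1.0, 0.0, 0.0, 0.0]]),
]
gate = [0.75, 0.25]  # task-dependent expert weights from the gate

delta_W = [[0.0] * d for _ in range(d)]
for (B, A), g in zip(experts, gate):
    delta_W = scale_add(delta_W, matmul(B, A), g)

print(delta_W)
```

A different task would produce different gate weights, so the same shared set of low-rank experts yields a task-specific ΔW without duplicating the frozen base weights.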


https://github.com/liuqidong07/MOELoRA-peft/blob/master/src/MLoRA/peft/tuners/mmoelora.py

```python
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F

# LoraLayer, transpose, and the surrounding scaffolding come from the
# repository's bundled peft package (src/MLoRA/peft).


class MMOELoraLayer(LoraLayer):
    def __init__(self, in_features: int, out_features: int, expert_num: int):
        super().__init__(in_features, out_features)
        self.expert_num = expert_num

    def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
        self.r[adapter_name] = r
        self.lora_alpha[adapter_name] = lora_alpha
        if lora_dropout > 0.0:
            lora_dropout_layer = nn.Dropout(p=lora_dropout)
        else:
            lora_dropout_layer = nn.Identity()
        self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer}))
        # Actual trainable parameters
        if r > 0:
            self.lora_A.update(nn.ModuleDict(
                {adapter_name: MMOELinearA(self.in_features, r, self.expert_num)}))
            self.lora_B.update(nn.ModuleDict(
                {adapter_name: MMOELinearB(r, self.out_features, self.expert_num)}))
            self.scaling[adapter_name] = lora_alpha / r
        if init_lora_weights:
            self.reset_lora_parameters(adapter_name)
        self.to(self.weight.device)

    def reset_lora_parameters(self, adapter_name):
        if adapter_name in self.lora_A.keys():
            # initialize A the same way as the default for nn.Linear and B to zero
            for i in range(self.expert_num):
                nn.init.normal_(self.lora_A[adapter_name].loraA[i].mlp.weight,
                                mean=0.0, std=0.01)
                nn.init.zeros_(self.lora_B[adapter_name].loraB[i].mlp.weight)


class MMOELoraLinear(nn.Linear, MMOELoraLayer):
    # Lora implemented in a dense layer
    # nn.Linear is the pretrained weights in LLM, MMOELoraLayer is the designed trainable Lora
    def __init__(
        self,
        adapter_name: str,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        **kwargs,
    ):
        init_lora_weights = kwargs.pop("init_lora_weights", True)
        self.expert_num = kwargs.pop("expert_num", True)
        self.task_num = kwargs.pop("task_num", True)
        self.te_dim = kwargs.pop("task_embedding_dim", True)

        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        MMOELoraLayer.__init__(self, in_features=in_features,
                               out_features=out_features,
                               expert_num=self.expert_num)

        # init the Gate network
        self.lora_task_embedding = nn.ModuleDict({})
        self.lora_gate = nn.ModuleDict({})
        self.lora_task_embedding.update(nn.ModuleDict(
            {adapter_name: nn.Embedding(self.task_num + 1, self.te_dim)}))
        self.lora_gate.update(nn.ModuleDict(
            {adapter_name: Gate(self.te_dim, self.expert_num)}))

        # Freezing the pre-trained weight matrix
        self.weight.requires_grad = False
        self.fan_in_fan_out = fan_in_fan_out
        if fan_in_fan_out:
            self.weight.data = self.weight.data.T
        nn.Linear.reset_parameters(self)
        self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
        self.active_adapter = adapter_name

    def merge(self, task_id):
        if self.active_adapter not in self.lora_A.keys():
            return
        if self.merged:
            warnings.warn("Already merged. Nothing to do.")
            return
        if self.r[self.active_adapter] > 0:
            expert_weight = self.lora_gate[self.active_adapter](
                self.lora_task_embedding[self.active_adapter](task_id))
            for i in range(self.expert_num):
                lora_A_weights = self.lora_A[self.active_adapter].loraA[i].mlp.weight
                lora_B_weights = self.lora_B[self.active_adapter].loraB[i].mlp.weight
                self.weight.data += (
                    transpose(lora_B_weights @ lora_A_weights, self.fan_in_fan_out)
                    * self.scaling[self.active_adapter]
                    * expert_weight[..., i]
                )
            self.merged = True

    def unmerge(self, task_id):
        if self.active_adapter not in self.lora_A.keys():
            return
        if not self.merged:
            warnings.warn("Already unmerged. Nothing to do.")
            return
        if self.r[self.active_adapter] > 0:
            expert_weight = self.lora_gate[self.active_adapter](
                self.lora_task_embedding[self.active_adapter](task_id))
            for i in range(self.expert_num):
                lora_A_weights = self.lora_A[self.active_adapter].loraA[i].mlp.weight
                lora_B_weights = self.lora_B[self.active_adapter].loraB[i].mlp.weight
                self.weight.data -= (
                    transpose(lora_B_weights @ lora_A_weights, self.fan_in_fan_out)
                    * self.scaling[self.active_adapter]
                    * expert_weight[..., i]
                )
            self.merged = False

    def forward(self, x: torch.Tensor, **kwargs):
        task_id = kwargs["task_id"]
        previous_dtype = x.dtype
        if self.active_adapter not in self.lora_A.keys():
            # No adapter, directly use linear
            return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
        if self.disable_adapters:
            # No adapter
            if self.r[self.active_adapter] > 0 and self.merged:
                # merge the adapter to linear
                self.unmerge(task_id)
            result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
        elif self.r[self.active_adapter] > 0 and not self.merged:
            # general lora process
            result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
            x = x.to(self.lora_A[self.active_adapter].loraA[0].weight.dtype)
            expert_weight = self.lora_gate[self.active_adapter](
                self.lora_task_embedding[self.active_adapter](task_id))
            for i in range(self.expert_num):
                result += (
                    # lora process
                    self.lora_B[self.active_adapter].loraB[i](
                        self.lora_A[self.active_adapter].loraA[i](
                            self.lora_dropout[self.active_adapter](x))
                    )
                    * self.scaling[self.active_adapter]
                    * expert_weight[..., i].unsqueeze(-1).unsqueeze(0)
                )
        else:
            result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
        result = result.to(previous_dtype)
        return result


class MMOELinearA(nn.Module):
    """MMOE based LoRA block"""

    def __init__(self, in_features, out_features, expert_num) -> None:
        super().__init__()
        self.expert_num = expert_num
        self.in_features, self.out_features = in_features, out_features
        self.loraA = nn.ModuleList([])
        # lora rank should be divisible by the expert number
        assert self.out_features % self.expert_num == 0
        self.r = self.out_features // self.expert_num
        for _ in range(self.expert_num):
            self.loraA.append(Expert(self.in_features, self.r))

    def forward(self, x):
        """input x is a tensor; the output is a list with one entry per expert"""
        outputs = []
        for i in range(self.expert_num):
            outputs.append(self.loraA[i](x))
        return outputs


class MMOELinearB(nn.Module):
    """MMOE based LoRA block"""

    def __init__(self, in_features, out_features, expert_num) -> None:
        super().__init__()
        self.expert_num = expert_num
        self.in_features, self.out_features = in_features, out_features
        self.loraB = nn.ModuleList([])
        assert self.in_features % self.expert_num == 0
        self.r = self.in_features // self.expert_num
        for _ in range(self.expert_num):
            self.loraB.append(Expert(self.r, self.out_features))

    def forward(self, x):
        """input x is a list; the output is also a list"""
        outputs = []
        for i in range(self.expert_num):
            outputs.append(self.loraB[i](x[i]))
        return outputs


class Expert(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features
        self.mlp = nn.Linear(self.in_features, self.out_features, bias=False)
        self.weight = self.mlp.weight

    def forward(self, x):
        # LoRA A or B block
        y = self.mlp(x)
        return y


class Gate(nn.Module):
    def __init__(self, input_size, expert_num):
        super().__init__()
        # map the task embedding to per-expert weights with a bias-free linear layer
        self.GateL = nn.Linear(input_size, expert_num, bias=False)
        self.act = nn.Softmax(dim=1)  # dim 0 is the batch size

    def forward(self, x):
        y = self.GateL(x)
        y = self.act(y)
        return y
```
