当前位置: 首页 > news >正文

BN / LN / RMSNorm

BN / LN / RMSNorm 归一化方法总结

一、背景与动机

深度网络训练中常见问题:

  • 梯度消失 / 梯度爆炸
  • 不同层输入分布变化(Internal Covariate Shift)
  • 收敛慢、训练不稳定

👉 归一化(Normalization)的核心目标:

将特征标准化到稳定分布,加速训练并提升模型稳定性。BN、LN 和 RMSNorm 的本质区别在于归一化维度与是否中心化,其中 BN 依赖 batch,LN 按特征归一化,而 RMSNorm 仅做尺度归一化以提升效率与稳定性。

统一形式:

x ^ = x − μ σ 2 + ϵ \hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}}x^=σ2+ϵxμ

再进行仿射变换:

y = γ x ^ + β y = \gamma \hat{x} + \betay=γx^+β


二、Batch Normalization(BN)

1. 核心思想

batch 维度 + 空间维度做归一化(常用于 CNN):

μ c = E B , H , W [ x ] \mu_c = \mathbb{E}_{B,H,W}[x]μc=EB,H,W[x]

σ c 2 = Var B , H , W [ x ] \sigma_c^2 = \text{Var}_{B,H,W}[x]σc2=VarB,H,W[x]

x ^ = x − μ c σ c 2 + ϵ \hat{x} = \frac{x - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}}x^=σc2+ϵxμc


2. 推理阶段(重要)

使用滑动平均:

μ r u n n i n g = ( 1 − m ) μ + m μ b a t c h \mu_{running} = (1-m)\mu + m\mu_{batch}μrunning=(1m)μ+mμbatch

σ r u n n i n g 2 = ( 1 − m ) σ 2 + m σ b a t c h 2 \sigma^2_{running} = (1-m)\sigma^2 + m\sigma^2_{batch}σrunning2=(1m)σ2+mσbatch2


3. 特点

  • 依赖 batch size
  • 训练 / 推理行为不同
  • 适合 CNN

三、Layer Normalization(LN)

1. 核心思想

单个样本的特征维度做归一化(Transformer 常用):

μ = E C [ x ] \mu = \mathbb{E}_{C}[x]μ=EC[x]

σ 2 = Var C [ x ] \sigma^2 = \text{Var}_{C}[x]σ2=VarC[x]

x ^ = x − μ σ 2 + ϵ \hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}}x^=σ2+ϵxμ


2. 特点

  • 不依赖 batch size
  • 训练 / 推理一致
  • 适合 NLP / Transformer

四、RMSNorm(Root Mean Square Norm)

1. 核心思想

只做缩放,不减均值:

R M S ( x ) = E [ x 2 ] RMS(x) = \sqrt{\mathbb{E}[x^2]}RMS(x)=E[x2]

x ^ = x E [ x 2 ] + ϵ \hat{x} = \frac{x}{\sqrt{\mathbb{E}[x^2] + \epsilon}}x^=E[x2]+ϵx

y = γ x ^ y = \gamma \hat{x}y=γx^


2. 特点

  • 去掉均值中心化(更简单)
  • 计算更快
  • 在大模型中表现良好(如 LLaMA)

五、三者对比(面试重点)

方法归一化维度是否减均值是否依赖Batch训练/推理差异典型应用
BNB, H, WCNN
LNCTransformer
RMSNormC大模型

六、本质区别总结

1. 归一化维度不同

  • BN:跨样本
  • LN / RMSNorm:单样本

2. 是否中心化(减均值)

  • BN / LN:有
  • RMSNorm:无

3. 数学表达差异

  • BN / LN:

x − μ σ \frac{x - \mu}{\sigma}σxμ

  • RMSNorm:

x E [ x 2 ] \frac{x}{\sqrt{\mathbb{E}[x^2]}}E[x2]x


七、代码实现

# NOTE BN/LN/RMSNormimporttorchimporttorch.nnasnnimporttorch.nn.functionalasFclassBatchNorm(nn.Module):# BN is usually used for CNN, and the input dimensions are B, C, H, W.def__init__(self,channels_dim,eps=1e-5,momentum=0.1):super().__init__()self.eps=eps self.momentum=momentum# NOTE: momentum is the update speed of running_mean and running_varself.register_buffer('running_mean',torch.zeros(1,channels_dim,1,1))self.register_buffer('running_var',torch.ones(1,channels_dim,1,1))self.gamma=nn.Parameter(torch.ones(1,channels_dim,1,1))self.beta=nn.Parameter(torch.zeros(1,channels_dim,1,1))defforward(self,x):ifself.training:mean=x.mean(dim=[0,2,3],keepdim=True)# B,C,H,W -> 1,C,1,1var=x.var(dim=[0,2,3],keepdim=True,unbiased=False)# B,C,H,W -> 1,C,1,1# update running statsself.running_mean=(1-self.momentum)*self.running_mean+self.momentum*mean self.running_var=(1-self.momentum)*self.running_var+self.momentum*varelse:mean=self.running_mean var=self.running_var x_normed=(x-mean)/torch.sqrt(var+self.eps)out=self.gamma*x_normed+self.betareturnoutclassLayerNorm(nn.Module):# LN is usually used for RNN/Transformer, and the input dimensions are B, L, C.def__init__(self,channels_dim,eps=1e-5):super().__init__()self.eps=eps self.gamma=nn.Parameter(torch.ones(1,1,channels_dim))self.beta=nn.Parameter(torch.zeros(1,1,channels_dim))defforward(self,x):mean=x.mean(dim=-1,keepdim=True)# B,L,C -> B,L,1var=x.var(dim=-1,keepdim=True,unbiased=False)# B,L,C -> B,L,1x_normed=(x-mean)/torch.sqrt(var+self.eps)out=self.gamma*x_normed+self.betareturnoutclassRMSNorm(nn.Module):# RMSNorm is a variant of LN, which only normalizes the variance and does not normalize the mean.# It is usually used for RNN/Transformer, and the input dimensions are B, L, C.def__init__(self,channels_dim,eps=1e-5):super().__init__()self.eps=eps self.gamma=nn.Parameter(torch.ones(1,1,channels_dim))defforward(self,x):rms=torch.mean(x**2,dim=-1,keepdim=True)# B,L,C -> B,L,1x_normed=x/torch.sqrt(rms+self.eps)out=self.gamma*x_normedreturnoutif__name__=="__main__":x=torch.rand(10,5,768)LN=LayerNorm(768)x_LN=LN(x)print(x_LN.shape)RMSN=RMSNorm(768)x_RMS=RMSN(x)print(x_RMS.shape)cnn_x=torch.rand(4,12,512,512)BN=BatchNorm(12)x_BN=BN(cnn_x)print(x_BN.shape)

http://www.jsqmd.com/news/633802/

相关文章:

  • 终极生物图像分析指南:如何用CellProfiler自动处理数千张图像
  • Rust的Pin类型:理解自引用结构体的安全固定
  • 设计企业级SKILL的7个最佳实战原则
  • 高效截图工具对比:Snipaste与FastStone Capture的实战应用
  • Finereport报表导出进阶:利用JS与URL参数实现Sheet页的精准筛选与导出
  • 软件范围管理中的需求变更控制
  • OpCore Simplify终极指南:5分钟搞定Hackintosh EFI配置,小白也能轻松上手
  • IINA播放器完整指南:macOS专业视频播放解决方案深度解析
  • Performance-Fish:让《环世界》流畅度提升400%的终极性能优化方案
  • 云容笔谈·东方红颜影像生成系统实战:为游戏角色批量生成古风立绘
  • 微波管参数全解析:高能辐射
  • BIThesis 3.7.0更新指南:北京理工大学研究生论文格式规范升级解析
  • 精通猫抓扩展:7个高级配置与流媒体解析实战技巧
  • 项目介绍 MATLAB实现基于RNN-XGBoost-CNN 递归神经网络(RNN)结合极限梯度提升(XGBoost)与卷积神经网络(CNN)进行股票价格预测的详细项目实例(含模型描述及部分示例代码)
  • 全球压缩机式家用冰淇淋机市场分析报告
  • Seaborn调色板实战:从数据特征到视觉表达的配色艺术
  • GEE实战指南:Sentinel-2多光谱植被指数批量计算与生态监测应用
  • 快速将HDRI转换为立方体贴图的终极免费工具指南
  • AIGlasses OS Pro AI编程助手实践:自动生成图像处理代码
  • 2026年4月AI爆发周:阿里连推三款模型、字节全双工语音上线,国内大模型进入“落地竞速“新阶段
  • Realtek USB网卡驱动深度解析:群晖NAS网络性能提升实战指南
  • 如何用QMCDecode快速解密QQ音乐加密音频文件:免费Mac工具完整指南
  • 关于串和代码的应用(涉及BF算法、KMP算法)
  • 遵义广和巧手名车维修电话多少?2026年官方联系方式与靠谱指南 - 精选优质企业推荐榜
  • Qwen3-Embedding 模型融合实战:Slerp 技术在跨领域任务中的优化策略
  • WarcraftHelper终极指南:5分钟让魔兽争霸3重获新生
  • GLM-4.1V-9B-Base高算力适配教程:双GPU分层加载与显存优化详解
  • 配置管理方案环境变量与配置文件
  • GLM-4.1V-9B-Base多模态内容审核效果实测:精准识别违规图片与文本
  • gte-base-zh实战:用Python代码调用API实现智能文本相似度计算