当前位置: 首页 > news >正文

Transformer 位置编码深入解析:从正弦编码到 RoPE、ALiBi

Transformer 位置编码深入解析:从正弦编码到 RoPE、ALiBi

1. 引言

Transformer 的自注意力机制是置换不变的——打乱输入 token 的顺序,输出不变。位置编码(Positional Encoding)是让模型知道"谁在前、谁在后"的关键。本文将从最初的正弦编码到现代的 RoPE、ALiBi,系统讲解位置编码的演进。

位置编码分类:

类型代表特点
绝对位置编码正弦编码、可学习编码加到输入嵌入上
相对位置编码ALiBi、Relative PE编码 token 间距离
旋转位置编码RoPE注意力分数中注入位置
无显式编码CNN 隐式位置依赖结构归纳偏置

2. 正弦位置编码(Sinusoidal)

2.1 原理

PE(pos, 2i) = sin(pos / 10000^(2i/d)) PE(pos, 2i+1) = cos(pos / 10000^(2i/d)) 其中: pos = token 在序列中的位置 i = 嵌入维度的索引 d = 嵌入维度

2.2 实现

importtorchimportmathclassSinusoidalPositionalEncoding(torch.nn.Module):def__init__(self,d_model,max_len=5000):super().__init__()pe=torch.zeros(max_len,d_model)position=torch.arange(0,max_len,dtype=torch.float).unsqueeze(1)div_term=torch.exp(torch.arange(0,d_model,2).float()*(-math.log(10000.0)/d_model))pe[:,0::2]=torch.sin(position*div_term)pe[:,1::2]=torch.cos(position*div_term)pe=pe.unsqueeze(0)# (1, max_len, d_model)self.register_buffer('pe',pe)defforward(self,x):# x: (batch, seq_len, d_model)returnx+self.pe[:,:x.size(1)]

2.3 特点

  • 优点:无需训练,可泛化到任意长度
  • 缺点:绝对位置编码,与注意力计算分离
  • 关键性质:PE(pos+k) 可以表示为 PE(pos) 的线性变换

3. 可学习位置编码

classLearnedPositionalEncoding(torch.nn.Module):def__init__(self,d_model,max_len=512):super().__init__()self.embedding=torch.nn.Embedding(max_len,d_model)defforward(self,x):positions=torch.arange(x.size(1),device=x.device)returnx+self.embedding(positions)

ViT、GPT-2 使用此方案。简单有效,但无法泛化到训练时未见过的长度。

4. RoPE(Rotary Position Embedding)

4.1 核心思想

RoPE 不直接修改输入嵌入,而是在注意力计算时注入位置信息

q_m = R(m) · q // 查询向量旋转 m 个位置 k_n = R(n) · k // 键向量旋转 n 个位置 q_m · k_n = (R(m)·q) · (R(n)·k) = q · R(m-n)·k 结果只依赖相对位置 (m-n)!

4.2 二维旋转

对于二维向量 [x, y],旋转角度 θ: R(θ) = [cos θ -sin θ] [sin θ cos θ] R(θ)·[x, y] = [x·cos θ - y·sin θ, x·sin θ + y·cos θ]

4.3 高维 RoPE 实现

classRotaryPositionalEmbedding:"""RoPE 实现"""def__init__(self,dim,base=10000):# 计算频率inv_freq=1.0/(base**(torch.arange(0,dim,2).float()/dim))self.register_buffer('inv_freq',inv_freq)def_build_cache(self,seq_len,device):t=torch.arange(seq_len,device=device).float()freqs=torch.outer(t,self.inv_freq.to(device))# cos 和 sinself._cos_cached=freqs.cos()self._sin_cached=freqs.sin()def_apply_rotary(self,x,cos,sin):"""应用旋转"""# x: (batch, heads, seq_len, head_dim)d=x.size(-1)//2x1,x2=x[...,:d],x[...,d:]# 扩展 cos/sin 维度cos=cos[:x.size(2),:].unsqueeze(0).unsqueeze(0)sin=sin[:x.size(2),:].unsqueeze(0).unsqueeze(0)# 旋转out1=x1*cos-x2*sin out2=x1*sin+x2*cosreturntorch.cat([out1,out2],dim=-1)defforward(self,q,k,seq_len):self._build_cache(seq_len,q.device)q_rot=self._apply_rotary(q,self._cos_cached,self._sin_cached)k_rot=self._apply_rotary(k,self._cos_cached,self._sin_cached)returnq_rot,k_rot# 简化版 RoPE(实际使用)defapply_rope(q,k,freqs_cis):"""freqs_cis: 预计算的复数频率"""# 转为复数q_complex=torch.view_as_complex(q.float().reshape(*q.shape[:-1],-1,2))k_complex=torch.view_as_complex(k.float().reshape(*k.shape[:-1],-1,2))# 乘以旋转因子q_out=torch.view_as_real(q_complex*freqs_cis).flatten(-2)k_out=torch.view_as_real(k_complex*freqs_cis).flatten(-2)returnq_out.type_as(q),k_out.type_as(k)

4.4 RoPE 长度外推

# NTK-aware 缩放defapply_ntk_scaling(base,dim,factor):"""NTK-aware RoPE 缩放,支持更长上下文"""new_base=base*factor**(dim/(dim-2))inv_freq=1.0/(new_base**(torch.arange(0,dim,2).float()/dim))returninv_freq# YaRN(Yet another RoPE extensioN)defyarn_rope(dim,base,factor,beta_fast=32,beta_slow=1):"""YaRN 动态缩放"""# 计算每个频率维度的缩放因子inv_freq=1.0/(base**(torch.arange(0,dim,2).float()/dim))freqs=1.0/inv_freq# 找到需要缩放的维度low=dim*math.log(beta_slow)/math.log(base)high=dim*math.log(beta_fast)/math.log(base)scale=torch.ones(dim//2)foriinrange(dim//2):iffreqs[i]<low:scale[i]=1.0/factoreliffreqs[i]>high:scale[i]=1.0else:# 线性插值t=(freqs[i]-low)/(high-low)scale[i]=(1-t)/factor+treturninv_freq*scale

5. ALiBi(Attention with Linear Biases)

5.1 原理

ALiBi 完全不用位置编码,而是在注意力分数上加一个线性偏置

Attention(Q, K) = softmax(Q·K^T / √d + m · bias) bias[i,j] = -|i - j| // 距离越远,惩罚越大 m = 斜率,每个头不同(几何级数)

5.2 实现

importtorchimportmathdefalibi_bias(num_heads,max_len):"""生成 ALiBi 偏置矩阵"""# 斜率:几何级数 2^(-8/n), 2^(-16/n), ...slopes=2**(-8*torch.arange(1,num_heads+1)/num_heads)# 距离矩阵positions=torch.arange(max_len)distance=positions.unsqueeze(0)-positions.unsqueeze(1)# (max_len, max_len)distance=distance.abs().float()# 偏置 = 斜率 × 距离bias=-slopes.unsqueeze(1).unsqueeze(2)*distance.unsqueeze(0)returnbias# (num_heads, max_len, max_len)classALiBiAttention(torch.nn.Module):def__init__(self,d_model,num_heads,max_len=4096):super().__init__()self.num_heads=num_heads self.head_dim=d_model//num_heads self.q_proj=torch.nn.Linear(d_model,d_model)self.k_proj=torch.nn.Linear(d_model,d_model)self.v_proj=torch.nn.Linear(d_model,d_model)self.o_proj=torch.nn.Linear(d_model,d_model)# 预计算 ALiBi 偏置self.register_buffer('alibi',alibi_bias(num_heads,max_len))defforward(self,x):B,L,_=x.shape q=self.q_proj(x).view(B,L,self.num_heads,self.head_dim).transpose(1,2)k=self.k_proj(x).view(B,L,self.num_heads,self.head_dim).transpose(1,2)v=self.v_proj(x).view(B,L,self.num_heads,self.head_dim).transpose(1,2)# 注意力 + ALiBiattn=(q @ k.transpose(-2,-1))/math.sqrt(self.head_dim)attn=attn+self.alibi[:,:L,:L].unsqueeze(0)attn=attn.softmax(dim=-1)out=(attn @ v).transpose(1,2).reshape(B,L,-1)returnself.o_proj(out)

6. 各方案对比

方案外推能力训练开销推理开销长序列效果
正弦编码有限一般
可学习编码
RoPE需缩放
ALiBi天然外推
YaRN RoPE强外推很好

主流选择:

  • Llama/Qwen/DeepSeek→ RoPE(配 YaRN 长度外推)
  • BLOOM/MPT→ ALiBi
  • GPT 系列→ 可学习绝对位置编码

7. 总结

位置编码的核心演进:

  1. 正弦编码:开创性工作,但与注意力分离
  2. RoPE:在注意力分数中注入相对位置,成为主流
  3. ALiBi:最简单方案,线性偏置即可,天然支持外推
  4. YaRN:RoPE 的长度外推方案,支持 128K+ 上下文
http://www.jsqmd.com/news/1059823/

相关文章:

  • League Akari:英雄联盟智能助手如何提升你的游戏体验5倍?
  • 基于Playwright与AI的闲鱼智能监控机器人:自动化抓取与语义分析实战
  • 解密pyautocad架构:Python驱动AutoCAD自动化的工程化策略
  • DLSS Swapper完全指南:一站式管理游戏DLSS文件,让NVIDIA显卡性能最大化
  • Seedance 2.0:多模态视频生成协议层解析
  • 终极指南:如何用OmenSuperHub彻底掌控惠普游戏本性能与散热
  • 5大SillyTavern关键技术故障深度解析与实战修复
  • DeepSeek R1技术报告深度解析:大模型数据配方与训练工艺
  • 0622晨间日记
  • 居家办公曲面屏选购指南:人体工学与视觉舒适度实战解析
  • import/export不是语法糖:JavaScript模块系统底层原理
  • OpenClaw:本地AI工作流编排工具与中文封装实践
  • 国产GPU实现大模型Day-0推理:摩尔线程SGLang-MUSA深度解析
  • 基于Z-Score的TinyML异常检测系统设计与实现
  • MobX + React Native 实战避坑指南:SafeAreaProvider 与 observer 渲染优化
  • Seedance 2.0:多模态AI视频创作的即梦工作流
  • 如何用开源工具永久保存你的数字记忆:从聊天记录到年度报告
  • Apollo配置加密实战:从Jasypt集成到KMS密钥管理
  • ERNIE-Image 8B:中文文生图模型的精准文字渲染实践
  • Vue 3可复用分页组件设计:契约驱动、服务端/客户端双模式与BEM样式解耦
  • 跨架构兼容技术突破:Box64实现ARM设备高效运行x86_64程序的完整解决方案
  • 【飞机】自主无人机飞行稳定和轨迹跟踪Matlab实现
  • 网盘下载速度慢怎么办?从PanDownload解析到kdown实测
  • DeepSeek V4国产化实测:MXFP4与TileLang技术解析
  • Kimi K2.6 Agent调度原理:从胶水代码到生产级资源纳管
  • Nginx平滑升级实战:零中断热替换二进制原理与落地
  • Trae Spec/Plan模式:结构化AI编程新范式
  • AI编程最后一公里:从写代码到懂工程上下文
  • 电力系统稳定性分析新范式:数据驱动与分布式认证技术详解
  • 【船舶】基于mrDMD和Koopman理论的数据驱动船舶运动分析附Matlab代码