Transformer 位置编码深入解析:从正弦编码到 RoPE、ALiBi
Transformer 位置编码深入解析:从正弦编码到 RoPE、ALiBi
1. 引言
Transformer 的自注意力机制是置换不变的——打乱输入 token 的顺序,输出不变。位置编码(Positional Encoding)是让模型知道"谁在前、谁在后"的关键。本文将从最初的正弦编码到现代的 RoPE、ALiBi,系统讲解位置编码的演进。
位置编码分类:
| 类型 | 代表 | 特点 |
|---|---|---|
| 绝对位置编码 | 正弦编码、可学习编码 | 加到输入嵌入上 |
| 相对位置编码 | ALiBi、Relative PE | 编码 token 间距离 |
| 旋转位置编码 | RoPE | 注意力分数中注入位置 |
| 无显式编码 | CNN 隐式位置 | 依赖结构归纳偏置 |
2. 正弦位置编码(Sinusoidal)
2.1 原理
PE(pos, 2i) = sin(pos / 10000^(2i/d)) PE(pos, 2i+1) = cos(pos / 10000^(2i/d)) 其中: pos = token 在序列中的位置 i = 嵌入维度的索引 d = 嵌入维度2.2 实现
importtorchimportmathclassSinusoidalPositionalEncoding(torch.nn.Module):def__init__(self,d_model,max_len=5000):super().__init__()pe=torch.zeros(max_len,d_model)position=torch.arange(0,max_len,dtype=torch.float).unsqueeze(1)div_term=torch.exp(torch.arange(0,d_model,2).float()*(-math.log(10000.0)/d_model))pe[:,0::2]=torch.sin(position*div_term)pe[:,1::2]=torch.cos(position*div_term)pe=pe.unsqueeze(0)# (1, max_len, d_model)self.register_buffer('pe',pe)defforward(self,x):# x: (batch, seq_len, d_model)returnx+self.pe[:,:x.size(1)]2.3 特点
- 优点:无需训练,可泛化到任意长度
- 缺点:绝对位置编码,与注意力计算分离
- 关键性质:PE(pos+k) 可以表示为 PE(pos) 的线性变换
3. 可学习位置编码
classLearnedPositionalEncoding(torch.nn.Module):def__init__(self,d_model,max_len=512):super().__init__()self.embedding=torch.nn.Embedding(max_len,d_model)defforward(self,x):positions=torch.arange(x.size(1),device=x.device)returnx+self.embedding(positions)ViT、GPT-2 使用此方案。简单有效,但无法泛化到训练时未见过的长度。
4. RoPE(Rotary Position Embedding)
4.1 核心思想
RoPE 不直接修改输入嵌入,而是在注意力计算时注入位置信息:
q_m = R(m) · q // 查询向量旋转 m 个位置 k_n = R(n) · k // 键向量旋转 n 个位置 q_m · k_n = (R(m)·q) · (R(n)·k) = q · R(m-n)·k 结果只依赖相对位置 (m-n)!4.2 二维旋转
对于二维向量 [x, y],旋转角度 θ: R(θ) = [cos θ -sin θ] [sin θ cos θ] R(θ)·[x, y] = [x·cos θ - y·sin θ, x·sin θ + y·cos θ]4.3 高维 RoPE 实现
classRotaryPositionalEmbedding:"""RoPE 实现"""def__init__(self,dim,base=10000):# 计算频率inv_freq=1.0/(base**(torch.arange(0,dim,2).float()/dim))self.register_buffer('inv_freq',inv_freq)def_build_cache(self,seq_len,device):t=torch.arange(seq_len,device=device).float()freqs=torch.outer(t,self.inv_freq.to(device))# cos 和 sinself._cos_cached=freqs.cos()self._sin_cached=freqs.sin()def_apply_rotary(self,x,cos,sin):"""应用旋转"""# x: (batch, heads, seq_len, head_dim)d=x.size(-1)//2x1,x2=x[...,:d],x[...,d:]# 扩展 cos/sin 维度cos=cos[:x.size(2),:].unsqueeze(0).unsqueeze(0)sin=sin[:x.size(2),:].unsqueeze(0).unsqueeze(0)# 旋转out1=x1*cos-x2*sin out2=x1*sin+x2*cosreturntorch.cat([out1,out2],dim=-1)defforward(self,q,k,seq_len):self._build_cache(seq_len,q.device)q_rot=self._apply_rotary(q,self._cos_cached,self._sin_cached)k_rot=self._apply_rotary(k,self._cos_cached,self._sin_cached)returnq_rot,k_rot# 简化版 RoPE(实际使用)defapply_rope(q,k,freqs_cis):"""freqs_cis: 预计算的复数频率"""# 转为复数q_complex=torch.view_as_complex(q.float().reshape(*q.shape[:-1],-1,2))k_complex=torch.view_as_complex(k.float().reshape(*k.shape[:-1],-1,2))# 乘以旋转因子q_out=torch.view_as_real(q_complex*freqs_cis).flatten(-2)k_out=torch.view_as_real(k_complex*freqs_cis).flatten(-2)returnq_out.type_as(q),k_out.type_as(k)4.4 RoPE 长度外推
# NTK-aware 缩放defapply_ntk_scaling(base,dim,factor):"""NTK-aware RoPE 缩放,支持更长上下文"""new_base=base*factor**(dim/(dim-2))inv_freq=1.0/(new_base**(torch.arange(0,dim,2).float()/dim))returninv_freq# YaRN(Yet another RoPE extensioN)defyarn_rope(dim,base,factor,beta_fast=32,beta_slow=1):"""YaRN 动态缩放"""# 计算每个频率维度的缩放因子inv_freq=1.0/(base**(torch.arange(0,dim,2).float()/dim))freqs=1.0/inv_freq# 找到需要缩放的维度low=dim*math.log(beta_slow)/math.log(base)high=dim*math.log(beta_fast)/math.log(base)scale=torch.ones(dim//2)foriinrange(dim//2):iffreqs[i]<low:scale[i]=1.0/factoreliffreqs[i]>high:scale[i]=1.0else:# 线性插值t=(freqs[i]-low)/(high-low)scale[i]=(1-t)/factor+treturninv_freq*scale5. ALiBi(Attention with Linear Biases)
5.1 原理
ALiBi 完全不用位置编码,而是在注意力分数上加一个线性偏置:
Attention(Q, K) = softmax(Q·K^T / √d + m · bias) bias[i,j] = -|i - j| // 距离越远,惩罚越大 m = 斜率,每个头不同(几何级数)5.2 实现
importtorchimportmathdefalibi_bias(num_heads,max_len):"""生成 ALiBi 偏置矩阵"""# 斜率:几何级数 2^(-8/n), 2^(-16/n), ...slopes=2**(-8*torch.arange(1,num_heads+1)/num_heads)# 距离矩阵positions=torch.arange(max_len)distance=positions.unsqueeze(0)-positions.unsqueeze(1)# (max_len, max_len)distance=distance.abs().float()# 偏置 = 斜率 × 距离bias=-slopes.unsqueeze(1).unsqueeze(2)*distance.unsqueeze(0)returnbias# (num_heads, max_len, max_len)classALiBiAttention(torch.nn.Module):def__init__(self,d_model,num_heads,max_len=4096):super().__init__()self.num_heads=num_heads self.head_dim=d_model//num_heads self.q_proj=torch.nn.Linear(d_model,d_model)self.k_proj=torch.nn.Linear(d_model,d_model)self.v_proj=torch.nn.Linear(d_model,d_model)self.o_proj=torch.nn.Linear(d_model,d_model)# 预计算 ALiBi 偏置self.register_buffer('alibi',alibi_bias(num_heads,max_len))defforward(self,x):B,L,_=x.shape q=self.q_proj(x).view(B,L,self.num_heads,self.head_dim).transpose(1,2)k=self.k_proj(x).view(B,L,self.num_heads,self.head_dim).transpose(1,2)v=self.v_proj(x).view(B,L,self.num_heads,self.head_dim).transpose(1,2)# 注意力 + ALiBiattn=(q @ k.transpose(-2,-1))/math.sqrt(self.head_dim)attn=attn+self.alibi[:,:L,:L].unsqueeze(0)attn=attn.softmax(dim=-1)out=(attn @ v).transpose(1,2).reshape(B,L,-1)returnself.o_proj(out)6. 各方案对比
| 方案 | 外推能力 | 训练开销 | 推理开销 | 长序列效果 |
|---|---|---|---|---|
| 正弦编码 | 有限 | 无 | 低 | 一般 |
| 可学习编码 | 无 | 低 | 低 | 差 |
| RoPE | 需缩放 | 中 | 中 | 好 |
| ALiBi | 天然外推 | 无 | 低 | 好 |
| YaRN RoPE | 强外推 | 中 | 中 | 很好 |
主流选择:
- Llama/Qwen/DeepSeek→ RoPE(配 YaRN 长度外推)
- BLOOM/MPT→ ALiBi
- GPT 系列→ 可学习绝对位置编码
7. 总结
位置编码的核心演进:
- 正弦编码:开创性工作,但与注意力分离
- RoPE:在注意力分数中注入相对位置,成为主流
- ALiBi:最简单方案,线性偏置即可,天然支持外推
- YaRN:RoPE 的长度外推方案,支持 128K+ 上下文
