当前位置: 首页 > news >正文

Transformer 核心模块详解:多头注意力、前馈网络与词嵌入

【学习记录】Transformer 核心模块详解:多头注意力、前馈网络与词嵌入

Transformer 是现代大语言模型的基石,而多头注意力(MultiHeadAttention)前馈网络(FFN)词嵌入(Embedding)是其最核心的三个组件。本文从原理到代码,逐层拆解这三个模块,并提供 Python(PyTorch)和 C++(LibTorch)实现,附带完整的复杂度分析。


📌 目录

  1. MultiHeadAttention(多头注意力)
  2. FFN(前馈网络)
  3. Embedding(词嵌入)
  4. 三个模块的组合使用
  5. 复杂度总结

一、多头注意力(MultiHeadAttention)

1.1 作用

多头注意力机制允许模型同时关注输入序列中不同位置的不同表示子空间。它通过将查询(Q)、键(K)、值(V)线性映射到多个头,分别计算注意力,最后拼接并映射回原维度。

1.2 数学公式

标准缩放点积注意力:

Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)VAttention(Q,K,V)=softmax(dkQKT)V

多头注意力:

MultiHead ( Q , K , V ) = Concat ( head 1 , . . . , head h ) W O \text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_1, ..., \text{head}_h) W^OMultiHead(Q,K,V)=Concat(head1,...,headh)WO

其中head i = Attention ( Q W i Q , K W i K , V W i V ) \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)headi=Attention(QWiQ,KWiK,VWiV)

1.3 代码实现(Python/PyTorch)

importtorchimporttorch.nnasnnimportmathclassMultiHeadAttention(nn.Module):def__init__(self,d_model,num_heads):super().__init__()assertd_model%num_heads==0self.d_model=d_model self.num_heads=num_heads self.d_k=d_model//num_heads self.Wq=nn.Linear(d_model,d_model)self.Wk=nn.Linear(d_model,d_model)self.Wv=nn.Linear(d_model,d_model)self.Wo=nn.Linear(d_model,d_model)defforward(self,Q,K,V,mask=None):batch_size=Q.size(0)# 1. 线性映射并拆分为多头Q=self.Wq(Q).view(batch_size,-1,self.num_heads,self.d_k).transpose(1,2)K=self.Wk(K).view(batch_size,-1,self.num_heads,self.d_k).transpose(1,2)V=self.Wv(V).view(batch_size,-1,self.num_heads,self.d_k).transpose(1,2)# 2. 缩放点积注意力scores=torch.matmul(Q,K.transpose(-2,-1))/math.sqrt(self.d_k)# 3. 应用掩码(可选)ifmaskisnotNone:scores=scores.masked_fill(mask==0,-1e9)attn_weights=torch.softmax(scores,dim=-1)output=torch.matmul(attn_weights,V)# 4. 合并多头并输出output=output.transpose(1,2).contiguous().view(batch_size,-1,self.d_model)returnself.Wo(output)

1.4 图解(文本示意)

输入: (B, T, D) │ ├─→ 线性映射 Wq, Wk, Wv → (B, T, D) │ ├─→ view + transpose → (B, n_head, T, d_k) │ ├─→ scores = Q @ K^T / sqrt(d_k) → (B, n_head, T, T) │ │ │ └─→ mask (可选) 填充 -1e9 │ ├─→ softmax → (B, n_head, T, T) │ ├─→ output = attn @ V → (B, n_head, T, d_k) │ ├─→ transpose + view → (B, T, D) │ └─→ Wo 线性映射 → (B, T, D)

1.5 C++ 代码(LibTorch)

#include<torch/torch.h>classMultiHeadAttentionImpl:publictorch::nn::Module{public:intd_model,num_heads,d_k;torch::nn::Linear Wq,Wk,Wv,Wo;MultiHeadAttentionImpl(intd_model_,intnum_heads_):d_model(d_model_),num_heads(num_heads_),d_k(d_model_/num_heads_),Wq(torch::nn::Linear(d_model,d_model)),Wk(torch::nn::Linear(d_model,d_model)),Wv(torch::nn::Linear(d_model,d_model)),Wo(torch::nn::Linear(d_model,d_model)){register_module("Wq",Wq);register_module("Wk",Wk);register_module("Wv",Wv);register_module("Wo",Wo);}torch::Tensorforward(torch::Tensor Q,torch::Tensor K,torch::Tensor V,torch::Tensor mask={}){intbatch_size=Q.size(0);// 线性映射Q=Wq->forward(Q).view({batch_size,-1,num_heads,d_k}).transpose(1,2);K=Wk->forward(K).view({batch_size,-1,num_heads,d_k}).transpose(1,2);V=Wv->forward(V).view({batch_size,-1,num_heads,d_k}).transpose(1,2);// 注意力分数autoscores=torch::matmul(Q,K.transpose(-2,-1))/std::sqrt(d_k);if(mask.defined()){scores=scores.masked_fill(mask==0,-1e9);}autoattn=torch::softmax(scores,-1);autooutput=torch::matmul(attn,V);output=output.transpose(1,2).contiguous().view({batch_size,-1,d_model});returnWo->forward(output);}};TORCH_MODULE(MultiHeadAttention);

1.6 复杂度分析

操作时间复杂度空间复杂度
线性映射 (Q,K,V)O(B×T×D²)O(B×T×D)
拆分多头O(B×T×D)O(B×n_head×T×d_k)
分数矩阵乘法O(B×n_head×T²×d_k)O(B×n_head×T²)
SoftmaxO(B×n_head×T²)O(B×n_head×T²)
加权求和O(B×n_head×T²×d_k)O(B×n_head×T×d_k)
合并与输出映射O(B×T×D²)O(B×T×D)
总计O(B × T² × D)O(B × n_head × T²)

其中D = d_model,d_k = D / n_head


二、前馈网络(FFN)

2.1 作用

FFN 对每个位置独立进行非线性变换,增加模型表达能力。标准结构:线性 → ReLU → 线性,通常中间维度d_ffd_model的 4 倍左右。

2.2 数学公式

FFN ( x ) = ReLU ( x W 1 + b 1 ) W 2 + b 2 \text{FFN}(x) = \text{ReLU}(xW_1 + b_1)W_2 + b_2FFN(x)=ReLU(xW1+b1)W2+b2

2.3 代码实现(Python/PyTorch)

classFFN(nn.Module):def__init__(self,d_model,d_ff):super().__init__()self.linear1=nn.Linear(d_model,d_ff)self.linear2=nn.Linear(d_ff,d_model)self.activation=nn.ReLU()defforward(self,x):returnself.linear2(self.activation(self.linear1(x)))

2.4 图解

输入 (B, T, D) │ ├─→ linear1 (D → d_ff) → (B, T, d_ff) │ ├─→ ReLU → (B, T, d_ff) │ └─→ linear2 (d_ff → D) → (B, T, D)

2.5 C++ 代码(LibTorch)

classFFNImpl:publictorch::nn::Module{public:torch::nn::Linear linear1,linear2;FFNImpl(intd_model,intd_ff):linear1(d_model,d_ff),linear2(d_ff,d_model){register_module("linear1",linear1);register_module("linear2",linear2);}torch::Tensorforward(torch::Tensor x){returnlinear2->forward(torch::relu(linear1->forward(x)));}};TORCH_MODULE(FFN);

2.6 复杂度分析

操作时间复杂度空间复杂度
linear1O(B × T × D × d_ff)O(B × T × d_ff)
ReLUO(B × T × d_ff)O(B × T × d_ff)
linear2O(B × T × d_ff × D)O(B × T × D)
总计O(B × T × D × d_ff)O(B × T × max(D, d_ff))

d_ff = 4 × D时,复杂度约为O(4 × B × T × D²)


三、词嵌入(Embedding)

3.1 作用

将离散的 token ID 序列映射为稠密的连续向量,并乘以√d_model进行缩放,以便与位置编码相加时尺度匹配。

3.2 代码实现(Python/PyTorch)

classEmbedding(nn.Module):def__init__(self,vocab_size,d_model):super().__init__()self.embedding=nn.Embedding(vocab_size,d_model)self.d_model=d_modeldefforward(self,x):returnself.embedding(x)*math.sqrt(self.d_model)

3.3 图解

输入: (B, T) token IDs [ [1, 3, 0, ...] ] │ └─→ nn.Embedding 查表 (vocab_size × D) │ └─→ 输出 (B, T, D) │ └─→ 乘以 √D → (B, T, D)

3.4 C++ 代码(LibTorch)

classEmbeddingImpl:publictorch::nn::Module{public:torch::nn::Embedding embedding;intd_model;EmbeddingImpl(intvocab_size,intd_model_):embedding(vocab_size,d_model_),d_model(d_model_){register_module("embedding",embedding);}torch::Tensorforward(torch::Tensor x){returnembedding->forward(x)*std::sqrt(d_model);}};TORCH_MODULE(Embedding);

3.5 复杂度分析

操作时间复杂度空间复杂度
查表O(B × T)O(B × T × D)
乘法O(B × T × D)O(B × T × D)
总计O(B × T × D)O(B × T × D)

四、三个模块的组合使用

一个完整的 Transformer 编码器层通常由多头注意力 + 残差连接 + 层归一化FFN + 残差连接 + 层归一化构成。

classTransformerEncoderLayer(nn.Module):def__init__(self,d_model,num_heads,d_ff):super().__init__()self.self_attn=MultiHeadAttention(d_model,num_heads)self.ffn=FFN(d_model,d_ff)self.norm1=nn.LayerNorm(d_model)self.norm2=nn.LayerNorm(d_model)defforward(self,x,mask=None):# 自注意力 + 残差 + 层归一化attn_out=self.self_attn(x,x,x,mask)x=self.norm1(x+attn_out)# FFN + 残差 + 层归一化ffn_out=self.ffn(x)x=self.norm2(x+ffn_out)returnx

完整流程示例

vocab_size=10000d_model=512num_heads=8d_ff=2048batch_size=2seq_len=10# 输入 token IDsinput_ids=torch.randint(0,vocab_size,(batch_size,seq_len))# 嵌入层embed=Embedding(vocab_size,d_model)x=embed(input_ids)# (2,10,512)# 位置编码(此处略,可加上)# pos_enc = PositionalEncoding(d_model)# x = pos_enc(x)# Transformer 编码器层encoder_layer=TransformerEncoderLayer(d_model,num_heads,d_ff)output=encoder_layer(x)# (2,10,512)print(output.shape)# torch.Size([2, 10, 512])

五、复杂度总结

模块时间复杂度空间复杂度说明
MultiHeadAttentionO(B × T² × D)O(B × n_head × T²)核心瓶颈在 T²,长序列需优化
FFNO(B × T × D × d_ff)O(B × T × max(D, d_ff))通常 d_ff = 4D,复杂度约为 4×
EmbeddingO(B × T × D)O(B × T × D)查表操作,轻量级

优化建议

  • 对于长序列(T 很大),可使用稀疏注意力(如 FlashAttention)降低 T² 复杂度。
  • FFN 的中间维度 d_ff 越大模型容量越大,但计算量线性增加。
  • 嵌入层占参数量主要为vocab_size × D,大词表时需考虑参数共享或压缩。

🎯 总结

本文详细拆解了 Transformer 的三个核心模块:

  1. 多头注意力:让模型关注不同位置的多种关系,是 Transformer 成功的核心。
  2. 前馈网络:提供非线性变换,增强模型表达能力。
  3. 词嵌入:将离散符号映射到连续空间,是深度学习处理文本的起点。

通过理解这些模块的输入输出、形状变化和复杂度,能轻松搭建并优化自己的 Transformer 模型。

http://www.jsqmd.com/news/855569/

相关文章:

  • cp520靶场学习笔记
  • 【FPAI开发】超详细!YOLO26适配FPAI芯片部署过程详解!
  • 高级音频解密技术实现:ncmdump模块化架构解析与自动化工作流
  • 【附源码】在线骑行网站(源码+数据库+论文+答辩ppt一整套齐全)java开发springboot+vue框架javaweb,可做计算机毕业设计或课程设计
  • 【算法题攻略】模拟
  • 2026年知名的镇江防腐网格桥架优质厂家推荐榜 - 行业平台推荐
  • 鸿蒙动态信息流与健康档案模块:声明式列表与网格的深度融合
  • 电脑投屏工具,将电脑屏幕共享到手机、平板、电脑、智能电视、投影仪等其它设备上!既可以共享整个屏幕,也能单独共享某个应用窗口,可作为提词器使用,或者更多运用场景!
  • Taotoken多模型聚合在批量内容生成任务中的稳定性观察
  • OpenAI Embeddings API 申请及使用
  • AutoGLM 手机自动化测试滑动性能优化
  • O2OA(翱途)开发平台V10 财务管理|中小企业费用业务一体化
  • TK跨境直播网络链路实测分析
  • 告别MPU6050例程!ATK-IMU901与Arduino串口通信的3个关键避坑点
  • YimMenu:GTA5终极防护与增强完整指南
  • 软件测试笔记【黑盒测试篇】:基于需求、面向功能
  • 无人机算法之第四章 ArduPilot 主要配置参数及效果
  • 数据库一体机简史:谁为数据仓库正名?
  • Perplexity到底是什么:从信息熵到模型评估,一文讲透3个核心公式与4种误用场景
  • 基于PSoC 6与BMI160构建嵌入式IMU测试系统:从驱动到上位机全流程
  • COMSOL电磁超声仿真避坑指南:从‘域不适用’报错到结果收敛的完整调试流程
  • DeepSeek大模型推理显存爆满?揭秘vLLM+FlashAttention下GPU显存占用突增217%的真实根因
  • HC32F4A0实战:用SPI驱动国产BL25CMIA EEPROM,从引脚配置到可靠性存储的完整流程
  • 项目——基于C/S架构的文件传输系统平台 (2)——重构
  • 保姆级教程:在S32G274ARDB2上,用IPCF点亮RGB LED(附源码解析)
  • AI 写代码总跑偏?mirrorai 让 Claude Code、Cursor、Copilot 严格遵守你项目的真实规范
  • 2026年自助建站平台哪个好?推荐这4个知名建站平台!
  • Git 进阶(二):分支管理、暂存栈、远程仓库与多人协作
  • 【正式版上线】Open Claw 2.7.5 桌面端一键安装部署教程
  • 三步告别键盘连击:KeyboardChatterBlocker高效使用全攻略