当前位置: 首页 > news >正文

Transformer模型、整体结构,编码器与解码器内部组成

一、Transformer

此前的Seq2Seq模型通过Attention机制取得了一定提升,但由于整体结构仍依赖RNN,依然存在

计算效率低、难以建模长距离依赖等结构性限制。Transformer完全摒弃了RNN结构,转而使用注

意力机制直接建模序列中各位置之间的关系。

1、整体结构

与基于RNN的Seq2Seq模型一样,Transformer的解码器采用自回归方式生成目标序列。不同之处

在于,每一步的输入是此前已生成的全部词,模型会输出一个与输入长度相同的序列,但我们只取

最后一个位置的结果作为当前预测。这个过程不断重复,直到生成结束标记<eos>。

此外,Transformer的编码器和解码器模块分别由多个结构相同的层堆叠而成。通过层层堆叠,模

型能够逐步提取更深层次的语义特征,从而增强对复杂语言现象的建模能力。标准的Transformer

模型通常包含6个编码器层和6个解码器层。

2、编码器

每个Encoder Layer都包含两个子层(sublayer)自注意力子层(Self-Attention Sublayer)和前馈

神经网络子层(Feed-Forward Sublayer)。

2.1、自注意力子层

在序列内部建立各位置之间的依赖关系,使模型能够为每个位置生成融合全局信息的表示。

(1)生成Q,K,V向量

总结:Q发起匹配,K与Q匹配,V加权求和

(2)计算位置相关性

评分函数采用向量点积形式。由于在高维空间中,点积的数值可能过大,会影响softmax的稳定

性,因此在实际计算中对结果进行了缩放。最终的评分函数为:

其中𝑑𝑘是key向量的维度,用于缩放点积的幅度。这个分数越大,表示第i个位置越应该关注第j个

位置的信息。

(3)计算注意力权重

(4)加权汇总生成输出

(5)总结

整个自注意力机制的完整计算公式如下:

(6)多头注意力计算过程

要准确理解语义复杂的句子,Transformer引入了多头注意力机制(Multi-Head Attention)。其核

心思想是通过多组独立的Query、Key、Value投影,让不同注意力头分别专注于不同的语义关系,

最后将各头的输出拼接融合。

分别计算多头注意力输出

合并多头注意力

2.2、前馈神经网络子层

前馈神经网络(Feed-Forward Network,简称FFN),一个标准的FFN子层包含两个线性变换和一

个非线性激活函数,中间通常使用ReLU激活。

2.3、残差连接与层归一化

在Transformer的每个编码器层中,每个子层,包括自注意力子层和前馈神经网络子层,其输出都

要经过残差连接(Residual Connection)和层归一化(Layer Normalization)处理。这两者是深层

神经网络中常用的结构,用于缓解模型训练中的梯度消失、收敛困难等问题。

(1)残差连接(Residual Connection)

将子层的输入直接与其输出相加,形成一条跨越子层的“捷径”,其数学形式为:

(2)层归一化(Layer Normalization)

主要作用是规范输入序列中每个token的特征分布(某个token的表示可能在不同维度上有较大数值

差异),提升模型训练的稳定性。该操作会将每个token的向量调整为均值为0、方差为1的规范分

布。

2.4、位置编码

为了解决 Transformer 无法捕捉语序的问题,该模型引入了位置编码(Positional Encoding)机

制,通过为每个词添加位置向量,使其在获取词义的同时也能感知位置信息,从而理解语序。

为解决绝对位置编码数值倾斜及归一化导致的位置不一致问题,Transformer采用基于正弦和余弦

函数的固定位置编码,为每个位置生成唯一且与句子长度无关的向量,从而保证模型能稳定捕捉语

序信息。

3、解码器

每个Decoder Layer都包含三个子层,分别是Masked自注意力子层、编码器-解码器注意力子

层(Encoder-Decoder Attention)和前馈神经网络子层(Feed-Forward Network),每个子层后

也都配有残差连接与层归一化(Layer Normalization),结构设计与编码器保持一致,确保训练的

稳定性和效率。

此外,解码器在输入端同样需要加入位置编码(Positional Encoding),用于提供序列中的位

置信息,其计算方式与编码器中相同。

3.1、Masked自注意力子层

用于建模当前位置与前文词之间的依赖关系。为了在训练时模拟逐词生成的过程,引入遮盖机制

(Mask),限制每个位置只能关注它前面的词。

Mask是一个下三角矩阵,上三角设置为负无穷的原因是让上三角部分经过softmax之后,权重几乎

为0。

3.2、编码器--解码器注意力子层

在解码器的交叉注意力中,Query来自解码器当前输入,Key和Value来自编码器输出。通过计算

Query与所有Key的相似度,得到源序列各位置的权重,再对Value加权求和,从而为当前生成词提

取相关的上下文信息。

3.3、前馈神经网络子层

与编码器中结构完全一致,对每个位置的表示进行非线性变换,增强模型的表达能力。

4、模型训练和推理机制

4.1、模型训练

训练时,Transformer将目标序列整体输入解码器,并在每个位置同时进行预测。为防止模型“看到”

后面的词,破坏因果顺序,解码器在自注意力机子层中引入了遮盖机制(Mask),限制每个位置

只能关注它前面的词。

4.2、模型推理

推理时,每一步都要重新输入整个已生成序列,模型需要基于全量前文重新计算注意力分布,决定

下一个词的输出。整个过程必须顺序执行,无法并行。推理阶段,模型每一步都要重新输入当前已

生成的全部词,通过自注意力机制建模上下文关系,预测下一个词。

5、API使用

PyTorch提供了完整的Transformer官方实现,封装了编码器-解码器结构,适用于机器翻译、文本生成等序列建模任务。核心模块包括:

  • nn.Transformer:顶层接口,封装完整编码器-解码器架构,支持自定义层数、注意力头数、隐藏维度等参数。

  • nn.TransformerEncoder/Decoder:分别由多个编码器/解码器层堆叠而成,用于序列编码和目标序列生成。

  • nn.TransformerEncoderLayer/DecoderLayer:实现单层结构,编码器层包含多头自注意力和前馈子层;解码器层额外增加编码器-解码器注意力。各子层均配有残差连接和层归一化。

1、Transformer构造参数

2、Transformer.forward

nn.Transformer封装了完整的前向传播逻辑,其forward()方法接收源语言序列(编码器输入)和目

标语言序列(解码器输入),返回解码器的预测结果。

3、Transformer.encoder

nn.Transformer通过encoder属性(nn.TransformerEncoder实例)对源序列进行编码,提取上下文

相关的语义表示。

4、Test

import torch.nn as nn import torch model = nn.Transformer(d_model = 64,nhead = 8,num_encoder_layers = 3,num_decoder_layers = 3, dim_feedforward = 256,batch_first=True) print(model) src = torch.randn(32,10,64) tgt = torch.randn(32,24,64) output1 = model(src,tgt) print(output1.shape ) memory = model.encoder(src) print(memory.shape) output2 = model.decoder(tgt,memory) print(output2.shape) Transformer( (encoder): TransformerEncoder( (layers): ModuleList( (0-2): 3 x TransformerEncoderLayer( (self_attn): MultiheadAttention( (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True) ) (linear1): Linear(in_features=64, out_features=256, bias=True) (dropout): Dropout(p=0.1, inplace=False) (linear2): Linear(in_features=256, out_features=64, bias=True) (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True) (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True) (dropout1): Dropout(p=0.1, inplace=False) (dropout2): Dropout(p=0.1, inplace=False) ) ) (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True) ) (decoder): TransformerDecoder( (layers): ModuleList( (0-2): 3 x TransformerDecoderLayer( (self_attn): MultiheadAttention( (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True) ) (multihead_attn): MultiheadAttention( (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True) ) (linear1): Linear(in_features=64, out_features=256, bias=True) (dropout): Dropout(p=0.1, inplace=False) (linear2): Linear(in_features=256, out_features=64, bias=True) (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True) (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True) (norm3): LayerNorm((64,), eps=1e-05, elementwise_affine=True) (dropout1): Dropout(p=0.1, inplace=False) (dropout2): Dropout(p=0.1, inplace=False) (dropout3): Dropout(p=0.1, inplace=False) ) ) (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True) ) ) torch.Size([32, 24, 64]) torch.Size([32, 10, 64]) torch.Size([32, 24, 64])

6、Transformer中英翻译案例

PyTorch无内置位置编码,而Transformer不具备位置感知能力,故需手动实现位置编码,与嵌入层输出相加后输入模型。还需实现以下模块:

  • 源语言和目标语言的词嵌入层(nn.Embedding)

  • 输出层(nn.Linear),将模型输出映射至目标词表大小

day9 16

http://www.jsqmd.com/news/474914/

相关文章:

  • 手把手教你用MedGemma-X:AI影像诊断助手5分钟快速部署
  • OpenCode场景应用:程序员通勤路上用手机写代码,回家无缝衔接
  • 内联函数,函数的缺省值,函数重载,右值引用
  • 谷歌Gemini Pro API vs ChatGPT API:免费、配置难度与性能对比
  • AI 辅助开发实战:高效完成基于 Spring Boot 的 JavaWeb 毕设项目
  • PROJECT MOGFACE企业级部署:基于Docker与内网穿透的高可用架构
  • 手把手教你解决Vulhub环境搭建中的docker-compose up -d报错(含CentOS联网技巧)
  • C语言快速入门9-指针
  • 补天漏洞响应平台:白帽子与企业安全合作的桥梁
  • Windows下MissionPlanner地面站编译避坑指南:从Git克隆到VS2022完整流程
  • 从linux内核理解Java怎样实现Socket通信
  • CLAP模型在农业领域的创新应用:病虫害声音早期预警
  • 从STM32到语音交互:CosyVoice在嵌入式设备语音提示系统中的应用构想
  • 手机省电技巧|告别电量焦虑,一天一充不是梦
  • STM32 RTC数字校准、时间戳与低功耗机制全栈解析
  • PLSQL连接Oracle报ORA-12541?5个常见原因及快速排查方法
  • UiPath离线激活全流程:从生成Token到成功激活的保姆级教程
  • HttpCanary实战指南:从零开始掌握Android HTTPS抓包技巧
  • STM32 SPI/I2S状态机与安全停机机制深度解析
  • 《QGIS快速入门与应用基础》215:批量应用标注样式
  • 【项目实战】如何将接口传过来的html文件通过WPF控件展示在桌面应用程序?
  • 用Unity物理引擎还原真实赛车手感:齿轮变速+悬挂系统调试指南
  • 高德地图JSAPI实战:如何给北京市各区自定义颜色标记(附完整代码)
  • 基于Docker与macvlan:在Linux服务器上构建高性能OpenWrt软路由
  • MedGemma X-Ray开发者案例:gradio_app.py与Orthanc PACS双向DICOM通信
  • ESP32-C2技术文档体系与工程落地全链路指南
  • 多线程并发处理样例
  • 设计模式的六大原则:原理与实践
  • ESP32-C61总线与内存访问监控系统深度解析
  • ComicAI vs 传统漫画制作:实测AI生成30页漫画要花多少法力值?