当前位置: 首页 > news >正文

从零实现Transformer:第 3 部分 - 掩码多头注意力的掩码广播(Broadcasting of Masks in Masked Multi-Head Attention)

从零实现Transformer:第 3 部分 - 掩码多头注意力的掩码广播(Broadcasting of Masks in Masked Multi-Head Attention)

flyfish

以生成填充掩码 + 前瞻掩码的组合掩码 为例

1. 生成 Padding Mask(填充掩码)

屏蔽序列中的填充占位符(pad_id=0)填充的0是无效字符,模型不应该关注、学习这些无意义的占位符

2. 生成 Look-ahead Mask(前瞻掩码)

屏蔽当前位置之后的所有未来 token,解码器是自回归生成(一步步生成文本),绝对不能提前看到未来的词

3. 合并掩码

|运算把两个掩码合二为一:
只要是「填充位」或「未来位」,统一屏蔽(True)

用处

输出形状:[batch, 1, seq_len, seq_len]
这个掩码直接传入解码器的多头自注意力层
掩码为True→ 注意力分数置为负无穷,模型完全忽略该位置
掩码为False→ 正常计算注意力,模型可以关注该位置

importtorchdefcreate_tgt_mask(tgt_ids,pad_id):"""创建目标序列掩码(padding mask + look-ahead mask)"""#1.2维padding掩码[batch,seq_len]padding_mask_2d=(tgt_ids==pad_id)#2.升维适配注意力维度->[batch,1,1,seq_len]tgt_padding_mask=padding_mask_2d.unsqueeze(1).unsqueeze(1)#3.生成序列长度 tgt_seq_len=tgt_ids.shape[1]#4.构造上三角前瞻掩码[seq_len,seq_len]#diagonal=1:主对角线上方为1,遮挡未来位置look_ahead_mask=torch.triu(torch.ones(tgt_seq_len,tgt_seq_len,device=tgt_ids.device),diagonal=1).bool()#5.升维支持批量广播->[1,1,seq_len,seq_len]look_ahead_mask=look_ahead_mask.unsqueeze(0).unsqueeze(0)#6.合并掩码:任意一个为True就遮挡returntgt_padding_mask|look_ahead_mask # 测试if__name__=="__main__":pad_id=0#2个batch,序列长度50为padding tgt_ids=torch.tensor([[1,2,3,0,0],[4,5,0,0,0]])mask=create_tgt_mask(tgt_ids,pad_id)print("最终掩码形状:",mask.shape)# torch.Size([2,1,5,5])print("掩码内容:\n",mask)

输出

最终掩码形状:torch.Size([2,1,5,5])掩码内容:tensor([[[[False,True,True,True,True],[False,False,True,True,True],[False,False,False,True,True],[False,False,False,True,True],[False,False,False,True,True]]],[[[False,True,True,True,True],[False,False,True,True,True],[False,False,True,True,True],[False,False,True,True,True],[False,False,True,True,True]]]])

广播 = PyTorch 自动把 形状不同但兼容 的张量,复制拉伸成相同形状,然后再运算

两个张量形状

returntgt_padding_mask|look_ahead_mask

两个输入形状:

  1. tgt_padding_mask[B, 1, 1, S]→ 举例[2, 1, 1, 4]
  2. look_ahead_mask[1, 1, S, S]→ 举例[1, 1, 4, 4]

广播目标:把两个张量都自动变成[2, 1, 4, 4],再做|运算

广播规则

  1. 维度为1的位置,可以自动复制扩展成任意大小
  2. 扩展后,两个张量形状完全一致,就能运算

例子1:最简单的2维广播
模拟:小张量自动拉伸

importtorch# 形状 [1,4] → 1行4列a=torch.tensor([[True,False,True,False]])# 形状 [4,4] → 4行4列b=torch.ones(4,4).bool()# 广播运算:a自动复制4行,变成[4,4],再和b运算c=a|bprint("a形状:",a.shape)print("b形状:",b.shape)print("广播后运算结果形状:",c.shape)# 输出 [4,4]

[1,4]自动扩成[4,4]

例子2:3维广播(过渡)

# [2,1,4]a=torch.rand(2,1,4).bool()# [1,4,4]b=torch.rand(1,4,4).bool()# 自动广播成 [2,4,4]c=a|bprint(c.shape)# [2,4,4]

例子3:模拟代码的4维广播

importtorch# 模拟两个掩码B,S=2,4# 1. padding掩码 [2,1,1,4]tgt_pad_mask=torch.rand(B,1,1,S).bool()# 2. 前瞻掩码 [1,1,4,4]look_ahead_mask=torch.rand(1,1,S,S).bool()# 广播运算!final_mask=tgt_pad_mask|look_ahead_mask# 打印形状print("padding掩码形状:",tgt_pad_mask.shape)# [2,1,1,4]print("前瞻掩码形状:",look_ahead_mask.shape)# [1,1,4,4]print("广播后最终形状:",final_mask.shape)# [2,1,4,4]

代码里用到的

输入参数

# 2个句子,每个句子最长5个词tgt_ids=torch.tensor([[1,2,3,0,0],# 第1个样本:有效词3个,后2个是填充0[4,5,0,0,0]# 第2个样本:有效词2个,后3个是填充0])pad_id=0# 0代表填充位

批次大小B = 2
序列长度S = 5
标准掩码维度:[batch, num_heads, seq_q, seq_k]

最终维度是[2, 1, 5, 5]

[2, 1, 5, 5] = [批次B, 头数H, 查询序列长Q, 键序列长K]
  1. 2:一次性处理2 个句子(batch=2)
  2. 1:代码里没做多头,默认1 个注意力头
  3. 5:Query 向量数量 = 目标序列长度 = 5
  4. 5:Key 向量数量 = 目标序列长度 = 5

代码里的广播

  1. tgt_padding_mask形状:[2, 1, 1, 5]
  2. look_ahead_mask形状:[1, 1, 5, 5]
  3. PyTorch自动广播把两个张量都拉伸为[2, 1, 5, 5],再做|运算

掩码内容

最终掩码 =前瞻掩码填充掩码
True= 遮挡(不让看)
False= 允许看

1. 前瞻掩码(固定不变,所有样本共用)

torch.triu(..., diagonal=1)生成固定上三角矩阵

# 5x5 前瞻掩码(对角线以上全是True,遮挡未来词) [ [F, T, T, T, T], # 第1个词:只能看自己,不能看后面4个 [F, F, T, T, T], # 第2个词:能看自己+前1个,不能看后面3个 [F, F, F, T, T], # 第3个词:能看自己+前2个,不能看后面2个 [F, F, F, F, T], # 第4个词:能看自己+前3个,不能看后面1个 [F, F, F, F, F] # 第5个词:能看所有前面的词 ]

2. 填充掩码(每个样本不一样)

样本1[1,2,3,0,0]第4、5位是填充→ 掩码[F,F,F,T,T]
样本2[4,5,0,0,0]第3、4、5位是填充→ 掩码[F,F,T,T,T]

最终合并结果

样本1 输出(第一块 5x5)

[[False, True, True, True, True], [False, False, True, True, True], [False, False, False, True, True], [False, False, False, True, True], # 第4位是填充,永久遮挡 [False, False, False, True, True]] # 第5位是填充,永久遮挡

前3行:只受前瞻掩码影响
后2行:前瞻掩码 + 填充掩码双重遮挡

样本2 输出(第二块 5x5)

[[False, True, True, True, True], [False, False, True, True, True], [False, False, True, True, True], # 第3位是填充,永久遮挡 [False, False, True, True, True], # 第4位是填充,永久遮挡 [False, False, True, True, True]] # 第5位是填充,永久遮挡

前2行:只受前瞻掩码影响
后3行:前瞻掩码 + 填充掩码双重遮挡

简单的流程就是

  1. 维度[2,1,5,5]
    [2个句子, 1个注意力头, 每个句子5个Query, 每个句子5个Key]
  2. 掩码内容
    上三角的True= 遮挡未来词(前瞻掩码)
    后半列的True= 遮挡填充0(填充掩码)
  3. 两者合并,就是看到的输出

不用广播的写法

importtorchdefcreate_tgt_mask_no_broadcast(tgt_ids,pad_id):"""创建目标序列掩码(无广播版,手动扩展维度)"""B,S=tgt_ids.shape # 直接获取批次B=2,序列长S=5#1.2维padding掩码[batch,seq_len][2,5]padding_mask_2d=(tgt_ids==pad_id)#2.升维 →[B,1,1,S][2,1,1,5]tgt_padding_mask=padding_mask_2d.unsqueeze(1).unsqueeze(1)#==============替代广播==============# 把第3维(seq_q)从1复制成 S → 形状变成[B,1,S,S]=[2,1,5,5]tgt_padding_mask=tgt_padding_mask.repeat(1,1,S,1)#3.构造上三角前瞻掩码[S,S][5,5]look_ahead_mask=torch.triu(torch.ones(S,S,device=tgt_ids.device),diagonal=1).bool()#4.升维 →[1,1,S,S][1,1,5,5]look_ahead_mask=look_ahead_mask.unsqueeze(0).unsqueeze(0)#==============替代广播==============# 把第0维(batch)从1复制成 B → 形状变成[B,1,S,S]=[2,1,5,5]look_ahead_mask=look_ahead_mask.repeat(B,1,1,1)#6.两个掩码形状完全一致,直接运算(无任何广播)returntgt_padding_mask|look_ahead_mask # 测试if__name__=="__main__":pad_id=0tgt_ids=torch.tensor([[1,2,3,0,0],[4,5,0,0,0]])# 运行无广播版本 mask=create_tgt_mask_no_broadcast(tgt_ids,pad_id)print("最终掩码形状:",mask.shape)# 依旧是 torch.Size([2,1,5,5])print("掩码内容:\n",mask)

输出

最终掩码形状:torch.Size([2,1,5,5])掩码内容:tensor([[[[False,True,True,True,True],[False,False,True,True,True],[False,False,False,True,True],[False,False,False,True,True],[False,False,False,True,True]]],[[[False,True,True,True,True],[False,False,True,True,True],[False,False,True,True,True],[False,False,True,True,True],[False,False,True,True,True]]]])
http://www.jsqmd.com/news/807521/

相关文章:

  • RimWorld模组开发新范式:Riml元语言工具提升开发效率
  • VMware Unlocker 3.0:在普通PC上运行macOS虚拟机的终极指南
  • 积分、微分、指数和对数运算放大电路基础知识及Multisim电路仿真
  • WARPED框架:基于单目RGB视频的机器人模仿学习系统
  • 感应照明技术:从工业到家用,一场技术降维的工程冒险
  • 从零到一:手把手完成Jmeter与JDK环境搭建及配置验证
  • 长沙口碑好的学区房怎么选 - mypinpai
  • 小红书内容下载终极指南:如何用XHS-Downloader轻松保存无水印作品
  • Spec-Kit中文版:AI驱动的规范驱动开发实践指南
  • 如何在Windows和Linux上快速解锁VMware的macOS支持:Unlocker 3.0终极指南
  • 2025年项目管理工具TOP10:Gitee引领技术驱动新浪潮
  • AI编程工具的内卷:Copilot、Cursor、通义灵码,谁能笑到最后?
  • 2026年AI生成内容怕AI检测?7款专业工具帮你降AI率高效过关!收藏必备 - 降AI实验室
  • Shopify上线AI Toolkit:卖家运营提效新利器,却也暗藏风险与挑战
  • Display Driver Uninstaller终极指南:5分钟彻底解决显卡驱动残留问题
  • Elektra Skills:为AI编程助手引入结构化执行与自动化治理的解决方案架构师
  • 2026年口碑好的LED显示屏品牌排名 - mypinpai
  • 关于假发的几个偏见,今天一并说清楚
  • 机器学习在资产管理中的应用:从数据到投资组合的端到端框架
  • 长沙壹南府好不好用?有什么优点? - mypinpai
  • OpenAI 兼容接口调用 Claude 的迁移实战
  • claw-gatekeeper:构建稳定智能的数据抓取守护服务
  • 如何5分钟部署AzurLaneAutoScript:面向新手的终极自动化指南
  • 3分钟学会!用Video-subtitle-extractor轻松提取视频硬字幕,告别手动转录烦恼
  • 为什么 Promise 比 setTimeout 先执行?——JavaScript 事件循环与异步顺序完全指南
  • 2026年4月,口碑好的钨钢防弹插板供应商哪家强?钨钢防弹插板/q420C高强钢板/nm500耐磨板,防弹插板公司推荐 - 品牌推荐师
  • Java安装完全指南:从零搭建Java开发环境
  • 四大32位FPGA软核处理器实战对比:LEON3、OR1200、Nios II与MicroBlaze选型指南
  • 卖token有多赚钱
  • 雨之灵动获数千万融资,AI 仿生毛绒宠物 Walulu 能否建立品牌壁垒?