import torch# 1. 模拟批次输入,0=PAD
input_ids = torch.tensor([[1,2,3,0,0], [4,5,0,0,0]])
batch, seq_len = input_ids.shape# ----------------------
# 第一步:生成 Padding Mask
# ----------------------
pad_mask = (input_ids == 0) # [B, L]
# 扩维到 [B, L, L],适配注意力分数矩阵 [B, L, L]
pad_mask = pad_mask.unsqueeze(1).repeat(1, seq_len, 1) # [2,5,5]# ----------------------
# 第二步:生成 Causal 前瞻掩码(上三角)
# ----------------------
# [L, L] 上三角,diagonal=1 表示主对角线右侧全部遮挡
causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool() # [5,5]# ----------------------
# 第三步:合并掩码 逻辑或
# ----------------------
# 利用广播:[B,L,L] | [L,L] → [B,L,L]
full_mask = pad_mask | causal_mask# ----------------------
# 第四步:作用到注意力分数
# ----------------------
attn_score = torch.randn(batch, seq_len, seq_len) # 模拟注意力分数 [B,L,L]
attn_score = attn_score.masked_fill(full_mask, -1e9)
attn_weight = torch.softmax(attn_score, dim=-1)print(attn_weight)