当前位置: 首页 > news >正文

别再死记硬背Attention公式了!用Python+PyTorch手撕一个Hierarchical Attention Network(HAN)

从零实现层次注意力网络:用PyTorch构建可解释的文本分析模型

在自然语言处理领域,理解长文档的层次结构一直是个挑战。传统的注意力机制虽然强大,但面对嵌套的文本层级(如词→句→段落)时往往力不从心。这就是Hierarchical Attention Network(HAN)的用武之地——它像人类阅读一样,先理解词语,再把握句子,最后整合段落含义。本文将带您用PyTorch从零搭建这个精妙的架构,过程中您会发现:

  1. 注意力机制不再是黑箱,通过可视化权重能看到模型"关注"了什么
  2. GRU单元如何在不同层级间传递和提炼信息
  3. 为什么说HAN特别适合代码变更分析、医疗报告解析等结构化文本任务

1. 环境准备与数据预处理

1.1 安装必要依赖

建议使用Python 3.8+和最新版PyTorch。创建一个干净的虚拟环境后安装:

pip install torch==1.12.0 torchtext==0.13.0 matplotlib numpy

1.2 构建示例数据集

我们将模拟代码审查场景,构造包含三个层级的虚拟数据:

import torch from collections import defaultdict # 示例数据结构:hunk → lines → words sample_data = [ { "hunk_id": 1, "lines": [ {"line_id": 1, "text": "fix null pointer exception"}, {"line_id": 2, "text": "add input validation"} ] }, { "hunk_id": 2, "lines": [ {"line_id": 3, "text": "optimize database query"}, {"line_id": 4, "text": "remove redundant joins"} ] } ] # 构建词汇表 word_vocab = defaultdict(lambda: len(word_vocab)) word_vocab["<pad>"] = 0 # 填充标记 word_vocab["<unk>"] = 1 # 未知词 for hunk in sample_data: for line in hunk["lines"]: for word in line["text"].split(): _ = word_vocab[word.lower()] print(f"词汇表大小: {len(word_vocab)}")

2. 词级编码与注意力实现

2.1 双向GRU编码器

词级编码器需要捕获每个词的上下文信息:

import torch.nn as nn class WordLevelEncoder(nn.Module): def __init__(self, vocab_size, embed_dim=100, hidden_size=50): super().__init__() self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0) self.gru = nn.GRU( input_size=embed_dim, hidden_size=hidden_size, bidirectional=True, batch_first=True ) def forward(self, x): # x形状: (batch_size, seq_len) embedded = self.embedding(x) # (batch_size, seq_len, embed_dim) outputs, _ = self.gru(embedded) # (batch_size, seq_len, 2*hidden_size) return outputs

2.2 词级注意力机制

注意力层让模型学会聚焦关键词语:

class WordAttention(nn.Module): def __init__(self, hidden_size): super().__init__() self.linear = nn.Linear(2*hidden_size, hidden_size) self.context_vector = nn.Parameter(torch.randn(hidden_size)) def forward(self, encoder_outputs): # encoder_outputs形状: (batch_size, seq_len, 2*hidden_size) u = torch.tanh(self.linear(encoder_outputs)) # (batch_size, seq_len, hidden_size) scores = torch.matmul(u, self.context_vector) # (batch_size, seq_len) alphas = torch.softmax(scores, dim=1) # 注意力权重 return torch.sum(encoder_outputs * alphas.unsqueeze(-1), dim=1), alphas

3. 行级编码与注意力实现

3.1 行编码器结构

行编码器处理词级编码器的输出序列:

class LineLevelEncoder(nn.Module): def __init__(self, input_size, hidden_size=50): super().__init__() self.gru = nn.GRU( input_size=input_size, hidden_size=hidden_size, bidirectional=True, batch_first=True ) def forward(self, x): # x形状: (batch_size, num_lines, 2*word_hidden_size) outputs, _ = self.gru(x) # (batch_size, num_lines, 2*hidden_size) return outputs

3.2 行级注意力层

行级注意力识别文档中的关键句子:

class LineAttention(nn.Module): def __init__(self, hidden_size): super().__init__() self.linear = nn.Linear(2*hidden_size, hidden_size) self.context_vector = nn.Parameter(torch.randn(hidden_size)) def forward(self, encoder_outputs): # 实现与WordAttention类似,但处理的是行级表示 u = torch.tanh(self.linear(encoder_outputs)) scores = torch.matmul(u, self.context_vector.unsqueeze(1).squeeze(-1)) alphas = torch.softmax(scores, dim=1) return torch.sum(encoder_outputs * alphas.unsqueeze(-1), dim=1), alphas

4. 块级编码与完整HAN集成

4.1 块编码器设计

块级编码器处理行级表示序列:

class HunkLevelEncoder(nn.Module): def __init__(self, input_size, hidden_size=50): super().__init__() self.gru = nn.GRU( input_size=input_size, hidden_size=hidden_size, bidirectional=True, batch_first=True ) def forward(self, x): # x形状: (batch_size, num_hunks, 2*line_hidden_size) outputs, _ = self.gru(x) return outputs

4.2 完整HAN架构

整合所有组件构建端到端模型:

class HierarchicalAttentionNetwork(nn.Module): def __init__(self, vocab_size, word_embed_dim=100, word_hidden_size=50, line_hidden_size=50, hunk_hidden_size=50): super().__init__() self.word_encoder = WordLevelEncoder(vocab_size, word_embed_dim, word_hidden_size) self.word_attention = WordAttention(word_hidden_size) self.line_encoder = LineLevelEncoder(2*word_hidden_size, line_hidden_size) self.line_attention = LineAttention(line_hidden_size) self.hunk_encoder = HunkLevelEncoder(2*line_hidden_size, hunk_hidden_size) self.hunk_attention = LineAttention(hunk_hidden_size) # 复用LineAttention结构 def forward(self, hunks): # hunks是预处理后的输入数据 batch_size = len(hunks) # 词级处理 line_representations = [] word_attentions = [] for lines in hunks: line_reps = [] word_atts = [] for words in lines: word_outputs = self.word_encoder(words) line_rep, word_att = self.word_attention(word_outputs) line_reps.append(line_rep) word_atts.append(word_att) line_representations.append(torch.stack(line_reps)) word_attentions.append(torch.stack(word_atts)) # 行级处理 hunk_representations = [] line_attentions = [] for lines in line_representations: line_outputs = self.line_encoder(lines.unsqueeze(0)) hunk_rep, line_att = self.line_attention(line_outputs) hunk_representations.append(hunk_rep) line_attentions.append(line_att) # 块级处理 hunk_outputs = torch.stack(hunk_representations) final_output, hunk_attentions = self.hunk_attention(hunk_outputs) return { "output": final_output, "word_attentions": word_attentions, "line_attentions": line_attentions, "hunk_attentions": hunk_attentions }

5. 模型训练与注意力可视化

5.1 自定义训练循环

实现带注意力监控的训练过程:

def train_model(model, data_loader, criterion, optimizer, epochs=10): model.train() for epoch in range(epochs): total_loss = 0 for batch in data_loader: optimizer.zero_grad() outputs = model(batch["input"]) loss = criterion(outputs["output"], batch["label"]) loss.backward() optimizer.step() total_loss += loss.item() print(f"Epoch {epoch+1}, Loss: {total_loss/len(data_loader):.4f}") # 可视化第一个样本的注意力权重 visualize_attention(batch["input"][0], outputs["word_attentions"][0], outputs["line_attentions"][0], outputs["hunk_attentions"][0])

5.2 注意力权重可视化

理解模型关注的重点:

import matplotlib.pyplot as plt def visualize_attention(sample, word_att, line_att, hunk_att): fig, axes = plt.subplots(3, 1, figsize=(10, 12)) # 词级注意力 words = ["<pad>"]*len(word_att) # 这里应替换为实际的词序列 axes[0].bar(range(len(word_att)), word_att.detach().numpy()) axes[0].set_title("Word-level Attention") # 行级注意力 lines = ["Line 1", "Line 2"] # 示例行内容 axes[1].bar(range(len(line_att)), line_att.detach().numpy()) axes[1].set_xticks(range(len(lines))) axes[1].set_xticklabels(lines, rotation=45) axes[1].set_title("Line-level Attention") # 块级注意力 hunks = ["Hunk 1", "Hunk 2"] # 示例块描述 axes[2].bar(range(len(hunk_att)), hunk_att.detach().numpy()) axes[2].set_xticks(range(len(hunks))) axes[2].set_xticklabels(hunks) axes[2].set_title("Hunk-level Attention") plt.tight_layout() plt.show()

6. 实际应用中的优化技巧

在真实项目中部署HAN时,有几个关键优化点值得注意:

  1. 批处理优化:原始实现逐样本处理效率低,可改用pack_padded_sequence处理变长序列
  2. 注意力计算加速:当序列较长时,使用缩放点积注意力(scaled dot-product)计算更快
  3. 多任务学习:在输出层同时预测多个相关标签(如代码审查中同时预测缺陷类型和严重程度)
  4. 层次Dropout:在不同层级应用不同dropout率,通常词级>行级>块级

一个优化后的注意力计算示例:

class EfficientAttention(nn.Module): def __init__(self, hidden_size): super().__init__() self.query = nn.Linear(hidden_size, hidden_size) self.key = nn.Linear(hidden_size, hidden_size) def forward(self, encoder_outputs): # encoder_outputs: (batch_size, seq_len, hidden_size) q = self.query(encoder_outputs) # (batch_size, seq_len, hidden_size) k = self.key(encoder_outputs) # (batch_size, seq_len, hidden_size) scores = torch.bmm(q, k.transpose(1,2)) / (encoder_outputs.size(-1)**0.5) alphas = torch.softmax(scores, dim=-1) return torch.bmm(alphas, encoder_outputs), alphas
http://www.jsqmd.com/news/673322/

相关文章:

  • 【侯俊霞全网最全收集--PLC1200/200SMART(88课时) 中级课程 第1章】
  • 软件测试计划模板
  • 5200000 个文件,rm -rf 报错,如何快速清理?
  • 车载问答系统开发不再踩坑:Dify v0.12.3适配Autosar AP平台完整技术白皮书(含ASAM MCD-2 MC接口映射表)
  • 【Dify插件开发黄金法则】:20年AI平台架构师亲授,从零构建可商用插件的5大核心步骤
  • 别再死磕理论了!用PCL+KinectFusion手把手教你从照片到3D模型(保姆级避坑指南)
  • 软件标准管理中的规范执行监督
  • 从源码演变看PyTorch forward设计:从v0.1.12到2.x的钩子(Hook)机制进化史
  • 【2026年最新600套毕设项目分享】微信小程序的新闻资讯系统(30117)
  • Path of Building:3大核心功能彻底改变流放之路角色构筑
  • 单细胞分析入门:用Python的AnnData管理你的第一个单细胞数据集(附代码)
  • 文档解析准确率从81.6%→99.2%:Dify v0.8.5+自定义Chunker调优全流程,仅限内部技术团队验证的7个关键参数
  • 哔哩下载姬完整教程:5分钟掌握B站视频下载与处理终极方案
  • 移动后端开发API设计与推送服务
  • SAP S/4HANA Cloud 公有云实施:广州企业服务商选型与落地实践
  • PTP协议精讲(2.11):纳秒从何而来——硬件时间戳的奥秘
  • Spring Boot 入门:Java 生态最流行的应用开发框架介绍
  • 打卡信奥刷题(3134)用C++实现信奥题 P7552 [COCI 2020/2021 #6] Anagramistica
  • 从‘硬’到‘软’:柔性阵列与稳健波束形成入门避坑指南
  • GEO深水区:AI信息分发革命下,行业乱象的底层逻辑与价值终局 - 速递信息
  • 2026年4月液液萃取设备厂家推荐,金属/连续/锂/沉锂母液/发酵液萃取设备,专业萃取解决方案供应商 - 品牌推荐用户报道者
  • Honor of Kings 2026.04.19
  • PTP协议精讲(2.12):PTP的十种语言——报文格式全解析
  • Python实战:用京东云SDK三行代码搞定短信发送(附状态回调查询完整Demo)
  • 从‘复合管’(达林顿管)到现代功放芯片:一场关于‘放大能力’的技术演进简史
  • 深入S2A-Net的‘对齐卷积’:如何让卷积网络‘看懂’旋转的物体?
  • 从仿真波形看懂Xilinx FIFO:手把手教你用Vivado分析复位与empty信号的变化
  • 终极《环世界》性能优化指南:如何通过Performance-Fish实现400%帧率提升
  • 从创建到关闭:手把手带你走完一个Bug在Bugzilla中的完整生命周期
  • 微服务架构中的分布式事务处理方案与数据一致性保障