
PyTorch Character-Level RNN: A Hands-On Guide

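This guide summarizes and extends the official PyTorch tutorials (char_rnn_generation_tutorial and related ones). It centers on character-level RNNs and walks through three complete projects: name generation, name classification, and a (word-level) sequence-to-sequence machine translation model. Each part explains the underlying idea and comes with commented code.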


1. Core Model Comparison

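| Aspect | RNN name classification | RNN name generation | Seq2Seq translation |
|---|---|---|---|
| Input | Letter sequence of a name | Category + previous letter at each step | Word sequence in the source language |
| Output | Category label (language) | Next letter at each step → complete name | Word sequence in the target language |
| Architecture | Single RNN, final output used | Single RNN, looped until EOS | Encoder + decoder + attention |
| Where the loss is computed | Last step only | Every step | Every step |
| Loss function | NLLLoss | NLLLoss | NLLLoss (on log-softmax outputs) |
| Teacher forcing | Not applicable | Always (the true letter is fed in during training) | Optional, chosen at random via `teacher_forcing_ratio` |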

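2. Data Acquisition and Preprocessing

First, download the official dataset and unzip it into the current working directory, so that the name files end up in `data/names/*.txt`.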

```python
# Dataset download: https://download.pytorch.org/tutorial/data.zip
# Unzip it so that the files are located at data/names/*.txt
import glob
import os
import random
import string
import unicodedata

import torch
import torch.nn as nn

# Allowed character set
all_letters = string.ascii_letters + " .,;'-"
n_letters = len(all_letters) + 1  # +1 for the EOS marker


def unicodeToAscii(s):
    """Convert a Unicode string to plain ASCII (strip accents, drop unknown characters)."""
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn' and c in all_letters
    )


def readLines(filename):
    """Read a file and return its names converted to ASCII."""
    with open(filename, encoding='utf-8') as f:
        lines = f.read().strip().split('\n')
    return [unicodeToAscii(line) for line in lines]


# Build the mapping: language -> list of names
category_lines = {}
all_categories = []
# sorted() keeps the category order stable across operating systems
for filename in sorted(glob.glob('data/names/*.txt')):
    category = os.path.splitext(os.path.basename(filename))[0]
    all_categories.append(category)
    category_lines[category] = readLines(filename)

n_categories = len(all_categories)
if n_categories == 0:
    raise RuntimeError('No files in data/names/: download and unzip data.zip first')

print(f'{n_categories} languages:', all_categories)
print(f'First 5 Italian names: {category_lines["Italian"][:5]}')
```
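
The accent stripping works because NFD normalization splits a letter such as `Ś` into a base letter plus a combining mark (Unicode category `Mn`), and the filter then drops the mark. A minimal standalone check of the same `unicodeToAscii` function:

```python
import string
import unicodedata

all_letters = string.ascii_letters + " .,;'-"


def unicodeToAscii(s):
    # NFD splits accented letters into base letter + combining mark ('Mn');
    # the filter drops the marks and any character outside all_letters
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn' and c in all_letters
    )


print(unicodeToAscii('Ślusàrski'))  # Slusarski
print(unicodeToAscii('Müller'))     # Muller
```

Characters that NFD does not decompose and that are not in `all_letters` (for example the Polish `ł`) are silently dropped rather than transliterated.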


3. Generating Names with a Character-Level RNN

3.1 Network Architecture

```python
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        # Input at each step = (category, letter, hidden) concatenated
        self.i2h = nn.Linear(n_categories + input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(n_categories + input_size + hidden_size, output_size)
        self.o2o = nn.Linear(hidden_size + output_size, output_size)
        self.dropout = nn.Dropout(0.1)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, category, input, hidden):
        # Concatenate all inputs along the feature dimension
        input_combined = torch.cat((category, input, hidden), 1)
        hidden = self.i2h(input_combined)
        output = self.i2o(input_combined)
        # Combine the new hidden state with the first output, then refine it
        output_combined = torch.cat((hidden, output), 1)
        output = self.o2o(output_combined)
        output = self.dropout(output)
        output = self.softmax(output)  # log-probabilities over letters + EOS
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, self.hidden_size)
```
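
Because the category, the current letter and the hidden state are concatenated, the first two linear layers take `n_categories + n_letters + hidden_size` inputs. The standalone shape check below traces one step with random weights; it assumes 18 categories (the number of files in `data/names`) and 59 letter slots (58 characters plus EOS), and it leaves out dropout, so only the shapes are meaningful:

```python
import torch
import torch.nn as nn

n_categories = 18   # assumption: 18 language files in data/names
n_letters = 59      # 58 characters + EOS
hidden_size = 128

combined_size = n_categories + n_letters + hidden_size  # 18 + 59 + 128 = 205
i2h = nn.Linear(combined_size, hidden_size)
i2o = nn.Linear(combined_size, n_letters)
o2o = nn.Linear(hidden_size + n_letters, n_letters)

category = torch.zeros(1, n_categories)
category[0, 3] = 1                      # one-hot language
letter = torch.zeros(1, n_letters)
letter[0, 26] = 1                       # one-hot 'A' (index 26 in ascii_letters)
hidden = torch.zeros(1, hidden_size)

x = torch.cat((category, letter, hidden), dim=1)
print(x.shape)                          # torch.Size([1, 205])
new_hidden = i2h(x)
out = o2o(torch.cat((new_hidden, i2o(x)), dim=1))
log_probs = torch.log_softmax(out, dim=1)
print(new_hidden.shape, log_probs.shape)  # torch.Size([1, 128]) torch.Size([1, 59])
```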

3.2 Tensor Conversion Functions

```python
def categoryTensor(category):
    """One-hot tensor (1 x n_categories) for a language."""
    li = all_categories.index(category)
    tensor = torch.zeros(1, n_categories)
    tensor[0][li] = 1
    return tensor


def inputTensor(line):
    """Convert a string into a sequence of one-hot tensors (len(line) x 1 x n_letters)."""
    tensor = torch.zeros(len(line), 1, n_letters)
    for li in range(len(line)):
        letter = line[li]
        tensor[li][0][all_letters.find(letter)] = 1
    return tensor


def targetTensor(line):
    """Target tensor: position i holds the index of letter i+1; the last position is EOS."""
    letter_indexes = [all_letters.find(line[li]) for li in range(1, len(line))]
    letter_indexes.append(n_letters - 1)  # EOS marker
    return torch.tensor(letter_indexes, dtype=torch.long)


def randomTrainingExample():
    """Pick a random (language, name) pair and build its three tensors."""
    category = random.choice(all_categories)
    line = random.choice(category_lines[category])
    return categoryTensor(category), inputTensor(line), targetTensor(line)
```
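
The key detail in `targetTensor` is the shift by one: at step `i` the network sees letter `i` and must predict letter `i+1`, and after the last letter it must predict EOS. A standalone sketch that prints the alignment for the illustrative name "Abe":

```python
import string

import torch

all_letters = string.ascii_letters + " .,;'-"
n_letters = len(all_letters) + 1  # 59: index 58 is reserved for EOS


def targetTensor(line):
    # Shift by one: the target at step i is the letter at step i+1, then EOS
    letter_indexes = [all_letters.find(line[li]) for li in range(1, len(line))]
    letter_indexes.append(n_letters - 1)
    return torch.tensor(letter_indexes, dtype=torch.long)


line = 'Abe'
for i, t in enumerate(targetTensor(line).tolist()):
    nxt = 'EOS' if t == n_letters - 1 else all_letters[t]
    print(f'input {line[i]!r} -> target {nxt}')
# input 'A' -> target b
# input 'b' -> target e
# input 'e' -> target EOS
```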

3.3 Training Loop

```python
rnn = RNN(n_letters, 128, n_letters)
criterion = nn.NLLLoss()
learning_rate = 0.0005


def train(category_tensor, input_line_tensor, target_line_tensor):
    # (seq_len,) -> (seq_len, 1), so each step's target has shape (1,)
    target_line_tensor.unsqueeze_(-1)
    hidden = rnn.initHidden()
    rnn.zero_grad()

    loss = 0
    # The true letter is fed in at every step (teacher forcing), and the loss is summed over steps
    for i in range(input_line_tensor.size(0)):
        output, hidden = rnn(category_tensor, input_line_tensor[i], hidden)
        loss += criterion(output, target_line_tensor[i])

    loss.backward()

    # Manual SGD update (the old add_(scalar, tensor) signature is deprecated)
    for p in rnn.parameters():
        p.data.add_(p.grad.data, alpha=-learning_rate)

    # Report the average loss per letter
    return output, loss.item() / input_line_tensor.size(0)


# Training loop (100k iterations)
n_iters = 100000
for it in range(1, n_iters + 1):
    output, loss = train(*randomTrainingExample())
    if it % 5000 == 0:
        print(f'iteration {it}, loss: {loss:.4f}')
```
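
The manual loop over `rnn.parameters()` is plain SGD. The same update can be delegated to `torch.optim.SGD`, which makes it easy to swap in another optimizer later. This standalone check uses a toy `nn.Linear` (an assumption, not the RNN above) and confirms that both variants produce the same parameters after one step:

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)
learning_rate = 0.0005
model_manual = nn.Linear(4, 3)
model_optim = copy.deepcopy(model_manual)  # identical starting weights
x = torch.randn(2, 4)
target = torch.tensor([0, 2])
criterion = nn.CrossEntropyLoss()

# Variant 1: manual update, as in train() above
model_manual.zero_grad()
criterion(model_manual(x), target).backward()
for p in model_manual.parameters():
    p.data.add_(p.grad.data, alpha=-learning_rate)

# Variant 2: the same step through torch.optim.SGD (no momentum, no weight decay)
optimizer = torch.optim.SGD(model_optim.parameters(), lr=learning_rate)
optimizer.zero_grad()
criterion(model_optim(x), target).backward()
optimizer.step()

same = all(torch.allclose(a, b) for a, b in
           zip(model_manual.parameters(), model_optim.parameters()))
print(same)  # True
```

In `train()`, this would mean creating `optimizer = torch.optim.SGD(rnn.parameters(), lr=learning_rate)` once and calling `optimizer.zero_grad()` / `optimizer.step()` in place of `rnn.zero_grad()` and the manual loop.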

3.4 Generating Names

```python
max_length = 20


def sample(category, start_letter='A'):
    """Generate one name from a language and a start letter."""
    # rnn stays in training mode, so dropout keeps adding some randomness here
    with torch.no_grad():  # no gradients needed for sampling
        category_tensor = categoryTensor(category)
        input = inputTensor(start_letter)
        hidden = rnn.initHidden()

        output_name = start_letter

        for i in range(max_length):
            output, hidden = rnn(category_tensor, input[0], hidden)
            topv, topi = output.topk(1)  # greedy: most likely next letter
            topi = topi[0][0].item()
            if topi == n_letters - 1:  # stop at EOS
                break
            letter = all_letters[topi]
            output_name += letter
            input = inputTensor(letter)

        return output_name


def samples(category, start_letters='ABC'):
    """Generate one name per start letter."""
    for start_letter in start_letters:
        print(sample(category, start_letter))


print("Generated Russian names:")
samples('Russian', 'RUS')
print("Generated German names:")
samples('German', 'GER')
```
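
`topk(1)` always takes the most likely letter, so apart from dropout noise the model tends to repeat the same names. A common alternative, which is not part of the original tutorial, is to sample the next letter from the softmax distribution scaled by a temperature. A minimal sketch of that selection step, assuming `output` is the `(1, n_letters)` log-probability row returned by `rnn(...)`; the helper name `pick_next` is hypothetical:

```python
import torch


def pick_next(output, temperature=1.0):
    """Sample a letter index from log-probabilities `output` of shape (1, n_letters).

    temperature < 1 sharpens the distribution (closer to greedy),
    temperature > 1 flattens it (more varied names).
    """
    probs = torch.softmax(output / temperature, dim=1)
    return torch.multinomial(probs, num_samples=1).item()


torch.manual_seed(0)
# Toy log-probabilities over 3 "letters"
output = torch.log(torch.tensor([[0.7, 0.2, 0.1]]))
counts = [0, 0, 0]
for _ in range(1000):
    counts[pick_next(output, temperature=0.5)] += 1
print(counts)  # index 0 dominates: temperature 0.5 turns 0.7 into roughly 0.91
```

Inside `sample()`, the two `topk` lines would become `topi = pick_next(output, temperature=0.8)`, with the temperature chosen by experiment.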



4. Classifying Names with a Character-Level RNN

4.1 Classification Network

```python
class ClassifyRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(ClassifyRNN, self).__init__()
        self.hidden_size = hidden_size
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(input_size + hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        combined = torch.cat((input, hidden), 1)
        hidden = self.i2h(combined)
        output = self.i2o(combined)
        output = self.softmax(output)  # log-probabilities over languages
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, self.hidden_size)


# Instantiate the model
classify_rnn = ClassifyRNN(n_letters, 128, n_categories)
```
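
Unlike the generator, the classifier reads the whole name and only the output of the final step is used for the prediction and the loss. A standalone sketch with random (untrained) weights, assuming 59 letter slots and 18 categories as above; `classify` is a hypothetical helper name:

```python
import string

import torch
import torch.nn as nn

all_letters = string.ascii_letters + " .,;'-"
n_letters = len(all_letters) + 1  # 59, matching Section 2
n_categories = 18                 # assumption: 18 language files


class ClassifyRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(input_size + hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        combined = torch.cat((input, hidden), 1)
        return self.softmax(self.i2o(combined)), self.i2h(combined)

    def initHidden(self):
        return torch.zeros(1, self.hidden_size)


def classify(model, name):
    """Feed a name letter by letter and return only the last step's output."""
    hidden = model.initHidden()
    for letter in name:
        x = torch.zeros(1, n_letters)
        x[0, all_letters.find(letter)] = 1
        output, hidden = model(x, hidden)  # intermediate outputs are discarded
    return output


torch.manual_seed(0)
model = ClassifyRNN(n_letters, 128, n_categories)
with torch.no_grad():
    output = classify(model, 'Jackson')
print(output.shape)        # torch.Size([1, 18])
print(output.exp().sum())  # ~1.0: a distribution over the 18 languages
```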

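4.2 Training and Evaluation

The script below is self-contained: it repeats the data preparation from Section 2 and redefines `randomTrainingExample` with a different return value than in Section 3, so run it in a fresh session or a separate file.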

```python
import glob
import os
import random
import string
import unicodedata

import torch
import torch.nn as nn

# ==================== Data preparation ====================
# Allowed character set
all_letters = string.ascii_letters + " .,;'-"
n_letters = len(all_letters) + 1  # +1 for EOS (unused here, kept to match Section 2)


def unicodeToAscii(s):
    """Convert a Unicode string to plain ASCII."""
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn' and c in all_letters
    )


def readLines(filename):
    """Read a file and return its names converted to ASCII."""
    with open(filename, encoding='utf-8') as f:
        lines = f.read().strip().split('\n')
    return [unicodeToAscii(line) for line in lines]


# Build the mapping: language -> list of names
category_lines = {}
all_categories = []
for filename in sorted(glob.glob('data/names/*.txt')):
    category = os.path.splitext(os.path.basename(filename))[0]
    all_categories.append(category)
    category_lines[category] = readLines(filename)

n_categories = len(all_categories)
print(f'{n_categories} languages: {all_categories}')
print(f'First 5 Italian names: {category_lines["Italian"][:5]}')


# ==================== Tensor conversion ====================
def letterToIndex(letter):
    """Map a single letter to its index."""
    return all_letters.find(letter)


def letterToTensor(letter):
    """One-hot tensor (1 x n_letters) for a single letter."""
    tensor = torch.zeros(1, n_letters)
    tensor[0][letterToIndex(letter)] = 1
    return tensor


def lineToTensor(line):
    """Tensor (len(line) x 1 x n_letters) for a name."""
    tensor = torch.zeros(len(line), 1, n_letters)
    for li, letter in enumerate(line):
        tensor[li][0][letterToIndex(letter)] = 1
    return tensor


# ==================== Model ====================
class ClassifyRNN(nn.Module):
    """Simple RNN for name classification (same as Section 4.1)."""

    def __init__(self, input_size, hidden_size, output_size):
        super(ClassifyRNN, self).__init__()
        self.hidden_size = hidden_size
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(input_size + hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        # Concatenate the input and the hidden state
        combined = torch.cat((input, hidden), 1)
        hidden = self.i2h(combined)
        output = self.i2o(combined)
        output = self.softmax(output)
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, self.hidden_size)


# ==================== Hyperparameters ====================
n_hidden = 128
learning_rate = 0.005

classify_rnn = ClassifyRNN(n_letters, n_hidden, n_categories)
criterion = nn.NLLLoss()


# ==================== Random training examples ====================
def randomChoice(l):
    return l[random.randint(0, len(l) - 1)]


def randomTrainingExample():
    """Pick a random training example."""
    category = randomChoice(all_categories)
    line = randomChoice(category_lines[category])
    category_tensor = torch.tensor([all_categories.index(category)], dtype=torch.long)
    line_tensor = lineToTensor(line)
    return category, line, category_tensor, line_tensor


# ==================== Training step ====================
def train_classify(category_tensor, line_tensor):
    hidden = classify_rnn.initHidden()
    classify_rnn.zero_grad()

    for i in range(line_tensor.size()[0]):
        output, hidden = classify_rnn(line_tensor[i], hidden)

    # The loss uses only the output of the last step
    loss = criterion(output, category_tensor)
    loss.backward()

    # Manual parameter update (plain SGD)
    for p in classify_rnn.parameters():
        p.data.add_(p.grad.data, alpha=-learning_rate)

    return output, loss.item()


# ==================== Evaluation ====================
def evaluate(line_tensor):
    hidden = classify_rnn.initHidden()
    for i in range(line_tensor.size()[0]):
        output, hidden = classify_rnn(line_tensor[i], hidden)
    return output


def categoryFromOutput(output):
    top_n, top_i = output.topk(1)
    category_i = top_i[0].item()
    return all_categories[category_i], category_i


# ==================== Prediction ====================
def predict(input_line, n_predictions=3):
    print(f'\n> {input_line}')
    with torch.no_grad():
        output = evaluate(lineToTensor(input_line))
        # Top-n languages; the values are log-probabilities
        topv, topi = output.topk(n_predictions, 1, True)
        for i in range(n_predictions):
            value = topv[0][i].item()
            category = all_categories[topi[0][i].item()]
            print(f'({value:.2f}) {category}')


# ==================== Training loop ====================
n_iters = 100000
print_every = 5000
print("\nTraining...")
for it in range(1, n_iters + 1):
    category, line, category_tensor, line_tensor = randomTrainingExample()
    output, loss = train_classify(category_tensor, line_tensor)
    if it % print_every == 0:
        guess, guess_i = categoryFromOutput(output)
        correct = '✓' if guess == category else f'✗ ({category})'
        print(f'{it} {it / n_iters * 100:.1f}% loss {loss:.4f} {line:<20} {guess:<10} {correct}')

print("\nTraining finished!")

# ==================== Test predictions ====================
print("\n" + "=" * 50)
print("Test predictions:")
print("=" * 50)
predict('Dovesky')
predict('Jackson')
predict('Satoshi')
```


4.3 Visualizing the Confusion Matrix

```python
import matplotlib.pyplot as plt

# Rows: true language, columns: predicted language
confusion = torch.zeros(n_categories, n_categories)
n_confusion = 10000

# Note: the names are drawn from the training data, so this shows fit, not generalization
with torch.no_grad():
    for i in range(n_confusion):
        category, line, category_tensor, line_tensor = randomTrainingExample()
        output = evaluate(line_tensor)  # reuse evaluate() from Section 4.2
        guess, guess_i = categoryFromOutput(output)
        category_i = all_categories.index(category)
        confusion[category_i][guess_i] += 1

# Normalize each row so that it sums to 1
for i in range(n_categories):
    confusion[i] = confusion[i] / confusion[i].sum()

# Plot on an explicit Axes so that figsize takes effect
# (plt.matshow on its own would open a new figure and ignore it)
fig = plt.figure(figsize=(15, 15))
ax = fig.add_subplot(111)
cax = ax.matshow(confusion.numpy())
fig.colorbar(cax)
ax.set_xticks(range(n_categories))
ax.set_xticklabels(all_categories, rotation=90)
ax.set_yticks(range(n_categories))
ax.set_yticklabels(all_categories)
plt.show()
```
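
Because each row is normalized, the diagonal entry `confusion[i][i]` is the fraction of names from language `i` that were classified correctly (per-class recall), and the largest off-diagonal entry is the most frequent mistake. A small standalone sketch on a made-up 3×3 count matrix (the numbers are illustrative, not results from the model):

```python
import torch

categories = ['English', 'French', 'Italian']  # illustrative labels
# Rows: true language, columns: predicted language (made-up counts)
counts = torch.tensor([[80., 15., 5.],
                       [10., 60., 30.],
                       [0., 20., 80.]])

confusion = counts / counts.sum(dim=1, keepdim=True)  # row-normalize in one step
per_class = confusion.diag()
for name, acc in zip(categories, per_class.tolist()):
    print(f'{name:<8} {acc:.2f}')
# English  0.80
# French   0.60
# Italian  0.80

# Most frequent confusion: largest off-diagonal entry
off_diag = confusion - torch.diag(per_class)
true_i, pred_i = divmod(off_diag.argmax().item(), len(categories))
print(f'{categories[true_i]} is most often mistaken for {categories[pred_i]}')
# French is most often mistaken for Italian
```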


5. Seq2Seq Translation Model

5.1 Encoder

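```python
import torch
import torch.nn as nn

# Run on the GPU when one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size
        # input_size = size of the source-language vocabulary
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)

    def forward(self, input, hidden):
        # One word index at a time -> (seq_len=1, batch=1, hidden_size)
        embedded = self.embedding(input).view(1, 1, -1)
        output, hidden = self.gru(embedded, hidden)
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)
```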
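
The encoder is called once per word, but `nn.GRU` can also consume the whole embedded sentence in a single call. This standalone check (random weights, a hypothetical 5-word sentence, vocabulary size 10) confirms that the step-by-step loop used in Section 5.3 yields the same per-step outputs and final hidden state as one call over the full sequence:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, hidden_size = 10, 8
embedding = nn.Embedding(vocab_size, hidden_size)
gru = nn.GRU(hidden_size, hidden_size)

sentence = torch.tensor([3, 7, 1, 4, 2])  # hypothetical word indices

with torch.no_grad():
    # Step by step, as in EncoderRNN.forward
    hidden = torch.zeros(1, 1, hidden_size)
    step_outputs = []
    for word in sentence:
        out, hidden = gru(embedding(word).view(1, 1, -1), hidden)
        step_outputs.append(out[0, 0])
    step_outputs = torch.stack(step_outputs)  # (5, hidden_size)

    # Whole sequence at once: input shape (seq_len, batch=1, hidden_size)
    full_out, full_hidden = gru(embedding(sentence).unsqueeze(1),
                                torch.zeros(1, 1, hidden_size))

print(torch.allclose(step_outputs, full_out[:, 0], atol=1e-6))  # True
print(torch.allclose(hidden, full_hidden, atol=1e-6))           # True
```

The per-word interface is kept in the tutorial code because the fixed-size `encoder_outputs` buffer is filled position by position; for batched training, the single-call form is the usual choice.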

5.2 Decoder with Attention (Attentional Decoder)

```python
import torch.nn.functional as F  # needed for softmax / relu / log_softmax below


class AttnDecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=10):
        super(AttnDecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size  # size of the target-language vocabulary
        self.dropout_p = dropout_p
        self.max_length = max_length    # maximum source sentence length

        self.embedding = nn.Embedding(output_size, hidden_size)
        self.attn = nn.Linear(hidden_size * 2, max_length)
        self.attn_combine = nn.Linear(hidden_size * 2, hidden_size)
        self.dropout = nn.Dropout(dropout_p)
        self.gru = nn.GRU(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, input, hidden, encoder_outputs):
        embedded = self.embedding(input).view(1, 1, -1)
        embedded = self.dropout(embedded)

        # Attention weights: one score per source position (max_length of them).
        # In this variant the scores depend only on the current embedding and
        # hidden state, not on the encoder outputs themselves.
        attn_weights = F.softmax(
            self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1)
        # Weighted sum of encoder outputs: (1, 1, L) x (1, L, H) -> (1, 1, H)
        attn_applied = torch.bmm(attn_weights.unsqueeze(0),
                                 encoder_outputs.unsqueeze(0))

        # Combine the embedding with the attention context vector
        output = torch.cat((embedded[0], attn_applied[0]), 1)
        output = self.attn_combine(output).unsqueeze(0)
        output = F.relu(output)

        output, hidden = self.gru(output, hidden)
        output = F.log_softmax(self.out(output[0]), dim=1)
        return output, hidden, attn_weights

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)
```
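
The `torch.bmm` call is a weighted average of the encoder outputs, with one weight per source position. A standalone check with random tensors (the sizes are arbitrary assumptions) that it matches an explicit weighted sum:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
max_length, hidden_size = 10, 16
encoder_outputs = torch.randn(max_length, hidden_size)  # one row per source position
scores = torch.randn(1, max_length)                      # stand-in for self.attn(...)
attn_weights = F.softmax(scores, dim=1)                  # (1, max_length), sums to 1

# As in AttnDecoderRNN.forward: (1, 1, L) x (1, L, H) -> (1, 1, H)
attn_applied = torch.bmm(attn_weights.unsqueeze(0), encoder_outputs.unsqueeze(0))

# The same context vector as an explicit weighted sum
manual = sum(attn_weights[0, i] * encoder_outputs[i] for i in range(max_length))

print(attn_applied.shape)                                      # torch.Size([1, 1, 16])
print(torch.allclose(attn_applied[0, 0], manual, atol=1e-6))  # True
```

In the training code, positions beyond the real sentence length are zero rows in `encoder_outputs`, so any weight assigned to them contributes nothing to the context vector but still takes away probability mass from the real words.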

5.3 Training and Inference

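```python
import random

import torch

# Special token indices (same convention as the official translation tutorial);
# `device` is defined in Section 5.1
SOS_token = 0  # start of sentence
EOS_token = 1  # end of sentence


def train_encoder_decoder(input_tensor, target_tensor, encoder, decoder,
                          encoder_optimizer, decoder_optimizer, criterion,
                          max_length=10, teacher_forcing_ratio=0.5):
    """One training step on a single sentence pair.

    input_tensor / target_tensor: word indices of shape (seq_len, 1).
    The input sentence must not be longer than max_length.
    """
    encoder_hidden = encoder.initHidden()

    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    input_length = input_tensor.size(0)
    target_length = target_tensor.size(0)

    # Fixed-size buffer of encoder outputs; unused positions stay zero
    encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)

    loss = 0

    # Encoding phase
    for ei in range(input_length):
        encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)
        encoder_outputs[ei] = encoder_output[0, 0]

    # Decoding phase: start from SOS, initialized with the encoder's final hidden state
    decoder_input = torch.tensor([[SOS_token]], device=device)
    decoder_hidden = encoder_hidden

    use_teacher_forcing = random.random() < teacher_forcing_ratio

    if use_teacher_forcing:
        # Teacher forcing: the ground-truth target word is the next input
        for di in range(target_length):
            decoder_output, decoder_hidden, _ = decoder(
                decoder_input, decoder_hidden, encoder_outputs)
            loss += criterion(decoder_output, target_tensor[di])
            decoder_input = target_tensor[di]
    else:
        # Without teacher forcing: the model's own prediction is the next input
        for di in range(target_length):
            decoder_output, decoder_hidden, _ = decoder(
                decoder_input, decoder_hidden, encoder_outputs)
            topv, topi = decoder_output.topk(1)
            decoder_input = topi.squeeze().detach()  # detach: no backprop through the choice
            loss += criterion(decoder_output, target_tensor[di])
            if decoder_input.item() == EOS_token:
                break

    loss.backward()

    encoder_optimizer.step()
    decoder_optimizer.step()

    return loss.item() / target_length
```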
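
Inference reuses the same encoder and decoder loop, but without teacher forcing and without gradients. A minimal greedy-decoding sketch, assuming an encoder/decoder pair with the interfaces from Sections 5.1 and 5.2 and the same `SOS_token = 0` / `EOS_token = 1` convention; the function name `greedy_translate` is hypothetical, and mapping the returned indices back to words is left out because the vocabulary code is not part of this guide:

```python
import torch

SOS_token, EOS_token = 0, 1  # assumed vocabulary convention, as in Section 5.3


def greedy_translate(encoder, decoder, input_tensor, max_length=10):
    """Greedily decode one sentence; return target-word indices and attention weights.

    input_tensor: source word indices of shape (seq_len, 1), seq_len <= max_length.
    """
    with torch.no_grad():
        device = input_tensor.device
        encoder_hidden = encoder.initHidden()
        encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)
        for ei in range(input_tensor.size(0)):
            encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)
            encoder_outputs[ei] = encoder_output[0, 0]

        decoder_input = torch.tensor([[SOS_token]], device=device)
        decoder_hidden = encoder_hidden
        decoded, attentions = [], []
        for _ in range(max_length):
            decoder_output, decoder_hidden, attn_weights = decoder(
                decoder_input, decoder_hidden, encoder_outputs)
            attentions.append(attn_weights[0])
            topi = decoder_output.argmax(dim=1)  # most likely next word, shape (1,)
            decoded.append(topi.item())
            if topi.item() == EOS_token:
                break
            decoder_input = topi.view(1, 1)      # feed the prediction back in
        return decoded, torch.stack(attentions)
```

Calling `encoder.eval()` and `decoder.eval()` beforehand switches off the decoder's dropout. The returned attention matrix (one row per generated word) can be plotted with `matshow` in the same way as the confusion matrix in Section 4.3.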

6. Extensions and Practical Tips

6.1 Upgrading to More Capable RNN Variants

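```python
class LSTMGenerator(nn.Module):
    """2-layer LSTM name generator that processes a whole name in one call."""

    def __init__(self, n_categories, input_size, hidden_size, output_size):
        super(LSTMGenerator, self).__init__()
        self.hidden_size = hidden_size
        # dropout=0.2 is applied between the two stacked LSTM layers
        self.lstm = nn.LSTM(n_categories + input_size, hidden_size,
                            num_layers=2, dropout=0.2, batch_first=False)
        self.out = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=2)

    def forward(self, category, input_seq, hidden=None):
        # category: (1, n_categories); input_seq: (seq_len, 1, input_size)
        # Repeat the category vector at every time step, then concatenate
        combined = torch.cat(
            [category.unsqueeze(0).expand(input_seq.size(0), -1, -1), input_seq],
            dim=2)
        output, hidden = self.lstm(combined, hidden)  # hidden=None -> zero initial state
        output = self.softmax(self.out(output))       # (seq_len, 1, output_size)
        return output, hidden
```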
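
Because the LSTM consumes the whole name at once, the per-step loss loop from Section 3.3 collapses into a single `NLLLoss` call over all time steps. A standalone training-step sketch that repeats the class above, uses a random one-hot "name" in place of real data, and swaps in `torch.optim.Adam` (an assumption; any optimizer works):

```python
import torch
import torch.nn as nn


class LSTMGenerator(nn.Module):
    def __init__(self, n_categories, input_size, hidden_size, output_size):
        super().__init__()
        self.lstm = nn.LSTM(n_categories + input_size, hidden_size,
                            num_layers=2, dropout=0.2)
        self.out = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=2)

    def forward(self, category, input_seq, hidden=None):
        combined = torch.cat(
            [category.unsqueeze(0).expand(input_seq.size(0), -1, -1), input_seq], dim=2)
        output, hidden = self.lstm(combined, hidden)
        return self.softmax(self.out(output)), hidden


torch.manual_seed(0)
n_categories, n_letters, seq_len = 18, 59, 6   # assumed sizes
model = LSTMGenerator(n_categories, n_letters, 128, n_letters)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.NLLLoss()

category = torch.zeros(1, n_categories)
category[0, 0] = 1
letters = torch.randint(0, n_letters - 1, (seq_len,))   # fake name, no EOS inside
input_seq = torch.zeros(seq_len, 1, n_letters)
input_seq[torch.arange(seq_len), 0, letters] = 1         # one-hot encode
target = torch.cat([letters[1:], torch.tensor([n_letters - 1])])  # shift by one + EOS

output, _ = model(category, input_seq)                   # (seq_len, 1, n_letters)
loss = criterion(output.reshape(seq_len, n_letters), target)  # all steps at once
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(output.shape, loss.item())
```

`NLLLoss` averages over the time steps by default, while `train()` in Section 3.3 backpropagates the sum; the gradients differ only by the factor `seq_len`.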

6.2 More Datasets to Try

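| Task | Input | Output |
|---|---|---|
| Gender prediction | Name | Gender label |
| Character attribution | Character name | Author |
| Country → city | Country | City name |
| Part-of-speech tagging | Word | POS tag |
| Product categorization | Product name | Category |
| Dialogue generation | Input text | Reply text |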

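6.3 Summary of Optimization Tips

  1. Teacher forcing: with some probability, feed the ground-truth target as the next input instead of the model's own prediction; this speeds up convergence.

  2. Dropout: reduces overfitting, especially in generation tasks.

  3. Gradient clipping: prevents exploding gradients (e.g. torch.nn.utils.clip_grad_norm_).

  4. Learning-rate scheduling: lower the learning rate gradually as training progresses.

  5. Batched training: use DataLoader and mini-batches for efficiency; since names differ in length, each batch must be padded (e.g. with torch.nn.utils.rnn.pad_sequence).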
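
Tips 3 and 4 take only a few lines each in PyTorch. A standalone sketch with a stand-in `nn.RNN` and random inputs (the model, loss and schedule values are illustrative assumptions): `clip_grad_norm_` rescales all gradients so their combined L2 norm is at most `max_norm`, and `StepLR` halves the learning rate every two epochs:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.RNN(input_size=59, hidden_size=128)          # stand-in model (assumption)
optimizer = torch.optim.SGD(model.parameters(), lr=0.005)
# Halve the learning rate every 2 epochs (illustrative schedule)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)

max_norm = 1.0
for epoch in range(4):
    x = torch.randn(7, 1, 59)                  # fake sequence: (seq_len, batch, features)
    output, h = model(x)
    loss = output.pow(2).mean() * 1000         # deliberately large loss -> large gradients
    optimizer.zero_grad()
    loss.backward()
    # Rescale gradients in place; the return value is the total norm *before* clipping
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    optimizer.step()
    scheduler.step()                           # once per epoch, after optimizer.step()
    print(epoch, f'{total_norm.item():.2f}', optimizer.param_groups[0]['lr'])
```

In a real training loop, `clip_grad_norm_` goes between `loss.backward()` and `optimizer.step()` on every iteration, while `scheduler.step()` is usually called once per epoch.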


7. Suggested Project Structure

```text
project/
├── data/
│   └── names/
│       ├── English.txt
│       ├── French.txt
│       └── ...            (18 files in total)
├── models/
│   ├── encoder.py
│   ├── decoder.py
│   └── rnn_generator.py
├── utils/
│   ├── data_loader.py
│   └── text_processing.py
├── train.py               # training script
├── generate.py            # name generation script
├── predict.py             # classification prediction script
└── translate.py           # translation script
```

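8. References

  1. PyTorch official tutorial: Generating Names with a Character-Level RNN (char_rnn_generation_tutorial)

  2. Names dataset download: https://download.pytorch.org/tutorial/data.zip

  3. Seq2Seq paper: Cho et al., "Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation"

  4. Attention paper: Bahdanau et al., "Neural Machine Translation by Jointly Learning to Align and Translate"

  5. PyTorch official tutorial: Translation with a Sequence to Sequence Network and Attention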

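This guide covers the core material from a basic character-level RNN up to an attention-based Seq2Seq model. The code in Sections 2 to 4 runs once the dataset is unzipped; the Seq2Seq code in Section 5 additionally needs the vocabulary and sentence-to-tensor helpers from the official translation tutorial. Happy learning!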
