当前位置: 首页 > news > 正文

唐诗检索模型的训练及使用

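数据集:tang_poems.txt(原文此处的下载链接缺失),UTF-8 编码,需与训练脚本放在同一目录。每首诗占一段,段与段之间以空行分隔;每段首行为“【作者】《标题》”格式(例如“【孟浩然】《春晓》”),其后各行为诗句正文。首行不符合该格式的段落会被跳过。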

训练代码(train_tangshi.py;依赖同目录下本文未给出的 retrieval 模块)

# train_tangshi.py — 检索模型:对比学习编码 + 诗句向量库相似度检索
import argparse
import random
import re
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, Sampler

from retrieval import eval_text_hit, search_text

# ---- Paths (relative to this script) -------------------------------------
DATA_PATH = Path(__file__).parent / "tang_poems.txt"
SAVE_DIR = Path(__file__).parent / "checkpoints"

# ---- Data / batching -----------------------------------------------------
MAX_LEN = 48  # every line is truncated / right-padded to this many characters
BATCH_SIZE = 256
POEMS_PER_BATCH = 32
LINES_PER_POEM = BATCH_SIZE // POEMS_PER_BATCH  # 256 // 32 = 8 lines per poem

# ---- Encoder architecture ------------------------------------------------
EMBED_DIM = 256
HIDDEN_DIM = 512
NUM_LAYERS = 2
DROPOUT = 0.2

# ---- Optimisation / contrastive loss -------------------------------------
LR = 3e-4
EPOCHS = 20
TEMPERATURE = 0.07
SEED = 42

# ---- Training-time augmentation ------------------------------------------
SUBSTR_PROB = 0.35  # probability a training line may be cut to a random substring
MIN_SUBSTR = 2      # shortest substring that augmentation may produce

# ---- Inference chunk sizes -----------------------------------------------
ENCODE_BATCH = 512
VAL_CHUNK = 256


def get_device() -> torch.device:
    """Return the preferred torch device: CUDA if present, else Apple MPS, else CPU."""
    if torch.cuda.is_available():
        kind = "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        kind = "mps"
    else:
        kind = "cpu"
    return torch.device(kind)


# Header line that opens each poem block: 【author】《title》 (trailing spaces allowed).
HEADER_RE = re.compile(r"^【(.+?)】《(.+?)》\s*$")
@dataclass
class Poem:
    """One poem: author, title and its body lines (the header line excluded)."""

    author: str
    title: str
    lines: list[str]

    def format(self) -> str:
        """Render the poem in the dataset layout: header line, then one body line per row."""
        return f"【{self.author}】《{self.title}》\n" + "\n".join(self.lines)


def load_poems(path: Path) -> list[Poem]:
    """Parse a UTF-8 text file of blank-line-separated poem blocks.

    Each block's first non-empty line must match ``HEADER_RE``; blocks without a
    valid header, or without any body lines, are skipped silently.
    """
    text = path.read_text(encoding="utf-8")
    poems = []
    for block in re.split(r"\n\s*\n", text.strip()):
        raw = [ln.strip() for ln in block.splitlines() if ln.strip()]
        if not raw:
            continue
        m = HEADER_RE.match(raw[0])
        if not m:
            continue
        # `raw` is already stripped and non-empty, so the body is simply the rest.
        lines = raw[1:]
        if lines:
            poems.append(Poem(m.group(1), m.group(2), lines))
    return poems


class CharVocab:
    """Character-level vocabulary. Id 0 is ``<pad>``, id 1 is ``<unk>``."""

    def __init__(self, chars: str):
        # Sorted so that id assignment is deterministic for the same corpus.
        self.itos = ["<pad>", "<unk>"] + sorted(set(chars))
        self.stoi = {c: i for i, c in enumerate(self.itos)}
        self.pad_id = 0

    def __len__(self):
        return len(self.itos)

    @classmethod
    def from_itos(cls, itos: list[str]) -> "CharVocab":
        """Rebuild a vocabulary from a saved id -> char list (e.g. from a checkpoint)."""
        v = cls.__new__(cls)
        v.itos = itos
        v.stoi = {c: i for i, c in enumerate(itos)}
        v.pad_id = 0
        return v

    def encode(self, s: str, max_len: int) -> torch.Tensor:
        """Map ``s`` to a LongTensor of exactly ``max_len`` ids: truncate, then right-pad."""
        ids = [self.stoi.get(c, 1) for c in s[:max_len]]
        if len(ids) < max_len:
            ids += [self.pad_id] * (max_len - len(ids))
        return torch.tensor(ids, dtype=torch.long)


def build_samples(poems: list[Poem]) -> list[tuple[str, int]]:
    """Flatten poems into ``(line, poem_index)`` pairs, dropping lines shorter than 2 chars."""
    samples = []
    for pid, poem in enumerate(poems):
        for line in poem.lines:
            line = line.strip()
            if len(line) >= 2:
                samples.append((line, pid))
    return samples


class Line2PoemDataset(Dataset):
    """Dataset of ``(encoded_line, poem_id)``; optionally augments lines with random substrings."""

    def __init__(self, samples: list[tuple[str, int]], vocab: CharVocab, max_len: int, augment: bool):
        self.samples = samples
        self.vocab = vocab
        self.max_len = max_len
        self.augment = augment

    def __len__(self):
        return len(self.samples)

    def _maybe_substring(self, line: str) -> str:
        """With probability ``SUBSTR_PROB`` return a random substring of at least ``MIN_SUBSTR`` chars.

        This simulates users querying with only a fragment of a line.
        """
        if not self.augment or len(line) <= MIN_SUBSTR or random.random() > SUBSTR_PROB:
            return line
        n = random.randint(MIN_SUBSTR, len(line))
        start = random.randint(0, len(line) - n)
        return line[start : start + n]

    def __getitem__(self, idx):
        line, pid = self.samples[idx]
        line = self._maybe_substring(line)
        return self.vocab.encode(line, self.max_len), pid


class PoemBatchSampler(Sampler):
    """Yield batches of ``poems_per_batch`` poems x ``lines_per_poem`` lines each.

    Putting several lines of the same poem into one batch guarantees in-batch
    positives for the contrastive loss. Poems with fewer lines than
    ``lines_per_poem`` are sampled with replacement. Each epoch, the leftover
    poems that cannot fill a whole batch are dropped.
    """

    def __init__(self, poem_to_indices: dict[int, list[int]], poems_per_batch: int, lines_per_poem: int):
        self.poem_to_indices = poem_to_indices
        self.pids = list(poem_to_indices.keys())
        self.poems_per_batch = poems_per_batch
        self.lines_per_poem = lines_per_poem

    def __len__(self):
        return len(self.pids) // self.poems_per_batch

    def __iter__(self):
        pids = self.pids.copy()
        random.shuffle(pids)
        for start in range(0, len(pids) - self.poems_per_batch + 1, self.poems_per_batch):
            batch_idx = []
            for pid in pids[start : start + self.poems_per_batch]:
                idxs = self.poem_to_indices[pid]
                if len(idxs) >= self.lines_per_poem:
                    batch_idx.extend(random.sample(idxs, self.lines_per_poem))
                else:
                    batch_idx.extend(random.choices(idxs, k=self.lines_per_poem))
            yield batch_idx


def collate_batch(batch):
    """Stack encoded lines into one tensor and poem ids into a LongTensor."""
    xs, pids = zip(*batch)
    return torch.stack(xs), torch.tensor(pids, dtype=torch.long)


class LineEncoder(nn.Module):
    """Character LSTM encoder.

    The top layer's final hidden state goes through dropout and a linear
    projection, and the result is L2-normalised.
    """

    def __init__(self, vocab_size: int, embed_dim: int, hidden_dim: int, num_layers: int, dropout: float):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(
            embed_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0.0,
        )
        self.dropout = nn.Dropout(dropout)
        self.proj = nn.Linear(hidden_dim, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a (batch, seq) id tensor into unit-norm (batch, embed_dim) vectors."""
        emb = self.embed(x)
        # Count non-pad ids. This relies on right padding, which CharVocab.encode
        # guarantees. Clamp to 1 so that all-pad rows can still be packed.
        lengths = (x != 0).sum(dim=1).clamp(min=1)
        packed = nn.utils.rnn.pack_padded_sequence(emb, lengths.cpu(), batch_first=True, enforce_sorted=False)
        _, (h, _) = self.lstm(packed)
        h = self.dropout(h[-1])
        z = self.proj(h)
        return F.normalize(z, dim=1)


def supcon_loss(features: torch.Tensor, labels: torch.Tensor, temperature: float) -> torch.Tensor:
    """Supervised contrastive loss in which lines of the same poem are mutual positives.

    Args:
        features: (batch, dim) embeddings. They are assumed to be L2-normalised,
            as ``LineEncoder`` returns; this is not checked.
        labels: (batch,) poem ids; equal ids mark positive pairs.
        temperature: divisor applied to the similarity logits.

    Returns:
        Scalar loss averaged over anchors that have at least one in-batch
        positive. For a non-empty batch where no anchor has a positive, it is 0.
    """
    device = features.device
    b = features.size(0)
    sim = torch.matmul(features, features.T) / temperature
    # Log-sum-exp stabilisation: subtract each row's max before exp(). The shift
    # cancels in log_prob, but it keeps exp() <= 1, so small temperatures
    # (e.g. 0.01 -> logits near 100) no longer overflow to inf. Otherwise the
    # inf is multiplied by the 0 self-mask and produces NaN.
    sim = sim - sim.max(dim=1, keepdim=True).values.detach()
    labels = labels.view(-1, 1)
    logits_mask = 1.0 - torch.eye(b, device=device)  # 0 on the diagonal: no self-pairs
    mask_pos = (labels == labels.T).float() * logits_mask
    exp_sim = torch.exp(sim) * logits_mask
    # `tiny` only guards log(0), e.g. for a batch of one. A fixed 1e-8 would bias
    # the denominator once the logits are shifted down.
    denom = exp_sim.sum(dim=1, keepdim=True) + torch.finfo(exp_sim.dtype).tiny
    log_prob = sim - torch.log(denom)
    pos_count = mask_pos.sum(dim=1)
    loss = -((mask_pos * log_prob).sum(dim=1) / pos_count.clamp(min=1))
    valid = pos_count > 0
    if valid.any():
        return loss[valid].mean()
    return loss.mean()
@torch.no_grad()
def encode_samples(
    model: LineEncoder,
    samples: list[tuple[str, int]],
    vocab: CharVocab,
    max_len: int,
    device: torch.device,
    batch_size: int = ENCODE_BATCH,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Embed every sample line with ``model`` in eval mode, ``batch_size`` lines at a time.

    Returns:
        ``(embeddings, poem_ids)``, both on the CPU and row-aligned with
        ``samples``. ``model`` is left in eval mode.
    """
    model.eval()
    pieces: list[torch.Tensor] = []
    for start in range(0, len(samples), batch_size):
        batch = samples[start : start + batch_size]
        ids = torch.stack([vocab.encode(text, max_len) for text, _ in batch])
        pieces.append(model(ids.to(device)).cpu())
    poem_ids = torch.tensor([pid for _, pid in samples], dtype=torch.long)
    return torch.cat(pieces, dim=0), poem_ids
@torch.no_grad()
def eval_retrieval(
    model: LineEncoder,
    queries: list[tuple[str, int]],
    bank: torch.Tensor,
    bank_pids: torch.Tensor,
    vocab: CharVocab,
    max_len: int,
    device: torch.device,
    topk: int = 1,
) -> float:
    """Hit@K of line-to-poem retrieval against a precomputed line-embedding bank.

    A query counts as a hit when any of its ``topk`` most similar bank rows
    belongs to the query's true poem.

    Args:
        queries: ``(line, poem_id)`` pairs to look up.
        bank / bank_pids: row-aligned embeddings and poem ids, e.g. from
            ``encode_samples``. They must come from the same weights as ``model``.
        topk: K; clipped to the bank size.

    Returns:
        Fraction of hits in ``[0, 1]``; 0.0 when ``queries`` is empty.
    """
    if not queries:
        return 0.0
    # Disable dropout. Callers typically pass a model straight out of
    # train_one_epoch (train mode); otherwise validation would be stochastic.
    model.eval()
    bank = bank.to(device)
    bank_pids = bank_pids.to(device)
    k = min(topk, bank.size(0))
    hits = 0
    for i in range(0, len(queries), VAL_CHUNK):
        chunk = queries[i : i + VAL_CHUNK]
        xs = torch.stack([vocab.encode(line, max_len) for line, _ in chunk]).to(device)
        true_pids = torch.tensor([pid for _, pid in chunk], dtype=torch.long, device=device)
        topk_idx = (model(xs) @ bank.T).topk(k, dim=1).indices
        # (chunk, k) poem ids of the retrieved lines, compared to each query's true id.
        hits += int((bank_pids[topk_idx] == true_pids.unsqueeze(1)).any(dim=1).sum())
    return hits / len(queries)
@torch.no_grad()
def search_lines(
    model: LineEncoder,
    vocab: CharVocab,
    query: str,
    bank: torch.Tensor,
    bank_pids: torch.Tensor,
    max_len: int,
    device: torch.device,
    topk: int,
) -> list[tuple[int, float]]:
    """Return the ``topk`` bank lines most similar to ``query`` as ``(poem_id, score)`` pairs.

    Scores are dot products between the query embedding and the bank rows,
    highest first. A poem id can appear more than once if several of its lines
    rank highly. Puts ``model`` in eval mode.
    """
    model.eval()
    x = vocab.encode(query.strip(), max_len).unsqueeze(0).to(device)
    q = model(x)
    sim = (q @ bank.to(device).T).squeeze(0)
    scores, idx = sim.topk(min(topk, sim.size(0)))
    return [(int(bank_pids[i]), float(scores[j])) for j, i in enumerate(idx.tolist())]


def train_one_epoch(model, loader, optimizer, device, scaler=None):
    """Run one epoch of supervised-contrastive training.

    Returns the mean loss weighted by batch size. fp16 autocast with ``scaler``
    is used only when a scaler is given and ``device`` is CUDA. Gradients are
    clipped to norm 1.0 in both paths.
    """
    model.train()
    total_loss, n = 0.0, 0
    use_amp = scaler is not None and device.type == "cuda"
    for x, y in loader:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        if use_amp:
            with torch.amp.autocast("cuda", dtype=torch.float16):
                z = model(x)
                loss = supcon_loss(z, y, TEMPERATURE)
            scaler.scale(loss).backward()
            # Unscale before clipping so the norm is measured on true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            z = model(x)
            loss = supcon_loss(z, y, TEMPERATURE)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
        total_loss += loss.item() * x.size(0)
        n += x.size(0)
    return total_loss / max(n, 1)


def main():
    """Train the line encoder, report validation Hit@K every epoch, and save the best model.

    CLI flags: ``--epochs`` and ``--batch-size``. The checkpoint
    (``SAVE_DIR / "tangshi_retriever_best.pt"``) bundles the weights, the
    vocabulary, all poems and an embedding bank of every sample line.
    """
    parser = argparse.ArgumentParser(description="唐诗检索模型(对比学习)")
    parser.add_argument("--epochs", type=int, default=EPOCHS)
    parser.add_argument("--batch-size", type=int, default=BATCH_SIZE)
    args = parser.parse_args()

    poems_per_batch = max(8, args.batch_size // LINES_PER_POEM)
    # At least 2 lines per poem. With 1 (or 0, for a tiny --batch-size) no anchor
    # has an in-batch positive, and supcon_loss gives no training signal.
    lines_per_poem = max(2, args.batch_size // poems_per_batch)

    random.seed(SEED)
    torch.manual_seed(SEED)
    device = get_device()
    print(f"使用设备: {device}")
    if device.type == "cuda":
        print(f"GPU: {torch.cuda.get_device_name(0)}")
    print("任务: 检索模型(对比学习 + 诗句向量库)")
    print(f"每批 {poems_per_batch} 首诗 × {lines_per_poem} 句 = {poems_per_batch * lines_per_poem} 样本")

    poems = load_poems(DATA_PATH)
    print(f"诗词数量: {len(poems)}")
    all_chars = "".join(p.author + p.title + "".join(p.lines) for p in poems)
    vocab = CharVocab(all_chars)
    print(f"字符表大小: {len(vocab)}")
    samples = build_samples(poems)
    print(f"诗句行总数: {len(samples)}")

    # 98/2 split by line. Validation lines are looked up against the other lines
    # of their poem that stay in the training bank.
    random.shuffle(samples)
    split = int(len(samples) * 0.98)
    train_samples, val_samples = samples[:split], samples[split:]
    poem_to_indices: dict[int, list[int]] = defaultdict(list)
    for i, (_, pid) in enumerate(train_samples):
        poem_to_indices[pid].append(i)

    train_ds = Line2PoemDataset(train_samples, vocab, MAX_LEN, augment=True)
    train_loader = DataLoader(
        train_ds,
        batch_sampler=PoemBatchSampler(poem_to_indices, poems_per_batch, lines_per_poem),
        collate_fn=collate_batch,
        num_workers=0,
    )
    model = LineEncoder(len(vocab), EMBED_DIM, HIDDEN_DIM, NUM_LAYERS, DROPOUT).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
    scaler = torch.amp.GradScaler("cuda") if device.type == "cuda" else None
    SAVE_DIR.mkdir(parents=True, exist_ok=True)
    ckpt_path = SAVE_DIR / "tangshi_retriever_best.pt"
    best_hit1 = 0.0

    # The text baseline takes no model argument, so it is computed once
    # instead of every epoch.
    text1 = eval_text_hit(poems, val_samples, topk=1)
    text5 = eval_text_hit(poems, val_samples, topk=5)

    for epoch in range(1, args.epochs + 1):
        loss = train_one_epoch(model, train_loader, optimizer, device, scaler)

        # Re-encode the bank with the weights just trained. Queries and bank must
        # come from the same model. A bank from an earlier epoch, or from the
        # random initialisation, makes val_neural@K (and therefore best-checkpoint
        # selection) meaningless.
        print("正在编码训练集诗句库(用于验证)...")
        train_bank, train_bank_pids = encode_samples(model, train_samples, vocab, MAX_LEN, device)
        hit1 = eval_retrieval(model, val_samples, train_bank, train_bank_pids, vocab, MAX_LEN, device, topk=1)
        hit5 = eval_retrieval(model, val_samples, train_bank, train_bank_pids, vocab, MAX_LEN, device, topk=5)
        print(
            f"Epoch {epoch}/{args.epochs}  loss={loss:.4f}  "
            f"val_text@1={text1:.2%}  val_text@5={text5:.2%}  "
            f"val_neural@1={hit1:.2%}  val_neural@5={hit5:.2%}"
        )

        if hit1 >= best_hit1:
            best_hit1 = hit1
            print("  刷新最佳,重建全库向量...")
            full_bank, full_bank_pids = encode_samples(model, samples, vocab, MAX_LEN, device)
            torch.save(
                {
                    "model": model.state_dict(),
                    "vocab_itos": vocab.itos,
                    "poems": [{"author": p.author, "title": p.title, "lines": p.lines} for p in poems],
                    "line_bank": full_bank.half(),
                    "line_poem_ids": full_bank_pids,
                    "config": {
                        "max_len": MAX_LEN,
                        "embed_dim": EMBED_DIM,
                        "hidden_dim": HIDDEN_DIM,
                        "num_layers": NUM_LAYERS,
                        "dropout": DROPOUT,
                        "temperature": TEMPERATURE,
                        "model_type": "retrieval",
                    },
                },
                ckpt_path,
            )

        # Spot-check one random validation line (skipped when there are none).
        if val_samples and (epoch == 1 or epoch % 3 == 0):
            line, pid = random.choice(val_samples)
            th = search_text(poems, line, limit=1)
            if th:
                rid, score = th[0].poem_id, th[0].score
                tag = "文本"
            else:
                rid, score = search_lines(model, vocab, line, train_bank, train_bank_pids, MAX_LEN, device, topk=1)[0]
                tag = "向量"
            ok = "✓" if rid == pid else "✗"
            print(f"  试查[{tag}] {ok} 输入: {line}")
            print(
                f"  预测: 【{poems[rid].author}】《{poems[rid].title}》  "
                f"真值: 【{poems[pid].author}】《{poems[pid].title}》  {score:.3f}"
            )

    print(f"训练完成,最佳 val_neural@1: {best_hit1:.2%}")
    print(f"权重: {SAVE_DIR / 'tangshi_retriever_best.pt'}")
    print('推理: python predict_tangshi.py "春眠不觉晓"')
    print('或仅文本: python predict_tangshi.py --text-only "春眠不觉晓"')


if __name__ == "__main__":
    main()
train_tangshi.py

(此处原有一张图片,已缺失)

预测代码(predict_tangshi.py)

# predict_tangshi.py — 混合检索:精确/子串/模糊文本 + 向量相似度
import argparse
from pathlib import Path

import torch

from retrieval import (
    Hit,
    build_exact_index,
    dedupe_poems,
    merge_hits,
    search_exact,
    search_text,
)
from train_tangshi import (
    MAX_LEN,
    Poem,
    LineEncoder,
    CharVocab,
    build_samples,
    encode_samples,
    get_device,
    search_lines,
)

CKPT_PATH = Path(__file__).parent / "checkpoints" / "tangshi_retriever_best.pt"


def load_checkpoint(path: Path, device):
    """Rebuild the model, vocabulary, poems and line-embedding bank from a checkpoint.

    If the checkpoint has no ``"line_bank"``, every poem line is encoded on the
    spot, which is slow.

    Returns:
        ``(model, vocab, poems, bank, bank_pids, max_len)``; ``model`` is in eval mode.
    """
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # objects, so only load checkpoints you trust. train_tangshi.py saves only
    # tensors, dicts, lists, strings and numbers, so weights_only=True may be
    # enough. TODO: verify before switching.
    ckpt = torch.load(path, map_location=device, weights_only=False)
    poems = [Poem(p["author"], p["title"], p["lines"]) for p in ckpt["poems"]]
    vocab = CharVocab.from_itos(ckpt["vocab_itos"])
    cfg = ckpt["config"]
    max_len = cfg.get("max_len", MAX_LEN)

    model = LineEncoder(
        len(vocab),
        cfg["embed_dim"],
        cfg["hidden_dim"],
        cfg["num_layers"],
        cfg["dropout"],
    ).to(device)
    model.load_state_dict(ckpt["model"])
    model.eval()

    if "line_bank" not in ckpt:
        print("检查点无向量库,正在现场编码全库(较慢)...")
        bank, bank_pids = encode_samples(model, build_samples(poems), vocab, max_len, device)
    else:
        # Saved as fp16 by the trainer; upcast for the similarity matmul.
        bank = ckpt["line_bank"].float()
        bank_pids = ckpt["line_poem_ids"]
    return model, vocab, poems, bank, bank_pids, max_len


def neural_hits(model, vocab, query, bank, bank_pids, max_len, device, search_k: int) -> list[Hit]:
    """Vector-search candidates as ``Hit`` records tagged ``"neural"``.

    The third ``Hit`` field is passed as ``""``. It is presumably the matched
    line (``format_hit`` reads ``h.matched_line``); confirm against ``retrieval.Hit``.
    """
    ranked = search_lines(model, vocab, query, bank, bank_pids, max_len, device, topk=search_k)
    return [Hit(poem_id, score, "", "neural") for poem_id, score in ranked]


def format_hit(poems: list[Poem], h: Hit, rank: int) -> str:
    """Render one ranked candidate: a header line, the matched line if any, then the full poem."""
    labels = {"exact": "精确匹配", "text": "文本匹配", "neural": "向量检索"}
    method_cn = labels.get(h.method, h.method)
    out = [f"--- 候选 {rank}({method_cn},得分 {h.score:.3f})---"]
    if h.matched_line:
        out.append(f"匹配句: {h.matched_line}")
    out.append(poems[h.poem_id].format())
    return "\n".join(out)


def predict(
    poems: list[Poem],
    exact_index: dict[str, int],
    model,
    vocab,
    bank,
    bank_pids,
    query: str,
    device,
    max_len: int,
    topk: int = 3,
    search_k: int = 30,
    text_only: bool = False,
) -> str:
    """Answer one query, trying an exact match first, then text search, then hybrid search.

    Text results alone are used when ``text_only`` is set or the best text
    score is at least 0.85. Otherwise text and vector hits are merged. Returns
    the printable report.
    """
    query = query.strip()
    if not query:
        return "请输入诗句。"

    exact = search_exact(exact_index, query)
    if exact:
        return f"(精确匹配)\n{poems[exact.poem_id].format()}"

    text_hits = search_text(poems, query, min_score=0.55, limit=search_k)
    if text_only or (text_hits and text_hits[0].score >= 0.85):
        hits = dedupe_poems(text_hits, topk)
        mode = "文本检索(子串/模糊)"
    else:
        # text_only is always False here, so vector search always runs.
        vector_hits = neural_hits(model, vocab, query, bank, bank_pids, max_len, device, search_k)
        hits = merge_hits(text_hits, vector_hits, topk)
        mode = "混合检索"

    if not hits:
        return f"查询: {query}\n未检索到相关诗句。"

    report = [f"查询: {query}", f"方式: {mode}", ""]
    for rank, h in enumerate(hits, start=1):
        report += [format_hit(poems, h, rank), ""]
    return "\n".join(report).rstrip()


def main():
    """CLI entry point: answer the query given on the command line, or run an interactive loop."""
    parser = argparse.ArgumentParser(description="唐诗检索:输入诗句,返回整首")
    parser.add_argument("query", nargs="?", help="诗句,省略则进入交互模式")
    parser.add_argument("--ckpt", type=Path, default=CKPT_PATH)
    parser.add_argument("--topk", type=int, default=3)
    parser.add_argument("--text-only", action="store_true", help="仅用文本匹配,不用向量")
    args = parser.parse_args()
    device = get_device()

    if not args.text_only:
        if not args.ckpt.exists():
            print(f"未找到: {args.ckpt}")
            print("可先用文本匹配: python predict_tangshi.py --text-only \"春眠不觉晓\"")
            print("或训练后: python train_tangshi.py")
            return
        model, vocab, poems, bank, bank_pids, max_len = load_checkpoint(args.ckpt, device)
        print(f"已加载 {len(poems)} 首诗,向量库 {bank.size(0)} 条,设备: {device}\n")
    else:
        # Text-only mode needs only the raw corpus, not a checkpoint.
        from train_tangshi import load_poems, DATA_PATH

        poems = load_poems(DATA_PATH)
        model = vocab = bank = bank_pids = None
        max_len = MAX_LEN
        print(f"已加载 {len(poems)} 首诗(仅文本匹配)\n")

    exact_index = build_exact_index(poems)

    def run_one(q: str):
        answer = predict(
            poems,
            exact_index,
            model,
            vocab,
            bank,
            bank_pids,
            q,
            device,
            max_len,
            args.topk,
            text_only=args.text_only,
        )
        print(answer)
        print()

    if args.query:
        run_one(args.query)
        return

    print("交互模式:输入诗句或片段;空行退出。")
    while True:
        try:
            q = input("诗句> ").strip()
        except (EOFError, KeyboardInterrupt):
            print()
            break
        if not q:
            break
        run_one(q)


if __name__ == "__main__":
    main()
predict_tangshi.py

(此处原有一张图片,已缺失)

 

http://www.jsqmd.com/news/897490/

相关文章:

  • 深度解析IDM激活脚本:从新手到专家的完整实战指南
  • AI生成内容声明必须包含的6个法律锚点,少1个即触发GDPR第58条执法调查——ChatGPT声明合规性压力测试报告
  • 全球ChatGPT替代率警报:客服、初阶编程、基础法律咨询等7类岗位需求萎缩超35%,但复合型提示工程师缺口达210万(附认证路径图)
  • 抖音无水印批量下载工具:三步法搞定内容采集与数据管理
  • 基于C2PA与TPM的实时视频流媒体内容溯源与认证系统设计与实现
  • Hive性能调优实战:告别Order By,拥抱Sort By与Distribute By
  • 5分钟免费汉化Axure全版本:告别英文界面,提升设计效率的完整指南
  • 从数据精准到非标定制:2026年污水COD检测仪哪家靠谱?头部企业技术实力与品牌解析 - 品牌推荐大师1
  • OpCore Simplify:5分钟自动化完成OpenCore配置的黑苹果利器
  • 教练辅助MARL框架:提升多智能体系统在智能体崩溃下的鲁棒性
  • 2026南京结婚西装定制权威评测:准新郎必收藏5大高口碑店铺排名 - 西装爱好者
  • 从零打造可落地的直流电机 PID 驱动系统 (十二):电流环控制实现
  • 从API密钥管理混乱到集中管控与审计日志带来的安全感
  • OpenClaw Agent 工作流无缝接入 Taotoken 的配置要点详解
  • 华硕笔记本性能优化神器GHelper:5分钟从卡顿到流畅的实战指南
  • 从 Web 到移动端再到打印:Highcharts 如何实现跨平台一致性图表体验
  • 说明书驱动机器学习开发:用Warp/Oz架构解决MLOps协作难题
  • 5分钟快速上手:用novelWriter高效管理你的小说创作
  • Codex「自我蒸馏」秘籍曝光:从程序员专属到全场景适用,能否解决token难题?
  • CentOS7 上 Oracle12c 企业级部署与深度配置实战
  • 万国全国售后网络焕新升级:2026年6月最新官方客户服务全指南 - 亨得利官方服务中心
  • RAG 系统知识库查不准问题治理:从模块职责划分到检索链路闭环设计
  • 专业守护时光:2026浪琴官方售后服务体系全解析 - 浪琴服务中心
  • LuaJIT字节码反编译:从黑盒到可读代码的3步实战指南
  • 基于主动推理的计算连续体碳感知调度:架构设计与工程实践
  • Flutter Widget组件学习(专为 Uniapp 转 Flutter 定制)
  • 体验Taotoken旗舰模型首发更新第一时间用上最新最强模型
  • 多云管理工具:统一管理多个云平台资源
  • 2026年河北玻璃钢环保设备采购指南:电缆桥架、化粪池、一体化泵站品牌深度横评 - 精选优质企业推荐官
  • 基于诊断引导与置信感知的故障鲁棒声源定位系统