当前位置: 首页 > news >正文

知识图谱实战:手把手用PyTorch复现TuckER模型完成链接预测任务

知识图谱实战:手把手用PyTorch复现TuckER模型完成链接预测任务

知识图谱作为结构化知识的重要载体,在智能搜索、推荐系统和问答系统中发挥着关键作用。然而,现实中的知识图谱往往存在大量缺失链接,如何自动补全这些缺失信息成为学术界和工业界共同关注的焦点。本文将带你从零开始,用PyTorch实现TuckER模型——一种基于张量分解的知识图谱补全方法,完成链接预测任务。

1. 环境准备与数据加载

在开始编码之前,我们需要搭建合适的开发环境。推荐使用Python 3.8+和PyTorch 1.10+版本,这些组合能提供良好的兼容性和性能表现。

conda create -n kg python=3.8 conda activate kg pip install torch==1.10.0 torchvision==0.11.0 torchaudio==0.10.0 pip install pandas tqdm numpy

我们将使用FB15k-237数据集,这是知识图谱链接预测领域的标准基准数据集。它包含14,541个实体和237种关系,是原始FB15k数据集的改进版本,解决了其中的测试集泄露问题。

import pandas as pd def load_data(dataset_path): train = pd.read_csv(f"{dataset_path}/train.txt", sep="\t", header=None, names=["head", "relation", "tail"]) valid = pd.read_csv(f"{dataset_path}/valid.txt", sep="\t", header=None, names=["head", "relation", "tail"]) test = pd.read_csv(f"{dataset_path}/test.txt", sep="\t", header=None, names=["head", "relation", "tail"]) # 构建实体和关系的词汇表 entities = set(train["head"]).union(set(train["tail"])) relations = set(train["relation"]) return train, valid, test, entities, relations

提示:FB15k-237数据集可以从https://github.com/TimDettmers/ConvE获取,下载后解压到项目目录的data文件夹下。

2. TuckER模型架构解析与实现

TuckER模型的核心思想源自Tucker张量分解,它将知识图谱的三元组(头实体,关系,尾实体)表示为一个三阶张量,并通过分解这个张量来学习实体和关系的嵌入表示。

2.1 模型数学原理

TuckER模型的评分函数定义为:

φ(eₛ, r, eₒ) = W ×₁ eₛ ×₂ r ×₃ eₒ

其中:

  • W ∈ ℝ^{dₑ×dᵣ×dₑ} 是核心张量
  • eₛ, eₒ ∈ ℝ^{dₑ} 是头尾实体的嵌入向量
  • r ∈ ℝ^{dᵣ} 是关系的嵌入向量
  • ×ₙ 表示n模乘积

2.2 PyTorch实现

import torch import torch.nn as nn class TuckER(nn.Module): def __init__(self, num_entities, num_relations, entity_dim=200, relation_dim=200): super(TuckER, self).__init__() self.entity_embedding = nn.Embedding(num_entities, entity_dim) self.relation_embedding = nn.Embedding(num_relations, relation_dim) # 核心张量W的初始化 self.W = nn.Parameter(torch.randn(entity_dim, relation_dim, entity_dim)) # 初始化参数 nn.init.xavier_normal_(self.entity_embedding.weight) nn.init.xavier_normal_(self.relation_embedding.weight) nn.init.xavier_normal_(self.W) self.bce_loss = nn.BCELoss() def forward(self, heads, relations, tails): # 获取嵌入向量 e_s = self.entity_embedding(heads) # [batch_size, entity_dim] r = self.relation_embedding(relations) # [batch_size, relation_dim] e_o = self.entity_embedding(tails) # [batch_size, entity_dim] # 计算n模乘积 # W ×₂ r: [entity_dim, relation_dim, entity_dim] × [batch_size, relation_dim] W_r = torch.einsum('ijk,bi->bjk', self.W, r) # [batch_size, entity_dim, entity_dim] # (W ×₂ r) ×₁ e_s W_r_e_s = torch.einsum('bjk,bj->bk', W_r, e_s) # [batch_size, entity_dim] # ((W ×₂ r) ×₁ e_s) ×₂ e_o score = torch.einsum('bk,bk->b', W_r_e_s, e_o) # [batch_size] return torch.sigmoid(score)

3. 训练流程与技巧

3.1 负采样策略

知识图谱中只有正例三元组,我们需要生成负例来训练模型。常用的负采样方法包括:

  • 随机替换头实体或尾实体
  • 基于频率的替换(更少出现的实体有更高概率被选中)
  • 对抗性负采样(根据当前模型预测结果选择困难负例)
def generate_negative_samples(batch, num_entities, device='cpu'): heads, relations, tails = batch batch_size = heads.size(0) # 随机选择替换头实体或尾实体 neg_heads = heads.clone() neg_tails = tails.clone() # 50%概率替换头实体,50%概率替换尾实体 mask = torch.rand(batch_size, device=device) < 0.5 random_entities = torch.randint(0, num_entities, (batch_size,), device=device) neg_heads[mask] = random_entities[mask] neg_tails[~mask] = random_entities[~mask] return neg_heads, relations, neg_tails

3.2 训练循环实现

def train(model, train_data, num_epochs=100, batch_size=128, learning_rate=0.001): optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) for epoch in range(num_epochs): model.train() total_loss = 0.0 # 随机打乱训练数据 indices = torch.randperm(len(train_data)) for i in range(0, len(train_data), batch_size): batch_indices = indices[i:i+batch_size] batch = train_data[batch_indices] # 正例 pos_scores = model(batch[:,0], batch[:,1], batch[:,2]) pos_labels = torch.ones_like(pos_scores) # 负例 neg_heads, neg_rels, neg_tails = generate_negative_samples(batch, len(model.entity_embedding.weight)) neg_scores = model(neg_heads, neg_rels, neg_tails) neg_labels = torch.zeros_like(neg_scores) # 合并正负例 all_scores = torch.cat([pos_scores, neg_scores]) all_labels = torch.cat([pos_labels, neg_labels]) # 计算损失 loss = model.bce_loss(all_scores, all_labels) # 反向传播 optimizer.zero_grad() loss.backward() optimizer.step() total_loss += loss.item() print(f"Epoch {epoch+1}, Loss: {total_loss / (len(train_data)//batch_size)}")

注意:在实际应用中,建议添加学习率调度器和早停机制来优化训练过程。

4. 模型评估与结果分析

4.1 评估指标

知识图谱链接预测常用的评估指标包括:

  • MRR(Mean Reciprocal Rank):正确结果排名的倒数的平均值
  • Hits@k:正确结果出现在前k个预测中的比例
def evaluate(model, test_data, all_entities, device='cpu'): model.eval() ranks = [] with torch.no_grad(): for head, relation, tail in test_data: # 预测尾实体 head_tensor = torch.tensor([head]*len(all_entities), device=device) rel_tensor = torch.tensor([relation]*len(all_entities), device=device) tail_tensor = torch.tensor(list(all_entities), device=device) scores = model(head_tensor, rel_tensor, tail_tensor) sorted_indices = torch.argsort(scores, descending=True) # 找到正确尾实体的排名 rank = (sorted_indices == tail).nonzero().item() + 1 ranks.append(rank) # 计算指标 mrr = torch.mean(1.0 / torch.tensor(ranks, dtype=torch.float)).item() hits_10 = sum(r <= 10 for r in ranks) / len(ranks) hits_3 = sum(r <= 3 for r in ranks) / len(ranks) hits_1 = sum(r == 1 for r in ranks) / len(ranks) return {"MRR": mrr, "Hits@1": hits_1, "Hits@3": hits_3, "Hits@10": hits_10}

4.2 性能优化技巧

  1. 嵌入维度选择

    • 实体嵌入维度通常设置在100-500之间
    • 关系嵌入维度可以略小于实体嵌入维度
  2. 批量归一化: 在评分函数计算后添加批量归一化层可以稳定训练过程

  3. 标签平滑: 使用标签平滑技术可以防止模型对训练数据过拟合

class TuckERWithBN(nn.Module): def __init__(self, num_entities, num_relations, entity_dim=200, relation_dim=200): super().__init__() # ... 其他初始化代码同上 ... self.bn = nn.BatchNorm1d(1) def forward(self, heads, relations, tails): # ... 前面计算score的代码同上 ... score = torch.einsum('bk,bk->b', W_r_e_s, e_o) score = self.bn(score.unsqueeze(1)).squeeze(1) return torch.sigmoid(score)

5. 高级主题与扩展

5.1 多任务学习

TuckER的核心张量W可以看作是在不同关系间共享的知识,这种结构天然支持多任务学习。我们可以通过添加辅助任务来进一步提升模型性能:

class MultiTaskTuckER(TuckER): def __init__(self, num_entities, num_relations, entity_dim=200, relation_dim=200): super().__init__(num_entities, num_relations, entity_dim, relation_dim) self.relation_classifier = nn.Linear(relation_dim, num_relations) def forward(self, heads, relations, tails): # 原始链接预测任务 link_scores = super().forward(heads, relations, tails) # 关系分类任务 r_emb = self.relation_embedding(relations) relation_logits = self.relation_classifier(r_emb) return link_scores, relation_logits

5.2 模型压缩与部署

在实际应用中,模型可能需要部署到资源受限的环境。我们可以通过以下技术减小模型大小:

  1. 张量分解:对核心张量W进行低秩分解
  2. 量化:将模型参数从FP32转换为INT8
  3. 知识蒸馏:训练一个小型学生模型模仿大型教师模型的行为
class CompressedTuckER(nn.Module): def __init__(self, num_entities, num_relations, entity_dim=200, relation_dim=200, rank=50): super().__init__() self.entity_embedding = nn.Embedding(num_entities, entity_dim) self.relation_embedding = nn.Embedding(num_relations, relation_dim) # 低秩分解核心张量 self.U = nn.Parameter(torch.randn(entity_dim, rank)) self.V = nn.Parameter(torch.randn(relation_dim, rank)) self.W = nn.Parameter(torch.randn(entity_dim, rank)) def forward(self, heads, relations, tails): e_s = self.entity_embedding(heads) r = self.relation_embedding(relations) e_o = self.entity_embedding(tails) # 近似计算核心张量 core = torch.einsum('ik,jk,lk->ijl', self.U, self.V, self.W) # 计算评分 score = torch.einsum('ijk,bi,bj,bk->b', core, e_s, r, e_o) return torch.sigmoid(score)

在实际项目中,我发现合理设置嵌入维度和核心张量秩的平衡对模型性能影响很大。通常可以先使用完整模型训练,然后通过分析核心张量的奇异值来决定压缩后的秩大小。

http://www.jsqmd.com/news/681922/

相关文章:

  • Vue Antd Admin架构实战:如何构建高性能企业级中后台系统
  • 基于安卓的心理健康自评与干预系统毕设
  • 别再死记硬背DC脚本了!一个真实项目带你搞定Synopsys DC综合全流程(附完整脚本)
  • 飞书群聊的Jira Bug看板:手把手教你配置Jenkins定时任务和参数化构建
  • 为什么你需要Webcamoid:重新定义网络摄像头体验的终极工具
  • AssetRipper完全指南:三步掌握Unity资源提取终极工具
  • 金蝶云星空K3Cloud实战:手把手教你搞定生产退料单WEBAPI自定义(附完整C#代码)
  • 4月22日成都地区包钢产无缝钢管(8163-20#;外径42-630mm)现货报价 - 四川盛世钢联营销中心
  • 别再只会用QMessageBox::information了!Qt对话框进阶:手把手教你打造自定义按钮和详细信息的弹窗
  • 从模型到芯片:手把手教你用RKNN-Toolkit Lite在RV1126开发板上跑通第一个AI Demo
  • 手把手教你用STM32F411CEU6和W25Q128打造一个超迷你的U盘(附完整代码)
  • ExplorerPatcher终极指南:免费恢复Windows 11经典界面与高效工作流
  • NeRF实战:用Google Colab免费GPU,30分钟从照片生成你的第一个3D模型
  • Tesseract OCR终极指南:如何用开源引擎实现高效文字识别
  • openKylin 2.0 SP2第三次更新:优化关键模块,新增装包功能提升速度
  • TI C2000 DSP的CAN中断实战:一个邮箱如何接收多个ID的数据帧?
  • 5分钟快速上手PKHeX自动合法性插件:宝可梦数据合规终极指南
  • 从‘秒’到‘纳秒’:手把手教你用`std::chrono`设计一个带暂停/重置功能的跨平台计时器类
  • 别再只用MD5了!深入对比PostgreSQL的SCRAM-SHA-256和MD5,附AWS RDS实战配置避坑指南
  • Django后台进阶:用SimpleUI自定义菜单与数据展示,打造你的专属运营中台
  • 22日成都市批发兼零售螺旋焊管(Q235B;内径DN200-3500mm)现货报价 - 四川盛世钢联营销中心
  • Mac音乐解密神器:3分钟解锁QQ音乐加密格式,让音乐自由播放
  • ComfyUI-Impact-Pack:AI图像精细化处理的全能工具包
  • Visual Syslog Server:Windows平台最完整的日志集中管理终极指南
  • 彻底告别激活烦恼:KMS智能激活脚本终极解决方案
  • 目前口碑好的GEO全托管供应商找哪家 - 小张小张111
  • 如何高效解决B站视频下载难题:BiliDownloader实战指南
  • 联想电脑开机进入 Diagnostics UEFI 界面?一文教你快速退出 + 排查原因
  • 抖音无水印视频下载终极教程:3步免费批量保存完整作品集
  • DPABI实战:手把手教你搞定静息态fMRI统计分析与多重比较矫正(附避坑指南)