当前位置: 首页 > news >正文

torch-rechub学习打卡笔记(一)

【Torch-RecHub 学习笔记】Task 1:环境搭建与 DSSM 召回实战

1. 任务背景

在推荐系统领域,高效的特征处理和模型框架是开展研究的基础。本次 Task 1 的核心目标是完成 Torch-RecHub 环境配置,并跑通基础的召回模型流程。


2. 环境准备

根据官方文档要求,确保系统满足以下基础配置:

  • Python: 3.9+
  • PyTorch: 1.7+(推荐 CUDA 版本)
  • 核心依赖: NumPy, Pandas, SciPy, Scikit-learn

3. 框架设计与 DSSM 模型

Torch-RecHub 采用模块化设计,将推荐任务抽象为 特征(Features)模型(Models)训练器(Trainers)

本次实验使用 DSSM(Deep Structured Semantic Model) 作为召回模型。DSSM 是一种经典的双塔结构,通过分别构建用户塔和物品塔,将二者映射到同一低维向量空间中,再利用向量相似度完成匹配。


4. 实验流程说明

整体实验流程如下:

  1. 数据加载与预处理(MovieLens-1M 采样数据)
  2. 类别特征编码(Label Encoding)
  3. 构建用户特征与物品特征
  4. 构造序列特征与训练样本
  5. 定义 DSSM 双塔模型
  6. 模型训练
  7. 用户与物品向量导出

5. 关键代码实现

5.1 数据加载与预处理

import pandas as pd
import numpy as np
import torch
from sklearn.preprocessing import LabelEncoderfrom torch_rechub.basic.features import SparseFeature, SequenceFeature
from torch_rechub.models.matching import DSSM
from torch_rechub.trainers import MatchTrainer
from torch_rechub.utils.data import df_to_dict, MatchDataGenerator
from torch_rechub.utils.match import generate_seq_feature_match, gen_model_inputtorch.manual_seed(2022)# Load data
data_url = "https://raw.githubusercontent.com/datawhalechina/torch-rechub/main/examples/matching/data/ml-1m/ml-1m_sample.csv"
data = pd.read_csv(data_url)
print(f"Dataset size: {len(data)} records")# Category feature
data["cate_id"] = data["genres"].apply(lambda x: x.split("|")[0])

5.2 特征编码

user_col, item_col = "user_id", "movie_id"
sparse_features = ["user_id", "movie_id", "gender", "age", "occupation", "zip", "cate_id"]feature_max_idx = {}
for feat in sparse_features:encoder = LabelEncoder()data[feat] = encoder.fit_transform(data[feat]) + 1feature_max_idx[feat] = data[feat].max() + 1

5.3 用户与物品画像构建

user_cols = ["user_id", "gender", "age", "occupation", "zip"]
item_cols = ["movie_id", "cate_id"]user_profile = data[user_cols].drop_duplicates("user_id")
item_profile = data[item_cols].drop_duplicates("movie_id")

5.4 序列特征与训练数据生成

df_train, df_test = generate_seq_feature_match(data,user_col,item_col,time_col="timestamp",item_attribute_cols=[],sample_method=1,mode=0,neg_ratio=3,min_item=0
)x_train = gen_model_input(df_train, user_profile, user_col, item_profile, item_col, seq_max_len=50)
y_train = x_train["label"]x_test = gen_model_input(df_test, user_profile, user_col, item_profile, item_col, seq_max_len=50)

5.5 特征类型定义(重点修正)

user_features = [SparseFeature(name, vocab_size=feature_max_idx[name], embed_dim=16)for name in user_cols
]user_features += [SequenceFeature("hist_movie_id",vocab_size=feature_max_idx["movie_id"],embed_dim=16,pooling="mean",shared_with="movie_id")
]item_features = [SparseFeature(name, vocab_size=feature_max_idx[name], embed_dim=16)for name in item_cols
]

5.6 DataLoader 与模型定义

all_item = df_to_dict(item_profile)
test_user = x_testdg = MatchDataGenerator(x=x_train, y=y_train)
train_dl, test_dl, item_dl = dg.generate_dataloader(test_user, all_item, batch_size=256)model = DSSM(user_features,item_features,temperature=0.02,user_params={"dims": [128, 64, 32], "activation": "prelu"},item_params={"dims": [128, 64, 32], "activation": "prelu"},
)

5.7 模型训练与向量导出

trainer = MatchTrainer(model,mode=0,optimizer_params={"lr": 1e-4, "weight_decay": 1e-6},n_epoch=3,device="cpu",
)trainer.fit(train_dl)user_embedding = trainer.inference_embedding(model, mode="user", data_loader=test_dl, model_path="./")
item_embedding = trainer.inference_embedding(model, mode="item", data_loader=item_dl, model_path="./")print(user_embedding.shape)
print(item_embedding.shape)

6. 运行结果与分析

  • 样本规模:共处理 100 条采样记录,生成训练集 384 条,测试集 2 条。
  • 向量维度:成功导出用户与物品向量,Embedding 维度均为 32。
User embedding shape: torch.Size([2, 32])
Item embedding shape: torch.Size([93, 32])

7. 学习总结

通过 Task 1 的学习,我完整跑通了从环境搭建、特征工程到 DSSM 模型推理的全过程。Torch-RecHub 在特征抽象与训练流程上的高度封装,使得实验实现更加清晰高效。后续计划将社区搜索与召回模型相结合,探索更复杂的交互式推荐场景。


http://www.jsqmd.com/news/366762/

相关文章:

  • 还在愁论文?AI 写论文软件排行榜你真会选吗?
  • Linux文件目录权限
  • TEASOFT驱动Keysight示波器自动截图:一键获取波形图并嵌入CSDN
  • AI行业入门必看:收藏这份岗位指南,小白也能抓住大模型机遇!
  • 还在找论文神器?AI 写作软件排行榜答案在这
  • 掌握AI能力图谱,从入门到精通:收藏这份AI产品经理实战指南
  • 合肥三十六行(石家庄)分公司 本地生活数字化服务标杆 - 野榜数据排行
  • 拒绝被替代:做 AI 时代的“知识饲养员”,而不是“操作工”
  • 语音通话库——VoLTE功能集成方案
  • 完整教程:核药:以放射性核素为 “探针” 与 “武器”,重塑疾病精准诊疗格局
  • 利用MATLAB程序复现二氧化钒(VO2)介电常数的计算方法及在CST中创建Drude模型的详...
  • 精密制造QMS解决方案:海岸线PQM破解质量追溯与交付难题
  • 2026国内最新全屋定制板材十大实力厂家推荐!山东等地优质环保/抗菌/ENF级/门墙柜一体化板材品牌权威榜单发布 - 品牌推荐2026
  • 直流电压源+双向DCDC变换器+负载+锂离子电池+控制系统,Simulink仿真模型。 有两种...
  • 2026年2月哈尔滨跟团游旅行社竞争格局深度分析报告 - 2026年企业推荐榜
  • 2026年全国真发假发定制品牌哪家专业?聚焦高端品质与个性化适配方向 - 深度智识库
  • 2026年开福区足疗老店评测:一站式奢享体验成新标杆 - 2026年企业推荐榜
  • 打卡信奥刷题(2825)用C++实现信奥题 P4231 三步必杀
  • 从ChatGPT到新质生产力:一份信息驱动的AI研究方向指南
  • Zabbix数据采集页面,主机可用性是灰色的问题排查解决笔记
  • YC 2026未来方向
  • 阿如那从极致反派到热血番男主,网友:内娱需要这样的男主
  • 2026全国管材源头厂家实力榜:涵盖 PE 管、PVC 管、复合管 - 深度智识库
  • 多号发圈终于不用来回切换了,3步搞定!
  • 劝所有私域运营/销售:微信自动回复早用早轻松
  • 动态模型切割工具EzySlice完整实现逻辑
  • 环境治理AI:异常检测在基础设施污染源的自动定位工具
  • 2026年2月哈尔滨跟团游旅行社战略选择与五强深度解析 - 2026年企业推荐榜
  • 京东比价项目的开展和API接口接入的具体步骤是什么?
  • 大模型时代,普通人也能入行AI?收藏这份3步进阶指南,3-5个月实现职业跃迁!