当前位置: 首页 > news >正文

从公式到代码:手把手复现阿里ESMM模型(PaddlePaddle/PyTorch版)

从公式到代码:手把手复现阿里ESMM模型(PaddlePaddle/PyTorch版)

在推荐系统的技术演进中,多任务学习已成为提升模型效果的关键策略。阿里妈妈团队提出的ESMM(Entire Space Multi-Task Model)通过创新的概率分解思想,有效解决了转化率预估中的样本选择偏差和数据稀疏问题。本文将带您深入理解这一经典模型的数学原理,并分别用PaddlePaddle和PyTorch框架实现完整解决方案。

1. ESMM核心原理与技术突破

1.1 概率图视角下的转化链路建模

电商场景的用户行为遵循严格的曝光→点击→转化顺序。设:

  • $x$:用户特征和上下文特征
  • $y$:点击行为(0/1)
  • $z$:转化行为(0/1)

ESMM的核心公式揭示了三个关键概率的关系: $$ pCTCVR = pCTR \times pCVR $$ 其中:

  • $pCTR = p(y=1|x)$
  • $pCVR = p(z=1|y=1,x)$
  • $pCTCVR = p(y=1,z=1|x)$

这种分解带来两个重要优势:

  1. 全空间建模:CTR和CTCVR任务可以使用全部曝光样本训练
  2. 隐式学习:CVR参数通过乘积关系间接优化

1.2 网络架构设计精要

ESMM的模型结构包含三个核心组件:

组件输入输出训练样本
共享Embedding层原始特征特征嵌入全量曝光样本
CTR塔特征嵌入pCTR全量曝光样本
CVR塔特征嵌入pCVR仅点击样本

注意:虽然CVR塔理论上只处理点击样本,但其参数通过CTCVR的联合损失进行更新

2. PaddlePaddle实现详解

2.1 环境配置与数据准备

首先安装必要依赖:

pip install paddlepaddle==2.4.0 pip install pandas sklearn

准备示例数据格式:

user_iditem_idcate_idclickconversion
12345678910
23456789000

特征处理关键代码:

import paddle from paddle.io import Dataset class ESMNDataset(Dataset): def __init__(self, data): self.data = data def __getitem__(self, idx): sample = self.data.iloc[idx] return { 'user_id': sample['user_id'], 'item_id': sample['item_id'], 'cate_id': sample['cate_id'], 'click': sample['click'], 'conversion': sample['conversion'] & sample['click'] }

2.2 模型构建

完整模型实现:

class ESMM(paddle.nn.Layer): def __init__(self, user_num, item_num, cate_num, embed_dim=64): super().__init__() # 共享特征嵌入 self.user_emb = paddle.nn.Embedding(user_num, embed_dim) self.item_emb = paddle.nn.Embedding(item_num, embed_dim) self.cate_emb = paddle.nn.Embedding(cate_num, embed_dim) # CTR塔 self.ctr_mlp = paddle.nn.Sequential( paddle.nn.Linear(embed_dim*3, 128), paddle.nn.ReLU(), paddle.nn.Linear(128, 64), paddle.nn.ReLU(), paddle.nn.Linear(64, 2) ) # CVR塔 self.cvr_mlp = paddle.nn.Sequential( paddle.nn.Linear(embed_dim*3, 128), paddle.nn.ReLU(), paddle.nn.Linear(128, 64), paddle.nn.ReLU(), paddle.nn.Linear(64, 2) ) def forward(self, inputs): # 特征嵌入 user_emb = self.user_emb(inputs['user_id']) item_emb = self.item_emb(inputs['item_id']) cate_emb = self.cate_emb(inputs['cate_id']) concat_emb = paddle.concat([user_emb, item_emb, cate_emb], axis=1) # CTR预测 ctr_logits = self.ctr_mlp(concat_emb) ctr_pred = paddle.nn.functional.softmax(ctr_logits)[:, 1] # CVR预测 cvr_logits = self.cvr_mlp(concat_emb) cvr_pred = paddle.nn.functional.softmax(cvr_logits)[:, 1] # CTCVR计算 ctcvr_pred = ctr_pred * cvr_pred return ctr_pred, cvr_pred, ctcvr_pred

2.3 自定义损失函数

实现论文中的联合损失:

class ESMNLoss(paddle.nn.Layer): def __init__(self): super().__init__() self.ctr_loss = paddle.nn.BCELoss() self.ctcvr_loss = paddle.nn.BCELoss() def forward(self, preds, labels): ctr_pred, _, ctcvr_pred = preds click_label = labels['click'] conversion_label = labels['conversion'] ctr_loss = self.ctr_loss(ctr_pred, click_label) ctcvr_loss = self.ctcvr_loss(ctcvr_pred, conversion_label) return ctr_loss + ctcvr_loss

3. PyTorch实现方案

3.1 模型结构迁移

PyTorch版本的核心差异:

import torch import torch.nn as nn class ESMMTorch(nn.Module): def __init__(self, user_num, item_num, cate_num, embed_dim=64): super().__init__() # 共享嵌入层 self.user_emb = nn.Embedding(user_num, embed_dim) self.item_emb = nn.Embedding(item_num, embed_dim) self.cate_emb = nn.Embedding(cate_num, embed_dim) # 网络塔结构 self.ctr_tower = nn.Sequential( nn.Linear(embed_dim*3, 128), nn.ReLU(), nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 1), nn.Sigmoid() ) self.cvr_tower = nn.Sequential( nn.Linear(embed_dim*3, 128), nn.ReLU(), nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 1), nn.Sigmoid() ) def forward(self, x): user_emb = self.user_emb(x['user_id']) item_emb = self.item_emb(x['item_id']) cate_emb = self.cate_emb(x['cate_id']) concat_emb = torch.cat([user_emb, item_emb, cate_emb], dim=1) pctr = self.ctr_tower(concat_emb) pcvr = self.cvr_tower(concat_emb) pctcvr = pctr * pcvr return pctr.squeeze(), pcvr.squeeze(), pctcvr.squeeze()

3.2 训练流程优化

PyTorch训练循环示例:

def train_epoch(model, dataloader, optimizer, device): model.train() total_loss = 0 for batch in dataloader: optimizer.zero_grad() # 数据转移到设备 inputs = {k: v.to(device) for k,v in batch.items()} labels = { 'click': inputs['click'].float(), 'conversion': (inputs['click'] * inputs['conversion']).float() } # 前向计算 ctr_pred, cvr_pred, ctcvr_pred = model(inputs) # 损失计算 ctr_loss = F.binary_cross_entropy(ctr_pred, labels['click']) ctcvr_loss = F.binary_cross_entropy(ctcvr_pred, labels['conversion']) loss = ctr_loss + ctcvr_loss # 反向传播 loss.backward() optimizer.step() total_loss += loss.item() return total_loss / len(dataloader)

4. 实战技巧与调优策略

4.1 特征工程最佳实践

  • 用户侧特征
    • 历史点击/转化统计
    • 兴趣标签
    • 活跃时段
  • 物品侧特征
    • 类目属性
    • 价格分段
    • 销量统计
  • 上下文特征
    • 曝光位置
    • 时间周期
    • 设备类型

4.2 模型调优关键点

  1. Embedding维度选择

    • 高基数特征:16-64维
    • 低基数特征:8-16维
  2. 塔结构设计

    # 更深的网络结构示例 self.ctr_tower = nn.Sequential( nn.Linear(embed_dim*3, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.3), nn.Linear(256, 128), nn.BatchNorm1d(128), nn.ReLU(), nn.Linear(128, 1), nn.Sigmoid() )
  3. 损失函数加权

    # 根据任务重要性调整权重 loss = alpha * ctr_loss + beta * ctcvr_loss

4.3 线上部署考量

  • 性能优化
    • 使用TensorRT加速推理
    • 实现Embedding缓存机制
  • 效果监控
    • 建立CTR/CVR漂移检测
    • 设计A/B测试分层策略

在电商推荐系统中,ESMM的落地需要与召回模块、排序模块协同工作。实际部署时,我们发现将CVR预测结果与CTR预测进行动态加权(如score = CTR^α * CVR^β),能够更好地平衡点击率和转化率的目标。

http://www.jsqmd.com/news/978412/

相关文章:

  • 除了点灯,在STM32F407上跑OpenHarmony还能做什么?聊聊外设驱动与生态拓展
  • 别再死记硬背了!从Buck电路入手,图解SPST/SPDT开关的半导体实现原理
  • 别再只用UUID v4了!5个版本(v1到v5)的实战选择指南与Node.js代码示例
  • 别再搞混了!一文讲透Windbg网络调试、远程调试与真机双机调试的区别
  • 不只是编译:用OpenMVG 2.0 + CloudCompare 玩转你的第一份3D稀疏点云
  • 2026年价格实惠的去核机推荐厂家 - mypinpai
  • 从ESP-01S到ESP-12F:一个毕业生的物联网上云踩坑实录(附完整接线图与避坑清单)
  • 符号不变注意力机制:Transformer架构的创新改进
  • 2026年6月重庆大学城靠谱画室评测:4家机构核心维度对比 - 奔跑123
  • 别再手动调Excel了!用Python的openpyxl批量设置字体、边框和行高,效率翻倍
  • 从CPLD到低成本FPGA:利用AGM AG576SL100,我如何为老项目“偷”出了4个额外IO口?
  • 计算机毕业设计之基于 Hadoop技术贝壳网商品房租赁数据分析与可视化
  • 新手电商开店必看:快递批量查询从入门到精通(完整版)
  • STM32单片机光照检测智能调光系统Protest仿真+代码+报告+讲解视频
  • 2026年哈氏合金管口碑好的品牌排名 - mypinpai
  • WPS表格转换踩坑实录:逗号、空格用不对,格式全乱!附正确设置图解
  • 02-Hooks完全指南——08-useTransition 与 useDeferredValue
  • WPS表格进阶玩法:巧用‘文本转表格’功能,一键处理调查问卷和导出数据
  • 不止于稀疏点云:用OpenMVG 2.0完成SFM后,如何无缝衔接OpenMVS进行稠密重建?
  • 别再手动对齐了!用Word/WPS的‘文本转表格’功能,5分钟搞定杂乱数据整理
  • pdfplumber:Python PDF 解析与表格提取利器
  • 简单C++
  • 其他推荐 - 本地品牌推荐
  • 光猫‘死前’信号揭秘:DyingGasp电路在PON网络中的实战应用与故障排查指南
  • 【STM32】配置vscode+C工具链+Cortex-Debug开发环境,IC:STM32F411CEU6
  • 双组份背胶选购指南,兴佰诚值得选吗 - mypinpai
  • 从水箱报警到花盆浇水:用窗口比较器LM393DIY一个超实用的水位监控器
  • MyComputerManager:基于WPF的Windows注册表管理系统架构深度解析
  • 多标签表单与文件上传的完美结合
  • 从OFDM仿真到性能对比:深入理解LMMSE与LS信道估计的MATLAB实战(含信噪比影响分析)