当前位置: 首页 > news >正文

从数据清洗到特征提取:用PyTorch Tensor索引函数(masked_select/non_zero/gather)搞定真实数据处理任务

从数据清洗到特征提取:用PyTorch Tensor索引函数搞定真实数据处理任务

想象一下这样的场景:你刚拿到一份用户行为数据集,准备训练推荐模型。打开CSV文件后却发现——数据里混杂着缺失值、异常点击记录、非标准ID格式。此时你需要快速清洗数据并提取有效特征,而PyTorch的Tensor索引函数就是你的瑞士军刀。本文将带你用non_zeromasked_selectgather三个函数,完成从原始数据到模型输入的完整预处理流程。

1. 实战场景:用户行为数据清洗

假设我们有一份电商平台的用户点击日志,原始数据如下表所示:

用户ID点击时间戳停留时长(秒)商品类别是否购买
U1001163302400012031
U10021633024015-110
NaN163302402030020
U10031633024100551
U1001163302415060NaN0

1.1 用non_zero定位有效数据

首先加载数据并转换为Tensor:

import torch import numpy as np # 模拟原始数据 data = np.array([ ['U1001', 1633024000, 120, 3, 1], ['U1002', 1633024015, -1, 1, 0], [np.nan, 1633024020, 300, 2, 0], ['U1003', 1633024100, 5, 5, 1], ['U1001', 1633024150, 60, np.nan, 0] ]) # 转换为数值型Tensor user_ids = torch.tensor([int(uid[1:]) if isinstance(uid, str) else -1 for uid in data[:,0]]) timestamps = torch.tensor(data[:,1], dtype=torch.float32) durations = torch.tensor(data[:,2], dtype=torch.float32) categories = torch.tensor([c if not np.isnan(c) else -1 for c in data[:,3]], dtype=torch.long) purchased = torch.tensor(data[:,4], dtype=torch.bool)

找出有效用户ID的记录:

valid_user_mask = (user_ids != -1) valid_indices = valid_user_mask.nonzero().squeeze() print(f"有效记录索引:{valid_indices.tolist()}")

提示:nonzero()返回的是二维坐标,用squeeze()压缩单维度后更易处理

1.2 用masked_select清洗异常值

处理停留时长的异常值(负值和过长值):

# 定义合理停留时长范围(5秒到10分钟) reasonable_duration = (durations >= 5) & (durations <= 600) # 同时满足用户有效和时长合理 clean_mask = valid_user_mask & reasonable_duration clean_durations = durations.masked_select(clean_mask) print(f"清洗后时长数据:{clean_durations}")

处理商品类别缺失值:

valid_categories = categories.masked_select(categories != -1) unique_categories = torch.unique(valid_categories) print(f"有效商品类别:{unique_categories.tolist()}")

2. 特征工程:构建用户画像

2.1 用gather实现特征重组

假设我们需要按用户ID重组特征:

# 创建用户ID到索引的映射 user_dict = {uid.item(): idx for idx, uid in enumerate(user_ids[valid_user_mask].unique())} # 示例:统计每个用户的总停留时长 user_duration = torch.zeros(len(user_dict)) for uid, idx in user_dict.items(): user_mask = (user_ids == uid) user_duration[idx] = durations.masked_select(user_mask).sum()

更高效的做法是使用gather

# 构建用户-行为关系矩阵 user_idx = torch.tensor([user_dict[uid.item()] for uid in user_ids[valid_user_mask]]) behavior_matrix = torch.stack([ durations[valid_user_mask], purchased[valid_user_mask].float() ], dim=1) # 按用户ID聚合特征 aggregated_features = torch.zeros(len(user_dict), 2) aggregated_features.index_add_(0, user_idx, behavior_matrix)

2.2 组合索引实现高级特征提取

提取高价值用户(停留时间长且购买率高)的特征:

avg_duration = aggregated_features[:,0] / (user_idx.bincount(minlength=len(user_dict)).float()) purchase_rate = aggregated_features[:,1] / (user_idx.bincount(minlength=len(user_dict)).float()) high_value_mask = (avg_duration > 30) & (purchase_rate > 0.3) high_value_users = torch.tensor(list(user_dict.keys())).masked_select(high_value_mask)

3. 性能优化技巧

3.1 避免不必要的内存拷贝

# 不推荐写法(创建临时Tensor) tmp = data[:, 2] clean_durations = tmp[(tmp >=5) & (tmp <=600)] # 推荐写法(直接操作原始Tensor) clean_durations = durations.masked_select(reasonable_duration)

3.2 利用广播机制批量处理

同时处理多个条件:

# 定义多条件掩码 condition_mask = ( (durations >= 5) & (durations <= 600) & (categories != -1) & (user_ids != -1) ) # 一次性提取符合所有条件的数据 clean_data = { 'durations': durations.masked_select(condition_mask), 'categories': categories.masked_select(condition_mask), 'users': user_ids.masked_select(condition_mask) }

4. 完整案例:构建推荐系统训练数据

将上述技术整合到完整流程:

def preprocess(raw_data): # 转换原始数据 user_ids = torch.tensor([int(uid[1:]) if isinstance(uid, str) else -1 for uid in raw_data[:,0]]) features = torch.tensor(raw_data[:,1:], dtype=torch.float32) # 数据清洗 valid_mask = (user_ids != -1) & ~torch.any(features.isnan(), dim=1) clean_features = features[valid_mask] clean_users = user_ids[valid_mask] # 特征工程 user_stats = torch.zeros(clean_users.max()+1, 3) user_stats[:,0] = torch.scatter_add( torch.zeros(clean_users.max()+1), 0, clean_users, clean_features[:,0] # 停留时长 ) user_stats[:,1] = torch.scatter_add( torch.zeros(clean_users.max()+1), 0, clean_users, clean_features[:,2] # 购买金额 ) user_stats[:,2] = torch.scatter_add( torch.zeros(clean_users.max()+1), 0, clean_users, torch.ones_like(clean_users, dtype=torch.float) # 访问次数 ) return user_stats # 示例用法 training_data = preprocess(data) print(f"最终训练数据维度:{training_data.shape}")

这个案例展示了如何将三个核心索引函数串联使用:先用nonzero定位有效数据,再用masked_select过滤异常值,最后用gather(及其变体scatter_add)重组特征。实际项目中,这种处理流程可以节省大量Pandas操作时间,特别适合在GPU上处理大规模数据。

http://www.jsqmd.com/news/791196/

相关文章:

  • LangGraph 常见错误与排错实战手册
  • 如何3步解决Blue Archive自动脚本Mumu模拟器检测问题
  • ThinkPad风扇终极静音方案:TPFanCtrl2智能温控神器深度解析
  • QKeyMapper:Windows平台下无需重启系统的终极按键映射解决方案
  • Java的反射机制
  • 2026宁波黄金回收店哪家好?本地7家正规商家实测排名 - 生活测评君
  • 构建AI增强的第二大脑:从知识管理到智能创造的实战指南
  • 揭秘2026全球AI大会签到系统崩溃真相:生物识别+区块链双认证背后的17个失效节点
  • 【SITS 2026权威前瞻】:AI原生研发的5大范式跃迁与企业落地避坑指南
  • 从命令行安装命令行包管理器:Windows用户的自动化救星
  • 将Taotoken作为统一网关整合至企业现有微服务架构
  • 在CentOS 7虚拟机上部署ICC 2016:从安装器配置到环境调优全流程
  • QueryExcel:批量Excel数据检索的自动化解决方案
  • postman使用
  • 心理咨询医院暖心指南与真实案例分享
  • 从根桥选举到环路防护:一张图看懂RSTP的5大保护机制(附配置命令)
  • 3步解锁微信网页版:高效实用的浏览器插件解决方案
  • 世界模型:通往AGI的必经之路,还是数据驱动的幻觉?
  • 从陈硕的测试数据看,为什么muduo网络库的吞吐量能比Boost.Asio高15%?
  • 从按钮到进度条:深度解析QSS text-align属性的‘有限’支持与实战替代方案
  • SAP资产折旧别只记成本中心了!试试这招,让项目成本核算更清晰(附ACSET避坑点)
  • 从入场到泊车仅97秒,2026 AI大会智能诱导系统深度拆解,含V2X路侧单元部署图谱
  • 为什么92%的AI项目卡在工程化?AI原生开发流程重构,从概念验证到规模化交付的终极解法
  • 初创公司如何借助taotoken多模型能力快速构建ai产品原型
  • 如何快速搭建专业Webmail系统:Roundcube完整配置指南
  • 开发AI应用时如何利用Taotoken模型广场进行选型测试
  • 保姆级教程:用PCL的ProgressiveMorphologicalFilter搞定机载LiDAR点云地面提取(附避坑指南)
  • 别再为喜马拉雅xm格式发愁了!实测微软商店版喜马拉雅,下载的音频直接就是mp3
  • 如何为 Hermes Agent 配置 Taotoken 作为自定义模型供应商
  • 将Claude Code编程助手无缝切换至Taotoken平台的配置指南