从数据清洗到特征提取:用PyTorch Tensor索引函数(masked_select/non_zero/gather)搞定真实数据处理任务
从数据清洗到特征提取:用PyTorch Tensor索引函数搞定真实数据处理任务
想象一下这样的场景:你刚拿到一份用户行为数据集,准备训练推荐模型。打开CSV文件后却发现——数据里混杂着缺失值、异常点击记录、非标准ID格式。此时你需要快速清洗数据并提取有效特征,而PyTorch的Tensor索引函数就是你的瑞士军刀。本文将带你用non_zero、masked_select和gather三个函数,完成从原始数据到模型输入的完整预处理流程。
1. 实战场景:用户行为数据清洗
假设我们有一份电商平台的用户点击日志,原始数据如下表所示:
| 用户ID | 点击时间戳 | 停留时长(秒) | 商品类别 | 是否购买 |
|---|---|---|---|---|
| U1001 | 1633024000 | 120 | 3 | 1 |
| U1002 | 1633024015 | -1 | 1 | 0 |
| NaN | 1633024020 | 300 | 2 | 0 |
| U1003 | 1633024100 | 5 | 5 | 1 |
| U1001 | 1633024150 | 60 | NaN | 0 |
1.1 用non_zero定位有效数据
首先加载数据并转换为Tensor:
import torch import numpy as np # 模拟原始数据 data = np.array([ ['U1001', 1633024000, 120, 3, 1], ['U1002', 1633024015, -1, 1, 0], [np.nan, 1633024020, 300, 2, 0], ['U1003', 1633024100, 5, 5, 1], ['U1001', 1633024150, 60, np.nan, 0] ]) # 转换为数值型Tensor user_ids = torch.tensor([int(uid[1:]) if isinstance(uid, str) else -1 for uid in data[:,0]]) timestamps = torch.tensor(data[:,1], dtype=torch.float32) durations = torch.tensor(data[:,2], dtype=torch.float32) categories = torch.tensor([c if not np.isnan(c) else -1 for c in data[:,3]], dtype=torch.long) purchased = torch.tensor(data[:,4], dtype=torch.bool)找出有效用户ID的记录:
valid_user_mask = (user_ids != -1) valid_indices = valid_user_mask.nonzero().squeeze() print(f"有效记录索引:{valid_indices.tolist()}")提示:
nonzero()返回的是二维坐标,用squeeze()压缩单维度后更易处理
1.2 用masked_select清洗异常值
处理停留时长的异常值(负值和过长值):
# 定义合理停留时长范围(5秒到10分钟) reasonable_duration = (durations >= 5) & (durations <= 600) # 同时满足用户有效和时长合理 clean_mask = valid_user_mask & reasonable_duration clean_durations = durations.masked_select(clean_mask) print(f"清洗后时长数据:{clean_durations}")处理商品类别缺失值:
valid_categories = categories.masked_select(categories != -1) unique_categories = torch.unique(valid_categories) print(f"有效商品类别:{unique_categories.tolist()}")2. 特征工程:构建用户画像
2.1 用gather实现特征重组
假设我们需要按用户ID重组特征:
# 创建用户ID到索引的映射 user_dict = {uid.item(): idx for idx, uid in enumerate(user_ids[valid_user_mask].unique())} # 示例:统计每个用户的总停留时长 user_duration = torch.zeros(len(user_dict)) for uid, idx in user_dict.items(): user_mask = (user_ids == uid) user_duration[idx] = durations.masked_select(user_mask).sum()更高效的做法是使用gather:
# 构建用户-行为关系矩阵 user_idx = torch.tensor([user_dict[uid.item()] for uid in user_ids[valid_user_mask]]) behavior_matrix = torch.stack([ durations[valid_user_mask], purchased[valid_user_mask].float() ], dim=1) # 按用户ID聚合特征 aggregated_features = torch.zeros(len(user_dict), 2) aggregated_features.index_add_(0, user_idx, behavior_matrix)2.2 组合索引实现高级特征提取
提取高价值用户(停留时间长且购买率高)的特征:
avg_duration = aggregated_features[:,0] / (user_idx.bincount(minlength=len(user_dict)).float()) purchase_rate = aggregated_features[:,1] / (user_idx.bincount(minlength=len(user_dict)).float()) high_value_mask = (avg_duration > 30) & (purchase_rate > 0.3) high_value_users = torch.tensor(list(user_dict.keys())).masked_select(high_value_mask)3. 性能优化技巧
3.1 避免不必要的内存拷贝
# 不推荐写法(创建临时Tensor) tmp = data[:, 2] clean_durations = tmp[(tmp >=5) & (tmp <=600)] # 推荐写法(直接操作原始Tensor) clean_durations = durations.masked_select(reasonable_duration)3.2 利用广播机制批量处理
同时处理多个条件:
# 定义多条件掩码 condition_mask = ( (durations >= 5) & (durations <= 600) & (categories != -1) & (user_ids != -1) ) # 一次性提取符合所有条件的数据 clean_data = { 'durations': durations.masked_select(condition_mask), 'categories': categories.masked_select(condition_mask), 'users': user_ids.masked_select(condition_mask) }4. 完整案例:构建推荐系统训练数据
将上述技术整合到完整流程:
def preprocess(raw_data): # 转换原始数据 user_ids = torch.tensor([int(uid[1:]) if isinstance(uid, str) else -1 for uid in raw_data[:,0]]) features = torch.tensor(raw_data[:,1:], dtype=torch.float32) # 数据清洗 valid_mask = (user_ids != -1) & ~torch.any(features.isnan(), dim=1) clean_features = features[valid_mask] clean_users = user_ids[valid_mask] # 特征工程 user_stats = torch.zeros(clean_users.max()+1, 3) user_stats[:,0] = torch.scatter_add( torch.zeros(clean_users.max()+1), 0, clean_users, clean_features[:,0] # 停留时长 ) user_stats[:,1] = torch.scatter_add( torch.zeros(clean_users.max()+1), 0, clean_users, clean_features[:,2] # 购买金额 ) user_stats[:,2] = torch.scatter_add( torch.zeros(clean_users.max()+1), 0, clean_users, torch.ones_like(clean_users, dtype=torch.float) # 访问次数 ) return user_stats # 示例用法 training_data = preprocess(data) print(f"最终训练数据维度:{training_data.shape}")这个案例展示了如何将三个核心索引函数串联使用:先用nonzero定位有效数据,再用masked_select过滤异常值,最后用gather(及其变体scatter_add)重组特征。实际项目中,这种处理流程可以节省大量Pandas操作时间,特别适合在GPU上处理大规模数据。
