当前位置: 首页 > news >正文

图解PyTorch gather函数:从困惑到精通,一个例子讲清张量收集操作

图解PyTorch gather函数:从困惑到精通,一个例子讲清张量收集操作

很多PyTorch初学者在面对gather函数时都会感到困惑——这个看似简单的操作,为什么总是让人摸不着头脑?今天我们就用一个具体的例子,配合直观的图解,彻底讲清楚gather的工作原理和应用场景。

想象你正在处理一个分类任务的输出:模型对每个样本预测了各个类别的得分,形状为[batch_size, num_classes]。现在你需要根据真实标签(形状为[batch_size])收集每个样本对应正确类别的得分。这正是gather函数的典型应用场景。

1. 理解gather函数的基本概念

gather函数的官方定义是:沿着指定维度收集输入张量的值。听起来很抽象?让我们拆解这个定义:

torch.gather(input, dim, index, out=None) → Tensor
  • input:源张量,我们要从中收集数据
  • dim:收集操作的维度
  • index:指定收集位置的索引张量
  • out:可选参数,输出张量

关键点在于理解dim参数如何影响收集行为。我们可以把gather看作是一种"查表"操作:根据index提供的坐标,从input中取出对应的值。

2. 通过具体例子理解gather

让我们用一个具体的例子来演示。假设我们有一个2x3的张量:

import torch a = torch.arange(0, 6).reshape(2, 3) # tensor([[0, 1, 2], # [3, 4, 5]])

2.1 dim=0的情况

dim=0时,收集操作沿着第0维(行方向)进行。我们需要提供一个与输出形状相同的index张量,其中的值表示在第0维上的位置。

index = torch.LongTensor([[0, 1, 0], [1, 0, 0]]) result = torch.gather(a, 0, index) # tensor([[0, 4, 2], # [3, 1, 2]])

这个结果是怎么得到的呢?我们可以这样理解:

  1. 对于输出张量的每个位置(i,j),我们查看index[i,j]的值
  2. 这个值告诉我们从input的哪一行取数
  3. 具体来说:
    • result[0,0] = a[index[0,0],0] = a[0,0] = 0
    • result[0,1] = a[index[0,1],1] = a[1,1] = 4
    • result[0,2] = a[index[0,2],2] = a[0,2] = 2
    • result[1,0] = a[index[1,0],0] = a[1,0] = 3
    • result[1,1] = a[index[1,1],1] = a[0,1] = 1
    • result[1,2] = a[index[1,2],2] = a[0,2] = 2

2.2 dim=1的情况

dim=1时,收集操作沿着第1维(列方向)进行。同样需要一个与输出形状相同的index张量。

index = torch.LongTensor([[2, 0, 1], [1, 2, 0]]) result = torch.gather(a, 1, index) # tensor([[2, 0, 1], # [4, 5, 3]])

这次的计算逻辑是:

  1. 对于输出张量的每个位置(i,j),我们查看index[i,j]的值
  2. 这个值告诉我们从input的哪一列取数
  3. 具体来说:
    • result[0,0] = a[0,index[0,0]] = a[0,2] = 2
    • result[0,1] = a[0,index[0,1]] = a[0,0] = 0
    • result[0,2] = a[0,index[0,2]] = a[0,1] = 1
    • result[1,0] = a[1,index[1,0]] = a[1,1] = 4
    • result[1,1] = a[1,index[1,1]] = a[1,2] = 5
    • result[1,2] = a[1,index[1,2]] = a[1,0] = 3

3. gather在实际场景中的应用

让我们回到最初提到的分类任务场景。假设:

  • 模型预测得分:scores = torch.tensor([[0.1, 0.3, 0.6], [0.4, 0.2, 0.4]])(2个样本,3个类别)
  • 真实标签:labels = torch.tensor([2, 0])(第一个样本的真实类别是2,第二个是0)

我们需要收集每个样本对应真实类别的得分:

# 首先将labels的形状从[batch_size]变为[batch_size, 1] labels = labels.unsqueeze(1) # tensor([[2], # [0]]) # 然后使用gather收集得分 selected_scores = torch.gather(scores, 1, labels) # tensor([[0.6000], # [0.4000]])

这个例子展示了gather在机器学习中的典型应用——根据索引从预测结果中提取特定值。

4. gather与其他索引操作的对比

PyTorch提供了多种张量索引操作,理解它们与gather的区别很重要:

操作功能与gather的区别
index_select沿单一维度选择数据只能选择整个"切片",不能像gather那样灵活选择单个元素
masked_select根据布尔掩码选择元素返回一维张量,不保持输入形状
gather根据索引收集元素可以精确控制每个输出位置的来源,保持输入形状

提示:当需要保持输入张量的维度结构时,gather通常是更好的选择。

5. 常见问题与调试技巧

在使用gather时,经常会遇到以下问题:

  1. 维度不匹配错误:确保index张量的形状与输出形状一致

    • index的形状应该与input的形状相同,除了在dim维度上可以不同
  2. 索引越界错误:确保index中的值不超过inputdim维度上的大小减一

  3. 理解dim参数:记住dim指定的是"沿着哪个维度收集",而不是"收集哪个维度"

调试技巧:

  • 对于简单例子,可以手动计算几个值来验证理解
  • 使用小张量进行实验,打印中间结果
  • 画图辅助理解,特别是对于高维张量

6. 性能优化建议

虽然gather是一个非常有用的操作,但在性能敏感的场景中需要注意:

  1. 避免在循环中频繁调用gather,尽量批量处理
  2. 对于固定模式的索引,考虑使用index_select可能更高效
  3. 在GPU上,gather操作通常是并行化的,可以充分利用硬件优势
# 不推荐的写法 for i in range(batch_size): result[i] = torch.gather(input[i], dim, index[i]) # 推荐的写法 result = torch.gather(input, dim, index)

7. 高级应用示例

gather不仅可以用于简单的值收集,还可以实现一些高级操作。例如,实现一个简单的top-k选择:

# 假设我们有一个分数张量 scores = torch.tensor([[0.1, 0.3, 0.6, 0.2], [0.4, 0.2, 0.3, 0.1]]) # 获取每行的top-2值和索引 topk_values, topk_indices = torch.topk(scores, k=2, dim=1) # 使用gather重构topk_values reconstructed = torch.gather(scores, 1, topk_indices) # 应该与topk_values相同

这个例子展示了如何将gather与其他PyTorch操作结合使用,实现更复杂的功能。

http://www.jsqmd.com/news/792212/

相关文章:

  • 跨站请求伪造(CSRF)
  • AI技术大会摄影服务落地实录(SITS2026独家技术白皮书首发)
  • 英伟达巨额投资,四大云巨头财报亮眼,半导体产业扩张背后隐忧浮现
  • JiYuTrainer深度解析:3大核心技术实现极域电子教室破解与系统控制实战
  • day05补发
  • 2026年4月评价高的高密度硅酸钙板品牌推荐,玻璃热弯模具/汽车后视镜热弯模具,高密度硅酸钙板厂家怎么选择 - 品牌推荐师
  • 2026年4月行业内评价好的不锈钢板实力厂家口碑推荐,不锈钢装饰管/不锈钢折弯/不锈钢角钢,不锈钢板公司哪个好 - 品牌推荐师
  • 洛谷 P1333:瑞瑞的木棍 ← 欧拉回路 + 并查集
  • 掌握 ruby-build 环境变量配置:7 个技巧让 Ruby 安装效率翻倍
  • apio2026游记
  • 团队项目第二次作业
  • sparksql读取mysql表处理etl数据加工过程在把结果反插入库
  • 跨境电商物流解决方案-恒盛通国际快递服务 - 恒盛通物流
  • day05补发补充
  • 2026 年豆包开启付费订阅,中国 AI 大模型商业化迎来大考!
  • 时序数据库详解
  • 软工5月10号
  • Display Driver Uninstaller (DDU):彻底清理显卡驱动的终极解决方案
  • STM32 SDIO+PCM5102成功播放《义妹》
  • day04补发
  • 深入了解Python并发编程
  • 如何通过Noto Emoji实现跨平台表情符号统一:技术原理与应用实践
  • Qt/C++实战:手把手教你用QCustomPlot实现动态刷新热力图(模拟实时数据)
  • MySQL高级特性:索引优化详解
  • 2026年4月优质的初中效袋式过滤器批发厂家推荐,防潮设计适应潮湿环境 - 品牌推荐师
  • Redis数据结构与性能优化详解
  • 使用本地浏览器打开远程服务器生成的网页——详细教程
  • 打破语言壁垒:Translumo屏幕实时翻译工具的终极使用指南
  • 2026 年 Q1 全球互联网中断报告:断网、停电与战争
  • 20253221 2025-2026-2 《Python程序设计》实验3报告