CANN/pypto topk操作
# pypto.topk
【免费下载链接】pyptoPyPTO(发音: pai p-t-o):Parallel Tensor/Tile Operation编程范式。项目地址: https://gitcode.com/cann/pypto
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | √ |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | √ |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | √ |
功能说明
获取最后一个维度的前k个最大值或最小值及其对应的索引。
如果输入是向量,则在向量中找到前k个最大值或最小值及其对应的索引;如果输入是矩阵,则沿最后一个维度计算每行中前k个最大值或最小值及其对应的索引。如下图所示,对Shape为(4, 32)的二维矩阵进行排序,k设置为1,输出结果为[[32] [32] [32] [32]]。
函数原型
topk(input: Tensor, k: int, dim: Optional[int] = None, largest: bool = True, algo: TopKAlgo = TopKAlgo.MERGE_SORT) -> Tuple[Tensor, Tensor]参数说明
| 参数名 | 输入/输出 | 说明 |
|---|---|---|
| input | 输入 | 源操作数。 支持的类型为:Tensor。 Tensor支持的数据类型为: - MERGE_SORT: DT_FP32。 - RADIX_SELECT: DT_BF16,DT_FP16,DT_FP32。 不支持空Tensor;Shape仅支持1-4维;Shape Size不大于2147483647(即INT32_MAX)。 |
| k | 输入 | 返回元素的数量。 k的大小应该满足:1 <= k <= input.shape[dim]。 |
| dim | 输入 | 指定排序的维度。 目前仅支持按最后一个维度排序,即dim= -1或dim= input.shape.size() - 1。 |
| largest | 输入 | 如果为True,返回最大元素。如果为False,返回最小元素。 |
| algo | 输入 | 算法枚举类型,用以控制TopK计算的流程,具体定义为:TopKAlgo 。 默认为 MERGE_SORT(归并排序算法)。 |
返回值说明
返回一个命名元组(values, indices),其中包含input在指定维度dim下每行中最大或最小的k个元素的值和索引。
约束说明
- 只支持对尾轴进行topk操作;
- TileShape尾轴32bytes对齐(TileShape[-1]*4 % 32 == 0);
- 选用MERGE_SORT算法时,TileShape尾轴需要小于22KB(TileShape[-1]*4 < 22KB);
- 选用RADIX_SELECT算法时,记TileShape尾轴为tile,则需要临时空间2*tile*sizeof(srcType)+6*tile+1024+max(1024, 8*tile),临时空间加上输入输出的tile块不能超过UB大小;
- k <= TileShape[-1] && k <= input.shape[-1];
- RADIX_SELECT算法仅支持Ascend 950PR/Ascend 950DT;
调用示例
TileShape设置示例
说明:调用该operation接口前,应通过set_vec_tile_shapes设置TileShape。
TileShape维度应和输入input一致。
示例1:输入input shape为[m, n, p],dim为2,largest为True,输出为[m, n, k], TileShape设置为[m1, n1, p1], 则m1, n1, p1分别用于切分m, n, p轴。p1必须大于等于k,k轴不支持切分,必须保证全载。
pypto.set_vec_tile_shapes(4, 16, 32)接口调用示例
x = pypto.tensor([2, 3], pypto.DT_FP32) y = pypto.topk(x, 2, -1, True, pypto.TopKAlgo.MERGE_SORT)结果示例如下:
输入数据x: [[1.0 2.0 3.0], [1.0 2.0 3.0]] 输出数据y[0]: [[3.0 2.0], [3.0 2.0]] 输出数据y[1]: [[2, 1], [2, 1]]【免费下载链接】pyptoPyPTO(发音: pai p-t-o):Parallel Tensor/Tile Operation编程范式。项目地址: https://gitcode.com/cann/pypto
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考
