当前位置: 首页 > news >正文

CANN Bench交叉熵损失算子评测

CrossEntropyLoss 算子 API 描述

【免费下载链接】cann-bench评测AI在处理CANN领域代码任务的能力,涵盖算子生成、算子优化等领域,支撑模型选型、训练效果评估,统一量化评估标准,识别Agent能力短板,构建CANN领域评测平台,推动AI能力在CANN领域的持续演进。项目地址: https://gitcode.com/cann/cann-bench

1. 算子简介

计算交叉熵损失,用于分类任务。

主要应用场景

  • 多分类任务的损失函数(图像分类、文本分类等)
  • 语言模型的 next-token 预测训练
  • 支持硬标签(类别索引)和软标签(概率分布)两种模式

算子特征

  • 难度等级:L2(NumericalStable)
  • 双输入(logits 和 target)单输出(loss),涉及 softmax、对数、归约等多步计算
  • 输入 x 为 (N, C) 或更高维的 logits 张量,target 为 (N,) 的类别索引或 (N, C) 的软标签

2. 算子定义

数学公式

基本公式

$$ L = -\log\left(\frac{\exp(x_{target})}{\sum_{j}\exp(x_j)}\right) $$

等价于:

$$ L = -x_{target} + \log\left(\sum_{j}\exp(x_j)\right) $$

带权重的公式

$$ L = -weight_{target} \cdot \log\left(\frac{\exp(x_{target})}{\sum_{j}\exp(x_j)}\right) $$

其中:

  • reduction='none'时返回每个样本的损失,shape 为 (N,)
  • reduction='mean'时返回 batch 平均损失(标量)
  • reduction='sum'时返回 batch 总损失(标量)
  • ignore_index指定的标签不参与损失计算

3. 接口规范

算子原型

cann_bench.cross_entropy_loss(Tensor x, Tensor target, str reduction, int ignore_index) -> Tensor loss

输入参数说明

参数类型默认值描述
xTensor必选输入 logits 张量(未经 softmax)
targetTensor必选目标标签索引(hard labels)或概率分布(soft labels)
reductionstring"mean"损失聚合方式 ('none' | 'mean' | 'sum')
ignore_indexint-100忽略的标签索引(不影响损失计算)

输出

参数Shapedtype描述
lossreduction='none' 时为 (N,),否则为标量与输入 x 相同损失值

数据类型

x dtypetarget dtype输出 dtype
float32int32 / int64float32
float16int32 / int64float16
bfloat16int32 / int64bfloat16
float32float32float32
float16float16float16
bfloat16bfloat16bfloat16

规则与约束

  • x 的 shape 为 (N, C) 或 (N, C, d1, d2, ...),其中 N 为 batch size,C 为类别数
  • 硬标签模式:target 的 shape 为 (N,) 或 (N, d1, d2, ...),值为 [0, C) 范围内的类别索引
  • 软标签模式:target 的 shape 为 (N, C),值为概率分布
  • ignore_index仅在硬标签模式下生效
  • 输入 x 应为原始 logits(未经 softmax),内部自动应用 log_softmax
  • 需注意数值稳定性:内部实现应使用 log-sum-exp 技巧避免溢出

支持范围

输入 tensor 各维度与参数的支持范围:

维度 / 参数范围备注
N(batch size,x 第 0 维)1 ~ 2097152cases.csv 实测 2 ~ 1,000,003
C(类别数,x 第 1 维)2 ~ 16384cases.csv 实测 2 ~ 16,384
额外空间维度d_i(x 第 ≥2 维)1 ~ 1024cases.csv 实测 3 ~ 1024
rank(x)(x 维度数)2 ~ 8cases.csv 实测 2 ~ 5 维
rank(target)rank(x)-1 或 rank(x)硬标签缺 C 维;软标签同 x;cases.csv 全为硬标签
reduction"none" / "mean" / "sum"cases.csv 三种均覆盖
ignore_indexint64 任意值cases.csv 实测 -100 / -1 / 0 / 10 / 50 / 100

约束:硬标签模式下 target 各元素取值范围为 [0, C) 或等于ignore_index;软标签模式下 target 形状须与 x 完全一致。

4. 精度要求

采用生态算子精度标准进行验证。

误差指标

  1. 平均相对误差(MERE):采样点中相对误差平均值

    $$ \text{MERE} = \text{avg}(\frac{\text{abs}(actual - golden)}{\text{abs}(golden)+\text{1e-7}}) $$

  2. 最大相对误差(MARE):采样点中相对误差最大值

    $$ \text{MARE} = \max(\frac{\text{abs}(actual - golden)}{\text{abs}(golden)+\text{1e-7}}) $$

通过标准

数据类型FLOAT16BFLOAT16FLOAT32HiFLOAT32FLOAT8 E4M3FLOAT8 E5M2
通过阈值(Threshold)2^-102^-72^-132^-112^-32^-2

当平均相对误差 MERE < Threshold,最大相对误差 MARE < 10 * Threshold 时判定为通过。

5. 标准 Golden 代码

import torch """ CrossEntropyLoss 算子 Torch Golden 参考实现 计算交叉熵损失,用于分类任务 公式: L = -log(exp(x[target]) / sum(exp(x))) 或带 weight: L = -weight[target] * log(exp(x[target]) / sum(exp(x))) 参考 PyTorch API: torch.nn.CrossEntropyLoss https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html Parameters: - input: (N, C) 或 (N, C, H, W) 等 - logits 张量(未经 softmax) - target: (N,) 硬标签 或 (N, C) 软标签(概率分布) - weight: (C,) 各类别的权重(可选) - ignore_index: int, 默认 -100 - 忽略的标签索引 - reduction: 'none' | 'mean' | 'sum', 默认 'mean' - 损失聚合方式 """ def cross_entropy_loss( x: torch.Tensor, target: torch.Tensor, reduction: str = 'mean', ignore_index: int = -100 ) -> torch.Tensor: """ 计算交叉熵损失 Args: x: 输入 logits 张量,shape (N, C) 或 (N, C, d1, d2, ...) N = batch size, C = 类别数(channel_first 约定) 注意:输入应为 logits(未经 softmax),内部会自动应用 log_softmax target: 目标标签 - 硬标签:shape (N,) 或 (N, d1, d2, ...),值为类别索引 - 软标签:shape (N, C),值为概率分布 reduction: 损失聚合方式 'none': 返回每个样本的损失,shape (N,) 'mean': 返回 batch 平均损失 'sum': 返回 batch 总损失 ignore_index: 忽略的标签索引 当 target 为硬标签且值为 ignore_index 时,该样本不计入损失 Returns: 损失值:如果 reduction='none',返回 shape (N,) 的张量 否则返回标量张量 Examples: >>> N, C = 16, 10 # 16个样本,10个类别 >>> x = torch.randn(N, C) >>> target = torch.randint(0, C, (N,)) >>> loss = cross_entropy_loss(x, target) """ # 直接调用 PyTorch 标准 CrossEntropyLoss 实现 # torch.nn.functional.cross_entropy 内部会自动应用 log_softmax loss = torch.nn.functional.cross_entropy( input=x, target=target, reduction=reduction, ignore_index=ignore_index ) return loss

6. 额外信息

算子调用示例

import torch import cann_bench x = torch.randn(1024, 2048, dtype=torch.float32, device="npu") target = torch.randint(0, 2048, (1024,), dtype=torch.int64, device="npu") loss = cann_bench.cross_entropy_loss(x, target, reduction="mean", ignore_index=-100) loss = cann_bench.cross_entropy_loss(x, target, reduction="sum", ignore_index=-100) loss = cann_bench.cross_entropy_loss(x, target, reduction="none", ignore_index=-100)

【免费下载链接】cann-bench评测AI在处理CANN领域代码任务的能力,涵盖算子生成、算子优化等领域,支撑模型选型、训练效果评估,统一量化评估标准,识别Agent能力短板,构建CANN领域评测平台,推动AI能力在CANN领域的持续演进。项目地址: https://gitcode.com/cann/cann-bench

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

http://www.jsqmd.com/news/849341/

相关文章:

  • Matlab阶跃响应性能指标自动化计算:从原理到工程实践
  • 如何快速上手elec-ops-inspection:昇腾平台部署指南
  • Configor 自动重载功能深度解析:实现配置热更新的终极指南
  • CANN/hccl RDMA QP端口配置路径
  • 轨距调整片定制哪家好?2026年绝缘轨距块生产厂家优质供应商推荐指南:新建铁路配件领衔 - 栗子测评
  • 2026机房不间断电源生产厂家哪家好?深圳不间断电源生产厂家实力深度解析 - 栗子测评
  • cann/asc-devkit SetGradOutput接口
  • CANN ops-fft部署指南:生产环境中的配置、监控与故障排除
  • npc_gzip异常处理与调试手册:解决压缩器错误的10个实用技巧
  • Commit Mono版本管理指南:如何优雅地升级和回滚字体版本
  • 源头工厂直供:利成充气水池定制厂家,广东便携式宠物泳池、PVC 戏水玩具、水上充气浮排专业生产基地 - 栗子测评
  • 穿透算法黑箱:2026论文降AI率工具深度测评,早标网语义保真度99%
  • 橡胶垫板定制厂家推荐:新建铁路配件领衔,2026年口碑好的调高垫板批发厂家/轨道橡胶垫板生产厂家/精调件生产厂家盘点 - 栗子测评
  • Transformer架构解析:自注意力机制与LLM核心技术
  • CrossGeo:首个跨卫星-无人机-地面三重视角的6-DoF 3D重建与定位数据集详解
  • 【YOLO目标检测全栈实战】48 深入TensorRT加速:从28ms到6ms的C++推理实战
  • Seed-VC语音克隆指南:5分钟实现零样本实时语音转换的终极方案
  • ARM SPE Profiling Buffer机制与性能分析实践
  • 地空协同巡检新范式:elec-ops-inspection 3D空间建模技术
  • GIFT应用案例:从Web服务到移动应用的实际部署方案
  • USB/IP Windows:打破物理限制的USB设备网络共享终极方案
  • 钢制平开防火窗|2026价格与工程应用要点
  • STR71X芯片JTAG失效分析与Bootloader恢复指南
  • Symfony String国际化实战:为什么它比原生PHP字符串函数更强大
  • 如何用Lano Visualizer打造智能音频可视化桌面:从音乐爱好者到专业用户的完整指南
  • 【独家首发】Gemini Pro函数调用(Function Calling)深度解析:7个生产环境踩坑案例+可复用的TypeScript Schema模板
  • 保姆级教程:手把手教你用ROS话题转发搞定CARLA与Autoware的传感器数据对齐
  • Windows 11文件资源管理器标签化神器:终极窗口管理解决方案
  • Egg.js TodoMVC实现:完整CRUD操作与前端交互实战
  • 【YOLO目标检测全栈实战】49 模型服务化:用Triton Inference Server部署YOLOv8全流程实战