当前位置: 首页 > news >正文

cann/cann-bench CrossEntropyLoss算子API描述

CrossEntropyLoss 算子 API 描述

【免费下载链接】cann-bench评测AI在处理CANN领域代码任务的能力,涵盖算子生成、算子优化等领域,支撑模型选型、训练效果评估,统一量化评估标准,识别Agent能力短板,构建CANN领域评测平台,推动AI能力在CANN领域的持续演进。项目地址: https://gitcode.com/cann/cann-bench

1. 算子简介

计算交叉熵损失,用于分类任务。

主要应用场景

  • 多分类任务的损失函数(图像分类、文本分类等)
  • 语言模型的 next-token 预测训练
  • 支持硬标签(类别索引)和软标签(概率分布)两种模式

算子特征

  • 难度等级:L2(NumericalStable)
  • 双输入(logits 和 target)单输出(loss),涉及 softmax、对数、归约等多步计算
  • 输入 x 为 (N, C) 或更高维的 logits 张量,target 为 (N,) 的类别索引或 (N, C) 的软标签

2. 算子定义

数学公式

基本公式

$$ L = -\log\left(\frac{\exp(x_{target})}{\sum_{j}\exp(x_j)}\right) $$

等价于:

$$ L = -x_{target} + \log\left(\sum_{j}\exp(x_j)\right) $$

带权重的公式

$$ L = -weight_{target} \cdot \log\left(\frac{\exp(x_{target})}{\sum_{j}\exp(x_j)}\right) $$

其中:

  • reduction='none'时返回每个样本的损失,shape 为 (N,)
  • reduction='mean'时返回 batch 平均损失(标量)
  • reduction='sum'时返回 batch 总损失(标量)
  • ignore_index指定的标签不参与损失计算

3. 接口规范

算子原型

cann_bench.cross_entropy_loss(Tensor x, Tensor target, str reduction, int ignore_index) -> Tensor loss

输入参数说明

参数类型默认值描述
xTensor必选输入 logits 张量(未经 softmax)
targetTensor必选目标标签索引(hard labels)或概率分布(soft labels)
reductionstring"mean"损失聚合方式 ('none' | 'mean' | 'sum')
ignore_indexint-100忽略的标签索引(不影响损失计算)

输出

参数Shapedtype描述
lossreduction='none' 时为 (N,),否则为标量与输入 x 相同损失值

数据类型

x dtypetarget dtype输出 dtype
float32int32 / int64float32
float16int32 / int64float16
bfloat16int32 / int64bfloat16
float32float32float32
float16float16float16
bfloat16bfloat16bfloat16

规则与约束

  • x 的 shape 为 (N, C) 或 (N, C, d1, d2, ...),其中 N 为 batch size,C 为类别数
  • 硬标签模式:target 的 shape 为 (N,) 或 (N, d1, d2, ...),值为 [0, C) 范围内的类别索引
  • 软标签模式:target 的 shape 为 (N, C),值为概率分布
  • ignore_index仅在硬标签模式下生效
  • 输入 x 应为原始 logits(未经 softmax),内部自动应用 log_softmax
  • 需注意数值稳定性:内部实现应使用 log-sum-exp 技巧避免溢出

4. 精度要求

采用生态算子精度标准进行验证。

误差指标

  1. 平均相对误差(MERE):采样点中相对误差平均值

    $$ \text{MERE} = \text{avg}(\frac{\text{abs}(actual - golden)}{\text{abs}(golden)+\text{1e-7}}) $$

  2. 最大相对误差(MARE):采样点中相对误差最大值

    $$ \text{MARE} = \max(\frac{\text{abs}(actual - golden)}{\text{abs}(golden)+\text{1e-7}}) $$

通过标准

数据类型FLOAT16BFLOAT16FLOAT32HiFLOAT32FLOAT8 E4M3FLOAT8 E5M2
通过阈值(Threshold)2^-102^-72^-132^-112^-32^-2

当平均相对误差 MERE < Threshold,最大相对误差 MARE < 10 * Threshold 时判定为通过。

5. 标准 Golden 代码

import torch """ CrossEntropyLoss 算子 Torch Golden 参考实现 计算交叉熵损失,用于分类任务 公式: L = -log(exp(x[target]) / sum(exp(x))) 或带 weight: L = -weight[target] * log(exp(x[target]) / sum(exp(x))) 参考 PyTorch API: torch.nn.CrossEntropyLoss https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html Parameters: - input: (N, C) 或 (N, C, H, W) 等 - logits 张量(未经 softmax) - target: (N,) 硬标签 或 (N, C) 软标签(概率分布) - weight: (C,) 各类别的权重(可选) - ignore_index: int, 默认 -100 - 忽略的标签索引 - reduction: 'none' | 'mean' | 'sum', 默认 'mean' - 损失聚合方式 """ def cross_entropy_loss( x: torch.Tensor, target: torch.Tensor, reduction: str = 'mean', ignore_index: int = -100 ) -> torch.Tensor: """ 计算交叉熵损失 Args: x: 输入 logits 张量,shape (N, C) 或 (N, C, d1, d2, ...) N = batch size, C = 类别数(channel_first 约定) 注意:输入应为 logits(未经 softmax),内部会自动应用 log_softmax target: 目标标签 - 硬标签:shape (N,) 或 (N, d1, d2, ...),值为类别索引 - 软标签:shape (N, C),值为概率分布 reduction: 损失聚合方式 'none': 返回每个样本的损失,shape (N,) 'mean': 返回 batch 平均损失 'sum': 返回 batch 总损失 ignore_index: 忽略的标签索引 当 target 为硬标签且值为 ignore_index 时,该样本不计入损失 Returns: 损失值:如果 reduction='none',返回 shape (N,) 的张量 否则返回标量张量 Examples: >>> N, C = 16, 10 # 16个样本,10个类别 >>> x = torch.randn(N, C) >>> target = torch.randint(0, C, (N,)) >>> loss = cross_entropy_loss(x, target) """ # 直接调用 PyTorch 标准 CrossEntropyLoss 实现 # torch.nn.functional.cross_entropy 内部会自动应用 log_softmax loss = torch.nn.functional.cross_entropy( input=x, target=target, reduction=reduction, ignore_index=ignore_index ) return loss

6. 额外信息

算子调用示例

import torch import cann_bench x = torch.randn(1024, 2048, dtype=torch.float32, device="npu") target = torch.randint(0, 2048, (1024,), dtype=torch.int64, device="npu") loss = cann_bench.cross_entropy_loss(x, target, reduction="mean", ignore_index=-100) loss = cann_bench.cross_entropy_loss(x, target, reduction="sum", ignore_index=-100) loss = cann_bench.cross_entropy_loss(x, target, reduction="none", ignore_index=-100)

【免费下载链接】cann-bench评测AI在处理CANN领域代码任务的能力,涵盖算子生成、算子优化等领域,支撑模型选型、训练效果评估,统一量化评估标准,识别Agent能力短板,构建CANN领域评测平台,推动AI能力在CANN领域的持续演进。项目地址: https://gitcode.com/cann/cann-bench

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

http://www.jsqmd.com/news/787889/

相关文章:

  • 算法模拟与生命智能:从架构差异看AI的本质与局限
  • CANN/ops-nn硬Sigmoid反向传播算子
  • 命令行办公自动化:officecli-skills技能库实战指南
  • ARM虚拟处理器模型在无线设备开发中的关键作用
  • 房价预测项目:自己手写线性回归,值不值?
  • AI赋能食品工业:从合成生物学到智能制造的全面革新
  • Datadog Cursor插件:用自然语言对话查询监控数据的完整指南
  • CANN/pyasc算子编程接口
  • 3PEAK思瑞浦 LM2902A-TS2R-S TSSOP14 运算放大器
  • Meta广告AI代理实战:基于MCP协议构建自动化广告管理工具
  • Animal-AI环境:用强化学习复现动物认知实验,评估AI智能水平
  • 智能代理框架ProxyAI:AI赋能API网关与微服务架构实践
  • 集成学习在药物虚拟筛选中的应用:构建稳健AI预测模型
  • 基于FNN与XAI的微射流速度预测及气泡位置影响机制研究
  • 3PEAK思瑞浦 TPA3672-SO1R SOP8 运算放大器
  • SEO地理优化利器:hreflang与JSON-LD实战指南
  • AI赋能密度泛函理论:量子张量学习与机器学习泛函实践
  • 抖音内容下载终极指南:从零开始构建你的专属素材库
  • 动物森友会存档编辑器NHSE:终极完整指南与实战教程
  • AI驱动蛋白质工程:机器学习与拓扑数据分析的融合实践
  • AI接管运维:工程师秒变甩手掌柜
  • 5分钟掌握qmc-decoder:终极QQ音乐加密格式解密指南
  • 华为CANN通信远端内存API
  • CANN随机数算子库文档
  • Spring Boot 缓存优化:从入门到精通
  • 5G波形技术演进与新型解决方案对比
  • 钉钉机器人 Webhook 方式与 SDK 方式接入哪种更适合 CI/CD 场景?
  • 2026年四川地区钢材采购决策:如何筛选靠谱供应商与盛世钢联建立长期合作 - 四川盛世钢联营销中心
  • Arm安全协处理器寄存器架构与内存重映射技术解析
  • 2026粉末冶金加工厂家推荐:铜基与铁基粉末冶金厂家的工艺特点及应用领域 - 栗子测评