当前位置: 首页 > news >正文

损失函数 的 硬截断 和 平滑衰减

损失函数 的 硬截断 和 平滑衰减

flyfish

在逐样本损失计算完成、取平均之前,对损失过高的样本做权重压制,不删除样本,只削弱它们对梯度的贡献,属于软降权——既保留了样本的监督信号,又避免极端难样本/疑似错标样本带偏整个模型。

损失硬截断

损失硬截断是给单样本损失设置一个上限,超过这个阈值的损失,直接按阈值计算。相当于一刀切,超过上限的样本梯度不再放大。

代码实现

classFocalLossWithSmoothing(nn.Module):def__init__(self,gamma=2,alpha=None,smoothing=0.0,num_classes=2,max_loss=None):""" :param max_loss: 单样本损失上限,None表示不开启截断;设置数值后,单样本损失不会超过该值 """super().__init__()self.gamma=gamma self.alpha=torch.tensor(alpha).to(DEVICE)ifalphaelseNoneself.smoothing=smoothing self.num_classes=num_classes self.max_loss=max_loss# 损失截断阈值defforward(self,inputs,targets):targets_one_hot=torch.zeros_like(inputs).scatter_(1,targets.unsqueeze(1),1)soft_targets=targets_one_hot*(1-self.smoothing)+self.smoothing/self.num_classes log_probs=torch.nn.functional.log_softmax(inputs,dim=1)probs=torch.exp(log_probs)p_t=(probs*targets_one_hot).sum(dim=1,keepdim=True)focal_weight=(1-p_t)**self.gamma ce_loss=(-soft_targets*log_probs).sum(dim=1)loss=focal_weight.squeeze()*ce_lossifself.alphaisnotNone:alpha_t=(self.alpha.unsqueeze(0)*targets_one_hot).sum(dim=1)loss=loss*alpha_t# ========== 损失截断 ==========ifself.max_lossisnotNone:loss=torch.clamp(loss,max=self.max_loss)returnloss.mean()

使用方式

在训练函数里初始化损失时,多加一个max_loss参数即可:

# 示例:单样本损失最高不超过2.0,超过的全部按2.0计算criterion=FocalLossWithSmoothing(gamma=FOCAL_GAMMA,alpha=FOCAL_ALPHA,smoothing=LABEL_SMOOTHING,num_classes=NUM_CLASSES,max_loss=2.0# 开启截断,阈值可按需调整)

平滑衰减降权

硬截断是一刀切:损失超过阈值,直接砍平,损失值瞬间不再增长,像台阶一样突变;
平滑衰减是越涨越慢:损失低于阈值时正常计算,超过阈值后还能继续涨,但增长速度会越来越慢,过渡是顺滑的曲线,没有突变台阶。

它的目的:既保留损失越高、权重越大的相对顺序,又不让极端高损失样本无限放大梯度带偏模型,同时保证训练过程梯度平稳,不会出现跳变

代码实现 只需要把截断部分替换成平滑衰减逻辑即可:

classFocalLossWithSmoothing(nn.Module):def__init__(self,gamma=2,alpha=None,smoothing=0.0,num_classes=3,loss_threshold=1.8):super().__init__()self.gamma=gamma self.alpha=torch.tensor(alpha).to(DEVICE)ifalphaelseNoneself.smoothing=smoothing self.num_classes=num_classes self.loss_threshold=loss_threshold# 平滑衰减阈值defforward(self,inputs,targets):targets_one_hot=torch.zeros_like(inputs).scatter_(1,targets.unsqueeze(1),1)soft_targets=targets_one_hot*(1-self.smoothing)+self.smoothing/self.num_classes log_probs=torch.nn.functional.log_softmax(inputs,dim=1)probs=torch.exp(log_probs)p_t=(probs*targets_one_hot).sum(dim=1,keepdim=True)focal_weight=(1-p_t)**self.gamma ce_loss=(-soft_targets*log_probs).sum(dim=1)loss=focal_weight.squeeze()*ce_lossifself.alphaisnotNone:alpha_t=(self.alpha.unsqueeze(0)*targets_one_hot).sum(dim=1)loss=loss*alpha_t# 平滑衰减降权:压制极端高损失样本ifself.loss_thresholdisnotNone:high_loss_mask=loss>self.loss_threshold loss[high_loss_mask]=self.loss_threshold+torch.log(1+loss[high_loss_mask]-self.loss_threshold)returnloss.mean()

假设设置阈值 = 1.5,看不同原始损失对应的处理结果:

原始单样本损失硬截断后损失变化特点
1.0(正常样本)1.0低于阈值,完全不变
1.4(较难样本)1.4低于阈值,完全不变
1.5(阈值点)1.5刚好等于阈值
1.6(难样本)1.5超过一点点,直接被砍成1.5,瞬间停止增长
3.0(极难/错标样本)1.5不管多高,全砍成1.5,和1.6的样本权重完全一样

硬截断的问题

  1. 阈值点处损失突变,梯度也会突变,训练过程容易出现震荡;
  2. 所有超过阈值的样本,损失都一样,丢失了难分程度的差异信息——3.0的极难样本和1.6的轻微难样本,对模型的贡献变得完全相同,有点矫枉过正。

平滑衰减的逻辑:两段式 + 对数压缩

代码里用的是阈值以下正常计算,阈值以上对数压缩的两段式策略,公式是:
处理后损失={原始损失原始损失≤阈值阈值+log⁡(1+原始损失−阈值)原始损失>阈值 \text{处理后损失} = \begin{cases} \text{原始损失} & \text{原始损失} \le 阈值 \\ 阈值 + \log(1 + \text{原始损失} - 阈值) & \text{原始损失} > 阈值 \end{cases}处理后损失={原始损失阈值+log(1+原始损失阈值)原始损失阈值原始损失>阈值

为什么用 log(对数)函数?

对数函数有两个完美匹配需求的特性:

  1. 单调递增:原始损失越大,处理后的损失也一定越大,不会改变谁更难、谁损失更高的排序,样本的相对权重关系保留了;
  2. 增速递减:x 越大,log(x) 涨得越慢。原始损失越高,压缩力度越强,正好符合极端样本降权更多的需求。

直观对比效果

还是设阈值 = 1.5,算一组真实数值,一眼就能看出区别:

原始单样本损失硬截断后平滑衰减后直观感受
1.01.01.00低于阈值,两者完全一样
1.41.41.40低于阈值,两者完全一样
1.51.51.50阈值点,两者对齐
1.61.51.595只超了一点点,压缩很轻微,几乎和原值差不多
2.01.51.693超了0.5,增长明显放缓,不再是直线涨
3.01.51.946超了1.5,涨幅被大幅压缩,不会涨到3.0
5.01.52.208超了3.5,增速进一步变慢,和3.0的差距被缩小

可以明显看到:
刚超过阈值时,损失几乎不受影响,过渡非常顺滑;
损失越高,被压缩得越厉害,但始终保持越高越重的排序;
不会像硬截断那样,所有高损失全变成同一个值。

对应代码

loss[high_loss_mask]=self.loss_threshold+torch.log(1+loss[high_loss_mask]-self.loss_threshold)

拆解开:

  1. loss[high_loss_mask] - self.loss_threshold:算出损失超出阈值的部分(增量);
  2. 1 + 增量:加1保证对数的输入大于0,避免出现负数报错;
  3. torch.log(...):对超出的增量做对数压缩,让增量涨得变慢;
  4. self.loss_threshold + 压缩后的增量:把基准阈值加回来,保证阈值点处数值连续、没有台阶。

什么时候用硬截断,什么时候用平滑衰减?

方案场景特点
硬截断确定有大量标注错误,想直接屏蔽极端错标的影响简单粗暴,可控性强,调试方便
平滑衰减样本大多是标注正确的难样本(比如小目标、低对比度),只想削弱、不想完全屏蔽更温和,梯度平稳,训练更稳定,保留难样本的相对差异信息
http://www.jsqmd.com/news/1071887/

相关文章:

  • 如何高效解决浏览器全屏API兼容性问题:screenfull.js进阶实战指南
  • Get Shit Done:重新定义AI编程工作流的革命性框架
  • 拒绝踩坑!企业搭建多商户商城/知识付费平台,技术选型到底该看什么?
  • 全能免费在线工具箱ToolBoxMax,100+工具本地浏览器运行,保护隐私无需注册
  • 杭州吟颂职称政策调研:浙江省工程师申报要求
  • 双重检测不用慌!okbiye 分层降重降 AIGC 方案一次性打通论文审核关卡
  • 深度解析kohya_ss训练监控:5个关键技术指标与可视化实战指南
  • 为什么 SSR 一定会有 hydration mismatch?
  • 3步轻松上手ESP32物联网开发:Arduino核心的终极入门指南
  • 正态总体样本方差、t 分布 纯文本笔记
  • Git 超详细入门教程(附实战命令 常见坑)
  • 【影刀】手机自动化运行输入框无法输入文字,报错提示ACTION_SET_PROGRESS has failed on the element ‘android.view.accessibility.
  • 5个PDFPatcher实战技巧:免费解决PDF格式难题的完整指南
  • 流式微调(Streaming Fine-tuning)正在重构AI架构——3家头部企业已验证的4类低代码集成范式
  • PDFPatcher完全指南:5个实战技巧快速解决PDF处理难题
  • 终极指南:如何让老旧Mac免费安装最新macOS系统
  • 【昇腾/AscendC开发】AscendC 910B GM 标量/MTE 双向缓存不一致 Bug 详解
  • PREEMPT_RT 技术实现:local_lock
  • PDF补丁丁完全指南:5个免费开源技巧彻底解决PDF编辑难题
  • 如何让Intel显卡火力全开:MPV播放器硬件加速终极优化指南
  • 试试连Claude Code团队都在使用的终端软件Ghostty
  • PDF处理架构解析:PDFPatcher开源工具箱的技术实现与实战指南
  • 物联网智能锁实战:公寓/集团宿舍实名核验+远程授权落地方案
  • 太原食品级干冰
  • ESP32 Arduino开发终极指南:5步轻松配置物联网开发环境
  • 终极LX Music音源配置指南:3分钟解锁全网无损音乐
  • 视频电子设备音画不同步?可能是晶振温漂在“捣鬼”
  • 天磊卫士:全链路 AI 安全合规服务,护航人工智能规范落地
  • 射频内透热 vs 红外 vs EMS vs 艾灸:四种减重设备技术路线一文说清
  • 2026国内龙虾下载推荐 五款实测 Aionclaw 领衔自动化提效指南