常用分类损失CE Loss、Focal Loss及GHMC Loss理解与总结

一、CE Loss

定义

交叉熵损失(Cross-Entropy Loss,CE Loss)能够衡量同一个随机变量中的两个不同概率分布的差异程度,当两个概率分布越接近时,交叉熵损失越小,表示模型预测结果越准确。

公式

二分类

二分类的CE Loss公式如下,

其中,M:正样本数量,N:负样本数量,y_{i}:真实值, p_{i}:预测值

多分类

在计算多分类的CE Loss时,首先需要对模型输出结果进行softmax处理。公式如下,

其中, output:模型输出,p:对模型输出进行softmax处理后的值, ​​​​​:真实值的one hot编码​(假设模型在做5分类,如果y_{i}=2,则=[0,0,1,0,0])

代码实现

二分类

import torch
import torch.nn as nn
import mathcriterion = nn.BCELoss()
output = torch.rand(1, requires_grad=True)
label = torch.randint(0, 1, (1,)).float()
loss = criterion(output, label)print("预测值:", output)
print("真实值:", label)
print("nn.BCELoss:", loss)for i in range(label.shape[0]):if label[i] == 0:res = -math.log(1-output[i])elif label[i] == 1:res = -math.log(output[i])
print("自己的计算结果", res)"""
预测值: tensor([0.7359], requires_grad=True)
真实值: tensor([0.])
nn.BCELoss: tensor(1.3315, grad_fn=<BinaryCrossEntropyBackward0>)
自己的计算结果 1.331509556677378
"""

多分类

import torch
import torch.nn as nn
import mathcriterion = nn.CrossEntropyLoss()
output = torch.randn(1, 5, requires_grad=True)
label = torch.empty(1, dtype=torch.long).random_(5)
loss = criterion(output, label)print("预测值:", output)
print("真实值:", label)
print("nn.CrossEntropyLoss:", loss)output = torch.softmax(output, dim=1)
print("softmax后的预测值:", output)one_hot = torch.zeros_like(output).scatter_(1, label.view(-1, 1), 1)
print("真实值对应的one_hot编码", one_hot)res = (-torch.log(output) * one_hot).sum()
print("自己的计算结果", res)"""
预测值: tensor([[-0.7459, -0.3963, -1.8046,  0.6815,  0.2965]], requires_grad=True)
真实值: tensor([1])
nn.CrossEntropyLoss: tensor(1.9296, grad_fn=<NllLossBackward0>)
softmax后的预测值: tensor([[0.1024, 0.1452, 0.0355, 0.4266, 0.2903]], grad_fn=<SoftmaxBackward0>)
真实值对应的one_hot编码 tensor([[0., 1., 0., 0., 0.]])
自己的计算结果 tensor(1.9296, grad_fn=<SumBackward0>)
"""

二、Focal Loss

定义

虽然CE Loss能够衡量同一个随机变量中的两个不同概率分布的差异程度,但无法解决以下两个问题:1、正负样本数量不平衡的问题(如centernet的分类分支,它只将目标的中心点作为正样本,而把特征图上的其它像素点作为负样本,可想而知正负样本的数量差距之大);2、无法区分难易样本的问题(易分类的样本的分类错误的损失占了整体损失的绝大部分,并主导梯度)

为了解决以上问题,Focal Loss在CE Loss的基础上改进,引入了:1、正负样本数量调节因子以解决正负样本数量不平衡的问题;2、难易样本分类调节因子以聚焦难分类的样本

公式

二分类

公式如下,

 

​​​​​​​

其中,\alpha:正负样本数量调节因子,\gamma:难易样本分类调节因子

多分类

其中,\alpha _{y_{i}}y_{i}类别的权重

代码实现

二分类

def sigmoid_focal_loss(inputs: torch.Tensor,targets: torch.Tensor,alpha: float = -1,gamma: float = 2,reduction: str = "none",
) -> torch.Tensor:"""Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.Args:inputs: A float tensor of arbitrary shape.The predictions for each example.targets: A float tensor with the same shape as inputs. Stores the binaryclassification label for each element in inputs(0 for the negative class and 1 for the positive class).alpha: (optional) Weighting factor in range (0,1) to balancepositive vs negative examples. Default = -1 (no weighting).gamma: Exponent of the modulating factor (1 - p_t) tobalance easy vs hard examples.reduction: 'none' | 'mean' | 'sum''none': No reduction will be applied to the output.'mean': The output will be averaged.'sum': The output will be summed.Returns:Loss tensor with the reduction option applied."""inputs = inputs.float()targets = targets.float()p = torch.sigmoid(inputs)ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")p_t = p * targets + (1 - p) * (1 - targets)loss = ce_loss * ((1 - p_t) ** gamma)if alpha >= 0:alpha_t = alpha * targets + (1 - alpha) * (1 - targets)loss = alpha_t * lossif reduction == "mean":loss = loss.mean()elif reduction == "sum":loss = loss.sum()return loss

步骤1、首先对输入进行sigmoid处理,

p = torch.sigmoid(inputs)

步骤2、随后求出CE Loss,

ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")

步骤3、定义p_{t}^{i},公式为:

p_t = p * targets + (1 - p) * (1 - targets)

步骤4、为CE Loss添加难易样本分类调节因子,

loss = ce_loss * ((1 - p_t) ** gamma)

步骤5、定义\alpha _{t}^{i},公式为:

alpha_t = alpha * targets + (1 - alpha) * (1 - targets)

步骤6、为步骤4的损失添加正负样本数量调节因子,

loss = alpha_t * loss

多分类

def multi_cls_focal_loss(inputs: torch.Tensor,targets: torch.Tensor,alpha: torch.Tensor,gamma: float = 2,reduction: str = "none",
) -> torch.Tensor:inputs = inputs.float()targets = targets.float()ce_loss = nn.CrossEntropyLoss()(inputs, targets, reduction="none")one_hot = torch.zeros_like(inputs).scatter_(1, targets.view(-1, 1), 1)p_t = inputs * one_hotloss = ce_loss * ((1 - p_t) ** gamma)if alpha >= 0:alpha_t = alpha * one_hotloss = alpha_t * lossreturn loss

三、GHMC Loss

定义

Focal Loss在CE Loss的基础上改进后,解决了正负样本不平衡以及无法区分难易样本的问题,但也会过分关注难分类的样本(离群点),导致模型学歪。为了解决这个问题,GHMC(Gradient Harmonizing Mechanism-C)定义了梯度模长,该梯度模长正比于分类的难易程度,目的是让模型不要关注那些容易学的样本,也不要关注那些特别难分的样本

公式

1、定义梯度模长

二分类的CE Loss公式如下,

假设x是模型的输出,假设p=sigmoid(x),求损失对x的偏导,

因此,定义梯度模长如下,

其中, p:预测值,p^{\ast }:真实值

梯度模长与样本数量的关系如下,

2、定义梯度密度(单位梯度模长g上的样本数量

  

其中,g_{k}:第k个样本的梯度模长,\delta _{\varepsilon }(g_{k},g)g_{k}(g-\frac{\varepsilon }{2},g+\frac{\varepsilon }{2})范围内的样本数量,l_{\varepsilon }(g):区间(g-\frac{\varepsilon }{2},g+\frac{\varepsilon }{2})的长度

3、定义梯度密度协调参数(gradient density harmonizing parameter)

其中,N:样本总数

 4、定义GHMC Loss

 

代码实现

def _expand_binary_labels(labels, label_weights, label_channels):bin_labels = labels.new_full((labels.size(0), label_channels), 0)inds = torch.nonzero(labels >= 1).squeeze()if inds.numel() > 0:bin_labels[inds, labels[inds] - 1] = 1bin_label_weights = label_weights.view(-1, 1).expand(label_weights.size(0), label_channels)return bin_labels, bin_label_weightsclass GHMC(nn.Module):def __init__(self,bins=10,momentum=0,use_sigmoid=True,loss_weight=1.0):super(GHMC, self).__init__()self.bins = binsself.momentum = momentumself.edges = [float(x) / bins for x in range(bins+1)]self.edges[-1] += 1e-6if momentum > 0:self.acc_sum = [0.0 for _ in range(bins)]self.use_sigmoid = use_sigmoidself.loss_weight = loss_weightdef forward(self, pred, target, label_weight, *args, **kwargs):""" Args:pred [batch_num, class_num]:The direct prediction of classification fc layer.target [batch_num, class_num]:Binary class target for each sample.label_weight [batch_num, class_num]:the value is 1 if the sample is valid and 0 if ignored."""if not self.use_sigmoid:raise NotImplementedError# the target should be binary class labelif pred.dim() != target.dim():target, label_weight = _expand_binary_labels(target, label_weight, pred.size(-1))target, label_weight = target.float(), label_weight.float()edges = self.edgesmmt = self.momentumweights = torch.zeros_like(pred)# 计算梯度模长g = torch.abs(pred.sigmoid().detach() - target)valid = label_weight > 0tot = max(valid.float().sum().item(), 1.0)# 设置有效区间个数n = 0for i in range(self.bins):inds = (g >= edges[i]) & (g < edges[i+1]) & validnum_in_bin = inds.sum().item()if num_in_bin > 0:if mmt > 0:self.acc_sum[i] = mmt * self.acc_sum[i] \+ (1 - mmt) * num_in_binweights[inds] = tot / self.acc_sum[i]else:weights[inds] = tot / num_in_binn += 1if n > 0:weights = weights / nloss = F.binary_cross_entropy_with_logits(pred, target, weights, reduction='sum') / totreturn loss * self.loss_weight

步骤一、将梯度模长划分为bins(默认为10)个区域,

self.edges = [float(x) / bins for x in range(bins+1)]
"""
[0.0000, 0.1000, 0.2000, 0.3000, 0.4000, 0.5000, 0.6000, 0.7000, 0.8000, 0.9000, 1.0000]
"""

步骤二、计算梯度模长

g = torch.abs(pred.sigmoid().detach() - target)

步骤三、计算落入不同bin区间的梯度模长数量

valid = label_weight > 0
tot = max(valid.float().sum().item(), 1.0)
n = 0
for i in range(self.bins):inds = (g >= edges[i]) & (g < edges[i+1]) & validnum_in_bin = inds.sum().item()if num_in_bin > 0:if mmt > 0:self.acc_sum[i] = mmt * self.acc_sum[i] + (1 - mmt) * num_in_binweights[inds] = tot / self.acc_sum[i]else:weights[inds] = tot / num_in_binn += 1
if n > 0:weights = weights / n

步骤四、计算GHMC Loss

loss = F.binary_cross_entropy_with_logits(pred, target, weights, reduction='sum') / tot * self.loss_weight

【参考文章】

Focal Loss的理解以及在多分类任务上的使用(Pytorch)_focal loss 多分类_GHZhao_GIS_RS的博客-CSDN博客

focal loss 通俗讲解 - 知乎

Focal Loss损失函数(超级详细的解读)_BigHao688的博客-CSDN博客

5分钟理解Focal Loss与GHM——解决样本不平衡利器 - 知乎 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/21607.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

安装orcle报错:指定的 Oracle 系统标识符 (SID) 已在使用

安装orcle报错&#xff1a;[INS-35075]指定的 Oracle 系统标识符 (SID) 已在使用 说明前面的orcle没有彻底删除 解决这个问题&#xff1a; 搜索框 —— > 输入&#xff1a;regedit ——> 回车 运行regedit&#xff0c;选择HKEY_LOCAL_MACHINE SOFTWARE ORACLE&#xff…

数字图像处理【11】OpenCV-Canny边缘提取到FindContours轮廓发现

本章主要介绍图像处理中一个比较基础的操作&#xff1a;Canny边缘发现、轮廓发现 和 绘制轮廓。概念不难&#xff0c;主要是结合OpenCV 4.5的API相关操作&#xff0c;为往下 "基于距离变换的分水岭图像分割" 做知识储备。 Canny边缘检测 在讲述轮廓之前&#xff0c;…

【Hippo4j源码的方式安装部署教程】

&#x1f680; 线程池管理工具-Hippo4j &#x1f680; &#x1f332; AI工具、AI绘图、AI专栏 &#x1f340; &#x1f332; 如果你想学到最前沿、最火爆的技术&#xff0c;赶快加入吧✨ &#x1f332; 作者简介&#xff1a;硕风和炜&#xff0c;CSDN-Java领域优质创作者&#…

Object.fromEntries()将键值对列表转换为一个对象

Object.fromEntries() 静态方法将键值对列表转换为一个对象 将 Array 转换成对象&#xff1a; let arr [["name","张三"],["age","40"]] let obj Object.fromEntries(arr); console.log(obj);将 Map 转换成对象&#xff1a; let …

Spring 项目创建和使用2 (Bean对象的存取)

目录 一、创建 Bean 对象 二、将Bean对象存储到 Spring容器中 三、创建 Spring 上下文&#xff08;得到一个Spring容器&#xff09; 1. 通过在启动类中 ApplicationContext 获取一个 Spring容器 2. 通过在启动类种使用 BeanFactory 的方式来得到 Spring 对象 &#xff08;此…

C# Linq 详解一

目录 一、概述 二、Where 三、Select 四、GroupBy 五、First / FirstOrDefault 六、Last / LastOrDefault C# Linq 详解一 1.Where 2.Select 3.GroupBy 4.First / FirstOrDefault 5.Last / LastOrDefault C# Linq 详解二 1.OrderBy 2.OrderByDescending 3.Skip 4.Take …

第一百零六天学习记录:数据结构与算法基础:单链表(王卓教学视频)

线性表的链式表示和实现 结点在存储器中的位置是任意的&#xff0c;即逻辑上相邻的数据元素在物理上不一定相邻 线性表的链式表示又称为非顺序映像或链式映像。 用一组物理位置任意的存储单元来存放线性表的数据元素。 这组存储单元既可以是连续的&#xff0c;也可以是不连续的…

C#生成类库dll以及调用实例

本文讲解如何用C#语言生成类库并用winform项目进行调用 目录 创建C#类库项目 Winform调用dll 创建C#类库项目 编写代码 using System.Threading;namespace ClassLibrary1 {public class Class1{private Timer myTimer = null;//定义定时器用于触发事件//定义公共的委托和调…

短视频抖音seo矩阵系统源码开发者思路(一)

一套优秀的短视频获客系统&#xff0c;支持短视频智能剪辑、短视频定时发布&#xff0c;短视频排名查询及优化&#xff0c;短视频智能客服等&#xff0c;那么短视频seo系统具体开发应该具备哪些功能呢&#xff1f;今天小编就跟大家分享一下我们的技术开发思路。 抖音矩阵系统源…

Qt Https通信: TLS initialization failed 解决方法

Qt Https通信&#xff1a; TLS initialization failed 解决方法&#xff0c;Window端使用Qt 做开发请求Https资源时&#xff0c;会经常遇到 TLS initialization failed。 原因分析&#xff1a; 在Qt中并未包含 SSL所包含的库&#xff0c;因此需要开发者&#xff0c;自己将库拷贝…

百度iOS端长连接组件建设及应用实践

作者 | 百度消息中台团队 导读 在过去的十年里&#xff0c;移动端技术飞速发展&#xff0c;移动应用逐渐成为主要的便捷访问和使用互联网的方式&#xff0c;承接了越来越多的业务和功能&#xff0c;这也意味着对移动端和服务器之间的通信效率和稳定性提出了更高的要求。为了实现…

C语言实现简易通讯录

目录 普通版 功能需求 模块设计 test.c模块实现 contact.h模块实现 类型的声明 函数的声明 头文件、枚举、宏定义 contact.c 模块实现 初始化通讯录 增加联系人 显示所有联系人的信息 查找函数 删除指定联系人 查找指定联系人 修改指定联系人 进阶版通讯录&a…