交叉熵、Focal Loss以及其Pytorch实现

交叉熵、Focal Loss以及其Pytorch实现

本文参考链接:https://towardsdatascience.com/focal-loss-a-better-alternative-for-cross-entropy-1d073d92d075

文章目录

  • 交叉熵、Focal Loss以及其Pytorch实现
    • 一、交叉熵
    • 二、Focal loss
    • 三、Pytorch
      • 1.[交叉熵](https://pytorch.org/docs/master/generated/torch.nn.CrossEntropyLoss.html?highlight=nn+crossentropyloss#torch.nn.CrossEntropyLoss)
      • 2.[Focal loss](https://github.com/clcarwin/focal_loss_pytorch/blob/master/focalloss.py)

一、交叉熵

损失是通过梯度回传用来更新网络参数是之产生的预测结果和真实值之间相似。不同损失函数有着不同的约束作用,不同的数据对损失函数有着不同的影响。

交叉熵是常见的损失函数,常见于语义分割、对比学习等。其函数表达式如下,其中 Y i 和 p i Y_i和p_i Yipi分别表示真实值和预测结果:
C r o s s E n t r o p y = − ∑ i = 1 i = n Y i l o g ( p i ) CrossEntropy=-\sum_{i=1}^{i=n}Y_ilog(p_i) CrossEntropy=i=1i=nYilog(pi)
因为 p i p_i pi值在0~1之间,故交叉熵大于等于0。这个函数什么时候最小呢?数学证明结果表明当 Y i = p i Y_i=p_i Yi=pi时交叉熵最小。下面,我们取二分类情况来进行简单证明:
B C E L o s s = − y l o g x − ( 1 − y ) l o g ( 1 − x ) BCELoss=-ylogx-(1-y)log(1-x) BCELoss=ylogx(1y)log(1x)
对BCELoss求导可得:
− y x + 1 − y 1 − x = y − x x − x 2 -\frac{y}{x}+\frac{1-y}{1-x}=\frac{y-x}{x-x^2} xy+1x1y=xx2yx
所以当 y = x y=x y=x时,二分类交叉熵取最小值。

那么,交叉熵有啥子问题?

  • 从表达式可以看出,交叉熵只针对单个像素进行比较,像素和像素之间并没有联系,这就需要我们在模型中使用空间注意力机制等使得特征在空间上进行交互。针对这个问题,不少论文提出了改进方案,如Context Prior for Scene Segmentation这篇论文就使用了和precision、recall等类似的损失函数(就像Dice loss和F1 score指标一样)。

  • 类别不平衡:这个问题比较常见,语义分割中类别在图片上总像素占比是不平衡。如果类别不平衡比较严重,交叉熵损失就会偏向于占比较高的类别,导致对占比较少的类别预测结果较差。解决这一方法为给交叉熵损失添加权重(平衡交叉熵)等,如下式:
    B a l a n c e d C r o s s E n t r o p y = − ∑ i = 1 i = n α i Y i l o g ( p i ) BalancedCrossEntropy=-\sum_{i=1}^{i=n}\alpha_iY_ilog(p_i) BalancedCrossEntropy=i=1i=nαiYilog(pi)

  • 困难样本:首先,我们要知道困难样本是那些模型反复出现巨大损失的例子,而简单样本是那些容易分类的例子。交叉熵对于所有样本同等对待,导致无法辨别困难样本和简单样本。解决这一问题就是接下来的损失函数Focal loss

二、Focal loss

Focal loss关注的是模型出错的例子,而不是它可以自信地预测的例子,确保对困难的例子的预测随着时间的推移而改善,而不是对容易的例子变得过于自信。

这到底是怎么做到的呢?Focal loss是通过一个叫做Down Weighting的东西来实现的。下调权重是一种技术,它可以减少容易的例子对损失函数的影响,从而使人们更加关注困难的例子。这种技术可以通过在交叉熵损失中加入一个调节因子来实现。其表达式如下:
F o c a l L o s s = − ∑ i = 1 i = n ( 1 − p i ) γ l o g p i FocalLoss=-\sum_{i=1}^{i=n}(1-p_i)^{\gamma}logp_i FocalLoss=i=1i=n(1pi)γlogpi
不同的 γ \gamma γ对损失有什么影响呢?如下图所示

img

不同的 γ \gamma γ ( 1 − p i ) γ (1-p_i)^{\gamma} (1pi)γ有什么影响呢,如下:

img

  • 在误分类样本的情况下, p i pi pi很小,使得调制因子大约或非常接近于1,这使损失函数不受影响。此时,Focal Loss和交叉熵损失相似。
  • 随着模型置信度的提高,即 p i → 1 pi→1 pi1,调制因子将趋于0,从而降低了分类良好的例子的损失值。聚焦参数, γ \gamma γ≥1,将重新调整调制因子,使容易的例子比困难的例子降权更多,减少它们对损失函数的影响。例如,考虑预测概率为0.9和0.6。考虑到 γ \gamma γ=2,对0.9计算出的损失值是4.5e-4,降权系数( 1 / ( 1 − q i ) 2 1/(1-q_i)^2 1/(1qi)2)为100,对0.6则是3.5e-2,降权系数为6.25。从实验来看, γ \gamma γ=2来说效果最好。
  • γ \gamma γ=0时,Focal Loss等同于Cross Entropy。

此外,加入平衡因子 α \alpha α,用来平衡正负样本本身的比例不均:文中 α \alpha α取0.25,即正样本要比负样本占比小,这是因为负例易分。其表达式如下:
F o c a l L o s s = − ∑ i = 1 i = n α i ( 1 − p i ) γ l o g p i FocalLoss=-\sum_{i=1}^{i=n}\alpha_i(1-p_i)^{\gamma}logp_i FocalLoss=i=1i=nαi(1pi)γlogpi
Focal Loss自然地解决了阶级不平衡的问题,(1因为来自多数类别的例子通常容易预测,而来自少数类别的例子由于缺乏数据或来自多数类别的例子在损失和梯度过程中占主导地位而难以预测。由于这种相似性,Focal Loss可能能够解决这两个问题。

三、Pytorch

1.交叉熵

Pytorch可以直接调用交叉熵损失函数nn.CrossEntropyLoss(),其功能还是比较全的。其中weight可以用了进行权重平衡,ignore_index可以用来忽略特定类别。输入的标签不需要进行one hot编码,其内部已经实现。nn.CrossEntropyLoss()=nn.NLLoss() + nn.LogSoftmax。
在这里插入图片描述

2.Focal loss

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variableclass FocalLoss(nn.Module):def __init__(self, gamma=0, alpha=None, size_average=True):super(FocalLoss, self).__init__()self.gamma = gammaself.alpha = alphaif isinstance(alpha,(float,int,long)): self.alpha = torch.Tensor([alpha,1-alpha])if isinstance(alpha,list): self.alpha = torch.Tensor(alpha)self.size_average = size_averagedef forward(self, input, target):if input.dim()>2:input = input.view(input.size(0),input.size(1),-1)  # N,C,H,W => N,C,H*Winput = input.transpose(1,2)    # N,C,H*W => N,H*W,Cinput = input.contiguous().view(-1,input.size(2))   # N,H*W,C => N*H*W,Ctarget = target.view(-1,1)logpt = F.log_softmax(input)logpt = logpt.gather(1,target)logpt = logpt.view(-1)pt = Variable(logpt.data.exp())if self.alpha is not None:if self.alpha.type()!=input.data.type():self.alpha = self.alpha.type_as(input.data)at = self.alpha.gather(0,target.data.view(-1))logpt = logpt * Variable(at)loss = -1 * (1-pt)**self.gamma * logptif self.size_average: return loss.mean()else: return loss.sum()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/14904.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SpringCloud入门实战(十二)-Sleuth+Zipkin分布式请求链路跟踪详解

📝 学技术、更要掌握学习的方法,一起学习,让进步发生 👩🏻 作者:一只IT攻城狮 ,关注我,不迷路 。 💐学习建议:1、养成习惯,学习java的任何一个技术…

Keras-5-深度学习用于文本和序列-处理文本数据

深度学习用于文本和序列 说明: 本篇学习记录为:《Python 深度学习》第6章第1节(处理文本数据) 知识点: 深度学习处理文本或序列数据的基本方法是:循环神经网络 (recurrent neural network) 和 一维卷积神经网络 (1D convert)&…

SpringBoot—统一功能处理

SpringBoot—统一功能处理 🔎小插曲(通过一级路由调用多种方法)🔎使用拦截器实现用户登录权限的统一校验自定义拦截器将自定义拦截器添加至配置文件中拦截器的实现原理统⼀访问前缀添加 🔎统一异常的处理🔎统一数据格式的返回统一…

天天刷题-->LeetCode(两数相加)

个人名片: 🐅作者简介:一名大二在校生,热爱生活,爱好敲码! \ 💅个人主页 🥇:holy-wangle ➡系列内容: 🖼️ tkinter前端窗口界面创建与优化 &…

oracle启动/关闭/查看监听+启动/关闭/查看数据库实例命令

启动oracle第一步启动监听,第二步启动数据库实例 (1)输入su oracle进入oracle用户状态 (2)这里的密码是你的root密码 1 启动/关闭/查看监听命令 (1)启动监听—— lsnrctl start &am…

Permission denied (publickey,password)问题的解决办法

[15:29:00.146] Terminal shell path: C:\WINDOWS\System32\cmd.exe [15:29:01.703] > root59.110.21.45: Permission denied (publickey,password). 解决: RSA key 登录方法/home/user/ 目录下建立 .ssh/ 文件夹 cd ~/ mkdir .ssh # 注意.ssh文件夹的权限 ch…

GIT版本控制常规性操作演示汇总

文章目录 GIT基本操作GIT配置个人信息配置:GIT查看个人信息配置:GIT的三大区域GIT回滚:git resetGIT恢复日志:git reflogGIT三大区域转换GIT新建分支GIT合并分支GIT删除分支码云上创建项目GIT变基:git rebase合并提交记…

随机产生50个100以内的不重复的整数,设计位图排序算法进行排序。

1.问题 随机产生50个100以内的不重复的整数,设计位图排序算法进行排序。 2.设计思路 阶段1: 初始化一个空集合    for i[0,n)    bit[i]0 阶段2: 读入数据i,并设置bit[i]1    for each i in the input file    bit[i]1…

3.6.共享内存的学习

目录 前言1. 共享内存2. shared memory案例3. 补充知识总结 前言 杜老师推出的 tensorRT从零起步高性能部署 课程,之前有看过一遍,但是没有做笔记,很多东西也忘了。这次重新撸一遍,顺便记记笔记。 本次课程学习精简 CUDA 教程-共享…

基于C++、GDAL、OpenCV的矢量数据骨架线提取算法

基于C、GDAL、OpenCV的矢量数据骨架线提取算法 CGAL已经实现了该功能,但由于CGAL依赖于Boost库,编译后过大,因此本文所采用的这套方式实现骨架线提取功能。 效果: 思路: 1、将导入shp按照要素逐一拆分成新的shp 2、…

类Twitter风格的RSS阅读器

本文完成于 2 月中旬,其中的反代还是 frp npm 方案; 什么是 RSS ? RSS 是用 PHP、Laravel、Inertia.js、Tailwind 和 Vue.js 编写的简单的类Twitter 风格的 RSS阅读器,支持 RSS和ATOM 格式。 命令行安装 在群晖上以 Docker 方式安装。 官…

Nacos2.3.0源码启动报错找不到符号com.alibaba.nacos.consistency.entity

一. 源码下载编译:找不到符号com.alibaba.nacos.consistency.entity 如果报错找不到符号com.alibaba.nacos.consistency.entity Nacos\consistency\src\main\java\com\alibaba\nacos\consistency\entity 这个包下没有相关的java文件,其实是我们没有编译…