yolo增加slide loss,改善样本不平衡问题

slide loss的主要作用是让模型更加关注难例,可以轻微的改善模型在难例检测上的效果

论文地址:https://arxiv.org/pdf/2208.02019.pdf

代码:GitHub - Krasjet-Yu/YOLO-FaceV2: YOLO-FaceV2: A Scale and Occlusion Aware Face Detector

        样本不平衡问题,即在大多数情况下,容易样本的数量很大,而困难样本相对稀疏,引起了很多关注。在本文的工作中,设计了一个看起来像“slide”的Slide Loss函数来解决这个问题。简单样本和困难样本之间的区别是基于预测框和ground truth 框的IoU大小。为了减少超参数,将所有边界框的 IoU 值的平均值作为阈值 µ,小于µ的取负样本,大于µ的取正样本。

        然而,由于分类不明确,边界附近的样本往往会遭受较大的损失。希望模型能够学习优化这些样本,并更充分地使用这些样本来训练网络。然而,此类样本的数量相对较少。因此,尝试为困难样本分配更高的权重。首先通过参数μ将样本分为正样本和负样本。然后,通过加权函数Slide对边界处的样本进行强调,如图 4 所示。Slide加权函数可以表示为公式5。

在utils/loss.py增加

import math
class SlideLoss(nn.Module):def __init__(self, loss_fcn):super(SlideLoss, self).__init__()self.loss_fcn = loss_fcnself.reduction = loss_fcn.reductionself.loss_fcn.reduction = 'none'  # required to apply SL to each elementdef forward(self, pred, true, auto_iou=0.5):loss = self.loss_fcn(pred, true)if auto_iou < 0.2:auto_iou = 0.2b1 = true <= auto_iou - 0.1a1 = 1.0b2 = (true > (auto_iou - 0.1)) & (true < auto_iou)a2 = math.exp(1.0 - auto_iou)b3 = true >= auto_ioua3 = torch.exp(-(true - 1.0))modulating_weight = a1 * b1 + a2 * b2 + a3 * b3loss *= modulating_weightif self.reduction == 'mean':return loss.mean()elif self.reduction == 'sum':return loss.sum()else:  # 'none'return loss

在data\hyps\hyp.scratch-low.yaml中增加

slide_ratio: 1 # >=1启用slide loss, <1关闭

在utils/loss.py的ComputeLoss函数中做如下修改:

class ComputeLoss:# Compute lossesdef __init__(self, model, autobalance=False):super(ComputeLoss, self).__init__()device = next(model.parameters()).device  # get model deviceh = model.hyp  # hyperparameters# Define criteriaBCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device))BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']], device=device))# Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3self.cp, self.cn = smooth_BCE(eps=h.get('label_smoothing', 0.0))  # positive, negative BCE targets# slide lossself.slide_ratio = h['slide_ratio']if self.slide_ratio > 0:BCEcls, BCEobj = SlideLoss(BCEcls), SlideLoss(BCEobj)# Focal lossg = h['fl_gamma']  # focal loss gammaif g > 0:BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)det = model.module.model[-1] if is_parallel(model) else model.model[-1]  # Detect() moduleself.balance = {3: [4.0, 1.0, 0.4]}.get(det.nl, [4.0, 1.0, 0.25, 0.06, .02])  # P3-P7self.ssi = list(det.stride).index(16) if autobalance else 0  # stride 16 indexself.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, model.gr, h, autobalancefor k in 'na', 'nc', 'nl', 'anchors':setattr(self, k, getattr(det, k))def __call__(self, p, targets):  # predictions, targets, modeldevice = targets.devicelcls, lbox, lobj = torch.zeros(1, device=device), torch.zeros(1, device=device), torch.zeros(1, device=device)lrepBox, lrepGT = torch.zeros(1, device=device), torch.zeros(1, device=device)tcls, tbox, indices, anchors = self.build_targets(p, targets)  # targets# Lossesfor i, pi in enumerate(p):  # layer index, layer predictionsb, a, gj, gi = indices[i]  # image, anchor, gridy, gridxtobj = torch.zeros_like(pi[..., 0], device=device)  # target objn = b.shape[0]  # number of targetsif n:ps = pi[b, a, gj, gi]  # prediction subset corresponding to targets# Regressionpxy = ps[:, :2].sigmoid() * 2. - 0.5pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]pbox = torch.cat((pxy, pwh), 1)  # predicted boxiou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True)  # iou(prediction, target)auto_iou = iou.mean()lbox += (1.0 - iou).mean()  # iou loss# Objectnesstobj[b, a, gj, gi] = (1.0 - self.gr) + self.gr * iou.detach().clamp(0).type(tobj.dtype)  # iou ratio# Classificationif self.nc > 1:  # cls loss (only if multiple classes)t = torch.full_like(ps[:, 5:], self.cn, device=device)  # targetst[range(n), tcls[i]] = self.cpif self.slide_ratio > 0:lcls += self.BCEcls(ps[:, 5:], t, auto_iou)  # BCEelse:lcls += self.BCEcls(ps[:, 5:], t)  # BCE# Append targets to text file# with open('targets.txt', 'a') as file:#     [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]if self.slide_ratio > 0 and n:obji = self.BCEobj(pi[..., 4], tobj, auto_iou)else:obji = self.BCEobj(pi[..., 4], tobj)lobj += obji * self.balance[i]  # obj lossif self.autobalance:self.balance[i] = self.balance[i] * 0.9999 + 0.0001 / obji.detach().item()if self.autobalance:self.balance = [x / self.balance[self.ssi] for x in self.balance]lbox *= self.hyp['box']lobj *= self.hyp['obj']lcls *= self.hyp['cls']bs = tobj.shape[0]  # batch sizeloss = lbox + lobj + lclsreturn loss * bs, torch.cat((lbox, lobj, lcls, loss)).detach()

 主要修改如下:

1、__init__中增加

        # slide lossself.slide_ratio = h['slide_ratio']if self.slide_ratio > 0:BCEcls, BCEobj = SlideLoss(BCEcls), SlideLoss(BCEobj)

2、计算完iou后增加

auto_iou = iou.mean()

3、在类别损失函数上

                    if self.slide_ratio > 0:lcls += self.BCEcls(ps[:, 5:], t, auto_iou)  # BCEelse:lcls += self.BCEcls(ps[:, 5:], t)  # BCE

4、前背景损失函数上

            if self.slide_ratio > 0 and n:obji = self.BCEobj(pi[..., 4], tobj, auto_iou)else:obji = self.BCEobj(pi[..., 4], tobj)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/116278.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2023年“羊城杯”网络安全大赛 决赛 AWDP [Break+Fix] Web方向题解wp 全

终于迎来了我的第一百篇文章。 这次决赛赛制是AWDP。BreakFix&#xff0c;其实就是CTFFix&#xff0c;Fix规则有点难崩。Break和Fix题目是一样的。 总结一下&#xff1a;败北&#xff0c;还是太菜了得继续修炼一下。 一、Break ezSSTI 看到是SSTI&#xff0c;焚靖直接一把梭…

软件设计模式系列之十三——享元模式

1 模式的定义 享元模式&#xff08;Flyweight Pattern&#xff09;是一种结构型设计模式&#xff0c;它旨在减少内存占用或计算开销&#xff0c;通过共享大量细粒度对象来提高系统的性能。这种模式适用于存在大量相似对象实例&#xff0c;但它们的状态可以外部化&#xff08;e…

2023华为杯数学建模竞赛E题

一、前言 颅内出血&#xff08;ICH&#xff09;是由多种原因引起的颅腔内出血性疾病&#xff0c;既包括自发性出血&#xff0c;又包括创伤导致的继发性出血&#xff0c;诊断与治疗涉及神经外科、神经内科、重症医学科、康复科等多个学科&#xff0c;是临床医师面临的重要挑战。…

免费获取独立ChatGPT账户!!

GPT对于每个科研人员已经成为不可或缺的辅助工具&#xff0c;不同的研究领域和项目具有不同的需求。如在科研编程、绘图领域&#xff1a;1、编程建议和示例代码: 无论你使用的编程语言是Python、R、MATLAB还是其他语言&#xff0c;都可以为你提供相关的代码示例。2、数据可视化…

2023 年 Android 毕业设计选题推荐,200 道 Android 毕业设计题目,避免踩坑

前言 选择一个Android毕业设计题目是一个重要的决策&#xff0c;它将影响你未来几个月的工作。以下是一些关于如何选择一个合适的Android毕业设计题目以及如何避免踩坑的建议&#xff1a; 兴趣和热情&#xff1a;首先&#xff0c;选择你真正感兴趣的领域。如果你对某个领域充…

Python:Django框架的Hello wrold示例

Django是Python的目前很常用的web框架&#xff0c;遵循MVC设计模式。 以下介绍如何安装Django框架&#xff0c;并生成最简单的项目&#xff0c;输出Hello world。(开发工具VScode) 一、安装Django 在VScode终端控制台执行以下指令安装Django python install django 如果要查…

相机有俯仰角时如何将像素坐标正确转换到其他坐标系

一般像素坐标系转相机坐标系都是默认相机是水平的&#xff0c;没有考虑相机有俯仰角的情况&#xff0c;大致的过程是&#xff1a;像素坐标系统-->图像坐标系-->相机坐标系 ->世界坐标系或雷达坐标系: 像素坐标系 像素坐标系&#xff08;u&#xff0c;v&#xff09;是…

AIX360-CEMExplainer: MNIST Example

CEMExplainer: MNIST Example 这一部分屁话有点多&#xff0c;导包没问题的话可以跳过加载MNIST数据集加载经过训练的MNIST模型加载经过训练的卷积自动编码器模型&#xff08;可选&#xff09;初始化CEM解释程序以解释模型预测解释输入实例获得相关否定&#xff08;Pertinent N…

停车场系统源码

源码下载地址&#xff08;小程序开源地址&#xff09;&#xff1a;停车场系统小程序&#xff0c;新能源电动车充电系统&#xff0c;智慧社区物业人脸门禁小程序: 【涵盖内容】&#xff1a;城市智慧停车系统&#xff0c;汽车新能源充电&#xff0c;两轮电动车充电&#xff0c;物…

基于Android+OpenCV+CNN+Keras的智能手语数字实时翻译——深度学习算法应用(含Python、ipynb工程源码)+数据集(三)

目录 前言总体设计系统整体结构图系统流程图 运行环境模块实现1. 数据预处理2. 数据增强3. 模型构建4. 模型训练及保存1&#xff09;模型训练2&#xff09;模型保存 5. 模型评估 相关其它博客工程源代码下载其它资料下载 前言 本项目依赖于Keras深度学习模型&#xff0c;旨在对…

雷达编程实战之静态杂波滤除与到达角估计

雷达中经过混频的中频信号常常混有直流分量等一系列硬件设计引入的固定频率杂波&#xff0c;我们称之位静态杂波&#xff0c;雷达信号处理需要把这些静态杂波滤除从而有效的提高信噪比&#xff0c;实现准确的目标检测功能。 目标的到达角估计作为常规车载雷达信号处理的末端&am…

机器视觉康耐视Visionpro-脚本编写标记标识:点,直线,矩形,圆

显示标记标识的重要作用就是,对NG或者OK对操作机器视觉的人去看到具体位置缺陷或者NG坐标。 一.点CogPointMarker CogPointMarker PointMarker1 = new CogPointMarker();//创建对象,点CogPointMarker //注意运行工具 PointMarker1.X = 100; PointMarker1