Supervised Contrastive 损失函数详解

在这里插入图片描述
有什么不对的及时指出,共同学习进步。(●’◡’●)

有监督对比学习将自监督批量对比方法扩展到完全监督设置,能够有效地利用标签信息。属于同一类的点簇在嵌入空间中被拉到一起,同时将来自不同类的样本簇推开。这种损失显示出对自然损坏很稳健,并且对优化器和数据增强等超参数设置更稳定。

有监督对比学习论文的贡献

  1. 提出了对比损失函数一种新的扩展,允许每个锚点都有多个正样本,使对比学习适应完全监督设置。
  2. 该损失为很多数据集的top-1的准确率带来了提升,对自然损坏有稳健性。
  3. 损失函数的梯度鼓励从硬正样本和硬的负样本中学习。(硬的正样本与锚点图像不相似的正样本,硬的负样本就是与锚点图像相似的负样本,都是难以学习的那种)
  4. 对比损失函数不如交叉熵损失函数对超参数敏感。

自监督对比学习损失
在这里插入图片描述
有监督对比学习损失
在这里插入图片描述
文中对交叉熵损失训练,自监督对比损失训练和有监督对比损失训练进行比较
在这里插入图片描述
推理模型中的参数个数始终保持不变,应该是推理的时候就是编码器+分类头都一样。
上图是训练的时候,交叉熵损失不必说。
自监督损失一般采用的是个体判别代理任务,正样本是自身经过数据增强后的图像(一般一个正样本),其他的都是负样本,训练编码器的时候让正样本和锚点图像经过编码器得到的特征尽可能接近,与负样本之间的特征尽可能拉远。
有监督对比学习,有标签信息,正样本除了自身数据增强后的之外还有这个类别中的其他样本(一般这个batch_size中)。
stage1就是训练编码器。
stage2是训练分类头,作者指出不需要训练线性分类器,并且先前的工作已经使用k -最近邻分类或原型分类来评估分类任务上的表示。线性分类器也可以与编码器联合训练,只要不将梯度传播回编码器即可,就是分类头和编码器之间训练要分开。
有监督对比学习损失代码
对比学习对比的是特征,所以损失函数的输入是特征,有监督对比学习损失还要输入标签信息。
损失函数就是模型的输出和标签(这里是mask)之间的差距,输出和标签差距越大,那么loss就越大。
输出这里是编码器的输出就是特征,标签就是类别标签。标签是如何起作用的呢?就是让损失函数区分这个batchsize中的正负样本,属于同一类就是正样本,其他都是负样本。
其中标签mask怎么获得,一个是通过label,另一个直接输入。label是每个数据的类别信息,label.view(1,-1)变成列向量然后再与它的转置进行torch.eq(),得到一个矩阵mask,mask(i,j)如果第i个数据和第j个数据类别相同那么这个位置是True,否则为False,float就变成0,1。后面乘了一个对角线元素为0,其他位置元素为1的矩阵,就是不让每个feature与自身对比。
我们看它self.contrast_mode="one"的时候只是比较feature中第0个特征(也就是平常的第一个特征),那么锚点特征就是所有数据的第0个特征;"all"就是所有的特征都要对比;锚点特征就是所有数据的所有特征。 torch.cat(torch.unbind(features, dim=1), dim=0)把feature按照第1维拆开,然后在第0维上cat,然后比较的feature的形式就是每一个数据的第1个特征|每个数据的第2个特征|…|每个数据的第n个特征,排列,这些特征是排在一起的在一个维度上。锚点特征要么是输入特征组的每个数据的第0个特征要么就是这些比较的特征。(不太理解为什么one的时候比较特征还是所有的)
锚点特征与比较特征的转置相乘,得到的就是batch_size*channel个相似矩阵,每两个数据在这个特征下的相似度。然后这个相似度矩阵要和我们得到的mask进行比较,就是上面的第二个式子。
下面是详细解释。

"""
Author: Yonglong Tian (yonglong@mit.edu)
Date: May 07, 2020
"""
from __future__ import print_functionimport torch
import torch.nn as nnclass SupConLoss(nn.Module):"""Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.It also supports the unsupervised contrastive loss in SimCLR"""def __init__(self, temperature=0.07, contrast_mode='all',base_temperature=0.07):super(SupConLoss, self).__init__()self.temperature = temperatureself.contrast_mode = contrast_mode#设置对比的模式有one和all两种,代表对比一个channel还是所有,个人理解self.base_temperature = base_temperature #设置的温度def forward(self, features, labels=None, mask=None):"""Compute loss for model. If both `labels` and `mask` are None,it degenerates to SimCLR unsupervised loss:https://arxiv.org/pdf/2002.05709.pdfArgs:features: hidden vector of shape [bsz, n_views, ...].labels: ground truth of shape [bsz].mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample jhas the same class as sample i. Can be asymmetric.Returns:A loss scalar."""device = (torch.device('cuda')#设置设备if features.is_cudaelse torch.device('cpu'))if len(features.shape) < 3:raise ValueError('`features` needs to be [bsz, n_views, ...],''at least 3 dimensions are required')if len(features.shape) > 3:# batch_size, channel,H,W,平铺变成batch_size, channel, (H,W)features = features.view(features.shape[0], features.shape[1], -1)batch_size = features.shape[0]if labels is not None and mask is not None:#只能存在一个raise ValueError('Cannot define both `labels` and `mask`')elif labels is None and mask is None:#如果两个都没有就是无监督对比损失,mask就是一个单位阵mask = torch.eye(batch_size, dtype=torch.float32).to(device)elif labels is not None:#有标签,就把他变成masklabels = labels.contiguous().view(-1, 1)#contiguous深拷贝,与原来的labels没有关系,展开成一列,这样的话能够计算mask,否则labels一维的话labels.T是他本身捕获发生转置if labels.shape[0] != batch_size:raise ValueError('Num of labels does not match num of features')mask =  torch.eq(labels, labels.T).float().to(device)#label和label的转置比较,感觉应该是广播机制,让label和label.T都扩充了然后进行比较,相同的是1,不同是0.#这里就是由label形成mask,mask(i,j)代表第i个数据和第j个数据的关系,如果两个类别相同就是1, 不同就是0else:mask = mask.float().to(device)#有mask就直接用mask,mask也是代表两个数据之间的关系contrast_count = features.shape[1]#对比数是channel的个数contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)#把feature按照第1维拆开,然后在第0维上cat,(batch_size*channel,h*w..)#后面就是展开的feature的维度#这个操作就和后面mask.repeat对上了,这个操作是第一个数据的第一维特征+第二个数据的第一维特征+第三个数据的第一维特征这样排列的与mask对应if self.contrast_mode == 'one':#如果mode=one,比较feature中第1维中的0号元素(batch, h*w)anchor_feature = features[:, 0]anchor_count = 1elif self.contrast_mode == 'all':#all就(batch*channel, h*w)anchor_feature = contrast_featureanchor_count = contrast_countelse:raise ValueError('Unknown mode: {}'.format(self.contrast_mode))# compute logitsanchor_dot_contrast = torch.div(torch.matmul(anchor_feature, contrast_feature.T),#两个相乘获得相似度矩阵,乘积值越大代表越相关self.temperature)# for numerical stabilitylogits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)#计算其中最大值logits = anchor_dot_contrast - logits_max.detach()#减去最大值,都是负的了,指数就小于等于1# tile maskmask = mask.repeat(anchor_count, contrast_count)#repeat它就是把mask复制很多份# mask-out self-contrast caseslogits_mask = torch.scatter(#生成一个mask形状的矩阵除了对角线上的元素是0,其他位置都是1, 不会对自身进行比较torch.ones_like(mask),1,torch.arange(batch_size * anchor_count).view(-1, 1).to(device),0)mask = mask * logits_mask# compute log_probexp_logits = torch.exp(logits) * logits_mask#定义其中的相似度log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))#softmax# compute mean of log-likelihood over positive# modified to handle edge cases when there is no positive pair# for an anchor point. # Edge case e.g.:- # features of shape: [4,1,...]# labels:            [0,1,1,2]# loss before mean:  [nan, ..., ..., nan] mask_pos_pairs = mask.sum(1)#mask的和mask_pos_pairs = torch.where(mask_pos_pairs < 1e-6, 1, mask_pos_pairs)#满足返回1,不满足返回mask_pos_pairs.保证数值稳定mean_log_prob_pos = (mask * log_prob).sum(1) / mask_pos_pairs# lossloss = - (self.temperature / self.base_temperature) * mean_log_prob_pos#类似蒸馏temperature温度越高,分布曲线越平滑不易陷入局部最优解,温度低,分布陡峭loss = loss.view(anchor_count, batch_size).mean()#计算平均return loss

使用的化就是下面这段:

loss = criterion(features, labels)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/422567.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【表情识别阅读笔记】Towards Semi-Supervised Deep FER with An Adaptive Confidence Margin

论文名&#xff1a; Towards Semi-Supervised Deep Facial Expression Recognition with An Adaptive Confidence Margin 论文来源&#xff1a; CVPR 发表时间&#xff1a; 2022-04 研究背景&#xff1a; 对大量图片或视频进行手工标注表情是一件极其繁琐的事情&#xff0c;因此…

eNSP学习——部分VLAN间互通、部分VLAN间隔离、VLAN内用户隔离(MUX-VLAN)

MUX VLAN&#xff08;Multiplex VLAN&#xff09;提供了一种通过VLAN进行网络资源控制 的机制。通过MUX VLAN提供的二层流量隔离的机制可以实现企业内部员 工之间互相通信&#xff0c;而企业外来访客之间的互访是隔离的。 特点&#xff1a; 一、主VLAN端口可以和所有VLAN通信 二…

设计亚马逊按销售排名功能

1&#xff1a; 定义 Use Cases 和 约束 Use cases 作用域内的Use Case Service 通过目录计算过去一周内最受欢迎的产品User 通过目录去View过去周内最受欢迎的产品Service 有高可用 作用域外 整个电商网站 设计组件&#xff08;只是计算销售排名&#xff09; 约束和假设…

Windows云服务器如何配置多用户登录?(Windows 2012)华为云官方文档与视频地址

Windows云服务器如何配置多用户登录&#xff1f;&#xff08;Windows 2012&#xff09;_弹性云服务器 ECS_故障排除_多用户登录_华为云 打开任务栏左下角的“服务器管理器”&#xff0c;在左侧列表中选中“本地服务器” 然后将右侧“远程桌面”功能的选项修改为“启用”&#x…

插混、增程、纯电为什么说纯电是未来的趋势

技术路线&#xff1a;插混、增程、纯电趋势判断 新能源汽车目前有纯电动、增程式、插电式3 种主流技术路径&#xff0c;其中增程式和插电式均为混动技术。纯电动汽车是指以动力电池为动力&#xff0c;用电机驱动车轮行驶&#xff1b;混动技术分为串联、并联、混联3 种模式&…

计算机网络——第四层:传输层以及TCP UDP

1. 传输层的协议 1.1 TCP (传输控制协议) - rfc793 连接模式的传输。 保证按顺序传送数据包。 流量控制、错误检测和在数据包丢失时的重传。 用于需要可靠传输的应用&#xff0c;如网络&#xff08;HTTP/HTTPS&#xff09;、电子邮件&#xff08;SMTP, IMAP, POP3&#xff09;…

Mybatis 动态SQL条件查询(注释和XML方式都有)

需求 : 根据用户的输入情况进行条件查询 新建了一个 userInfo2Mapper 接口,然后写下如下代码,声明 selectByCondition 这个方法 package com.example.mybatisdemo.mapper; import com.example.mybatisdemo.model.UserInfo; import org.apache.ibatis.annotations.*; import j…

跟着我学Python进阶篇:03. 面向对象(下)

往期文章 跟着我学Python基础篇&#xff1a;01.初露端倪 跟着我学Python基础篇&#xff1a;02.数字与字符串编程 跟着我学Python基础篇&#xff1a;03.选择结构 跟着我学Python基础篇&#xff1a;04.循环 跟着我学Python基础篇&#xff1a;05.函数 跟着我学Python基础篇&#…

Django框架二

一、模型层及ORM 1.模型层定义 负责跟数据库之间进行通信 2.Django配置mysql 安装mysqlclient&#xff0c;mysqlclient版本最好在13.13以上 pip3 install mysqlclient DATABASES {default: {ENGINE: django.db.backends.mysql,NAME: "mysite1",USER:root,PASSWO…

U-Boot 中使用 nfs 命令加载文件报错指南

目录 问题一问题描述错误原因解决方案 问题二问题描述解决方案 更多内容 在嵌入式 Linux 开发中&#xff0c;我们经常使用 nfs 命令加载服务端的共享文件或者挂载文件系统。关于服务端 NFS 服务的搭建可以参考基于 NFS 的文件共享实现。 U-Boot 也支持了 nfs 命令&#xff0c;…

JRT和springboot比较测试

想要战胜他&#xff0c;必先理解他。这两天系统的学习Maven和跑springboot工程&#xff0c;从以前只是看着复杂&#xff0c;现在到亲手体验一下&#xff0c;亲自实践的才是更可靠的了解。 第一就是首先Maven侵入代码结构&#xff0c;代码一般要按约定搞src/main/java。如果是能…

2526. 随机数生成器(BSGS,推导)

题目路径&#xff1a; https://www.acwing.com/problem/content/2528/ 思路&#xff1a;