简单谈谈 EMP-SSL:自监督对比学习的一种极简主义风

论文链接:https://arxiv.org/pdf/2304.03977.pdf

代码:https://github.com/tsb0601/EMP-SSL

其他学习链接:突破自监督学习效率极限!马毅、LeCun联合发布EMP-SSL:无需花哨trick,30个epoch即可实现SOTA


主要思想

如图,一张图片裁剪成不同的 patch,对不同的 patch 做数据增强,分别输入 encoder,得到多个 embedding,对它们求均值,得到 \bar z 作为这张图片的 embedding。最后,拉近每个 patch 的 embedding 和图片的 embedding(\bar z)之间的余弦距离;再用 Total Coding Rate(TCR) 防止坍塌(即 encoder 对所有输入都输出相同的 embedding)

图片

图片

Total Coding Rate(TCR)

公式如下:

图片

其中,det 表示求矩阵的行列式,d 是 feature vector 的 dimension,b 是 batch size

查了查该公式的含义:expand all features of Z as large as possible,即尽可能拉远矩阵中特征之间的距离。

源自 PPT 第 24 页:

https://s3.amazonaws.com/sf-web-assets-prod/wp-content/uploads/2021/06/15175515/Deep_Networks_from_First_Principles.pdf

至于为什么最大化该公式的值就可以拉远矩阵中特征之间的距离,这背后的数学原理真难啃啊 /(ㄒoㄒ)/~~


核心代码解读

数据处理

https://github.com/tsb0601/EMP-SSL/blob/main/dataset/aug.py#L116C1-L138C27

class ContrastiveLearningViewGenerator(object):def __init__(self, num_patch = 4):self.num_patch = num_patchdef __call__(self, x):normalize = transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])aug_transform = transforms.Compose([transforms.RandomResizedCrop(32,scale=(0.25, 0.25), ratio=(1,1)),transforms.RandomHorizontalFlip(p=0.5),transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.2)], p=0.8),transforms.RandomGrayscale(p=0.2),GBlur(p=0.1),transforms.RandomApply([Solarization()], p=0.1),transforms.ToTensor(),  normalize])augmented_x = [aug_transform(x) for i in range(self.num_patch)]return augmented_x

由此看出返回的 数据 为:长度为 num_patches 个 tensor 的列表。其中,每个 tensor 的 shape 为 (B, C, H, W)。

主函数

https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L148C9-L162C63

for step, (data, label) in tqdm(enumerate(dataloader)):net.zero_grad()opt.zero_grad()data = torch.cat(data, dim=0) data = data.cuda()z_proj = net(data)z_list = z_proj.chunk(num_patches, dim=0)z_avg = chunk_avg(z_proj, num_patches)# Contractive Lossloss_contract, _ = contractive_loss(z_list, z_avg)loss_TCR = cal_TCR(z_proj, criterion, num_patches)

这里要稍微注意一下几个变量的 shape:

  • data 被 cat 完后:(num_patches * B,C,H,W)
  • z_proj:(num_patches * B,C)
  • z_list:(num_patches,B,C)
  • z_avg:(B,C)

其中,chunk_avg 就是对来自同一张图片的不同 patch 的 embedding 求均值(\bar z):

https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L67

def chunk_avg(x,n_chunks=2,normalize=False):x_list = x.chunk(n_chunks,dim=0)x = torch.stack(x_list,dim=0)if not normalize:return x.mean(0)else:return F.normalize(x.mean(0),dim=1)

loss

contractive_loss 就是计算每个 patch 的 embedding 和均值(\bar z)的余弦距离:

https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L76

class Similarity_Loss(nn.Module):def __init__(self, ):super().__init__()passdef forward(self, z_list, z_avg):z_sim = 0num_patch = len(z_list)z_list = torch.stack(list(z_list), dim=0)z_avg = z_list.mean(dim=0)z_sim = 0for i in range(num_patch):z_sim += F.cosine_similarity(z_list[i], z_avg, dim=1).mean()z_sim = z_sim/num_patchz_sim_out = z_sim.clone().detach()return -z_sim, z_sim_out

TCR loss:最大化矩阵之间特征的距离,即拉远负样本(不是来自同一个样本的 patches)之间的距离

https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L96

def cal_TCR(z, criterion, num_patches):z_list = z.chunk(num_patches,dim=0)loss = 0for i in range(num_patches):loss += criterion(z_list[i])loss = loss/num_patchesreturn loss

需要注意:函数输入的 z 是 z_proj,形状为(num_patches * B,C)。

所以,函数内部 z_list 的形状为(num_patches,B,C),即将数据分为了 num_patches 个组,每个组包含了来自不同图片里 patch 的 embedding。再分别对每个组求 TCR loss,最大化组内(不同图片的 patch)特征的距离。

所以,公式中的 Z 指的是一组来自不同图片里 patch 的 embedding,形状为(B,C)。

每个组内求 TCR loss 的代码按照公式计算,如下: 

图片

https://github.com/tsb0601/EMP-SSL/blob/main/loss.py#L76

class TotalCodingRate(nn.Module):def __init__(self, eps=0.01):super(TotalCodingRate, self).__init__()self.eps = epsdef compute_discrimn_loss(self, W):"""Discriminative Loss."""p, m = W.shape  #[d, B]I = torch.eye(p,device=W.device)scalar = p / (m * self.eps)logdet = torch.logdet(I + scalar * W.matmul(W.T))return logdet / 2.def forward(self,X):return - self.compute_discrimn_loss(X.T)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/67479.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

麦肯锡重磅发布2023年15项技术趋势,生成式AI首次入选,选对了就是风口

两位朋友在不同群里分享了同一份深度报告。 一位是LH美女,她在“AIGC时代”群里上传了这份文档,响应寥寥,可能是因为这些报告没有像八卦文那样容易带来冲击。 你看韩彬的这篇《金融妲己:基金公司女销售的瓜,一个比一个…

水库大坝安全监测系统实施方案

一、方案概述 水库大坝作为特殊的建筑,其安全性质与房屋等建筑物完全不同,并且建造在地质构造复杂、岩土特性不均匀的地基上,目前对于大坝监测多采用人工巡查的方法,存在一定的系统误差,其工作性态和安全状况随时都在变…

最新Kali Linux安装教程:从零开始打造网络安全之旅

Kali Linux,全称为Kali Linux Distribution,是一个操作系统(2013-03-13诞生),是一款基于Debian的Linux发行版,基于包含了约600个安全工具,省去了繁琐的安装、编译、配置、更新步骤,为所有工具运行提供了一个…

如何搭建个人邮件服务hmailserver并实现远程发送邮件

文章目录 1. 安装hMailServer2. 设置hMailServer3. 客户端安装添加账号4. 测试发送邮件5. 安装cpolar6. 创建公网地址7. 测试远程发送邮件8. 固定连接公网地址9. 测试固定远程地址发送邮件 hMailServer 是一个邮件服务器,通过它我们可以搭建自己的邮件服务,通过cpolar内网映射工…

母婴即时零售行业数据可视化分析

对新晋父母来说,很多母婴用品如同一位贴心的助手,为他们的宝宝提供温暖和呵护。从婴儿床垫到可爱的拼图玩具,每一件用品都是为宝宝的成长和发展量身定制。对于繁忙的父母们而言,这些用品不仅帮助照顾孩子,更是为他们减…

【BASH】回顾与知识点梳理(二十一)

【BASH】回顾与知识点梳理 二十一 二十一. Linux 的文件权限与目录配置21.1 使用者与群组属主(文件拥有者)属组(群组概念)其他人的概念root(万能的天神)Linux 用户身份与群组记录的文件 21.2 Linux 文件权限概念Linux 文件属性Linux 文件权限的重要性 21.3 如何改变文件属性与权…

21款美规奔驰GLS450更换中规高配主机,汉化操作更简单

很多平行进口的奔驰GLS都有这么一个问题,原车的地图在国内定位不了,语音交互功能也识别不了中文,原厂记录仪也减少了,使用起来也是很不方便的。 可以实现以下功能: ①中国地图 ②语音小助手(你好&#xf…

面向云思考安全

Gartner最近的一项研究表明,到 2025 年,85% 的企业会采用云战略,虽然这一数字是面向全球的,但可以看到在中国的环境中,基于云所带来的优势,越来越多的企业也同样开始积极向云转型。 但同时,有报…

SQL Server Reporting Services 报错:报表服务器无法访问服务帐户的私钥

解决这个问题,有小伙伴提到可以使用命令 exec DeleteEncryptedContent 但这对这边的环境时行不通的,我在【服务账户】的配置和【数据库】的配置中到使用了域账户,试了几次都不行。改成使用内置账户就好了。具体原因还没扒拉(欢迎…

【数据库】Sql Server可视化工具SSMS条件和SQL窗格以及版本信息

2023年,第34周,第1篇文章。给自己一个目标,然后坚持总会有收货,不信你试试! SQL SERVER 官方本身就有数据库可视化管理工具SSMS,所以大部分都会使用SSMS。以前版本是直接捆绑, 安装完成就自带有…

归并排序 与 计数排序

目录 1.归并排序 1.1 递归实现归并排序: 1.2 非递归实现归并排序 1.3 归并排序的特性总结: 1.4 外部排序 2.计数排序 2.1 操作步骤: 2.2 计数排序的特性总结: 3. 7种常见比较排序比较 1.归并排序 基本思想: 归并排序(MERGE-SORT)是建立在归并操作上的一种…

Vue使用jspdf和html2canvas组件库结合导出PDF文件

效果图: 1、安装依赖: npm install html2canvas --save npm install jspdf --save 或 yarn add html2canvas --save yarn add jspdf --save 2、封装全局调用方法:this.$exportPDF(#id,文件名) 新建js文件:/utils/html2Pdf.js&am…