CVPR2022人脸识别Partial FC论文及代码学习笔记

论文链接:https://openaccess.thecvf.com/content/CVPR2022/papers/An_Killing_Two_Birds_With_One_Stone_Efficient_and_Robust_Training_CVPR_2022_paper.pdf

代码链接:insightface/recognition/arcface_torch at master · deepinsight/insightface · GitHub

背景

使用基于百万规模的数据集和基于margin的softmax损失函数来学习区分性的embeddings是当前人脸识别的SOTA方法。然而,全连接层的内存和计算成本随着训练集中ID数量的增加而线性增加。此外,大规模训练数据存在类间冲突(同一个人被分成不同ID)和长尾分布的问题。

传统FC

将传统的FC层应用在大规模的数据集上时,存在以下缺陷:

1、gradient confusion under interclass conflict

WebFace42M里有很多不同类别对之间的余弦相似度大于0.4,这表明类间冲突仍然存在于这些清洗过的数据集中。直接优化的话会导致gradient confusion(同一个人的特征非常相似却要掰成两个ID)

2、centers of tail classes undergo too many passive updates

每个iteration都优化图片数量很少的id,可能会导致负优化

3、the storage and calculation of the FC layer can easily exceed current GPU capabilities

PartialFC

在训练期间仍然维护所有类别中心,但只随机采样一小部分负类别中心来计算基于margin的softmax损失,而不是在每次迭代中使用所有负类别中心。更具体地说,首先从每个GPU收集embeddings和标签,然后将组合的特征和标签分布到所有GPU。为了平衡每个GPU的内存使用和计算成本,为每个GPU设置了一个内存缓冲区(下面代码中的perm)。内存缓冲区的大小由类别总数和负类别中心的采样率决定。在每个GPU上,首先通过标签选择正类中心并放入缓冲区,然后随机选择一小部分负类中心(负类中心的数量为self.sample_rate * self.num_local)填充缓冲区的其余部分,

def sample(self, labels, index_positive):"""This functions will change the value of labelsParameters:-----------labels: torch.Tensorpassindex_positive: torch.Tensorpassoptimizer: torch.optim.Optimizerpass"""with torch.no_grad():positive = torch.unique(labels[index_positive], sorted=True).cuda()if self.num_sample - positive.size(0) >= 0:perm = torch.rand(size=[self.num_local]).cuda()perm[positive] = 2.0index = torch.topk(perm, k=self.num_sample)[1].cuda()index = index.sort()[0].cuda()else:index = positiveself.weight_index = indexlabels[index_positive] = torch.searchsorted(index, labels[index_positive])return self.weight[self.weight_index]

随后,使用选出的样本中心去与特征相乘并计算基于margin的softmax损失。

PFC在DDP框架下的流程图如下图所示,

整体代码如下,

class PartialFC_V2(torch.nn.Module):"""https://arxiv.org/abs/2203.15565A distributed sparsely updating variant of the FC layer, named Partial FC (PFC).When sample rate less than 1, in each iteration, positive class centers and a random subset ofnegative class centers are selected to compute the margin-based softmax loss, all classcenters are still maintained throughout the whole training process, but only a subset isselected and updated in each iteration... note::When sample rate equal to 1, Partial FC is equal to model parallelism(default sample rate is 1).Example:-------->>> module_pfc = PartialFC(embedding_size=512, num_classes=8000000, sample_rate=0.2)>>> for img, labels in data_loader:>>>     embeddings = net(img)>>>     loss = module_pfc(embeddings, labels)>>>     loss.backward()>>>     optimizer.step()"""_version = 2def __init__(self,margin_loss: Callable,embedding_size: int,num_classes: int,sample_rate: float = 1.0,fp16: bool = False,):"""Paramenters:-----------embedding_size: intThe dimension of embedding, requirednum_classes: intTotal number of classes, requiredsample_rate: floatThe rate of negative centers participating in the calculation, default is 1.0."""super(PartialFC_V2, self).__init__()assert (distributed.is_initialized()), "must initialize distributed before create this"self.rank = distributed.get_rank()self.world_size = distributed.get_world_size()self.dist_cross_entropy = DistCrossEntropy()self.embedding_size = embedding_sizeself.sample_rate: float = sample_rateself.fp16 = fp16self.num_local: int = num_classes // self.world_size + int(self.rank < num_classes % self.world_size)self.class_start: int = num_classes // self.world_size * self.rank + min(self.rank, num_classes % self.world_size)self.num_sample: int = int(self.sample_rate * self.num_local)self.last_batch_size: int = 0self.is_updated: bool = Trueself.init_weight_update: bool = Trueself.weight = torch.nn.Parameter(torch.normal(0, 0.01, (self.num_local, embedding_size)))# margin_lossif isinstance(margin_loss, Callable):self.margin_softmax = margin_losselse:raisedef sample(self, labels, index_positive):"""This functions will change the value of labelsParameters:-----------labels: torch.Tensorpassindex_positive: torch.Tensorpassoptimizer: torch.optim.Optimizerpass"""with torch.no_grad():positive = torch.unique(labels[index_positive], sorted=True).cuda()if self.num_sample - positive.size(0) >= 0:perm = torch.rand(size=[self.num_local]).cuda()perm[positive] = 2.0index = torch.topk(perm, k=self.num_sample)[1].cuda()index = index.sort()[0].cuda()else:index = positiveself.weight_index = indexlabels[index_positive] = torch.searchsorted(index, labels[index_positive])return self.weight[self.weight_index]def forward(self,local_embeddings: torch.Tensor,local_labels: torch.Tensor,):"""Parameters:----------local_embeddings: torch.Tensorfeature embeddings on each GPU(Rank).local_labels: torch.Tensorlabels on each GPU(Rank).Returns:-------loss: torch.Tensorpass"""local_labels.squeeze_()local_labels = local_labels.long()batch_size = local_embeddings.size(0)if self.last_batch_size == 0:self.last_batch_size = batch_sizeassert self.last_batch_size == batch_size, (f"last batch size do not equal current batch size: {self.last_batch_size} vs {batch_size}")_gather_embeddings = [torch.zeros((batch_size, self.embedding_size)).cuda()for _ in range(self.world_size)]_gather_labels = [torch.zeros(batch_size).long().cuda() for _ in range(self.world_size)]_list_embeddings = AllGather(local_embeddings, *_gather_embeddings)distributed.all_gather(_gather_labels, local_labels)embeddings = torch.cat(_list_embeddings)labels = torch.cat(_gather_labels)## 选出落在本进程对应的类别范围内的数据labels = labels.view(-1, 1)index_positive = (self.class_start <= labels) & (labels < self.class_start + self.num_local)## 标签不在本类别段的, 将其类别标签设为-1labels[~index_positive] = -1## 将类别ID平移到原点(因为不同进程都会初始化对应的self.weight, 若不平移回去, 则label与self.weight中的index会对应不上)labels[index_positive] -= self.class_startif self.sample_rate < 1:weight = self.sample(labels, index_positive)else:weight = self.weightwith torch.cuda.amp.autocast(self.fp16):norm_embeddings = normalize(embeddings)norm_weight_activated = normalize(weight)logits = linear(norm_embeddings, norm_weight_activated)if self.fp16:logits = logits.float()logits = logits.clamp(-1, 1)logits = self.margin_softmax(logits, labels)loss = self.dist_cross_entropy(logits, labels)return loss

实验结果

将PFC替换掉传统FC后,模型在WebFace(包括4m、12m、42m)上的性能会有所提升,

 消融实验的结果如下,

与SOTA方法的性能对比如下, 

结论与讨论

结论

作者提出了一种用于在大规模数据集上训练人脸识别模型的方法——Partial FC (PFC)。在PFC的每次迭代中,仅选择一小部分类别中心来计算基于边际的softmax损失,这样可以显著减少类间冲突的概率、尾类中心的被动更新频率以及计算需求。通过广泛的实验,作者验证了所提出的PFC的有效性、鲁棒性和高效性。

局限性

尽管在WebFace上训练的PFC模型在高质量测试集上取得了不错的结果,但在人脸分辨率较低或低光照条件下拍摄的人脸上,PFC模型的表现可能较差。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/701624.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

量化研究---A股赚钱日历,上证指数为例,提供源代码

今天把A股的全部数据导出做了一些赚钱日历分析&#xff0c;看那个月赚钱容易&#xff0c;那个月赚钱困难 导入需要的库 import pandas as pdimport matplotlib.pyplot as pltimport quantstats as qsfrom trader_tool.index_data import index_datafrom trader_tool import j…

BakedSDF: Meshing Neural SDFs for Real-Time View Synthesis 论文阅读

&#xff08;水一篇博客&#xff09; 项目主页 BakedSDF: Meshing Neural SDFs for Real-Time View Synthesis 作者介绍 是 Mildenhall 和 Barron 参与的工作&#xff08;都是谷歌的&#xff09;&#xff0c;同时一作是 Lipman 的学生&#xff0c;VolSDF 的一作。本文引用…

五分钟“手撕”时间复杂度与空间复杂度

目录 一、算法效率 什么是算法 如何衡量一个算法的好坏 算法效率 二、时间复杂度 时间复杂度的概念 大O的渐进表示法 推导大O阶方法 常见时间复杂度计算举例 三、空间复杂度 常见时间复杂度计算举例 一、算法效率 什么是算法 算法(Algorithm)&#xff1a;就是定…

24/05/14总结

签到2&#xff1a; 签到界面上有时间显示&#xff0c;签到码输入框&#xff0c;开始签到&#xff0c;当倒计时结束&#xff0c;老师端和学生端都会显示签到结果&#xff0c;所以签到结果需要建表&#xff1a;&#xff08;签到了的学生和未签到的学生&#xff0c; 这次签到的时间…

Elasticsearch优化手段

ES 的默认配置已经提供了良好的开箱即用的体验&#xff0c;但是仍有一些优化手段去继续提升它的使用性能。 一 General recommendations 通用建议。 01 Dont return large result sets 不要返回大量的结果集。ES 是一个搜索引擎&#xff0c;擅长于返回匹配度较高的几个文…

1.柔性数组

1.柔性数组 我们先来介绍一下什么是柔性数组&#xff1a; 在C语言中&#xff0c;柔性数组&#xff08;Flexible Array&#xff09;并不是一个标准的术语&#xff0c;但它通常指的是结构体中最后一个元素是一个没有指定大小的数组。这种结构体设计允许在运行时动态分配数组的大…

ES6之正则扩展

正则表达式扩展 u修饰符&#xff08;Unicode模式&#xff09;y修饰符&#xff08;Sticky或粘连模式&#xff09;s修饰符&#xff08;dotAll模式&#xff09;Unicode属性转义正则实例的flags属性字符串方法与正则表达式的整合 javascript的常用的正则表达式 验证数字邮箱验证手机…

Linux 第三十一章

&#x1f436;博主主页&#xff1a;ᰔᩚ. 一怀明月ꦿ ❤️‍&#x1f525;专栏系列&#xff1a;线性代数&#xff0c;C初学者入门训练&#xff0c;题解C&#xff0c;C的使用文章&#xff0c;「初学」C&#xff0c;linux &#x1f525;座右铭&#xff1a;“不要等到什么都没有了…

TortoiseGit的安装

TortoiseSvn和TortoiseGit都是针对代码进行版本管理的工具&#xff0c;又俗称小乌龟&#xff0c;简洁而可视化的操作界面&#xff0c;免去繁琐的命令行输入。只需要记住常用的几个操作步骤就能快速上手。 TortoiseGit安装 1、TortoiseGit作为git的版本管理工具 &#xff0c;但…

零基础10 天入门 Web3之第3天

10 天入门 Web3之第3天 什么是以太坊&#xff0c;以太坊能做什么&#xff1f;Web3 是互联网的下一代&#xff0c;它将使人们拥有自己的数据并控制自己的在线体验。Web3 基于区块链技术&#xff0c;该技术为安全、透明和可信的交易提供支持。我准备做一个 10 天的学习计划&…

粮油码垛机:自动化与智能化仓储的关键角色

在快速发展的现代化仓储物流领域&#xff0c;粮油码垛机正逐渐成为自动化与智能化仓储的关键角色。它以其高效、精准、节省人力的特点&#xff0c;赢得了众多粮油生产企业的青睐&#xff0c;成为仓储管理升级换代的明星产品。 一、粮油码垛机的技术革新 随着科技的发展&#…

【C语言】4.C语言数组(2)

文章目录 6. 二维数组的创建6.1 ⼆维数组的概念6.2 ⼆维数组的创建 7. 二维数组的初始化7.1 不完全初始化7.2 完全初始化7.3 按照⾏初始化7.4 初始化时省略⾏&#xff0c;但是不能省略列 8. 二维数组的使用8.1 ⼆维数组的下标8.2 ⼆维数组的输⼊和输出 9. 二维数组在内存中的存…