Improved Deep Metric Learning with Multi-class N-pair Loss Objective

Improved Deep Metric Learning with Multi-class N-pair Loss Objective

来源:

  • NIPS’2016
  • NEC Laboratories America

文章目录

  • Improved Deep Metric Learning with Multi-class N-pair Loss Objective
    • Distance Metric Learning
    • Deep Metric Learning with Multiple Negative Examples
      • N-pair loss for efficient deep metric learning
    • 总结
    • 参考

找到这篇论文是因为看了淘宝搜索出品的论文Rethinking the Role of Pre-ranking in Large-scale E-Commerce 1,文中就提到了传统的list-wise损失 不适用于列表中存在多个正样本的场景。从样本构造的角度来看,这种方式应该也适用于多标签分类。

度量学习一直是我想了解的一个领域,就拿这篇论文做一个开始吧。

Distance Metric Learning

度量学习(metric learning)2,简言之:学习数据的嵌入表示,嵌入具有这样的性质,相似的数据点距离近不相似的数据点距离远。度量学习中常见的两种损失:对比损失和三元组损失,二者形式化的表示:
L c o n t ( x i , x j ; f ) = 1 { y i = y j } ∣ ∣ f i − f j ∣ ∣ 2 2 + 1 { y i ≠ y j } m a x ( 0 , m − ∣ ∣ f i − f j ∣ ∣ 2 ) 2 \mathcal{L}_{cont}(x_i, x_j; f) = \mathbb{1}\{y_i = y_j\}||f_i - f_j||_2^2 + \mathbb{1}\{y_i \neq y_j\}max(0, m - ||f_i - f_j||_2)^2 Lcont(xi,xj;f)=1{yi=yj}∣∣fifj22+1{yi=yj}max(0,m∣∣fifj2)2

L t r i ( x , x + , x − ; f ) = m a x ( 0 , ∣ ∣ f − f + ∣ ∣ 2 2 − ∣ ∣ f − f − ∣ ∣ 2 2 + m ) \mathcal{L}_{tri}(x, x^+, x^-; f) = max(0, ||f - f^+||_2^2 - ||f - f^-||_2^2 + m) Ltri(x,x+,x;f)=max(0,∣∣ff+22∣∣ff22+m)

其中 L c o n t \mathcal{L}_{cont} Lcont为对比损失(现在火起来的对比学习), L t r i \mathcal{L}_{tri} Ltri为三元组损失, f f f表示样本的嵌入。在对比损失中,要求来自同类别的样本距离近,不同类别的样本距离远;三元组损失中要求正( x + x^+ x+)、负( x − x^- x)样本到锚点( x x x,如搜图场景中的查询图)的距离要大于一定的阈值。

度量学习有一些现在很常见的应用,例如人脸识别、搜图等。度量学习的样本中通常只有一个负样本,容易导致收敛速度慢和局部最优的问题。难负样本挖掘(提一嘴:随着更多的实践,愈发觉得数据质量的重要性,如何构造好的数据集是一个值得研究的问题)能够减轻这些问题,但是如何找到难负样本本身就是一个难题。

与常见的三元组损失(triplet loss)中一个锚样本、一个正样本和一个负样本不一样,论文提出了一个 ( N + 1 ) (N+1) (N+1)元组的损失,来使一个正样本与 N − 1 N-1 N1个负样本区分开来。

Deep Metric Learning with Multiple Negative Examples

在三元组损失中,如果要使得损失尽可能低,显然有这么几种情况:

  • 缩短正样本与锚样本的距离;
  • 增大负样本与锚样本的距离;
  • 以上二者的结合。

从三元组损失的计算方式上也可以看出,再一次更新中只会比较锚样本与一个负样本,忽略了其他类别的负样本。这就导致:每次只能使锚样本远离一种负类,或许又被推到其他负类那里去了。最终学习到的嵌入可能会出现这样的情况:锚样本离训练数据中出现较多的负类远,而离某些负类又很近

当然,我们可以为锚样本配很多个三元组,囊括不同类别的负样本,这样在多轮、充足的训练后嵌入能够具有理想的性质。这样做就面临了不稳定以及收敛速度慢的问题。因此,文中就提出了 N + 1 N+1 N+1元组的损失,二者的区别如下图所示:
Triplet loss and (N+1)-tuplet loss

Deep metric learning with (left) triplet loss and (right) (N+1)-tuplet loss.

上图中红色的圆表示负样本,蓝色的表示锚样本和正样本。从左侧可以看出, N + 1 N+1 N+1元组损失的一个很简单的出发点:既然一个负类的样本不够,那就每个负类都拿一个样本出来,组成一个 N + 1 N+1 N+1的元组。但是在类别很多的场景(比如人脸识别),计算的复杂度过高。文章的重点就在于如何设计这样一个计算上可行的损失函数。

下图是三元组损失(a)、 ( N + 1 ) (N+1) (N+1)元组损失(b)及其改进后的损失©的一个对比。 N N N-pair-mc loss(multi-class N-pair loss)损失就是文章最后提出的损失。

N-pair-mc loss
Triplet loss, (N+1)-tuplet loss, and multi-class N-pair loss with training batch construction.

( N + 1 ) (N+1) (N+1)元组损失可以定义如下:
L ( { x , x + , { x i } i = 1 N − 1 } ; f ) = l o g ( 1 + ∑ i = 1 N − 1 e x p ( f T f i − f T f + ) ) \mathcal{L}(\{x, x^+, \{x_i\}_{i=1}^{N-1}\}; f) = log(1 + \sum_{i=1}^{N-1} exp(f^T f_i - f^T f^+)) L({x,x+,{xi}i=1N1};f)=log(1+i=1N1exp(fTfifTf+))
N N N等于2的时候该损失是与三元组损失等价的。提一嘴,这个形式和softplus的形式是一样的:
s o f t p l u s ( x ) = l o g ( 1 + e x p ( x ) ) softplus(x) = log(1 + exp(x)) softplus(x)=log(1+exp(x))
( N + 1 ) (N+1) (N+1)元组的损失可以写为如下形式:
l o g ( 1 + ∑ i = 1 N − 1 e x p ( f T f i − f T f + ) ) = − l o g e x p ( f T f + ) e x p ( f T f + ) + ∑ i = 1 N − 1 e x p ( f T f i − f T f + ) ) log(1 + \sum_{i=1}^{N-1} exp(f^T f_i - f^T f^+)) = - log \frac{exp(f^T f^+)} {exp(f^T f^+) + \sum_{i=1}^{N-1} exp(f^T f_i - f^T f^+))} log(1+i=1N1exp(fTfifTf+))=logexp(fTf+)+i=1N1exp(fTfifTf+))exp(fTf+)
这样一看是不是就更顺眼了,这不就是多分类里的softmax loss嘛。

N-pair loss for efficient deep metric learning

论文提出了一种高效的批构造方法,以降低额外的计算开销。方法的名字叫multi-class N N N-pair loss( N N N-pair-mc),其构造方式如上图 ( c )所示。来个说文解字,道一道作者的解决方法。方法名中有个N-pair,就从这入手。假若我们有 N N N个pair:
{ ( x 1 , x 1 + ) , ⋯ , ( x N , x N + } , y i ≠ y j , ∀ i ≠ j \{(x_1, x_1^+), \cdots, (x_N, x_N^+\},\ y_i \neq y_j, \forall i \neq j {(x1,x1+),,(xN,xN+}, yi=yj,i=j
每个pair的样本来自不同的类别,在这 N N N个pair的基础上构建 N N N个元组 { S i } i = 1 N \{S_i\}_{i=1}^N {Si}i=1N,其中:
S i = { x i , x 1 + , x 2 + , ⋯ , x N + } S_i = \{x_i, x_1^+, x_2^+, \cdots, x_N^+\} Si={xi,x1+,x2+,,xN+}
其中 x i x_i xi就是锚样本。显然, S i S_i Si就是一个包含了一个 i i i类别正样本, N − 1 N-1 N1个其他类别负样本的 N + 1 N+1 N+1元组了。因此,对于一个由 N N N个查询组成的batch,只需要准备 2 N 2 N 2N个样本,即 N N N个锚样本和 N N N个对应类别的正样本,每个batch只需要** 2 N 2 N 2N次前向计算**样本的嵌入就可以了。而在三元组损失和 N + 1 N+1 N+1元组损失中分别是 3 N 3 N 3N ( N + 1 ) N (N+1) N (N+1)N。因此,对于 N N N个查询组成的batch,其损失可以如下计算:
L N − p a i r − m c ( { ( x i , x i + } i = 1 N ; f ) = 1 N ∑ i = 1 N l o g ( 1 + ∑ j ≠ i e x p ( f i T f j + − f i T f i + ) ) \mathcal{L}_{N-pair-mc}(\{(x_i, x_i^+\}_{i=1}^N ; f) = \frac{1} {N} \sum_{i=1}^N log (1 + \sum_{j \neq i} exp(f_i^T f_j^+ - f_i^T f_i^+)) LNpairmc({(xi,xi+}i=1N;f)=N1i=1Nlog(1+j=iexp(fiTfj+fiTfi+))
以上就是论文的主要内容了,当然论文中还提到了负类别挖掘,这个就暂且不提了。

总结

简言之,这篇论文将度量学习中常见的三元组损失中只有一个负样本扩展到每个样本中包含 N − 1 N-1 N1个负样本,并且为了计算的效率提出了 N N N-pair的batch构造方法以降低计算量。其实,如果在三元组损失的batch中精心设计各种类别样本的配比,比如每个batch只训练一个类别,是否也能达到类似的效果呢?

参考


  1. Rethinking the Role of Pre-ranking in Large-scale E-Commerce, KDD 2023. ↩︎

  2. 漫谈-Distance Metric Learning那些事儿:https://zhuanlan.zhihu.com/p/458114525. ↩︎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/60150.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

redis的持久化

第一章、redis的持久化 1.1)持久化概述 ①持久化可以理解为将数据存储到一个不会丢失的地方,Redis 的数据存储在内存中,电脑关闭数据就会丢失,所以放在内存中的数据不是持久化的,而放在磁盘就算是一种持久化。 ②为…

python实现简单的爬虫功能

前言 Python是一种广泛应用于爬虫的高级编程语言,它提供了许多强大的库和框架,可以轻松地创建自己的爬虫程序。在本文中,我们将介绍如何使用Python实现简单的爬虫功能,并提供相关的代码实例。 如何实现简单的爬虫 1. 导入必要的…

Python中的dataclass:简化数据类的创建

Python中的dataclass是一个装饰器,用于自动添加一些常见的方法,如构造函数、__repr__、__eq__等。它简化了创建数据类的过程,减少了样板代码,提高了代码的可读性和可维护性。有点类似java里面的Java Bean。 让我们看一个简单的例子…

Vscode无法写入文件 NoPermissions (FileSystemError): Error: EACCES: permission

用Vscode想要新建一个index.html的时候遇到了下图问题,说没有权限无法写入文件。 没有权限,咱们给他加上权限哈哈哈,博主是Mac电脑,如下操作: 1.找到你项目的根目录,右键,点击“显示简介”。 …

【非欧几里得域信号的信号处理】使用经典信号处理和图信号处理在一维和二维欧几里得域信号上应用低通滤波器研究(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

vue3+element-plus点击列表中的图片预览时,图片被表格覆盖

文章目录 问题解决 问题 视觉 点击图片进行预览&#xff0c;但还能继续选中其他的图片进行预览&#xff0c;鼠标放在表格上&#xff0c;那一行表格也会选中&#xff0c;如图所示第一行的效果。 代码 <el-table-column prop"id" label"ID" width"…

React源码解析18(1)------ React.createElement 和 jsx

1.React.createElement 我们知道在React17版本之前&#xff0c;我们在项目中是一定需要引入react的。 import React from “react” 即便我们有时候没有使用到React&#xff0c;也需要引入。原因是什么呢&#xff1f; 在React项目中&#xff0c;如果我们使用了模板语法JSX&am…

MySQL单表查询

单表查询 素材&#xff1a; 表名&#xff1a;worker-- 表中字段均为中文&#xff0c;比如 部门号 工资 职工号 参加工作 等 CREATE TABLE worker ( 部门号 int(11) NOT NULL, 职工号 int(11) NOT NULL, 工作时间 date NOT NULL, 工资 float(8,2) NOT NULL, 政治面貌 varch…

工程监测仪器振弦传感器信号转换器应用于隧洞监测

工程监测仪器振弦传感器信号转换器应用于隧洞监测 隧洞建设是重大工程项目&#xff0c;监测隧洞结构和环境的变化对确保隧洞安全和运行管理至关重要。工程监测仪器是实现隧洞监测的关键设备&#xff0c;其中振弦传感器和信号转换器是非常重要的组成部分。 振弦传感器是一种专门…

《Linux从练气到飞升》No.10 冯洛依曼体系结构

&#x1f57a;作者&#xff1a; 主页 我的专栏C语言从0到1探秘C数据结构从0到1探秘Linux菜鸟刷题集 &#x1f618;欢迎关注&#xff1a;&#x1f44d;点赞&#x1f64c;收藏✍️留言 &#x1f3c7;码字不易&#xff0c;你的&#x1f44d;点赞&#x1f64c;收藏❤️关注对我真的…

Finalshell连接Linux超时之Connection timed out: connect

目录 &#x1f349;前言 &#x1f33c;报错 &#x1f33c;摸索 &#x1f4aa;解决措施 &#x1f349;前言 &#xff08;1&#xff09;福利&#xff1a;花了2小时才解决的BUG&#xff0c;希望本篇文章能帮你10分钟解决&#xff01; &#xff08;2&#xff09;tips&#xff1…

Apache DolphinScheduler 3.1.8 版本发布,修复 SeaTunnel 相关 Bug

近日&#xff0c;Apache DolphinScheduler 发布了 3.1.8 版本。此版本主要基于 3.1.7 版本进行了 bug 修复&#xff0c;共计修复 16 个 bug, 1 个 doc, 2 个 chore。 其中修复了以下几个较为重要的问题&#xff1a; 修复在构建 SeaTunnel 任务节点的参数时错误的判断条件修复 …