Segment-anything学习到微调系列_SAM初步了解

## 前言

本系列文章是博主在工作中使用SAM模型时的学习笔记,包含三部分:

1. SAM初步理解,简单介绍模型框架,不涉及细节和代码
2. SAM细节理解,对各模块结合代码进一步分析
3. SAM微调实例,原始代码涉及隐私,此部分使用公开的VOC2007数据集,Point和Box作为提示进行mask decoder微调讲解

 

## 模型总览

SAM论文: https://arxiv.org/abs/2304.02643

SAM Github:https://github.com/facebookresearch/segment-anything

SAM在线demo: https://segment-anything.com/demo

 

SAM的一部分灵感是来源于NLP中的基座模型(Foundation Model),Foundation Model是OpenAI提出的一个概念,它指的是在超大量数据集上预训练过的大模型(如GPT系列、BERT),这些模型具有非常强大的 zero-shot 和 few-shot能力,结合prompt engineering和fine tuning等技术可以将基座模型应用在各种下游任务中并实现惊人的效果。

SAM就是想构建一个这样的图像分割基座模型,即使是一个未见过的数据集,模型也能自动或半自动(基于prompt)地完成下游的分割任务。为了实现这个目标,SAM定义了一种可提示化的分割任务(promptable segmentation task),这个提示可以是**点、框、掩码、文本**(代码中未实现)等形式,基于这个提示模型就能分割出提示处所在物体的masks。同时这种提示可以是模糊的,比如以下图剪刀握手那的黄色部分点为提示,分割掩码可以是下图最右边三种情况中任意一种,从上到下分别代表**whole, part, subpart**三种层级的分割,这也是SAM兼容的。要达到这种效果就需要足够的高质量分割数据,SAM团队用他们提出的Data Engine策略成功使用人工加模型自动标注的方式制作除了一个有10亿个masks的分割数据集**[SA-1B](https://ai.meta.com/datasets/segment-anything/)**,这也是他们核心的贡献之一,本文尾部会介绍相关流程。模型架构来说相对比较常规,主要是借鉴了ViT和DETR,本身创新不大。

 

 


如上图,SAM模型架构主要包括image encoder,prompt encoder和mask decoder三部分:

- image encoder,使用了ViT模型将图像编码得到image embedding
- prompt encoder,将point、box、mask、txt等提示信息进行编码,后续会和image embedding一起用于生成masks
- mask decoder,将上述两个模块得到的embeddings整合,然后结合两个可学习的tokens生成不同层级的masks和对应的置信度值

值得一提的是,**prompt encoder和mask decoder都是非常轻量的**,主要的计算开销都在image encoder上,这点从模型权重上也能看出来,以ViT_B为基础的SAM权重是375M,其中prompt encoder只有32.8k,mask decoder是16.3M(4.35%),剩余则是image encoder,可想而知图像编码这块是非常耗时的。因此在实际推理中,一般单张图的image embedding只计算一次,然后将结果缓存起来,需要的时候直接调用。在image embedding已经计算好的情况下,论文中说给定一个prompt,生成mask时prompt encoder和mask decoder在浏览器中的计算耗时也仅需50ms。下面会具体介绍下各模块的输入输出和流程,均只考虑batch size为1的情况,代码讲解在下一篇。

 

### Image encoder

**输入:**

默认是1024x1024的图像,如尺寸不一致会将原图按最长边resize

**输出:**

单张图的1x256x64x64的image embedding,即编码后的图像特征

 

#### 流程

 

上图是[ViT](https://arxiv.org/abs/2010.11929)ViT论文中的结构图,image encoder整体流程和ViT是一样的,区别在于不需要[class]token做分类,只输出最终的图像编码张量

* 输入1024的图,拆分成64x64的768维patchs
* 经过attention block(window和global的MSA,相对位置编码)和MLP得到同样大小64x64x768embbeding特征
* 再经过neck得到1x256x64x64的图片embedding

这块有一篇文字介绍的更详细,如果想了解更多细节可以看这篇:[Image encoder模块Vision Transformer网络解析](https://blog.csdn.net/yangyu0515/article/details/130200524)


### Prompt encoder

**输入:**

point、box、mask、txt(代码未实现)等prompt,格式一般如下,B为batch size

* point需要包含点的x,y坐标BxNx2和label(0为前景,1位背景)BxNx1
* box包含框的左上和右下两个点,BxNx4,对于某个gt即单个mask,只会有1个box;如果输入的是N个box最终会生成N个masks
* mask一般和SAM最终输出mask的hxw(256x256),Bx1xHxW
* txt在SAM代码中未实现,这块可以参考[Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything)

**输出两个:**

* **sparse_embeddings** 点和框的稀疏嵌入,形状为BxNx(embed_dim),其中**N由输入点和框的数量确定**,如果两者同时有则N的计算方式为(点的个数+2x框的个数)
* point box 全都没有,输出大小:Bx0x256
* 如果只有point,输出大小:Bx(N+1)x256,会补充一个[0,0]空点在最后,label为-1,表示只有点提示;
* 如果只有box,输出大小: (B\*N)x2x256
* piont、box都有,输出大小:BxNx256

* **dense_embeddings** 掩码的密集嵌入,形状为Bx(embed_dim)x(embed_H)x(embed_W),**默认大小为Bx256x64x64**,没有提示时会返回一个网络学习到的no mask默认嵌入

#### 流程

网络已自动学会了针对不通过类型提示的编码信息,输入的point、box、mask等提示加上位置编码后,再加上网络学会的综合编码信息,最终对point、box这种稀疏的提示会返回sparse embedding, 对mask会返回dense embeddings(没有mask提示时是网络学习到的embeddings)。这部分就相当于把各种提示转换为decoder能理解的格式。

 

### Mask decoder

**输入:**

* image encoder得到的image_embeddings和图像的positional encoding
* prompt encoder得到的prompt embeddings(sparse和dense两种)

**输出:**

* masks,如果指定了"multimask_output"参数则会输出3个层级的mask(whole, part, and subpart),否则只输出1个mask
* IoU scores,可以理解为每个mask的置信度,由网络中的iou token得到

#### 流程

* 首先会image_embeddings会混入dense embeddings的信息(两者直接相加),sparse embeddings则会与mask token和IoU token拼在一起成为一个新的token,mask token后续会用于生成mask,IoU token用于衡量每个mask的好坏

* 然后这个新的token和image_embeddings经过一个TwoWayTransformer模块(下图黄色框部分),先做token的self attention,然后做token(作为key)到图像的cross attention,经过MLP更新token,最后再图像(作为key)到token的attention,目的是不断更新图像和token中的信息,会重复两次

* 更新后token再做一次token(作为key)到图像的cross attention后,又拆出来之前的两个部分mask token和IoU token,后者就代表每个mask的置信度;

而图像信息经过转置卷积还原到原图大小后,会和mask token做矩阵乘法生成最终的masks,类似 [YOLACT](https://mp.weixin.qq.com/s/-kQ4uwvW9OQ0MmU03F8jDw)中的"prototype masks"和"mask coefficients"矩阵乘法


 

 


### 整图分割推理(segment everything)

 


#### 流程

在图片上生成32x32的网格,得到1024个采样点,每个采样点都当做1个前景的prompt进入prompt encoder然后和image encoder结果一起生成mask,每次会处理一个batch(默认64)的采样点;每个batch得到的mask都会进行以下几个过滤:

* predicted IoU过滤,mask decoder除了返回masks还会预测对应mask iou值,过滤低置信度(默认阈值0.88)的mask
* stability score过滤,stability score是mask在两个阈值下二值化后的IoU值,可以理解为改变过滤阈值后还能得到同样mask的能力,过滤低于0.95的mask
* mask threshold过滤,直接过滤mask logits值低于mask_threshold(默认0.0)的mask
* boundary过滤,每个mask生成外界矩形,过滤超过图像边界的mask

所有batch过滤后的的masks结果再进行nms过滤(mask对应外接矩形的nms,阈值0.7)就得到最终的分割结果

 

#### 最终结果

git上也有官方demo可以参考:[全图分割的官方demo](https://github.com/facebookresearch/segment-anything/blob/main/notebooks/automatic_mask_generator_example.ipynb)

 

 

## 数据引擎(data engine)

SAM除了模型外,还公开了一份有10亿个masks的1100万张图的分割数据集**[SA-1B](https://ai.meta.com/datasets/segment-anything/)**,基于他们提出的data engine方案得到,这块的贡献也是非常显著,也体现了Data-centric AI的惊人能力,[这块知乎上"一堆废纸"博主介绍的比较好](如何评价Meta/FAIR 最新工作Segment Anything? - 一堆废纸的回答 - 知乎
https://www.zhihu.com/question/593888697/answer/2972047807)。从论文里总结就是辅助人工标注、半自动标注、全自动标注三步,具体如下:

* 第一步以人工标注为主。初始模型在公开数据集训练后辅助生成masks,再人工精修调整,再用标好的新数据迭代模型。如此重复6次,从12万张图得到430万masks
* 第二步是模型半自动标注高置信度masks,然后人工标注补充剩余未标出的masks。mask的置信度判断是用一个模型对mask进行目标检测,如果能检测出物体则是置信度较高mask无需再人工标注,这个目标检测模型是基于第一步得到的数据训练的。如此迭代5次,从18万张图新增了590万masks
* 第三部是模型全自动标注。基于此前两步的数据得到模型,已有较好的分割能力且能适配模糊提示分割(局部mask或者整体mask),对一张图撒32x32的网格点进行segment everything,后处理会挑选搞IoU和搞稳定性的masks并做NMS得到全图最终的masks。针对所有图片自动分割,最终得到了SA-1B数据集

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/773280.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

一个贝塞尔曲线编辑工具(2d)

曲线在unity下如何绘制? 类似绘制圆,是用一段一段的线段拼接来模拟的,这边也是类似,可以用一段一段的线段来模拟曲线。既然要模拟,那我们也得知道贝塞尔曲线的公式才行。 一般用的比较多的就是3次贝塞尔曲线,该曲线由起点p1,p1的控制点c1,终点p2,p2的控制点c2组成。公…

烧录算法制作

前言 在使用Keil的时候,我们一般会通过一个下载器与目标芯片连接,这样就可以实现的代码下载或调试。那么下载器是如何将我们的应用程序烧写在我们芯片内部Flash当中的呢,是否可以同样的方式烧录在外部Flash上呢?这是此片文章所要说明的。 MDK下载算法原理 通过MDK创建一批与…

Mocreak Office Installer(Office安装部署工具) v2.3.0.703 中文绿色版

概述 Mocreak 是一款一键自动化下载、安装、部署正版 Office 的办公增强工具。该工具完全免费、无广告、绿色、无毒、简约、高效、安全。软件特点 一键快速下载、安装、部署最新版 Microsoft Office 软件。提供简约、高效,且可自定义的图形界面,提升部署效率。支持将 Office …

LLM大模型:deepspeed实战和原理解析

多年前搞大数据,因为单节点无力存储和计算PB级别的数据,所以hadoop这种分布式存储和计算框架是标配!如今搞大模型,仍然需要对大量样本数据做计算,因为涉及矩阵运算,单机单卡运算效率太低,也涉及到分布式计算了,大模型时代的分布式pre-train和Inference框架就有现成的—…

Iterator与Iterable(迭代器与可迭代)

一 前言 环境: python 3.10 win10 二 Iterator(迭代器) Iterator 是python的内置类型之一,看下其定义该类型的实例对象称之为iterator(迭代器对象) 要得到一个iterator(迭代器对象),可用内置函数iter()将 list tuple等转成迭代器对象 也可以自定义一个迭代器类型的class,…

java基础 手写回忆篇

java 特性:分布行,跨平台性,安全行,多线程,面向对象编程,简单性 高级语言分为编译型和解释型: 编译型:整个程序写完一起编译速度快效率高 解释性:需要一句解释一句编译速度慢效率低 java是两者综合:编译器(javac)先把你写好的代码编译成class文件(字节码文件)再用j…

洛谷 Markdown - 从入门到精通

洛谷 Markdown - 从入门到精通 编写——Jerrycyx(CSDN,洛谷) 洛谷博客查看因为洛谷博客的渲染机制和其它地方不一样,可能导致渲染错误,所以你可以到这里食用:https://www.luogu.com.cn/paste/wu019n2x绪论希望更丰富的展现?使用 Markdown。这是洛谷文字编辑时会出现的一行…

Dev-C++ 的功能与外观优化

预备 安装 安装 Dev-C++ 5.11:官方下载:https://sourceforge.net/projects/orwelldevcpp/(若下载缓慢可选择 Problem Downloading->Auto-select) 蓝奏云下载:https://wwu.lanzouq.com/iTwwW07r28ni运行安装包即可。 更改语言 如果界面语言为英文,选择 Tools -> Env…

OI 中各种输入方式的速度比较(C++,大量实测数据,附图表)

测试信息 本次共测试了以下几种输入方式的速度:scanf cin 快读 位运算快读 fread() + 位运算快读 关闭同步流的 cin 开启 tie 绑定并关闭同步流的 cin每组测试各输入方式均使用相同数据,为随机生成的 \(1000000\)(1E6) 个整数,范围在 \([-(2^{31}-1),2^{31}-1]\)(即 int …

乒乓球比赛计分系统需求流程——最小可用产品

计应222_杜晓瑾_2210502012 乒乓球比赛计分系统需求流程——最小可用产品 作为一名裁判,我希望可以在系统上进行网络计时、记成绩,以便大家可以监督和观看任务 Sprint 1 2 3 4 5 6 7 8 9 10 开通网络计时(1h) 1 0 网络计时(4h) 4 2 1 0 进行网络记成绩(3h…

Easysearch、Elasticsearch、Amazon OpenSearch 快照兼容对比

在当今的数据驱动时代,搜索引擎的快照功能在数据保护和灾难恢复中至关重要。本文将对 Easysearch、Elasticsearch 和 Amazon OpenSearch 的快照兼容性进行比较,分析它们在快照创建、恢复、存储格式和跨平台兼容性等方面的特点,帮助大家更好地理解这些搜索引擎的差异,从而选…