【视觉大模型SAM系列】PerSAM:Personalize Segment Anything Model with One Shot

【版权声明】
本文为博主原创文章,未经博主允许严禁转载,我们会定期进行侵权检索。

更多算法总结请关注我的博客:https://blog.csdn.net/suiyingy,或”乐乐感知学堂“公众号。
本文章来自于专栏《大模型》的系列文章,专栏地址为:https://blog.csdn.net/suiyingy/category_12473256.html。

        PerSAM是基于SAM(Segment Anything Model)进行改进的模型,并且实现OneShot功能。SAM一般需要人工选择先验的提示词,包括点、矩形框box和分割mask,并根据提示词分割出目标。PerSAM采用一张参考图片的目标(OneShot)来自动选择前景点和背景点,并作为SAM的提示词,而不需要手动选择。针对目标可能存在歧义的部分,PerSAM-F通过训练两个权重参数来对SAM输出的3组mask进行加权求和,进而使得分割更加准确。另一方面,PerSAM将第一次预测结果再次经过两次SAM来进行后处理微调。以上3点即为PerSAM模型的核心思想。其论文地址为“https://arxiv.org/abs/2305.03048”。

图1 PerSAM模型总体结构

        下面将结合PerSAM的程序来分析具体的模型结构。

1 环境安装

        PerSAM模型程序的下载地址为“https://github.com/ZrrSkywalker/Personalize-SAM”。Python环境搭建过程如下所示。由于大模型大多比较新,因而考虑到环境的兼容性,我们最好安装比较高版本的Python、CUDA、Pytorch。当前Pytorch安装的是1.13.1版本,很多大模型在2.0版本以上运行速度更快,但是至少需要安装CUDA 11.7。一般情况下CUDA套件的版本不能高于显卡驱动版本,否则可能会带来兼容性问题。

conda create -n persam python=3.10 -y
conda activate persam
pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116 #安装pytorch CUDA 11.6
git clone https://ghproxy.com/https://github.com/ZrrSkywalker/Personalize-SAM.git #github代理:https://ghproxy.com/,如果GitHub无法直接连接下载,可采用该方式。
cd Personalize-SAM
pip install -r requirements.txt -i  https://pypi.tuna.tsinghua.edu.cn/simple

2 数据下载

2.1 PerSeg.zip

        测试数据PerSeg.zip下载地址为“下载地址1:https://drive.google.com/file/d/18TbrwhZtAPY5dlaoEqkPa5h08G9Rjcio/view?usp=sharing”或“下载地址2:https://pan.baidu.com/s/1X-czD-FYW0ELlk2x90eTLg(提取码:222k)”。

        解压后包括Annotations和Images两个文件夹,共41组目标语义分割类别。Images存储输入图片,同一类别图片在一个文件夹下,文件夹名称表示类别。Annotations存储图片分割标注,采用图片mask,目标区域像素值为[128, 0, 0],背景为全黑[0, 0, 0]。解压后如下图所示,将data3重命名为data,并放到模型工程目录下即可。

图2 PerSeg示例图片

2.2 训练评估数据DAVIS

        数据集下载地址为“https://data.vision.ee.ethz.ch/csergi/share/davis/DAVIS-2017-trainval-480p.zip”。目录结构如下图所示。

图3 DAVIS目录结构

2.3 SAM预训练模型

        SAM预训练模型名称为sam_vit_h_4b8939.pth,下载地址为“https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth”。

3 SAM输入

        SAM论文地址为“https://arxiv.org/abs/2304.02643”,模型的整体结构如下图所示。

图4 SAM模型结构

        SAM的输入包括图像、分割mask、提示词(点points、矩形框box、文本text)。PerSAM不涉及文本提示词。图像输入尺寸为1024x1024。因此,图像将按照长边缩放到1024,短边则填充至1024像素,另外图像还将根据均值和标准差进行归一化。

        图形prompt包括点points和矩形框box两种类型。点由坐标及其标签组成。矩形框由左上和右下两个点组成。点和矩形框相对于语义分割任务来说属于一种稀疏表示方法,这是因为分割需要对每个点进行区分,若干个点相对于全部点来说是稀疏的。换句话说,点和矩形框是分割结果的近似表示形式。程序中点和矩形框进行特征变换,成为256维特征embedding。程序中定义为稀疏特征sparse_embedding。该特征维度为bxkx256,b表示batch size,以下将设置为1,k与点数和矩形框数量相关,256为特征维度)。由于矩形框由左上和右下两个点组成,每个点有一个embedding,因此一个矩形框有两个embedding,即1x2x256。算法会设置一个空白的无效sparse_embedding(1x256)用于表示没有输入矩形框prompt的情况。

        相比而言,mask是对区域进行分割标记,维度与图像特征相同,因而是稠密的。程序中定义为稠密特征dense_embedding(1x256x64x64,64为特征图尺寸)。在后续transformer变换过程中,该特征会与图像特征进行叠加。在没有mask输入时,程序也会设置一个无mask输入时的dense_embedding(1x256x64x64,no_mask_embedding)。

        图像经过特征提取模块image encoder之后的图像特征image_embedding维度为1x256x64x64,即1x4096x256,并且与mask的特征dense_embedding进行叠加。与此同时,模型会也会对各个像素位置进行特征编码,image_pe,1x256x64x64。位置编码属于transformer的常规操作。

4 transformer

        SAM的核心结构为transformer,而transformer的关键参数为Q(query)、K(key)、V(value)。Q可认为是我们要查询的结果或问题表述或提示信息,例如语义分割mask、mask的质量分        数、提示词embedding等。我们需要通过注意力计算出其对应的特征属性,一般是V的叠加。特征属性本质上来源于图像。因此,SAM第一个transformer的Q(tokens)由mask质量分数(1x1x256)、mask(1x4x256)、sparse_encoding(1x2x256,包含1个点和1个padding)组成,维度为1x7x256。

        Transformer根据K对Q的贡献程度来,得到Q各个元素的注意力特征V。通常Q和K会将位置进行关联,包含位置特征,而V则不包含位置特征。可以理解为,位置特征是不变的,不会随着特征变换发生改变,并不需要进行注意力叠加。例如,目标的矩形框位置坐标在特征图的位置实际上是固定的。

        QKV注意力的公式如下:

        由于每个Q元素会跟所有K计算softmax注意力权重,并将加权求和的结果作为注意力特征,因而其长度与A保持一致。假设Q的维度为BxLQxC,K的维度为BxLKxC,那么输出注意力的特征的维度也为BxLQxC。在多层注意力叠加时,模型会将Q与注意力结果进行一次叠加融合(二者维度完全一致,可直接进行叠加求和),并进行归一化操作。需要特别注意,叠加一般是对特征进行操作,不需要引入位置特征。

5 PerSAM-F模型结构

        程序中PerSAM-F的模型结构保持与SAM基本一致,只是对输出的mask做了加权参数优化。结果相比PerSAM更加简洁,因而这里先对其进行介绍。PerSAM-F模型总体过程主要包括相似度计算、SAM注意力特征提取层、解码层和mask权重微调训练三个部分,下面将分别进行详细介绍。

5.1 相似度计算

        模型首先根据参考图像mask区域特征采用均值与最大值相结合的方法计算出One-Shot的目标特征,维度为1x256,然后计算该特征与整图每个点特征的余弦相似度,并将相似度最高的作为prompt的输入点,即目标前景点。

        因此,SAM的prompt输入为参考图像及其相似度最高的输入点。具体选择过程如下所示。

# Image features encodingref_mask = predictor.set_image(ref_image, ref_mask) # 1x3x1024x1024ref_feat = predictor.features.squeeze().permute(1, 2, 0) # 64x64x256ref_mask = F.interpolate(ref_mask, size=ref_feat.shape[0: 2], mode="bilinear") # 1x3x64x64ref_mask = ref_mask.squeeze()[0] # 64x64# Target feature extractiontarget_feat = ref_feat[ref_mask > 0] # K x 256target_feat_mean = target_feat.mean(0) # 256
target_feat_max = torch.max(target_feat, dim=0)[0] # 256
target_feat = (target_feat_max / 2 + target_feat_mean / 2).unsqueeze(0) # 均值与最大值合并,1x256
# Cosine similarity
h, w, C = ref_feat.shape # 64x64x256
target_feat = target_feat / target_feat.norm(dim=-1, keepdim=True) # 归一化,1x256
ref_feat = ref_feat / ref_feat.norm(dim=-1, keepdim=True)# 归一化,64x64x256
ref_feat = ref_feat.permute(2, 0, 1).reshape(C, h * w) # 256x4096
sim = target_feat @ ref_feat # 1x4096
sim = sim.reshape(1, 1, h, w) # 1x1x64x64
sim = F.interpolate(sim, scale_factor=4, mode="bilinear") # 1x1x256x256
sim = predictor.model.postprocess_masks(sim,input_size=predictor.input_size,original_size=predictor.original_size).squeeze() # 插值回原图尺寸,hxw
# Positive location prior
topk_xy, topk_label = point_selection(sim, topk=1) # 选择相似度最高的像素点坐标,1x2,[1]

5.2 SAM注意力特征提取层

        SAM注意力特征提取层主要由transformer结构组成,其结构如下图所示。

图5 SAM注意力特征提取层

        其结构由两个transformer 结构堆叠组成,可通过深度depth参数来设置堆叠的数量。每一个transformer 结构包含一个自注意力模块和两个交叉注意力模块。最终输出融合后的tokens预测结果hs和图像特征src,二者维度分别为1x7x256和1x4096x256。

5.2.1 第1个transformer

        第1组transformer的输入如下:

        Q:tokens,1x7x256,包括预测目标和prompt特征。

             K:image_embedding,下文用IE表示,1x4096x256,包括图像和mask(dense-embedding)特征。

        Q_PE:Prompt位置编码,tokens,1x7x256。

        K_PE:图像特征位置编码,image_pe,下文用IE_P表示,1x4096x256。

        (1)自注意力

        q=k=v=Q,学习tokens特征之间的关联性,输出注意力结果作为新的Q(1x7x256)。注意力的head数量为8,特征为32,QKV全连接为Linear(256, 256),输出全连接为Linear(256, 256),norm层为LayerNorm(256)。

        (2)交叉注意力 tokens attending to image embedding

        tokens特征中融入图像特征。q= Q + Q_PE,k = K + K_PE,V=K。根据tokens对各个图像像素特征的关注程度,得到新的注意力特征attn_out(1x7x256),该特征进一步与Q进行叠加作为新的Q(1x7x256),从而使得原始tokens中融入图像特征。注意力的head数量为8,特征为16,QKV全连接为Linear(256, 128),输出全连接为Linear(128, 256),norm层为LayerNorm(256)。

        (3)MLP

        对Q的特征使用全连接层进行融合,全连接层为Linear(256, 2048)和Linear(2048, 256)。

mlp_out = self.mlp(queries) # 1x7x256
queries = queries + mlp_out # 1x7x256
queries = self.norm3(queries) # 1x7x256

        (4)交叉注意力 image embedding attending to tokens

        图像特征中融入tokens特征。q = K + K_PE,k= Q + Q_PE,V=Q。根据各个图像像素特征对tokens的关注程度,得到新的注意力特征attn_out(1x4096x256),该特征进一步与K进行叠加作为新的K(1x4096x256),从而使得图像特征中融入tokens特征。注意力的head数量为8,特征为16,QKV全连接为Linear(256, 128),输出全连接为Linear(128, 256),norm层为LayerNorm(256)。

        经过第1个ransformer后,tokens特征融入了图像特征,得到新的Q(1x7x256),图像特征也融入了tokens特征,得到了新的K(1x4096x256)。但是,位置编码是不变的,仍然分别保持为Q_PE和K_PE。

5.2.2 第二个transformer

        自注意力过程中由于Q已经融合了新的特征,而不再是由token的位置特征组成,因此q=Q + Q_pe,K=Q + Q_pe,V=Q。

        交叉注意力机制和MLP与第一个transformer一致,相当于特征的再次融合。融合后的Q和K的维度仍然分别为1x7x256和1x4096x256。

5.2.3 final attention layer

        上面我们分别得到了融合后的Q和K,但是最后两者还未进行关联融合人。因此,这里需要再次计算Q对K注意力结果,并与Q叠加得到新的Q。最终,经过注意力融合后的Q和K的维度仍然分别为1x7x256和1x4096x256,分别对应下面程序的hs和src。

# Run the transformer
hs, src = self.transformer(src, pos_src, tokens, attn_sim, target_embedding) # 1x7x256,1x4096x256

5.3 解码层

        PerSAM-F的解码层结构如下图所示。

图6 PerSAM-F解码层

        hs[:, 0, :]经过3层全连接Linear(256, 256)、Linear(256, 256)和Linear(256, 4)得到mask的质量得分iou_pred(1x4)。

        hs[:, 1:5, :]分别经过3层全连接Linear(256, 256)、Linear(256, 256)和Linear(256, 32)得到4组图像特征权重hyper_in(4x32)。

        src(1x256x64x64)经过逆卷积上采样操作后得到最终图像特征32x256x256,权重hyper_in(4x32)分别对图像特征进行加权求和得到4组分割masks(1x4x256x256)。

        在多mask输出时选择后3个mask为预测结果,那么mask和iou_pred的维度分别为1x3x256x256和1x3。

# Select the correct mask or masks for output
if multimask_output:mask_slice = slice(1, None)
else:mask_slice = slice(0, 1)
masks = masks[:, mask_slice, :, :]
iou_pred = iou_pred[:, mask_slice]
# Prepare output
return masks, iou_pred

        此时mask的分辨率是256x256,相当于原图的下采样,因而用low_res_masks表示。低分辨率mask经过上采样插值后的高分辨mask(high_res_masks),维度为1x3xhxw。最终有效目标masks是high_res_masks中大于0处的点。

high_res_masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)#将mask插值到原始图像分辨率,1x3xhxw
masks = high_res_masks > self.model.mask_threshold  # 0.0
return masks, iou_predictions, low_res_masks, high_res_masks

5.4 mask权重微调训练

mask权重微调训练过程如下图所示。

图7 mask权重微调训练过程

        PerSAM-F对high_res_masks中的3组mask进行加权求和得到最终mask,且权重组成为(w1, w2, 1-w1-w2)。可以看到,微调只需学习w1和w2这两个参数。加权得到的mask与真实mask进行损失计算,包括Dice Loss和Foca Loss,这也是两种比较常规的语义分割损失。

6 PerSAM-F推理过程

        训练完成之后得到经过优化的后的mask权重参数w1和w2。推理过程包括初步预测和两次后处理过程。

6.1 初步预测

        与训练过程一致,采用相似度最高的点与SAM Decoder获取masks、 scores、 logits、 logits_high。masks维度为3xhxw,是high_res_masks大于0的布尔运算结果;scores维度为3,是3组mask的质量得分,即iou_pred;logits维度为3x256x256,即low_res_masks;logits_high维度为3xhxw,即high_res_masks。

        logits_high经过加权求和得到预测的mask,并计算mask大于0区域的外接矩形,将该矩形作为初步预测的矩形框box。

6.2 使用SAM后处理微调

        模型仍然使用SAM来进行预测,此时prompt不仅包含相似度最大的点,还包括第一步预测的mask的外接矩形框,并且将第一步预测的logits作为mask prompt。因此输入tokens的维度为1x8x256,包括mask质量分数(1x1x256)、mask(1x4x256)、sparse_encoding(1x3x256,包含1个点和1个矩形框)。Decoder再次输出masks、scores、 logits、 logits_high。这次不再使用加权求和的方法来预测结果,而是根据最优的质量得分选择从masks中选择对应的维度的mask,属于常规的SAM预测操作。这里没有使用加权求和的原因在于后处理是独立的,即后处理方法为SAM模型。最优mask仍然对应一个外接矩形。

6.3 再次使用SAM后处理微调

        微调过程与上一次完全一致,并且最终输出最优的mask作为预测结果。

best_idx = np.argmax(scores)
final_mask = masks[best_idx] # hxw

7 PerSAM

        PerSAM与PerSAM-F的区别在于:

  1. 没有训练微调。
  2. 增加背景点输入。
  3. 引入相似度注意力机制。
  4. 推理阶段输出单mask。

        根据图像特征与参考目标的平均特征分别选择相似度最高的点作为前景点(label为1),选择相似度最低的点作为背景点(label为0)。另一方面,模型将图像与参考目标特征的相似度作为注意力的一部分,在Token-to-Image中进行注意力叠加,如下图所示。

图8 PerSAM模型结构

# Obtain the target guidance for cross-attention layers
sim = (sim - sim.mean()) / torch.std(sim) # 归一化,hxw
sim = F.interpolate(sim.unsqueeze(0).unsqueeze(0), size=(64, 64), mode="bilinear") # 1x1x64x64
attn_sim = sim.sigmoid_().unsqueeze(0).flatten(3) # 转换为0~1概率,相当于每个像素的受关注程度,1x1x1x4096
attn = q @ k.permute(0, 1, 3, 2)  # B x N_heads x N_tokens x N_tokens
attn = attn / math.sqrt(c_per_head)
attn = torch.softmax(attn, dim=-1)
if attn_sim is not None:attn = attn + attn_simattn = torch.softmax(attn, dim=-1)

        在Token-to-Image过程中,每个token会计算与4096个像素之间的注意力权重,而4096个像素特征之间根据相似度注意力也可以得到一个自身的注意权重,二者进行叠加实现了注意力融合。在计算tokens对每个像素点的注意力时,模型会通过Q、K和Softmax得到注意力矩阵,该注意力与相似度注意力attn_sim进行叠加再经过softmax得到新的注意力系数。

        PerSAM输入增加负样本点Prompt,因此tokens维度为1x8x256,包括mask质量分数(1x1x256)、mask(1x4x256)、sparse_encoding(1x3x256,包含2个点和1个padding)。

        每次进行transformer时Q中增加参考图像特征1x1x256,这相当于查找测试图像中与参考图像特征的相似关系。

        输出单个mask(1x1x256x256)预测结果masks[0]及其质量得分iou_pred[0],这说明假设第1个mask是最优的。最终返回masks、 scores、 logits、 logits_high。masks维度为1xhxw,是high_res_masks大于0的布尔运算结果;scores维度为1,是1组mask的质量得分,即iou_pred;logits维度为1x256x256,即low_res_masks;logits_high维度为1xhxw,即high_res_masks。

        两次后处理操作与PerSAM-F一样,仅使用原始SAM结构,不再使用相似度注意力。其中第一次微调不输入box矩形框。第一步预测的结果可以认为是mask的初步预测,并作为后处理的mask prompt特征。

8 多目标分割

        多目标分割的主要过程如下:

  1. 使用PerSAM或PerSAM-F预测mask;
  2. 将原图中mask区域置为黑色;
  3. 重复以上两步,直到mask数量达到最大数量的限制,或者mask与已有结果IOU重叠较大,达到重叠阈值。

        PerSAM和PerSAM-F的具体效果可前往论文查看,也可按照工程进行验证测试。

【版权声明】
本文为博主原创文章,未经博主允许严禁转载,我们会定期进行侵权检索。

更多算法总结请关注我的博客:https://blog.csdn.net/suiyingy,或”乐乐感知学堂“公众号。
本文章来自于专栏《大模型》的系列文章,专栏地址为:https://blog.csdn.net/suiyingy/category_12473256.html。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/139296.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

软件外包开发文档

编写软件开发文档是项目开发过程中的关键步骤,它有助于组织、记录和分享项目的信息和进展。以下是编写软件开发文档的一般步骤和建议,希望对大家有所帮助。北京木奇移动技术有限公司,专业的软件外包开发公司,欢迎交流合作。 1.文档…

千兆光模块存在哪些局限性

千兆光模块是目前使用最广泛的光模块之一,可以实现1Gbps的传输速度。随着科技的进步和应用场景的多样性,千兆光模块也因其固有的局限性而面临越来越多的挑战。以下是千兆光模块的局限性和如何克服这些局限性的讨论: 千兆光模块可以实现最大…

7.继承与多态 对象村的优质生活

7.1 民法亲属篇:继承(inheritance) 了解继承 在设计继承时,你会把共同的程序代码放在某个类中,然后告诉其他的类说此类是它们的父类。当某个类继承另一个类的时候,也就是子类继承自父类。以Java的方式说&…

【JVM面试题】JVM分代年龄为何是15次?能设置为16吗?

系列文章目录 【JVM系列】第一章 运行时数据区 【JVM面试题】第二章 从JDK7 到 JDK8, JVM为啥用元空间替换永久代? 【JVM面试题】第三章 JVM分代年龄为何是15次?能设置为16吗? 大家好,我是青花。拥有多项发明专利(都是…

【TES720D】青翼科技基于复旦微的FMQL20S400全国产化ARM核心模块

板卡概述 TES720D是一款基于上海复旦微电子FMQL20S400的全国产化核心模块。该核心模块将复旦微的FMQL20S400(兼容FMQL10S400)的最小系统集成在了一个50*70mm的核心板上,可以作为一个核心模块,进行功能性扩展,特别是用…

CSS的美化(文字、背景) Day02

一、文字控制属性 分为:字体样式属性 、文本样式属性 1.1 CSS字体样式属性 1.color定义元素内文字颜色2.font-size 字号大小3 font-family 字体4 font-weight 字体粗细5.font-style 字体风格6.font 字体综合属性 1.1.1 > 文字颜色 color 属性名: color color …

海量小文件数据传输如何确保安全性

在当今的信息化社会,企业需要处理和传输的文件越来越多,越来越大。其中,海量小文件数据是一种特殊的数据类型,它由数亿级别的小文件(通常小于1MB)组成,它在图片网站、物联网设备、日志分析等场景…

【Eclipse】取消按空格自动补全,以及出现没有src的解决办法

【Eclipse】设置自动提示 教程 根据上方链接,我们已经知道如何设置Eclipse的自动补全功能了,但是有时候敲变量名的时候按空格,本意是操作习惯,不需要自动补全,但是它却给我们自动补全了,这就造成了困扰&…

Unity游戏开发中ngui和ugui区别与优缺点详解

Unity3D是一款跨平台的游戏开发引擎,它支持多种平台,包括PC、移动设备和主机。在Unity3D中,UI系统是游戏开发中非常重要的一部分,它负责游戏中的用户界面的显示和交互。 对惹,这里有一个游戏开发交流小组,…

全国400电话办理,简单步骤帮您申请成功

导语:随着企业业务的扩展和市场需求的增加,越来越多的企业开始关注全国400电话的申请。本文将为您介绍全国400电话的办理步骤,帮助您顺利完成申请过程。 正文: 一、了解全国400电话 全国400电话是一种虚拟电话号码,…

百度发布全新 AI 互动式搜索:百度简单搜索

本心、输入输出、结果 文章目录 百度发布全新 AI 互动式搜索:百度简单搜索前言主要能力 相关资料能力介绍 百度搜索升级发文告用户如何获取百度简单搜索百度简单搜索的定位百度简单搜索在 APP 上面的体验讨论和点评我们关注的几个问题 弘扬爱国精神 百度发布全新 AI…

PostgreSQL性能调优:优化查询和索引设计

随着数据量的增长和业务需求的变化,数据库性能成为了许多企业关注的焦点之一。在众多的数据库管理系统中,PostgreSQL因其稳定性和可靠性而备受青睐。然而,即使是最强大的系统也需要合适的调优,以确保其能够高效地处理大规模数据和…