文章目录
- 1 Image Encoder的结构
- 1.1 图片分patch
- 1.2 attention block
- 1.2.1 window partition
- 1.2.2 window unpartition
- 1.2.3 relative partition embedding
- 1.3 neck
- 2 Prompt encoder
- 2.1 point embedding
- 2.2 box embedding
- 2.3 mask embedding
- 3 Mask decoder
- 3.1 预测mask的流程
- 3.2 TwoWayTransformer
- 3.3 后处理
- 4 全图分割
- 4.1 完整流程
- 4.2 生成masks
- 4.3 处理每一个crop_img
参考资料:
- demo: https://segment-anything.com/demo
- paper: https://arxiv.org/abs/2304.02643
- Github:https://github.com/facebookresearch/segment-anything
SAM模型大致上分成3个模块,一个标准的vit构成的image encoder、一个prompt encoder和一个mask decoder。其中:
- Image encoder: 用于输出image embedding;
- prompt encoder:用于接收point、box、txt的编码信息,并且与image embedding组合到一起送入mask decoder中;
- mask decoder:将上述两个encoder的编码信息转化为mask输出。
1 Image Encoder的结构
从结构上看,sam的encoder部分就是堆叠transformer的block结构,最后再跟一个neck,调整输出embedding的维度。Meta开源了三个模型,分别是vit_h, vit_l和vit_b,这三个模型的区别仅仅在于内部patch embedding维度、transformer的block的个数以及每个block中head的数量和全局attention的index:
模型 | patch embedding维度 | transformer head数量 | transformer block层数 | global attention 的block的index |
---|---|---|---|---|
vit_h | 1280 | 16 | 32 | [7, 15, 23, 31] |
vit_l | 1024 | 16 | 24 | [5, 11, 17, 23] |
vit_b | 768 | 12 | 12 | [2, 5, 8, 11] |
网络输入尺寸:1024x1024,
图片分path的尺寸:16,
image embedding的长度:256,
windows size:14。
1.1 图片分patch
原图进入网络之后,按照最大边长补充成方形,再resize到1024x1024。
1024x1024x3的图片输入进入网络后,首先使用一个16x16,stride=16,输出channel数为patch embedding维度的二维卷积。以vit_b为例,patch embedding的维度是768,因此经过卷积之后,图片变成了768x64x64的feature map,再调整维度就变成64x64x768。
在该feature map基础上,会再加一个绝对位置编码(absolute positional embedding),所谓绝对位置编码是指生成一组与feature map同样大小(64x64x768)的可学习参数,初始化时一般为0。
1.2 attention block
1.2.1 window partition
针对非global attention的block,会将上一小节输出的feature map进行补边,再拆分成14x14的网格。流程如下:
输入的特征图大小为:1x64x64x768
窗口的大小为:14x14
得到最小可整除特征图大小为1x70x70x768,因此采用0来padding,padding方式为右下角填充,再将特征图拆分为25x14x14x768。
1.2.2 window unpartition
针对非global attention的block,将attention层输出的特征图1x70x70x768转化为1x64x64x768的特征图,实际上是通过切片操作得到的,即取右上角特征图。
1.2.3 relative partition embedding
相对位置编码出现在attention模块中,用于在q*k计算完成之后,对于attention矩阵进行操作:
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, H, W, _ = x.shape
# qkv with shape (3, B, nHead, H * W, C)
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
# q, k, v with shape (B * nHead, H * W, C)
q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
attn = (q * self.scale) @ k.transpose(-2, -1)# 添加相对位置编码
if self.use_rel_pos: attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) attn = attn.softmax(dim=-1)
x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
x = self.proj(x)
return x
- 生成一组可学习的位置编码向量矩阵
相对位置编码针对于非global attention的block,即那些需要将特征图拆分成14x14大小的block。在h,w两个方向上生成的一组可学习的参数,维度为(2 * 14 - 1, 64)。
针对于需要global attention的block,即不拆分特征图,在h,w两个方向上生成的一组可学习的参数,维度为(2 * 64- 1, 64)
# 其中input_size[0] = input_size[1] = 14
# multi-head attention中的head维度:head_dim = block输出的维度768 / head的数量12 = 64
self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
为啥是2 * input_size[0] - 1呢?因为矩阵中最远的距离就是对角线元素之间的曼哈顿距离,所以只需生成2*H-1个向量即可。
- 根据特征图的大小,生成相对坐标的index
假设q,k的size为5,则生成的位置编码为:(torch.arange(5)[:, None] - torch.arange(5)[None, :]) + (5 -1)
效果如下:
tensor([[4, 3, 2, 1, 0],
[5, 4, 3, 2, 1],
[6, 5, 4, 3, 2],
[7, 6, 5, 4, 3],
[8, 7, 6, 5, 4]])
从图中可以看出,相对位置编码index是从特征图的某个角开始设置为0,距离该角越远,index越大。再使用该index,从上一步生成的位置编码向量矩阵中取出不同index下的编码向量。
- 基于query矩阵计算最终的相对位置编码
针对非global attention的block,输入的特征图大小为25x14x14x768,所以生成的index矩阵大小为14x14,再用index矩阵取其对应的位置编码向量,得到的就是一个14x14x64的位置编码矩阵,针对h,w两个方向的做同样的操作,得到2个14x14x64的相对位置编码矩阵Rh与Rw。其中Rh基于rel_pos_h生成,Rw基于rel_pos_w生成,两个矩阵不一样。
此时计算出来的query矩阵大小为300x196x64,将其还原到300x14x14x64,再分别与Rh和Rw做矩阵乘法,最终得到的就是位置编码,大小为300x14x14x14,对应代码:
def add_decomposed_rel_pos( attn: torch.Tensor, q: torch.Tensor, rel_pos_h: torch.Tensor, rel_pos_w: torch.Tensor, q_size: Tuple[int, int], k_size: Tuple[int, int], ) -> torch.Tensor: """ Calculate decomposed Relative Positional Embeddings Args: attn (Tensor): attention map. q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. q_size (Tuple): spatial sequence size of query q with (q_h, q_w). k_size (Tuple): spatial sequence size of key k with (k_h, k_w). Returns: attn (Tensor): attention map with added relative positional embeddings. """ # q: 300x196x64 # atten:300x196x196 q_h, q_w = q_size k_h, k_w = k_size# Rh: 14x14x64 Rh = get_rel_pos(q_h, k_h, rel_pos_h) Rw = get_rel_pos(q_w, k_w, rel_pos_w) B, _, dim = q.shape# r_q: 300x14x14x64 r_q = q.reshape(B, q_h, q_w, dim) # rel_h: 300x14x14x14 # 等价于: # rel_h = torch.matmul(r_q, Rh.transpose(1, 2))# rel_w = torch.matmul(r_q.transpose(1, 2), Rw.transpose(1, 2)).transpose(1, 2) rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)# 将相对位置编码加在atten里面,再resize回300x192x196 attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(B, q_h * q_w, k_h * k_w) return attn
- 直接将计算好的相对位置编码加到attention矩阵上
attention矩阵为300x196x196,reshape成300x14x14x14x14,再使用矩阵加法,将相对位置编码分别加到倒数2个维度上,再reshape回原来的大小。
对应的代码操作为:
attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(B, q_h * q_w, k_h * k_w)
1.3 neck
neck部分由两个卷积层组成,分别是256x768x1x1和256x256x3x3,最后输出的image imbedding的尺寸是1x256x64x64。
2 Prompt encoder
根据输入的point和boxs返回sparse embedding, 根据mask返回dense embeddings。
2.1 point embedding
- step1:首先生成一组可学习的向量point embedding,大小为:4x1x256:
point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
self.point_embeddings = nn.ModuleList(point_embeddings)
4代表了表示pos/neg + 2 box corners,即demo里面的添加点和消除点、以及box框的左上角和右下角;
0:neg,对应demo中的消除点
1:pos,对应demo中的添加点
2:代表box左上角点
3:代表box右下角点
- step2:再生成一组可学习的向量not_a_point_embed,大小为1x256,用于表示该位置不是一个点
self.not_a_point_embed = nn.Embedding(1, embed_dim)
- step3:如果传入的prompt里面没有bbox,则补充一个【0,0】点到每个point后面,其对应的label为-1
if pad: padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) points = torch.cat([points, padding_point], dim=1) labels = torch.cat([labels, padding_label], dim=1)
此时point大小为Nx2x2,label为Nx2
- step4:如果传入的还有bbox,此时的point大小为Nx1x2,label为Nx1
- step5:再根据point计算point embedding,其流程如下:
- 横纵坐标先归一化,即都除以输入的尺寸(1024, 1024);
- 再将point矩阵与一个随机高斯矩阵(2x128)矩阵相乘得到Nxax128的矩阵coord,其中(a=2表示只有point,a=1表示还有box作为prompt输入);
- 再分别对coord计算sin和cos,拼接矩阵得到最终的point embedding(Nxax256)
def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:"""Positionally encode points that are normalized to [0,1]."""# assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shapecoords = 2 * coords - 1coords = coords @ self.positional_encoding_gaussian_matrixcoords = 2 * np.pi * coords# outputs d_1 x ... x d_n x C shapereturn torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
- step:6再根据label,给point embedding加上之前生成的可学习的embeding向量
point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
point_embedding[labels == -1] = 0.0
point_embedding[labels == -1] += self.not_a_point_embed.weight # 对应label为-1的,加上not_a_point_embed
point_embedding[labels == 0] += self.point_embeddings[0].weight # neg点加上point_embeddings[0]
point_embedding[labels == 1] += self.point_embeddings[1].weight # pos点加上point_embeddings[1]
完整point embedding的流程如下:
def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool, ) -> torch.Tensor: """Embeds point prompts.""" points = points + 0.5 # Shift to center of pixel# 如果没有输入的box的话,会将points的长度用0补充形成Nx2x2,label用【-1】补充成Nx2if pad: padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) points = torch.cat([points, padding_point], dim=1) labels = torch.cat([labels, padding_label], dim=1)# 将points与一个2x128的随机高斯矩阵相乘再通过进行sin、cos运算,两者的运算结果拼接得到# point_embedding: Nx1x256 或者 Nx2x256point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) point_embedding[labels == -1] = 0.0 point_embedding[labels == -1] += self.not_a_point_embed.weightpoint_embedding[labels == 0] += self.point_embeddings[0].weightpoint_embedding[labels == 1] += self.point_embeddings[1].weightreturn point_embedding
返回之后需要与sparse embedding进行拼接:
# 如果只有point,那么sparse_embeddings的size是Nx2x256,如果还有box则是Nx1x256
sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
如果只有point,当前sparse_embeddings 的大小为Nx2x256
如果还有box,当前sparse_embeddings 的大小为Nx1x256
2.2 box embedding
bbox一般有2个点,其编码步骤如下:
step1: 所以回先resize为Nx2x2;
step2: 再使用point embedding编码的方式,得到corner_embedding,
step3: 再加上之前生成的可学习的embeding向量;
最后输出的corner_embedding大小为Nx2x256。
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: """Embeds box prompts.""" boxes = boxes + 0.5 # Shift to center of pixel# 操作与points类似,讲4个点resize成Nx2x2coords = boxes.reshape(-1, 2, 2)# 返回Nx2x256的embeddingcorner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)# 再加上点的embedding corner_embedding[:, 0, :] += self.point_embeddings[2].weight corner_embedding[:, 1, :] += self.point_embeddings[3].weight return corner_embedding
最后输出的box的embedding的尺寸是Nx2x256。
合并(concat)point embedding和corner embedding,可以得到sparse embedding:
- 全都没有:sparse embedding(1x0x256)
- 如果只有point:sparse embedding(Nx2x256)
- 如果只有box:sparse embedding(Nx2x256)
- piont、box都有:sparse embedding(Nx3x256)
2.3 mask embedding
- 如果没有配置mask,有一个长度为256的可学习向量,表示没有mask embedding,再将其拓展为1x256x64x64
self.no_mask_embed = nn.Embedding(1, embed_dim) # embed_dim=256
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(bs, -1, self.image_embedding_size[0], self.image_embedding_size[1])
- **如果有配置mask:**已知输入的mask是Nx1x256x256,经过3层卷积,最后得到与image embedding一样的size:
mask会先进入一个1x2x2x4的卷积,stride=2;LN;
然后再进入一个4x2x2x16的卷积,stride=2;LN;
最后再进入一个16x1x1x256的卷积;
得到最后的mask_embedding的size为Nx256x64x64
最终mask embeding作为dense embedding输出,大小为Nx256x64x64。
3 Mask decoder
初始化几个可学习的参数:
可学习的mask tokens:4x256
# num_mask_tokens = 3 + 1 = 4, transformer_dim = 256
# 输出一个4x256的矩阵
self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
可学习的iou tokens:1x256
self.iou_token = nn.Embedding(1, transformer_dim)
image_pe: 跟image embedding一样大的位置编码256x64x64 ,见prompt_encoder.py:PositionalEmbeddingRandom.get_dense_pe()
就是将64x64个坐标点归一化之后,与随机高斯矩阵相乘(2x128),再将结果分别进行sin和cos,最后再拼到一起,输出的大小为256x64x64,与image_embedding大小基本一致了。
class PositionEmbeddingRandom(nn.Module):"""Positional encoding using random spatial frequencies."""def init(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:super().init()if scale is None or scale <= 0.0:scale = 1.0# 构建一个2x128的随机矩阵作为位置编码高斯矩阵self.register_buffer("positional_encoding_gaussian_matrix",scale * torch.randn((2, num_pos_feats)),)def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:"""Positionally encode points that are normalized to [0,1]."""# assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shapecoords = 2 * coords - 1# 矩阵乘法:64x64xx2 @ 2x128 ---> 64x64x128coords = coords @ self.positional_encoding_gaussian_matrixcoords = 2 * np.pi * coords# outputs d_1 x ... x d_n x C shape# cat, 最后一个维度上拼接:64x64x256return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)def forward(self, size: Tuple[int, int]) -> torch.Tensor:"""Generate positional encoding for a grid of the specified size."""h, w = sizedevice: Any = self.positional_encoding_gaussian_matrix.device# 构造一个64x64的全1矩阵grid = torch.ones((h, w), device=device, dtype=torch.float32)# 行、列累加y_embed = grid.cumsum(dim=0) - 0.5x_embed = grid.cumsum(dim=1) - 0.5# 行列累加结果归一化y_embed = y_embed / hx_embed = x_embed / w# 行列拼接:64x64x2,编码后的结果是64x64x256pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))# 最后输出256x64x64return pe.permute(2, 0, 1) # C x H x W
3.1 预测mask的流程
- sparse embedding、iou token、mask token合并成一个tokens,作为point_embeddings
需要注意的是:
sparse embedding: point、bbox prompt合并后的产物,一般为NxXx256
iou token: 可学习参数,大小为1x256
mask token: 可学习参数,大小为4x256
首先将iou token和mask token 拼接得到一个5x256的矩阵,再将其拓展到与sparse embedding一个维度Nx5x256;
再将拓展后的矩阵与sparse embedding拼接得到tokens,其大小Nx(5+X)x256;
# 代码见:mask_decoder.py -> predict_masks
# 拼接iou_token和mask token得到i: 5x256 的tokens
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)# 拓展成稀疏prompt编码的个数:Nx5x256
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)# 再与稀疏矩阵拼接,假设稀疏矩阵只有point为Nx2x256,拼接之后则为Nx(5+2)x256
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
- image embedding与dense prompt直接相加得到Nx256x64x64的矩阵,命名为src,作为image_embedding
需要注意的是:
image embedding: 是image encoder的输出,大小为为1x256x64x64
dense prompt: 是mask embedding的产物,大小为Nx256x64x64
image embedding拓展维度之后直接与dense prompt相加,得到image_embedding,大小为Nx256x64x64:
# 将image embedding(1x256x64x64)拓展成稠密prompt的维度:Nx256x64x64
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)# 将拓展后的image embedding直接与稠密prompt相加:Nx256x64x64
src = src + dense_prompt_embeddings
- image_pe位置编码也拓展成Nx256x64x64的矩阵,命名为pos_src
需要注意的是:image_pe相当于特征图中每个位置进行了与point类似的编码操作
# 将256x64x64的位置编码,拓展成Nx256x64x64
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
- 将这三个送入TwoWayTransformer中,返回的结果后处理后就能得到最终的mask信息。
def predict_masks(self,image_embeddings: torch.Tensor,image_pe: torch.Tensor,sparse_prompt_embeddings: torch.Tensor,dense_prompt_embeddings: torch.Tensor,) -> Tuple[torch.Tensor, torch.Tensor]:"""Predicts masks. See 'forward' for more details."""# Concatenate output tokens# 拼接iou_token和mask token得到i: 5x256 的tokensoutput_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)# 拓展成稀疏prompt编码的个数:Nx5x256output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)# 再与稀疏矩阵拼接,假设稀疏矩阵只有point为Nx2x256,拼接之后则为Nx(5+2)x256tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)# Expand per-image data in batch direction to be per-mask# 将image embedding(1x256x64x64)拓展成稠密prompt的维度:Nx256x64x64src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)# 将拓展后的image embedding直接与稠密prompt相加:Nx256x64x64src = src + dense_prompt_embeddings# 将256x64x64的位置编码,拓展成Nx256x64x64pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)b, c, h, w = src.shape# Run the transformer:这里使用的TwoWayTransformer,有必要对输入再说明一下# src:image_bedding + dense_prompt(mask),Nx256x64x64# pos_src: 位置编码,Nx256x64x64# tokens: iou_tokens + mask_tokens + sparse_prompt(point/bbox),Nx(5+x)x256hs, src = self.transformer(src, pos_src, tokens)# 后处理iou_token_out = hs[:, 0, :]mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]# Upscale mask embeddings and predict masks using the mask tokenssrc = src.transpose(1, 2).view(b, c, h, w)upscaled_embedding = self.output_upscaling(src)hyper_in_list: List[torch.Tensor] = []for i in range(self.num_mask_tokens):hyper_in_list.append(self.output_hypernetworks_mlpsi)hyper_in = torch.stack(hyper_in_list, dim=1)b, c, h, w = upscaled_embedding.shapemasks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)# Generate mask quality predictionsiou_pred = self.iou_prediction_head(iou_token_out)return masks, iou_pred
3.2 TwoWayTransformer
【关于transformer的一个心得】:q、k、v中,k、v一定具有相同的size,最后输出的attention的size是由q来决定的。
参数:
- depth:2,表示attention block只有2个
- embedding_dim: 256
- mlp_dim: 2048
- num_heads: 8
所谓的TwoWay:两轮次循环,第一次point_embedding自注意,第二次则加上上一轮输出的query再进行attention。
两层TwoWayAttentionBlock:
- 第一层:q = q_pe = point_embedding,k = image_embedding, k_pe = image_pe
- 第二层:q = 第一层输出q, q_pe = point_embedding,k = 第一层输出k, k_pe = image_pe
整个流程如下:
- 先将image_embedding转换shape:Nx256x64x64 —> Nx4096x256 (注意此时的image embedding是encoder的输出+dense prompt)
- 再将位置编码也调整shape:Nx256x4096 —> Nx4096x256
- 将image embedding( Nx4096x256)作为key,point embedding(Nx(5+x)x256)作为querise,送入2层TwoWayAttentionBlock,需要注意的是point embedding是point、bbox和iou、mask embedding拼接结果
- TwoWayAttentionBlock返回 key和querise
- q = querise + point_embedding, k = key + image_pe,v = key将其输入到final_attn_token_to_image中
- 最后输出的结果是queries = queries + attn_out,再经过norm
def forward(self,image_embedding: Tensor,image_pe: Tensor,point_embedding: Tensor,) -> Tuple[Tensor, Tensor]:qur"""Args:image_embedding (torch.Tensor): image to attend to. Should be shapeB x embedding_dim x h x w for any h and w.image_pe (torch.Tensor): the positional encoding to add to the image. Musthave the same shape as image_embedding.point_embedding (torch.Tensor): the embedding to add to the query points.Must have shape B x N_points x embedding_dim for any N_points.Returns:torch.Tensor: the processed point_embeddingtorch.Tensor: the processed image_embedding"""# BxCxHxW -> BxHWxC == B x N_image_tokens x C# image_embedding: Nx256x64x64bs, c, h, w = image_embedding.shape# Nx256x4096 ---> Nx4096x256image_embedding = image_embedding.flatten(2).permute(0, 2, 1)# Nx256x4096 ---> Nx4096x256image_pe = image_pe.flatten(2).permute(0, 2, 1)# Prepare queriesqueries = point_embedding # Nx(5+x)x256keys = image_embedding # Nx4096x256# Apply transformer blocks and final layernorm# 将稀疏prompt和iou、mask的tokens组合tokens作为querise,image embedding作为key# 进入2层的TwoWayAttentionBlockfor layer in self.layers:queries, keys = layer(queries=queries,keys=keys,query_pe=point_embedding,key_pe=image_pe,)# Apply the final attention layer from the points to the image# 出TwoWayAttentionBlock的querise继续加上组合tokens作为queriseq = queries + point_embedding# 出TwoWayAttentionBlock的key再加上位置编码作为keyk = keys + image_pe# 进入最后的attention层,但是v还是没有加上位置编码的keyattn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)# querise + attention的输出并norm后输出queries = queries + attn_outqueries = self.norm_final_attn(queries)return queries, keys
3.3 后处理
TwoWayTransformer返回的结果为:
hs: Nx(5+x)x256
src: Nx4096x256
-
取tokens
- 取第一个位置为iou_token : hs[:, 0, :] —> Nx1x256
- 取1~5为位子为mask_token: hs[:, 1 : (1 + 4), :] —> Nx4x256
-
reshape src: Nx4096x256 —> Nx256x64x64
-
通过2层转置卷积,将src变成Nx32x256x256
-
将4个mask token分别送入4个独立的3层全连接网络中(每个channel进入不同的FC),最终得到hyper_in:Nx4x32
-
将hyper_in 矩阵乘 reshap src得到Nx4x256x256的矩阵,这就是输出的mask!
-
将tokens送入3层全连接网络,最后得到iou_pred,大小为Nx1x4,这就输出的IoU_pred!
流程示意图如下,最后输出的N与传入的prompt的数量有关
hs, src = self.transformer(src, pos_src, tokens)
# 后处理
iou_token_out = hs[:, 0, :]
mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]# Upscale mask embeddings and predict masks using the mask tokens
src = src.transpose(1, 2).view(b, c, h, w)
upscaled_embedding = self.output_upscaling(src)
hyper_in_list: List[torch.Tensor] = []
for i in range(self.num_mask_tokens):hyper_in_list.append(self.output_hypernetworks_mlpsi)hyper_in = torch.stack(hyper_in_list, dim=1)
b, c, h, w = upscaled_embedding.shape
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)# Generate mask quality predictions
iou_pred = self.iou_prediction_head(iou_token_out)
return masks, iou_pred
输出的mask有4个通道:
通道0:不用
通道1:whole
通道2:part
通道3:subpart
所以一般有一个multimask_output字段来控制是只输出whole,还是全部都输出。
4 全图分割
代码详见:automatic_mask_generator.py
class SamAutomaticMaskGenerator:def init(self,model: Sam, # sam模型points_per_side: Optional[int] = 32, # 每个边需要采样的点数默认为32,最后的总点数为32x32points_per_batch: int = 64, # 模型可以同时处理的点数,默认64,数字越大越快GPU越高pred_iou_thresh: float = 0.88, # iou阈值,默认0.88stability_score_thresh: float = 0.95, stability_score_offset: float = 1.0,box_nms_thresh: float = 0.7, crop_n_layers: int = 0, # 裁剪的层数 crop_nms_thresh: float = 0.7,crop_overlap_ratio: float = 512 / 1500, # 裁剪图片间的重叠情况crop_n_points_downscale_factor: int = 1,point_grids: Optional[List[np.ndarray]] = None, # 点的网格列表,与points_per_side直接相关min_mask_region_area: int = 0,output_mode: str = "binary_mask",) -> None:
-
每张图片默认撒点方式: 首先根据每个边需要采样的点数,默认时32,生成32x32的网格point_grids,所以最后输出的就是1024个坐标,且都归一化到0-1之间。
-
crop_n_layers: 裁剪的层数,每层裁剪的crop_img个数为2^(n+1)个,即第一层裁剪4个,第二层16个,依次类推.
- point_grids也可以根据crop_n_layers每层进行递减,通过crop_n_points_downscale_factor控制,默认设置为1表示所有的裁剪图crop_img也均匀采样1024个点
Step1: 原图裁剪,一般crop_n_layers设置为0,即送全图区域:
Step2: 将图片补边,再resize到1024x1024,送入Image Encoder中生成image embedding;
Step3: 图片宽高方向各均匀生成32个位置,组成1024个坐标点;
Step4: 每次送入64个坐标点,迭代1次,生成mask及iou_pred;
Step5: 结果后处理
-
根据iou阈值(默认0.88)过滤mask
-
对过滤后计算calculate_stability_score稳定性分值=(mask > 1的数量) / (mask > -1的数量)
-
根据calculate_stability_score过滤mask,阈值默认为0.95
-
对过滤后的mask取阈值0,得到掩膜,根据掩模计算外界矩形框
-
过滤外界矩形框达到crop边界的对应的mask
-
将截取图片crop_img的mask,映射到原图尺寸上
-
再将mask转化为rle编码,用于节省内存,mask拉平,(3,3)表示第3个元素开始,后面3个都是1
4.1 完整流程
def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:# Generate masks# 生成maskmask_data = self._generate_masks(image)# Filter small disconnected regions and holes in masks# 过滤小的区域或者空洞if self.min_mask_region_area > 0:mask_data = self.postprocess_small_regions(mask_data,self.min_mask_region_area,max(self.box_nms_thresh, self.crop_nms_thresh),)# Encode masks# 输出的mask格式,默认输出binary_mask二值掩码if self.output_mode == "coco_rle":mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]elif self.output_mode == "binary_mask":mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]else:mask_data["segmentations"] = mask_data["rles"]# Write mask records# 将结果整理输出curr_anns = []for idx in range(len(mask_data["segmentations"])):ann = {"segmentation": mask_data"segmentations","area": area_from_rle(mask_data"rles"),"bbox": box_xyxy_to_xywh(mask_data"boxes").tolist(),"predicted_iou": mask_data"iou_preds".item(),"point_coords": [mask_data"points".tolist()],"stability_score": mask_data"stability_score".item(),"crop_box": box_xyxy_to_xywh(mask_data"crop_boxes").tolist(),}curr_anns.append(ann)return curr_anns
4.2 生成masks
def _generate_masks(self, image: np.ndarray) -> MaskData:orig_size = image.shape[:2]#由于默认的crop_n_layers为0,所以返回的crop_box为全图,layer_idxs只有0crop_boxes, layer_idxs = generate_crop_boxes(orig_size, self.crop_n_layers, self.crop_overlap_ratio)# Iterate over image crops# 将每一个抠图区域送入网络中data = MaskData()for crop_box, layer_idx in zip(crop_boxes, layer_idxs):crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)data.cat(crop_data)# Remove duplicate masks between crops# 将所有crop图片的结果汇总到一起进行NMS过滤if len(crop_boxes) > 1:# Prefer masks from smaller cropsscores = 1 / box_area(data["crop_boxes"])scores = scores.to(data["boxes"].device)keep_by_nms = batched_nms(data["boxes"].float(),scores,torch.zeros_like(data"boxes"), # categoriesiou_threshold=self.crop_nms_thresh,)data.filter(keep_by_nms)data.to_numpy()return data
4.3 处理每一个crop_img
def _process_crop(self,image: np.ndarray,crop_box: List[int],crop_layer_idx: int,orig_size: Tuple[int, ...],) -> MaskData:# Crop the image and calculate embeddings# 裁剪图片x0, y0, x1, y1 = crop_boxcropped_im = image[y0:y1, x0:x1, :]cropped_im_size = cropped_im.shape[:2]# 将裁剪后的图片送入网络,变成了encoder所需的1024*1024# 1、先按照比例将图片最大的边resize到1024# 2、调整位置HxWxC ---> 1xCxHxW# 3、再将图片归一化# 4、用0padding,使其成为1024x1024的size# 5、送入encoder计算image_embeddingself.predictor.set_image(cropped_im)# Get points for this crop# 将1024x2的矩阵乘以裁剪图片的大小,就得到了在裁剪图片上的gridpoints_scale = np.array(cropped_im_size)[None, ::-1]points_for_image = self.point_grids[crop_layer_idx] * points_scale# Generate masks for this crop in batchesdata = MaskData()# 每张图送入网络1024个点,每次同时计算points_per_batch(64个),因此需要迭代1024 / 64 = 16次for (points,) in batch_iterator(self.points_per_batch, points_for_image):# 1、将坐标点映射到1024x1024的图片上# 2、每个点的label设置为1,label的size就是64x1# 3、送入decoder计算mask# 3.1 先送入prompt encoder,由于只有point,得到sparse embedding和dense embedding(no_mask_embed)# 3.2 将其送入mask decoder得到最后输出的# 3.3 切片输出,如果需要输出多个mask,取index 1 ~ 3, 如果只输出一个mask index取0# 3.4 对mask做后处理:先resize回1024x1024,取出非padding部分再resize回原图# 4、一系列的后处理# 4.1 根据iou阈值(默认0.88)过滤mask# 4.2 对过滤后计算calculate_stability_score稳定性分值=(mask > 1的数量) / (mask > -1的数量)# 4.3 根据calculate_stability_score过滤mask,阈值默认为0.95# 4.4 对过滤后的mask取阈值0,得到掩膜,根据掩模计算外界矩形框# 4.5 过滤外界矩形框达到crop边界的对应的mask# 4.6 将截取图片crop_img的mask,映射到原图尺寸上# 4.7 再将mask转化为rle编码,用于节省内存,mask拉平,(3,3)表示第3个元素开始,后面3个都是1batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)data.cat(batch_data)del batch_dataself.predictor.reset_image()# Remove duplicates within this crop.# 根据bbox,使用NMS过滤重复的结果keep_by_nms = batched_nms(data["boxes"].float(),data["iou_preds"],torch.zeros_like(data"boxes"), # categoriesiou_threshold=self.box_nms_thresh,)data.filter(keep_by_nms)# Return to the original image frame# bbox和point映射回原图坐标data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)data["points"] = uncrop_points(data["points"], crop_box)data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])return data