The All Things ViTs lecture series starts from the attention mechanism of ViT vision models; this post presents my reading notes on its DINO attention map visualization part.
Course video and slides: https://all-things-vits.github.io/atv/
Code: https://colab.research.google.com/github/all-things-vits/code-samples/blob/main/probing/dino_attention_maps.ipynb
Paper: Emerging Properties in Self-Supervised Vision Transformers
1. Overview
This code is quite similar to the mean attention distance (MAD) part of the CVPR 2023 Hybrid Tutorial: All Things ViTs. The paper Emerging Properties in Self-Supervised Vision Transformers argues that training with DINO (self-distillation with no labels) can match the effect of supervised training, and it uses attention map visualizations to show that DINO learns semantically meaningful structure without labels. The core of this code is to visualize, in the last transformer block, the attention map between the CLS token and the other image patch tokens, i.e. to show how strongly each patch relates to CLS, as in Fig 1:
Fig 1 DINO attention map
2. Key code
import torch
from PIL import Image


def get_attention_scores(image: Image, model: torch.nn.Module, processor):
    """Extracts attention scores given an image, a model, and its processor."""
    inputs = processor(image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)
    # Tuple of 12 tensors (one per transformer block), each [1, 12, 197, 197].
    return outputs.attentions
This function feeds the image through the ViT and returns the per-head attention scores (outputs.attentions). outputs.attentions is a tuple of 12 tensors, one per transformer block, each of shape [1, 12, 197, 197]. In each tensor, 12 is the number of attention heads and 197 is the number of tokens (1 CLS token plus 196 patch tokens), so the 197x197 slice holds the attention score between every pair of tokens.
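As a quick sanity check, here is a minimal usage sketch. It assumes the facebook/dino-vitb16 checkpoint on the Hugging Face Hub (a DINO ViT-B/16 model matching the shapes above); the sample image URL is only an illustration.

import requests
from PIL import Image
from transformers import AutoImageProcessor, ViTModel

# DINO ViT-B/16 checkpoint; the processor resizes/normalizes the image to 224x224.
processor = AutoImageProcessor.from_pretrained("facebook/dino-vitb16")
model = ViTModel.from_pretrained("facebook/dino-vitb16")

# Any RGB image works; this COCO sample is just an example.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

attention_scores = get_attention_scores(image, model, processor)
print(len(attention_scores))       # 12, one entry per transformer block
print(attention_scores[-1].shape)  # torch.Size([1, 12, 197, 197])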
from torchvision.transforms import functional as F


def process_attention_map(image: torch.Tensor, attention_scores, block_id=11, patch_size=16):
    """Processes the attention scores so that they can be overlaid on the input image.

    Args:
        image (torch.Tensor): The input image tensor.
        attention_scores (Tuple[torch.Tensor]): Tuple of attention scores.
        block_id (int, optional): The block ID. Default is 11, the last
            transformer block of a DINO base model.
        patch_size (int, optional): The size of the patches. Default is 16.

    Returns:
        numpy.ndarray: The processed attention map as a NumPy array.
    """
    height, width = image.shape[2:]
    w_featmap = width // patch_size
    h_featmap = height // patch_size
    num_heads = attention_scores[block_id].shape[1]  # number of attention heads: 12

    # Take the CLS token's attention (query index 0) to every patch token, per head.
    attentions = attention_scores[block_id][0, :, 0, 1:].reshape(num_heads, -1)
    print(attentions.shape)  # [12, 196]: per head, CLS attention to each of the 196 patches

    # Reshape the attention scores to resemble mini patches.
    attentions = attentions.reshape(num_heads, w_featmap, h_featmap)  # [12, 14, 14]
    print(attentions.shape)

    # Resize the attention patches to 224x224 (224 = 14 * 16).
    attentions = F.resize(attentions, size=(h_featmap * patch_size, w_featmap * patch_size))
    return attentions.numpy()
This function extracts the attention scores between the CLS token and all patch tokens, then upsamples the resulting map to the size of the input image so it can be drawn on top of it.
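To actually render one map, a minimal plotting sketch (matplotlib overlay; the head index and alpha value are arbitrary choices of mine, not part of the notebook code above):

import matplotlib.pyplot as plt

# process_attention_map reads the spatial size from the preprocessed pixel tensor.
inputs = processor(image, return_tensors="pt")
attention_maps = process_attention_map(inputs["pixel_values"], attention_scores)  # [12, 224, 224]

# Overlay head 0's map on the resized input image.
plt.imshow(image.resize((224, 224)))
plt.imshow(attention_maps[0], alpha=0.6, cmap="viridis")
plt.axis("off")
plt.show()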
3. One more thing
The visualization code itself is easy to follow; to get a fuller picture of what DINO learns, we can look at the attention maps of all heads at once, as in the sketch below:
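A sketch that tiles the last block's 12 head maps computed above into a single figure (the 3x4 grid layout is my own choice):

import matplotlib.pyplot as plt

# One panel per attention head of the last transformer block.
fig, axes = plt.subplots(3, 4, figsize=(12, 9))
for head_id, ax in enumerate(axes.flat):
    ax.imshow(attention_maps[head_id], cmap="viridis")
    ax.set_title(f"Head {head_id}")
    ax.axis("off")
fig.tight_layout()
plt.show()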
As the maps show, after DINO training the ViT attends to the semantically meaningful parts of the image, and different heads focus on different regions. With simple fine-tuning, the model can then generalize to downstream tasks.