【Transformer】detr之encoder逐行梳理(二)

every blog every motto: You can do more than you think.
https://blog.csdn.net/weixin_39190382?type=blog

0. 前言

detr之encoder逐行梳理

1. 整体

encoder由encoder layer构成

输入进encoder的特征shape:(hw,b,c),后文将给出说明

class Transformer(nn.Module):def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,activation="relu", normalize_before=False,return_intermediate_dec=False):super().__init__()# encoder layerencoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,dropout, activation, normalize_before)encoder_norm = nn.LayerNorm(d_model) if normalize_before else None# encoder 部分self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)... # 略def forward(self, src, mask, query_embed, pos_embed):# flatten bxCxHxW to HWxbxCbs, c, h, w = src.shape# (b,c,h,w) ->(b,c,hw) -> (hw,b,c) src = src.flatten(2).permute(2, 0, 1)# (b,c,h,w) ->(b,c,hw) -> (hw,b,c) pos_embed = pos_embed.flatten(2).permute(2, 0, 1)# (b,h,w) -> (b,hw)mask = mask.flatten(1)memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)... # 略

2. 部分

2.1 get_clone

用于对指定的层进行复制

def _get_clones(module, N):return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

2.2 Encoder

串联多个layer,输出作为输入

20240422143301

class TransformerEncoder(nn.Module):def __init__(self, encoder_layer, num_layers, norm=None):super().__init__()# 对指定的层进行复制self.layers = _get_clones(encoder_layer, num_layers)self.num_layers = num_layersself.norm = normdef forward(self, src,mask: Optional[Tensor] = None,src_key_padding_mask: Optional[Tensor] = None,pos: Optional[Tensor] = None):output = srcfor layer in self.layers:# 输出作为输入output = layer(output, src_mask=mask,src_key_padding_mask=src_key_padding_mask, pos=pos)if self.norm is not None:output = self.norm(output)return output

2.3 EncoderLayer

结构:

最开始的输入是backone的输出,即,src,后续的输入是上一层的输出
20240422150225

其中forward包含forward_post和forward_pre两个函数,主要区别是最开始进行标准化还是最后进行标准化。

由于self.normalize_before默认是False,所以默认是forward_post ,如下方的局部代码所示


q 和 k = backbone输出特征图 + 位置编码

这里对query和key增加位置编码 是因为需要在图像特征中各个位置之间计算相似度/相关性, 而value作为原图像的特征 和 相关性矩阵加权,
从而得到各个位置结合了全局相关性(增强后)的特征表示,所以q 和 k这种计算需要+位置编码 而v代表原图像不需要加位置编码


其中注意力计算主要涉及到两个参数:

  • key_padding_mask: 这部分就是我们在backbone中获取的mask,
    记录backbone生成的特征图中哪些是原始图像pad的部分 这部分是没有意义的
    计算注意力会被填充为-inf,这样最终生成注意力经过softmax时输出就趋向于0,相当于忽略不计。

  • attn_mask: 是在Transformer中用来“防作弊”的,即遮住当前预测位置之后的位置,忽略这些位置,不计算与其相关的注意力权重
    在encoder中通常为None,不使用,因为要计算全局的相关性。 decoder中才使用

forward_post局部代码:

def with_pos_embed(self, tensor, pos: Optional[Tensor]):return tensor if pos is None else tensor + posdef forward_post(self,src,src_mask: Optional[Tensor] = None,src_key_padding_mask: Optional[Tensor] = None,pos: Optional[Tensor] = None):q = k = self.with_pos_embed(src, pos) # q,k都添加位置编码# 计算# key_padding_mask: 记录backbone生成的特征图中哪些是原始图像pad的部分 这部分是没有意义的#                   计算注意力会被填充为-inf,这样最终生成注意力经过softmax时输出就趋向于0,相当于忽略不计# attn_mask: 是在Transformer中用来“防作弊”的,即遮住当前预测位置之后的位置,忽略这些位置,不计算与其相关的注意力权重#            而在encoder中通常为None 不适用  decoder中才使用src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,key_padding_mask=src_key_padding_mask)[0]# 残差连接src = src + self.dropout1(src2)# 标准化src = self.norm1(src)# FFNsrc2 = self.linear2(self.dropout(self.activation(self.linear1(src))))# 残差连接src = src + self.dropout2(src2)# 最后进行标准化src = self.norm2(src) return src

默认batch_first = False
20240422152811

所以输入的形式是(l,batch,d),即我们最开始看到的(hw,b,c)
20240422152750

输出两个值,第一个是计算结果,第二个是权重。只需要第一个所以上面用了[0]

20240422153424

EncoderLayer完整代码:

class TransformerEncoderLayer(nn.Module):def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,activation="relu", normalize_before=False):super().__init__()self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)# Implementation of Feedforward modelself.linear1 = nn.Linear(d_model, dim_feedforward)self.dropout = nn.Dropout(dropout)self.linear2 = nn.Linear(dim_feedforward, d_model)self.norm1 = nn.LayerNorm(d_model)self.norm2 = nn.LayerNorm(d_model)self.dropout1 = nn.Dropout(dropout)self.dropout2 = nn.Dropout(dropout)self.activation = _get_activation_fn(activation)self.normalize_before = normalize_beforedef with_pos_embed(self, tensor, pos: Optional[Tensor]):return tensor if pos is None else tensor + posdef forward_post(self,src,src_mask: Optional[Tensor] = None,src_key_padding_mask: Optional[Tensor] = None,pos: Optional[Tensor] = None):q = k = self.with_pos_embed(src, pos)src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,key_padding_mask=src_key_padding_mask)[0]src = src + self.dropout1(src2)src = self.norm1(src)src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))src = src + self.dropout2(src2)src = self.norm2(src) # 最后进行标准化return srcdef forward_pre(self, src,src_mask: Optional[Tensor] = None,src_key_padding_mask: Optional[Tensor] = None,pos: Optional[Tensor] = None):src2 = self.norm1(src) # 最开始进行标准化q = k = self.with_pos_embed(src2, pos)src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,key_padding_mask=src_key_padding_mask)[0]src = src + self.dropout1(src2)src2 = self.norm2(src)src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))src = src + self.dropout2(src2)return srcdef forward(self, src,src_mask: Optional[Tensor] = None,src_key_padding_mask: Optional[Tensor] = None,pos: Optional[Tensor] = None):# 默认是Falseif self.normalize_before:return self.forward_pre(src, src_mask, src_key_padding_mask, pos)return self.forward_post(src, src_mask, src_key_padding_mask, pos)

参考

  1. https://blog.csdn.net/weixin_39190382/article/details/137905915?spm=1001.2014.3001.5502
  2. https://hukai.blog.csdn.net/article/details/127616634

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/641531.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

java-springmvc 01

MVC就是和Tomcat有关。 01.MVC启动的第一步,启动Tomcat 02.Tomcat会解析web-inf的web.xml文件

《架构风清扬-Java面试系列第23讲》如何理解Java的泛型檫除?

晚上好,给大家加个餐 来,思考片刻,说出你的答案 1,什么是泛型檫除? 泛型擦除是指编译器在处理泛型代码时,会在编译阶段移除(擦除)所有与泛型相关的类型参数信息,将其替换…

Java 网络编程之TCP(三):基于NIO实现服务端,BIO实现客户端

前面的文章,我们讲述了BIO的概念,以及编程模型,由于BIO中服务器端的一些阻塞的点,导致服务端对于每一个客户端连接,都要开辟一个线程来处理,导致资源浪费,效率低。 为此,Linux 内核…

VulBG: 构建行为图加强基于深度学习的漏洞检测模型

近年来,人们提出了基于深度学习(DL)的漏洞检测系统,用于从源代码中自动提取特征。这些方法在合成数据集上可以实现理想的性能,但在检测真实世界的漏洞数据集时,准确率却大幅下降。此外,这些方法…

数字时代的社交王者:探索Facebook的社交帝国

引言:社交媒体的霸主 在数字化浪潮席卷全球的当下,社交媒体已然成为人们日常生活中不可或缺的一部分,而Facebook则是这个领域的不二之选。作为全球最大的社交网络,Facebook不仅拥有庞大的用户群体,更在技术创新、社会…

EJB和Spring

1. EJB 1.1. 背景 功能日趋复杂的软件,如果把所有的功能实现都放在客户端,不仅代码没有安全性,部署及发布运维都会变的很复杂,所以将软件的功能实现分为客户端和服务端,服务端和客户端之间通过网络调用进行功能实现。…

要养生也要时尚,益百分满足你的所有需求

要养生也要时尚,益百分满足你的所有需求 艾灸是个好东西,尤其是在近几年的时候,艾灸就像一阵浪潮席卷进了人们的日常生活之中,我们可以在街边看到大大小小的艾灸馆,有些评价比较高的艾灸馆门前甚至还排起了长长的队伍…

Leetcode 118 杨辉三角

目录 一、问题描述二、示例及约束三、代码方法一:数学 四、总结 一、问题描述 给定一个非负整数 numRows,生成「杨辉三角」的前 numRows 行。   在「杨辉三角」中,每个数是它左上方和右上方的数的和。 二、示例及约束 示例 1&#xff1a…

布局香港之零售小店篇 | 香港一人小企与连锁超市的竞争

近年来,内地品牌入驻香港市场开拓业务已成大势所趋。香港特区政府早前公布的「2023年有香港境外母公司的驻港公司按年统计调查」显示,2023年母公司在海外及内地的驻港公司数量高达9039家。内地品牌在香港的成功落地,不仅为香港市民带来了丰富…

Windows 平台上面管理服务器程式的高级 QoS 策略

在 Windows 平台上面,目前有两个办法来调整应用程式的 QoS 策略设置,一种是通过程式设置,一种是通过 “Windows 组策略控制”。 在阅读本文之前,您需要先查阅本人以下的几篇文献,作为前情提示: VC Windows…

跨境电商测评攻略:如何安全有效地提升业绩?

跨境电商做久了,卖家都会陷入一个困境,到底是该坚持慢慢做好,还是要测评? 有卖家表示,美客多基本的操作如果熟练了之后,就不用在运营上费太多功夫 这时候要好好规划一下测评的事情,做美客多到最后你会发…

双链向表专题

1.链表的分类 链表的种类非常多组合起来就有 2 2 8种 链表说明: 虽然有这么多的链表的结构,但是我们实际中最常⽤还是两种结构: 单链表 和 双向带头循环链表 1. 无头单向⾮循环链表:结构简单,⼀般不会单独⽤来存数…