Understanding the TransUNet code

First, get a sense of the overall architecture from the figure given in the paper (figure not reproduced here).

The vit_seg_modeling module

Imports and related constants:

# coding=utf-8
# __future__ imports keep newer-Python behaviour available in older Python versions.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import logging
import math

from os.path import join as pjoin

import torch
import torch.nn as nn
import numpy as np

from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
from scipy import ndimage

from . import vit_seg_configs as configs
from .vit_seg_modeling_resnet_skip import ResNetV2

logger = logging.getLogger(__name__)

ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
ATTENTION_K = "MultiHeadDotProductAttention_1/key"
ATTENTION_V = "MultiHeadDotProductAttention_1/value"
ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
FC_0 = "MlpBlock_3/Dense_0"
FC_1 = "MlpBlock_3/Dense_1"
ATTENTION_NORM = "LayerNorm_0"
MLP_NORM = "LayerNorm_2"

# Hyper-parameter configurations for the different model variants.
CONFIGS = {
    'ViT-B_16': configs.get_b16_config(),
    'ViT-B_32': configs.get_b32_config(),
    'ViT-L_16': configs.get_l16_config(),
    'ViT-L_32': configs.get_l32_config(),
    'ViT-H_14': configs.get_h14_config(),
    'R50-ViT-B_16': configs.get_r50_b16_config(),
    'R50-ViT-L_16': configs.get_r50_l16_config(),
    'testing': configs.get_testing(),
}

Definitions of the utility functions:
np2th converts numpy weights to torch tensors (and, for conv kernels, reorders HWIO to OIHW).

def np2th(weights, conv=False):
    """Possibly convert HWIO to OIHW."""
    if conv:
        weights = weights.transpose([3, 2, 0, 1])
    return torch.from_numpy(weights)
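
A quick way to see what the conv=True branch does is to push a dummy HWIO array through it (a minimal sketch, not from the repo; the kernel shape is made up and np2th above is assumed to be in scope):

import numpy as np
import torch

dummy_kernel = np.zeros((16, 16, 3, 768), dtype=np.float32)  # (H, W, In, Out) as stored in the JAX/TF checkpoint
as_torch = np2th(dummy_kernel, conv=True)
print(as_torch.shape)  # torch.Size([768, 3, 16, 16]) -> (Out, In, H, W), the layout nn.Conv2d expects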

swish is an activation function proposed by a Google team; their experiments showed that it can outperform ReLU on some challenging datasets.

def swish(x):
    return x * torch.sigmoid(x)


ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}

The code is easiest to follow top-down.
VisionTransformer is the full model: it wires together Transformer, DecoderCup and SegmentationHead, and its load_from method loads pretrained weights.

class VisionTransformer(nn.Module):
    def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.zero_head = zero_head
        self.classifier = config.classifier
        self.transformer = Transformer(config, img_size, vis)
        self.decoder = DecoderCup(config)
        self.segmentation_head = SegmentationHead(
            in_channels=config['decoder_channels'][-1],
            out_channels=config['n_classes'],
            kernel_size=3,
        )
        self.config = config

    def forward(self, x):
        if x.size()[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        x, attn_weights, features = self.transformer(x)  # (B, n_patch, hidden)
        x = self.decoder(x, features)
        logits = self.segmentation_head(x)
        return logits

    def load_from(self, weights):
        # torch.no_grad() temporarily disables gradient tracking, so the copies below
        # only update the parameter values.
        with torch.no_grad():
            res_weight = weights
            self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True))
            self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"]))

            self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"]))
            self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"]))

            posemb = np2th(weights["Transformer/posembed_input/pos_embedding"])
            posemb_new = self.transformer.embeddings.position_embeddings
            if posemb.size() == posemb_new.size():
                self.transformer.embeddings.position_embeddings.copy_(posemb)
            elif posemb.size()[1] - 1 == posemb_new.size()[1]:
                posemb = posemb[:, 1:]
                self.transformer.embeddings.position_embeddings.copy_(posemb)
            else:
                logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size()))
                ntok_new = posemb_new.size(1)
                if self.classifier == "seg":
                    _, posemb_grid = posemb[:, :1], posemb[0, 1:]
                gs_old = int(np.sqrt(len(posemb_grid)))
                gs_new = int(np.sqrt(ntok_new))
                print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new))
                posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)
                zoom = (gs_new / gs_old, gs_new / gs_old, 1)
                posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1)  # th2np
                posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1)
                posemb = posemb_grid
                self.transformer.embeddings.position_embeddings.copy_(np2th(posemb))

            # Encoder whole
            for bname, block in self.transformer.encoder.named_children():
                for uname, unit in block.named_children():
                    unit.load_from(weights, n_block=uname)

            if self.transformer.embeddings.hybrid:
                self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(res_weight["conv_root/kernel"], conv=True))
                # .view(-1) flattens the tensor to 1-D without changing the original tensor itself.
                gn_weight = np2th(res_weight["gn_root/scale"]).view(-1)
                gn_bias = np2th(res_weight["gn_root/bias"]).view(-1)
                self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight)
                self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias)

                for bname, block in self.transformer.embeddings.hybrid_model.body.named_children():
                    for uname, unit in block.named_children():
                        unit.load_from(res_weight, n_block=bname, n_unit=uname)
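
As a usage sketch, the model is built roughly the way the repo's train.py/test.py do it (this assumes the official repository layout with a networks package and its vit_seg_configs; the 9-class / 224-pixel values are only examples):

import torch
from networks.vit_seg_modeling import VisionTransformer, CONFIGS  # assumes the repo's package layout

config = CONFIGS['R50-ViT-B_16']
config.n_classes = 9                 # example value (e.g. Synapse multi-organ segmentation has 9 labels)
config.n_skip = 3
config.patches.grid = (14, 14)       # 224 // 16 = 14 patches per side

model = VisionTransformer(config, img_size=224, num_classes=config.n_classes)
x = torch.randn(2, 1, 224, 224)      # a single-channel input is repeated to 3 channels in forward()
logits = model(x)
print(logits.shape)                  # expected: torch.Size([2, 9, 224, 224])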

Next is the Transformer code.
Transformer consists of Embeddings followed by Encoder:

class Transformer(nn.Module):
    def __init__(self, config, img_size, vis):
        super(Transformer, self).__init__()
        self.embeddings = Embeddings(config, img_size=img_size)
        self.encoder = Encoder(config, vis)

    def forward(self, input_ids):
        embedding_output, features = self.embeddings(input_ids)
        encoded, attn_weights = self.encoder(embedding_output)  # (B, n_patch, hidden)
        return encoded, attn_weights, features

Embeddings corresponds to the embedding stage of the architecture figure (figure not reproduced).
ResNetV2 (its code is covered in the last section) extracts features from the image through convolutions and returns the feature maps of each stage to Embeddings.
Given the ResNetV2 output, the last feature map is split into patches; each patch (a patch_size x patch_size x channels block) is projected to a hidden_size-dimensional token, and a learned position embedding is added:

class Embeddings(nn.Module):
    """Construct the embeddings from patch, position embeddings."""
    def __init__(self, config, img_size, in_channels=3):
        super(Embeddings, self).__init__()
        self.hybrid = None
        self.config = config
        # _pair presumably converts the scalar img_size argument into a tuple: img_size = (img_size, img_size).
        img_size = _pair(img_size)

        if config.patches.get("grid") is not None:  # ResNet
            grid_size = config.patches["grid"]  # grid is a tuple: input image size // patch size
            patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])
            patch_size_real = (patch_size[0] * 16, patch_size[1] * 16)
            n_patches = (img_size[0] // patch_size_real[0]) * (img_size[1] // patch_size_real[1])
            self.hybrid = True
        else:
            patch_size = _pair(config.patches["size"])
            n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
            self.hybrid = False

        if self.hybrid:
            self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers, width_factor=config.resnet.width_factor)
            in_channels = self.hybrid_model.width * 16
        # patch_embeddings uses a convolution to turn the input into (B, hidden_size, n_patches^(1/2), n_patches^(1/2)).
        # hidden_size is the length of one token (the analogue of one input word).
        self.patch_embeddings = Conv2d(in_channels=in_channels,
                                       out_channels=config.hidden_size,
                                       kernel_size=patch_size,
                                       stride=patch_size)
        # Learned position embedding for each token.
        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, config.hidden_size))

        self.dropout = Dropout(config.transformer["dropout_rate"])

    def forward(self, x):
        if self.hybrid:
            x, features = self.hybrid_model(x)
        else:
            features = None
        x = self.patch_embeddings(x)  # (B, hidden, n_patches^(1/2), n_patches^(1/2))
        x = x.flatten(2)  # flatten from dim 2 onwards: (B, hidden, n_patches)
        x = x.transpose(-1, -2)  # swap the last two dims: (B, n_patches, hidden)
        embeddings = x + self.position_embeddings  # add the position embeddings
        embeddings = self.dropout(embeddings)
        return embeddings, features
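
The shape bookkeeping is easier to see in a standalone toy version of the pure-ViT branch (a sketch with made-up sizes: 224x224 input, 16x16 patches, hidden_size 768; no ResNet backbone):

import torch
import torch.nn as nn

B, C, H, W = 2, 3, 224, 224
patch, hidden = 16, 768
n_patches = (H // patch) * (W // patch)      # 14 * 14 = 196

patch_embed = nn.Conv2d(C, hidden, kernel_size=patch, stride=patch)
pos_embed = nn.Parameter(torch.zeros(1, n_patches, hidden))

x = torch.randn(B, C, H, W)
x = patch_embed(x)         # (2, 768, 14, 14)
x = x.flatten(2)           # (2, 768, 196)
x = x.transpose(-1, -2)    # (2, 196, 768)
tokens = x + pos_embed     # the position embedding broadcasts over the batch
print(tokens.shape)        # torch.Size([2, 196, 768])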

Encoder is the encoding stage: it stacks num_layers Block modules according to the config.

class Encoder(nn.Module):
    def __init__(self, config, vis):
        super(Encoder, self).__init__()
        self.vis = vis
        # nn.ModuleList() is a list of modules. Unlike a plain Python list it inherits from nn.Module,
        # so the parameters of the modules it holds are registered with the parent module.
        # It is still only a container, though: it does not implement a forward pass by itself.
        self.layer = nn.ModuleList()
        self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)
        for _ in range(config.transformer["num_layers"]):
            layer = Block(config, vis)
            self.layer.append(copy.deepcopy(layer))

    def forward(self, hidden_states):
        attn_weights = []
        for layer_block in self.layer:
            hidden_states, weights = layer_block(hidden_states)
            if self.vis:
                attn_weights.append(weights)
        encoded = self.encoder_norm(hidden_states)
        return encoded, attn_weights

Block contains an MSA (Multi-head Self-Attention) module and an MLP (Multi-Layer Perceptron) module, corresponding to the Transformer layer in the figure (figure not reproduced):

class Block(nn.Module):
    def __init__(self, config, vis):
        super(Block, self).__init__()
        self.hidden_size = config.hidden_size
        self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn = Mlp(config)
        self.attn = Attention(config, vis)

    def forward(self, x):
        h = x
        x = self.attention_norm(x)
        x, weights = self.attn(x)
        x = x + h

        h = x
        x = self.ffn_norm(x)
        x = self.ffn(x)
        x = x + h
        return x, weights

    def load_from(self, weights, n_block):
        ROOT = f"Transformer/encoderblock_{n_block}"
        with torch.no_grad():
            query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t()

            query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1)
            key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1)
            value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1)
            out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1)

            self.attn.query.weight.copy_(query_weight)
            self.attn.key.weight.copy_(key_weight)
            self.attn.value.weight.copy_(value_weight)
            self.attn.out.weight.copy_(out_weight)
            self.attn.query.bias.copy_(query_bias)
            self.attn.key.bias.copy_(key_bias)
            self.attn.value.bias.copy_(value_bias)
            self.attn.out.bias.copy_(out_bias)

            mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t()
            mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t()
            mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t()
            mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t()

            self.ffn.fc1.weight.copy_(mlp_weight_0)
            self.ffn.fc2.weight.copy_(mlp_weight_1)
            self.ffn.fc1.bias.copy_(mlp_bias_0)
            self.ffn.fc2.bias.copy_(mlp_bias_1)

            self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")]))
            self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")]))
            self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")]))
            self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")]))

Attention corresponds to the MSA part of the figure. num_heads is the number of attention heads, and attention_head_size is the output size of each head. Multi-head self-attention runs several attention heads, but the implementation does not loop over them: every head performs the same computation and only its learned parameters differ, so a single large projection matrix can be learned and then split into per-head slices (figure not reproduced; a small sketch of this equivalence follows the Attention code below).

class Attention(nn.Module):
    def __init__(self, config, vis):
        super(Attention, self).__init__()
        self.vis = vis
        self.num_attention_heads = config.transformer["num_heads"]
        self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = Linear(config.hidden_size, self.all_head_size)
        self.key = Linear(config.hidden_size, self.all_head_size)
        self.value = Linear(config.hidden_size, self.all_head_size)

        self.out = Linear(config.hidden_size, config.hidden_size)
        self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])

        self.softmax = Softmax(dim=-1)

    def transpose_for_scores(self, x):
        # new_x_shape: (B, n_patch, num_attention_heads, attention_head_size)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        # view() reshapes the tensor: same data, different dimensions.
        x = x.view(*new_x_shape)
        # permute can reorder all dimensions at once; transpose only swaps two of them.
        return x.permute(0, 2, 1, 3)  # (B, num_attention_heads, n_patch, attention_head_size)

    def forward(self, hidden_states):
        # hidden_states: (B, n_patch, hidden)
        # mixed_*: (B, n_patch, all_head_size)
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # torch.matmul: batched matrix multiplication.
        # key_layer.transpose(-1, -2): (B, num_attention_heads, attention_head_size, n_patch)
        # attention_scores: (B, num_attention_heads, n_patch, n_patch)
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = self.softmax(attention_scores)
        weights = attention_probs if self.vis else None
        attention_probs = self.attn_dropout(attention_probs)

        # context_layer: (B, num_attention_heads, n_patch, attention_head_size)
        context_layer = torch.matmul(attention_probs, value_layer)
        # context_layer: (B, n_patch, num_attention_heads, attention_head_size)
        # contiguous() is typically needed after transpose/permute so that view() can then reshape the tensor.
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        # new_context_layer_shape: (B, n_patch, all_head_size)
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        attention_output = self.out(context_layer)
        # attention_output: (B, n_patch, hidden_size)
        # Detail: attention_head_size = int(hidden_size / num_attention_heads) and
        # all_head_size = num_attention_heads * attention_head_size,
        # so hidden_size should be divisible by num_attention_heads.
        attention_output = self.proj_dropout(attention_output)
        return attention_output, weights
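
To make the "one big matrix instead of a per-head loop" point concrete, here is a small standalone sketch (not code from the repo; the sizes are the ViT-B defaults, hidden 768 and 12 heads) showing that slicing a single Linear head by head gives exactly the result of the reshape in transpose_for_scores:

import torch
import torch.nn as nn

torch.manual_seed(0)
B, n_patch, hidden, heads = 2, 196, 768, 12
head_size = hidden // heads                                  # 64

big_q = nn.Linear(hidden, hidden)                            # one query projection shared by all heads
x = torch.randn(B, n_patch, hidden)

# One matmul, then reshape into heads: (B, heads, n_patch, head_size)
q_all = big_q(x).view(B, n_patch, heads, head_size).permute(0, 2, 1, 3)

# The same thing computed head by head, slicing the big weight matrix.
per_head = []
for h in range(heads):
    w = big_q.weight[h * head_size:(h + 1) * head_size]      # (head_size, hidden)
    b = big_q.bias[h * head_size:(h + 1) * head_size]
    per_head.append(x @ w.t() + b)                           # (B, n_patch, head_size)
per_head = torch.stack(per_head, dim=1)                      # (B, heads, n_patch, head_size)

print(torch.allclose(q_all, per_head, atol=1e-5))            # True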

Mlp is simply a feed-forward network:

class Mlp(nn.Module):
    """Multi-Layer Perceptron."""
    def __init__(self, config):
        super(Mlp, self).__init__()
        self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"])
        self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size)
        self.act_fn = ACT2FN["gelu"]
        self.dropout = Dropout(config.transformer["dropout_rate"])

        self._init_weights()

    def _init_weights(self):
        # nn.init.xavier_uniform_ initialises the weights to mitigate vanishing/exploding gradients in deep networks.
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        # nn.init.normal_ initialises from a normal distribution.
        nn.init.normal_(self.fc1.bias, std=1e-6)
        nn.init.normal_(self.fc2.bias, std=1e-6)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x

This covers all the modules called by Transformer.


DecoderCup corresponds to the upsampling (decoder) path of the figure (figure not reproduced).

In the forward function:

B, n_patch, hidden = hidden_states.size()  # hidden_states: (B, n_patch, hidden)
h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))
x = hidden_states.permute(0, 2, 1)  # x: (B, hidden, n_patch)
x = x.contiguous().view(B, hidden, h, w)  # x: (B, hidden, h, w)
x = self.conv_more(x)  # (B, hidden, h, w) ===> (B, 512, h, w)

The Transformer output (B, n_patch, hidden) is first reshaped to (B, hidden, h, w), where $h = w = \sqrt{n\_patch} = \frac{H}{16} = \frac{W}{16}$, and the convolution conv_more then maps it to (B, 512, h, w).

class DecoderCup(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        head_channels = 512
        self.conv_more = Conv2dReLU(
            config.hidden_size,
            head_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=True,
        )
        decoder_channels = config.decoder_channels  # decoder_channels = (256, 128, 64, 16)
        in_channels = [head_channels] + list(decoder_channels[:-1])  # in_channels = [512, 256, 128, 64]
        out_channels = decoder_channels

        # config.n_skip = 3
        if self.config.n_skip != 0:
            skip_channels = self.config.skip_channels  # config.skip_channels = [512, 256, 64, 16]
            for i in range(4 - self.config.n_skip):  # re-select the skip channels according to n_skip
                skip_channels[3 - i] = 0  # => skip_channels = [512, 256, 64, 0]
        else:
            skip_channels = [0, 0, 0, 0]

        # in_channels = [512, 256, 128, 64], out_channels = (256, 128, 64, 16)
        blocks = [
            DecoderBlock(in_ch, out_ch, sk_ch) for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels)
        ]
        self.blocks = nn.ModuleList(blocks)

    def forward(self, hidden_states, features=None):
        B, n_patch, hidden = hidden_states.size()  # hidden_states: (B, n_patch, hidden)
        h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))
        x = hidden_states.permute(0, 2, 1)  # x: (B, hidden, n_patch)
        x = x.contiguous().view(B, hidden, h, w)  # x: (B, hidden, h, w)
        x = self.conv_more(x)  # (B, hidden, h, w) ===> (B, 512, h, w)
        for i, decoder_block in enumerate(self.blocks):
            if features is not None:
                skip = features[i] if (i < self.config.n_skip) else None
            else:
                skip = None
            x = decoder_block(x, skip=skip)
        return x
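
As a rough shape trace of the whole decoder path (a sketch that assumes the repo is importable and uses the R50-ViT-B_16 config with a 224x224 input; the skip-feature shapes below are the ones ResNetV2 produces in that setting):

import torch
from networks.vit_seg_modeling import DecoderCup, CONFIGS   # assumes the repo's package layout

config = CONFIGS['R50-ViT-B_16']
config.n_skip = 3
decoder = DecoderCup(config)

hidden_states = torch.randn(2, 14 * 14, config.hidden_size)  # Transformer output: (B, 196, 768)
features = [torch.randn(2, 512, 28, 28),                     # ResNetV2 skips, from high level to low level
            torch.randn(2, 256, 56, 56),
            torch.randn(2, 64, 112, 112)]

out = decoder(hidden_states, features)
print(out.shape)   # expected: torch.Size([2, 16, 224, 224])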

DecoderBlock performs one step of the upward decoding: bilinear upsampling (nn.UpsamplingBilinear2d) doubles H and W, the result is concatenated with the corresponding skip feature, and two convolutions follow (figure not reproduced):

class DecoderBlock(nn.Module):
    def __init__(
            self,
            in_channels,
            out_channels,
            skip_channels=0,
            use_batchnorm=True,
    ):
        super().__init__()
        self.conv1 = Conv2dReLU(
            in_channels + skip_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.conv2 = Conv2dReLU(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.up = nn.UpsamplingBilinear2d(scale_factor=2)

    def forward(self, x, skip=None):
        x = self.up(x)
        if skip is not None:
            x = torch.cat([x, skip], dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        return x
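
A standalone sketch of one decoder step, with plain Conv2d+ReLU standing in for the repo's Conv2dReLU (made-up sizes matching the first block: 512 in, 512 skip, 256 out):

import torch
import torch.nn as nn

up = nn.UpsamplingBilinear2d(scale_factor=2)
conv1 = nn.Sequential(nn.Conv2d(512 + 512, 256, kernel_size=3, padding=1), nn.ReLU(inplace=True))
conv2 = nn.Sequential(nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.ReLU(inplace=True))

x = torch.randn(2, 512, 14, 14)      # output of conv_more
skip = torch.randn(2, 512, 28, 28)   # skip feature from ResNetV2

x = up(x)                            # (2, 512, 28, 28)
x = torch.cat([x, skip], dim=1)      # (2, 1024, 28, 28)
x = conv2(conv1(x))                  # (2, 256, 28, 28)
print(x.shape)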

SegmentationHead corresponds to the final segmentation output of the figure (figure not reproduced).
nn.Identity() passes its input through unchanged. It is commonly used to replace the last layer of a classification network so that the features computed just before the classifier are returned, which is handy in transfer learning. Example usage:

import torch
import torch.nn as nn
from torchvision import models

model = models.resnet18()
# replace the last linear layer with nn.Identity
model.fc = nn.Identity()

# get the features for an input
x = torch.randn(1, 3, 224, 224)
out = model(x)
print(out.shape)
# > torch.Size([1, 512])

The SegmentationHead module:

class SegmentationHead(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
        conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
        upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity()
        super().__init__(conv2d, upsampling)
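
A quick check of the head on its own (a sketch with made-up numbers: 16 decoder channels in, 9 classes out; it assumes the class above is in scope):

import torch

head = SegmentationHead(in_channels=16, out_channels=9, kernel_size=3)
x = torch.randn(2, 16, 224, 224)     # output of the last DecoderBlock
print(head(x).shape)                 # torch.Size([2, 9, 224, 224]); with the default upsampling=1 the spatial size is unchanged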

Finally, the ResNetV2 module, defined in vit_seg_modeling_resnet_skip.py; it corresponds to the CNN encoder part of the figure (figure not reproduced).
The module's imports and utility functions:

import math

from os.path import join as pjoin
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F


def np2th(weights, conv=False):
    """Possibly convert HWIO to OIHW."""
    if conv:
        weights = weights.transpose([3, 2, 0, 1])
    return torch.from_numpy(weights)


class StdConv2d(nn.Conv2d):
    def forward(self, x):
        w = self.weight
        v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
        w = (w - m) / torch.sqrt(v + 1e-5)
        return F.conv2d(x, w, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)


def conv3x3(cin, cout, stride=1, groups=1, bias=False):
    return StdConv2d(cin, cout, kernel_size=3, stride=stride,
                     padding=1, bias=bias, groups=groups)


def conv1x1(cin, cout, stride=1, bias=False):
    return StdConv2d(cin, cout, kernel_size=1, stride=stride,
                     padding=0, bias=bias)
class ResNetV2(nn.Module):
    """Implementation of Pre-activation (v2) ResNet mode."""

    def __init__(self, block_units, width_factor):
        super().__init__()
        width = int(64 * width_factor)
        self.width = width

        self.root = nn.Sequential(OrderedDict([
            ('conv', StdConv2d(3, width, kernel_size=7, stride=2, bias=False, padding=3)),
            ('gn', nn.GroupNorm(32, width, eps=1e-6)),
            ('relu', nn.ReLU(inplace=True)),
            # ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0))
        ]))

        self.body = nn.Sequential(OrderedDict([
            ('block1', nn.Sequential(OrderedDict(
                [('unit1', PreActBottleneck(cin=width, cout=width*4, cmid=width))] +
                [(f'unit{i:d}', PreActBottleneck(cin=width*4, cout=width*4, cmid=width)) for i in range(2, block_units[0] + 1)],
            ))),
            ('block2', nn.Sequential(OrderedDict(
                [('unit1', PreActBottleneck(cin=width*4, cout=width*8, cmid=width*2, stride=2))] +
                [(f'unit{i:d}', PreActBottleneck(cin=width*8, cout=width*8, cmid=width*2)) for i in range(2, block_units[1] + 1)],
            ))),
            ('block3', nn.Sequential(OrderedDict(
                [('unit1', PreActBottleneck(cin=width*8, cout=width*16, cmid=width*4, stride=2))] +
                [(f'unit{i:d}', PreActBottleneck(cin=width*16, cout=width*16, cmid=width*4)) for i in range(2, block_units[2] + 1)],
            ))),
        ]))

    def forward(self, x):
        features = []
        b, c, in_size, _ = x.size()
        x = self.root(x)
        features.append(x)
        x = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)(x)
        for i in range(len(self.body)-1):
            x = self.body[i](x)
            right_size = int(in_size / 4 / (i+1))
            if x.size()[2] != right_size:
                pad = right_size - x.size()[2]
                assert pad < 3 and pad > 0, "x {} should {}".format(x.size(), right_size)
                feat = torch.zeros((b, x.size()[1], right_size, right_size), device=x.device)
                feat[:, :, 0:x.size()[2], 0:x.size()[3]] = x[:]
            else:
                feat = x
            features.append(feat)
        x = self.body[-1](x)
        return x, features[::-1]
class PreActBottleneck(nn.Module):
    """Pre-activation (v2) bottleneck block."""

    def __init__(self, cin, cout=None, cmid=None, stride=1):
        super().__init__()
        cout = cout or cin
        cmid = cmid or cout//4

        self.gn1 = nn.GroupNorm(32, cmid, eps=1e-6)
        self.conv1 = conv1x1(cin, cmid, bias=False)
        self.gn2 = nn.GroupNorm(32, cmid, eps=1e-6)
        self.conv2 = conv3x3(cmid, cmid, stride, bias=False)  # Original code has it on conv1!!
        self.gn3 = nn.GroupNorm(32, cout, eps=1e-6)
        self.conv3 = conv1x1(cmid, cout, bias=False)
        self.relu = nn.ReLU(inplace=True)

        if (stride != 1 or cin != cout):
            # Projection also with pre-activation according to paper.
            self.downsample = conv1x1(cin, cout, stride, bias=False)
            self.gn_proj = nn.GroupNorm(cout, cout)

    def forward(self, x):
        # Residual branch
        residual = x
        if hasattr(self, 'downsample'):
            residual = self.downsample(x)
            residual = self.gn_proj(residual)

        # Unit's branch
        y = self.relu(self.gn1(self.conv1(x)))
        y = self.relu(self.gn2(self.conv2(y)))
        y = self.gn3(self.conv3(y))

        y = self.relu(residual + y)
        return y

    def load_from(self, weights, n_block, n_unit):
        conv1_weight = np2th(weights[pjoin(n_block, n_unit, "conv1/kernel")], conv=True)
        conv2_weight = np2th(weights[pjoin(n_block, n_unit, "conv2/kernel")], conv=True)
        conv3_weight = np2th(weights[pjoin(n_block, n_unit, "conv3/kernel")], conv=True)

        gn1_weight = np2th(weights[pjoin(n_block, n_unit, "gn1/scale")])
        gn1_bias = np2th(weights[pjoin(n_block, n_unit, "gn1/bias")])

        gn2_weight = np2th(weights[pjoin(n_block, n_unit, "gn2/scale")])
        gn2_bias = np2th(weights[pjoin(n_block, n_unit, "gn2/bias")])

        gn3_weight = np2th(weights[pjoin(n_block, n_unit, "gn3/scale")])
        gn3_bias = np2th(weights[pjoin(n_block, n_unit, "gn3/bias")])

        self.conv1.weight.copy_(conv1_weight)
        self.conv2.weight.copy_(conv2_weight)
        self.conv3.weight.copy_(conv3_weight)

        self.gn1.weight.copy_(gn1_weight.view(-1))
        self.gn1.bias.copy_(gn1_bias.view(-1))

        self.gn2.weight.copy_(gn2_weight.view(-1))
        self.gn2.bias.copy_(gn2_bias.view(-1))

        self.gn3.weight.copy_(gn3_weight.view(-1))
        self.gn3.bias.copy_(gn3_bias.view(-1))

        if hasattr(self, 'downsample'):
            proj_conv_weight = np2th(weights[pjoin(n_block, n_unit, "conv_proj/kernel")], conv=True)
            proj_gn_weight = np2th(weights[pjoin(n_block, n_unit, "gn_proj/scale")])
            proj_gn_bias = np2th(weights[pjoin(n_block, n_unit, "gn_proj/bias")])

            self.downsample.weight.copy_(proj_conv_weight)
            self.gn_proj.weight.copy_(proj_gn_weight.view(-1))
            self.gn_proj.bias.copy_(proj_gn_bias.view(-1))
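
A quick shape check of the hybrid backbone (a sketch assuming the classes above are in scope; block_units=(3, 4, 9) and width_factor=1 are the values the R50 configs use, and the 224x224 batch is made up):

import torch

backbone = ResNetV2(block_units=(3, 4, 9), width_factor=1)
x = torch.randn(2, 3, 224, 224)
out, features = backbone(x)
print(out.shape)                     # torch.Size([2, 1024, 14, 14]) -> fed to patch_embeddings
print([f.shape for f in features])   # skip features, from high level to low level:
# [torch.Size([2, 512, 28, 28]), torch.Size([2, 256, 56, 56]), torch.Size([2, 64, 112, 112])]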

Since this code path is only used in hybrid mode, I have not yet dug into why StdConv2d and GroupNorm are used here; I will look for the answer in the ViT codebase later.
