分割模型TransNetR的pytorch代码学习笔记

这个模型在U-net的基础上融合了Transformer模块和残差网络的原理。

论文地址:https://arxiv.org/pdf/2303.07428.pdf

具体的网络结构如下:

网络的原理还是比较简单的,

编码分支用的是预训练的resnet模块,解码分支则重新设计了。

解码器分支的模块结构示意图如下:

可以看出来,就是Transformer模块和残差连接相加,然后再经过一个residual模块处理。

1,用pytorch实现时,首先要把这个解码器模块实现出来:

class residual_transformer_block(nn.Module):def __init__(self, in_c, out_c, patch_size=4, num_heads=4, num_layers=2, dim=None):super().__init__()self.ps = patch_sizeself.c1 = Conv2D(in_c, out_c)encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=num_heads)self.te = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)self.c2 = Conv2D(out_c, out_c, kernel_size=1, padding=0, act=False)self.c3 = Conv2D(in_c, out_c, kernel_size=1, padding=0, act=False)self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)self.r1 = residual_block(out_c, out_c)def forward(self, inputs):x = self.c1(inputs)b, c, h, w = x.shapenum_patches = (h*w)//(self.ps**2)x = torch.reshape(x, (b, (self.ps**2)*c, num_patches))x = self.te(x)x = torch.reshape(x, (b, c, h, w))x = self.c2(x)s = self.c3(inputs)x = self.relu(x + s)x = self.r1(x)return x

其中我们需要注意的是这里构建Transformer块的方法,也就是下面两句:

encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=num_heads)
self.te = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

首先,第一句是用nn.TransformerEncoderLayer定义了一个Transformer层,并存储在encoder_layer变量中。

nn.TransformerEncoderLayer的参数包括:d_model(输入特征的维度大小),nhead(自注意力机制中注意力头数量),dim_feedforward(前馈网络的隐藏层维度大小),dropout(dropout比例),apply(用于在编码器层及其子层上应用函数,例如初始化或者权重共享等功能)。

第二句则是将多个Transformer层堆叠在一起,构建一个Transformer编码器块。

nn.TransformerEncoder的参数包括:encoder_layer(用于构建模块的每个Transformer层),num_layer(堆叠的层数),norm(执行的标准化方法),apply(同上)。

2,在上面的解码器模块中,还有一个residual block需要额外实现,如下:

class residual_block(nn.Module):def __init__(self, in_c, out_c):super().__init__()self.conv = nn.Sequential(nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),nn.BatchNorm2d(out_c),nn.LeakyReLU(negative_slope=0.1, inplace=True),nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),nn.BatchNorm2d(out_c))self.shortcut = nn.Sequential(nn.Conv2d(in_c, out_c, kernel_size=1, padding=0),nn.BatchNorm2d(out_c))self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)def forward(self, inputs):x = self.conv(inputs)s = self.shortcut(inputs)return self.relu(x + s)

这个代码就是简单的残差卷积模块,不赘述。

3,重要的模块实现完了,接下来就是按照UNet的形状拼装网络,代码如下:

class Model(nn.Module):def __init__(self):super().__init__()""" Encoder """backbone = resnet50()self.layer0 = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu)self.layer1 = nn.Sequential(backbone.maxpool, backbone.layer1)self.layer2 = backbone.layer2self.layer3 = backbone.layer3self.layer4 = backbone.layer4self.e1 = Conv2D(64, 64, kernel_size=1, padding=0)self.e2 = Conv2D(256, 64, kernel_size=1, padding=0)self.e3 = Conv2D(512, 64, kernel_size=1, padding=0)self.e4 = Conv2D(1024, 64, kernel_size=1, padding=0)""" Decoder """self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)self.r1 = residual_transformer_block(64+64, 64, dim=64)self.r2 = residual_transformer_block(64+64, 64, dim=256)self.r3 = residual_block(64+64, 64)""" Classifier """self.outputs = nn.Conv2d(64, 1, kernel_size=1, padding=0)def forward(self, inputs):""" Encoder """x0 = inputsx1 = self.layer0(x0)    ## [-1, 64, h/2, w/2]x2 = self.layer1(x1)    ## [-1, 256, h/4, w/4]x3 = self.layer2(x2)    ## [-1, 512, h/8, w/8]x4 = self.layer3(x3)    ## [-1, 1024, h/16, w/16]e1 = self.e1(x1)e2 = self.e2(x2)e3 = self.e3(x3)e4 = self.e4(x4)""" Decoder """x = self.up(e4)x = torch.cat([x, e3], axis=1)x = self.r1(x)x = self.up(x)x = torch.cat([x, e2], axis=1)x = self.r2(x)x = self.up(x)x = torch.cat([x, e1], axis=1)x = self.r3(x)x = self.up(x)""" Classifier """outputs = self.outputs(x)return outputs

其中,x1,x2,x3,x4就是编码器模块,用的都是resnet50的预训练模块。

其中r1,r2,r3,r4则是解码器的模块,就是上面实现的模块。

而e1,e2,e3,e4则是在skip connection前给编码器的输出做1x1卷积,作用大体上就是减少计算量。

完整代码:https://github.com/DebeshJha/TransNetR/blob/main/model.py#L45

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/526725.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Masked Generative Distillation(MGD)2022年ECCV

Masked Generative Distillation(MGD)2022年ECCV 摘要 **目前的蒸馏算法通常通过模仿老师的输出来提高学生的表现。本文表明,教师还可以通过引导学生特征恢复来提高学生的代表性。从这个角度来看,我们提出的掩模生成蒸馏&#x…

先进电机技术 —— 高速电机与低速电机

一、背景 高速电机是指转速远高于一般电机的电动机,通常其转速在每分钟几千转至上万转甚至几十万转以上。这类电机具有功率密度高、响应速度快、输出扭矩大等特点,在航空航天、精密仪器、机器人、电动汽车、高端装备制造等领域有着广泛的应用。 高速电…

无人机生态环境监测、图像处理与GIS数据分析

构建“天空地”一体化监测体系是新形势下生态、环境、水文、农业、林业、气象等资源环境领域的重大需求,无人机生态环境监测在一体化监测体系中扮演着极其重要的角色。通过无人机航空遥感技术可以实现对地表空间要素的立体观测,获取丰富多样的地理空间数…

跨平台大小端判断与主机节序转网络字节序使用

1.macOS : 默认使用小端 ,高位使用高地址,转换为网络字节序成大端 #include <iostream> #include <arpa/inet.h> int main() {//大小端判断union{short s;char c[sizeof(short)];}un;un.s = 0x0102;printf("低地址:%d,高地址:%d\n",un.c[0],un.c[1]);if …

灯塔批量添加指纹信息

攻击地址https://github.com/loecho-sec/ARL-Finger-ADD 指纹文件https://github.com/lemonlove7/EHole_magic/blob/main/finger.json 成功导入可以看到

基于Springboot的在线租房和招聘平台(有报告)。Javaee项目,springboot项目。

演示视频&#xff1a; 基于Springboot的在线租房和招聘平台&#xff08;有报告&#xff09;。Javaee项目&#xff0c;springboot项目。 项目介绍&#xff1a; 采用M&#xff08;model&#xff09;V&#xff08;view&#xff09;C&#xff08;controller&#xff09;三层体系结…

采用 Amazon DocumentDB 和 Amazon Bedrock 上的 Claude 3 构建游戏行业产品推荐

前言 大语言模型&#xff08;LLM&#xff09;自面世以来即展示了其创新能力&#xff0c;但 LLM 面临着幻觉等挑战。如何通过整合外部数据库的知识&#xff0c;检索增强生成&#xff08;RAG&#xff09;已成为通用和可行的解决方案。这提高了模型的准确性和可信度&#xff0c;特…

【个人开发】llama2部署实践(三)——python部署llama服务(基于GPU加速)

1.python环境准备 注&#xff1a;llama-cpp-python安装一定要带上前面的参数安装&#xff0c;如果仅用pip install装&#xff0c;启动服务时并没将模型加载到GPU里面。 # CMAKE_ARGS"-DLLAMA_METALon" FORCE_CMAKE1 pip install llama-cpp-python CMAKE_ARGS"…

PyTorch搭建LeNet训练集详细实现

一、下载训练集 导包 import torch import torchvision import torch.nn as nn from model import LeNet import torch.optim as optim import torchvision.transforms as transforms import matplotlib.pyplot as plt import numpy as npToTensor()函数&#xff1a; 把图像…

【脚本玩漆黑的魅影】全自动刷努力值

文章目录 原理全部代码 原理 全自动练级&#xff0c;只不过把回城治疗改成吃红苹果。 吃一个可以打十下&#xff0c;背包留10个基本就练满了。 吃完会自动停止。 if img.getpixel(data_attack[0]) data_attack[1] or img.getpixel(data_attack_2[0]) data_attack_2[1]: # …

RESTful API关键部分组成和构建web应用程序步骤

RESTful API是一种基于HTTP协议的、符合REST原则的应用程序接口。REST&#xff08;Representational State Transfer&#xff09;是一种软件架构风格&#xff0c;用于设计网络应用程序的通信模式。 一个RESTful API由以下几个关键部分组成&#xff1a; 资源&#xff08;Resour…

关于天线综合4(伍德沃德——罗森取样法)

伍德沃德——罗森取样法 就是在各个点指定方向图的值&#xff0c;对其方向图取样 主要就是将线源电流分布分解成一组等幅度、线性相位的源的和 求出对应电流分量方向图 中心位于wwn 最大值为an&#xff0c; 其中wn控制该分量方向图最大值的位置&#xff0c;an控制分量方向图的幅…