PyTorch从零开始实现ResNet

文章目录

    • 代码实现
    • 参考

代码实现

本文实现 ResNet原论文 Deep Residual Learning for Image Recognition 中的50层,101层和152层残差连接。
在这里插入图片描述
代码中使用基础残差块这个概念,这里的基础残差块指的是上图中红色矩形圈出的内容:从上到下分别使用3, 4, 6, 3个基础残差块,每个基础残差块由三个卷积层组成,核大小分别为1x1, 3x3, 1x1 。

残差连接的结构

在这里插入图片描述

复现代码如下:

import torch
import torch.nn as nn# 基础残差块,后面ResNet要多次重复使用该块
class block(nn.Module):def __init__(self, in_channels, out_channels, identity_downsample=None, stride=1):super(block, self).__init__()self.expansion = 4  self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)self.bn1 = nn.BatchNorm2d(out_channels)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1)self.bn2 = nn.BatchNorm2d(out_channels)self.conv3 = nn.Conv2d(out_channels, out_channels*self.expansion, kernel_size=1, stride=1, padding=0)self.bn3 = nn.BatchNorm2d(out_channels*self.expansion)self.relu = nn.ReLU()self.identity_downsample = identity_downsampledef forward(self, x):identity = xx = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.conv2(x)x = self.bn2(x)x = self.relu(x)x = self.conv3(x)x = self.bn3(x)# x 和 identity形状一致,才能相加if self.identity_downsample is not None:identity = self.identity_downsample(identity)x += identityx = self.relu(x)return xclass ResNet(nn.Module):def __init__(self, block, layers, image_channels, num_classes):super(ResNet, self).__init__()# 初始化的层self.in_channels = 64self.conv1 = nn.Conv2d(image_channels, 64, kernel_size=7, stride=2, padding=3)self.bn1 = nn.BatchNorm2d(64)self.relu = nn.ReLU()self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)# ResNet layersself.layer1 = self._make_layer(block, layers[0], out_channels=64, stride=1)self.layer2 = self._make_layer(block, layers[1], out_channels=128, stride=2)self.layer3 = self._make_layer(block, layers[2], out_channels=256, stride=2)self.layer4 = self._make_layer(block, layers[3], out_channels=512, stride=2)self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512*4, num_classes)def forward(self, x):x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.avgpool(x)x = x.reshape(x.shape[0], -1)x = self.fc(x)return x# 核心函数:调用block基础残差块,构造ResNet的每一层def _make_layer(self, block, num_residual_blocks, out_channels, stride):identity_downsample = Nonelayers = []# 修改形状,使得残差连接可以相加:x + identityif stride != 1 or self.in_channels != out_channels * 4:identity_downsample = nn.Sequential(nn.Conv2d(self.in_channels, out_channels*4, kernel_size=1,stride=stride),                                               nn.BatchNorm2d(out_channels*4))layers.append(block(self.in_channels, out_channels, identity_downsample, stride))self.in_channels = out_channels * 4for i in range(num_residual_blocks - 1):layers.append(block(self.in_channels, out_channels)) # 256 -> 64, 64*4(256) againreturn nn.Sequential(*layers)# 构造ResNet50层:默认图像通道3,分类类别为1000
def resnet50(img_channels=3, num_classes=1000):return ResNet(block, [3, 4, 6, 3], img_channels, num_classes)# 构造ResNet101层  
def resnet101(img_channels=3, num_classes=1000):return ResNet(block, [3, 4, 23, 3], img_channels, num_classes)# 构造ResNet152层  
def resnet152(img_channels=3, num_classes=1000):return ResNet(block, [3, 8, 36, 3], img_channels, num_classes)# 测试输出y的形状是否满足1000类
def test():net = resnet152()x = torch.randn(2, 3, 224, 224)y = net(x)print(y.shape) # [2, 1000]test()

参考

[1] Deep Residual Learning for Image Recognition
[2] https://www.youtube.com/watch?v=DkNIBBBvcPs&list=PLhhyoLH6IjfxeoooqP9rhU3HJIAVAJ3Vz&index=19

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/69287.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

插入、希尔、归并、快速排序(java实现)

目录 插入排序 希尔排序 归并排序 快速排序 插入排序 排序原理: 1.把所有元素分为两组,第一组是有序已经排好的,第二组是乱序未排序。 2.将未排序一组的第一个元素作为插入元素,倒序与有序组比较。 3.在有序组中找到比插入…

docker实现Nginx

文章目录 1.docker 安装2.docker环境实现Nginx 1.docker 安装 1.使用环境为红帽8.1,添加源 yum-config-manager --add-repo https://mirrors.aliyun.com/docker-ce/linux/centos/docker-ce.repo2.安装 yum install docker-ce docker-ce-cli containerd.io显示出错 Docker C…

机器学习深度学习——seq2seq实现机器翻译(数据集处理)

👨‍🎓作者简介:一位即将上大四,正专攻机器学习的保研er 🌌上期文章:机器学习&&深度学习——从编码器-解码器架构到seq2seq(机器翻译) 📚订阅专栏:机…

一波七折之寻找遗失的容器ip

最近业务有个需求,需要在宿主机上获取容器ip。获取ip这就不是个事,平凡而普通。 常用手段获取ip 常用的命令ifconfig, ip等等首先被pass,因为宿主机环境未知,这些命令可能没有。 那么只能从系统api着手。getifaddrs可以获取网卡…

WSL2 Ubuntu子系统安装cuda+cudnn+torch

文章目录 前言一、安装cudncudnn安装pytorch 前言 确保Windows系统版本高于windows10 21H2或Windows11,然后在Windows中将显卡驱动升级到最新即可,WSL2已支持对显卡的直接调用。 一、安装cudncudnn 配置cuda环境,WSL下的Ubuntu子系统的cu…

Redis心跳检测

在命令传播阶段&#xff0c;从服务器默认会以每秒一次的频率&#xff0c;向主服务器发送命令&#xff1a; REPLCON FACK <rep1 ication_ offset>其中replication_offset是从服务器当前的复制偏移量。 发送REPLCONF ACK命令对于主从服务器有三个作用&#xff1a; 检测主…

Linux交叉编译opencv并移植ARM端

Linux交叉编译opencv并移植ARM端 - 知乎 一、安装交叉编译器 目标平台为arm7l&#xff0c;此为32位ARM架构&#xff0c;要安装合适的编译器 sudo apt install arm-linux-gnueabihf-gcc sudo apt install arm-linux-gnueabihf-g注意&#xff1a;64位ARM架构的编译器与32位ARM架…

macOS CLion 使用 bits/stdc++.h

macOS 下 CLion 使用 bits/stdc.h 头文件 terminal运行 brew install gccCLion里配置 -D CMAKE_CXX_COMPILER/usr/local/bin/g-11

MySQL卸载并重装指定版本

MySQL卸载并重装制定版本 学习新的项目&#xff0c;发现之前的Navicat已经失去了与现有MySQL的链接&#xff0c;而且版本也不适合&#xff0c;为了少走弯路&#xff0c;准备直接重装相应版本的MySQL 卸载现有MySQL 停止windows的MySQL服务&#xff0c;【windowsR】打开运行框…

CI/CD流水线实战

不知道为什么&#xff0c;现在什么技术都想学&#xff0c;因为我觉得我遇到了技术的壁垒&#xff0c;大的项目接触不到&#xff0c;做的项目一个字辣*。所以&#xff0c;整个人心浮气躁&#xff0c;我已经得通过每天的骑行和长跑缓解这种浮躁了。一个周末&#xff0c;我再次宅在…

使用phpstorm开发调试thinkphp

1.环境准备 1.开发工具下载&#xff1a;PhpStorm: PHP IDE and Code Editor from JetBrains 2.PHP下载&#xff1a;PHP: Downloads 3. PHP扩展&#xff1a;PECL :: Package search 4.用与调试的xdebug模块&#xff1a; Xdebug: Downloads xdebug模块&#xff0c;如果是php8以…

TypeScript 语法

环境搭建 以javascript为基础构建的语言&#xff0c;一个js的超集&#xff0c;可以在任何支持js的平台中执行&#xff0c;ts扩展了js并且添加了类型&#xff0c;但是ts不能被js解析器直接执行&#xff0c;需要编译器编译为js文件&#xff0c;然后引入到 html 页面使用。 ts增…