霹雳吧啦Wz《pytorch图像分类》-p5ResNet网络

《pytorch图像分类》p5ResNet网络结构

  • 1 网络中的亮点
    • 1.1 超深的网络结构
    • 1.2 residual模块
    • 1.3 Batch Normalization
    • 1.4 迁移学习简介
  • 2 模块类代码
    • 2.1 BasicBlock(18 & 32 layers)
    • 2.2 Bottleneck(50 & 101 & 152layers)
    • 2.3 ResNet
  • 3 课程代码
    • 3.1 modle.py
    • 3.2 train.py
    • 3.3 predict.py

1 网络中的亮点

论文连接:Deep Residual Learning for Image Recognition

1.1 超深的网络结构

AlexNet,VggNet等之前学习的网络当中,深度只有十几层到二十几层,如果层数加深,会产生梯度消失或者梯度爆炸现象,错误率反而会增高(如左图),ResNet是添加了残差模块之后才能搭建足够深的网络模型(如右图)。
在这里插入图片描述
网络深度突破了1000层
在这里插入图片描述

1.2 residual模块

残差模块(Residual module)是一种常用于深度卷积神经网络(CNN)中的构建块,用于解决梯度消失或爆炸的问题,并帮助网络更有效地进行训练。
残差模块的核心思想是通过引入跳跃连接shortcut connections(论文中写的术语)来绕过一些卷积层,将输入直接添加到输出中

在这里插入图片描述
当输入特征图和输出特征图在维度上不匹配时,可以通过1x1卷积层调整维度,以确保它们可以相加。
在这里插入图片描述
Down sampling is performed by conv3_1, conv4_1, and conv5_1 with a stride of 2.
下采样由conv3_1、conv4_1和conv5_1执行,步长为2。
在这里插入图片描述

1.3 Batch Normalization

Batch Normalization(批归一化)是一种在深度神经网络中用于提高训练速度和稳定性的技术。它通过对每个小批次(batch)的输入数据进行归一化处理,以使输入的均值和方差保持在稳定范围内

举个例子:我们在烹饪中做蛋糕的过程。每次用相同的配方烤蛋糕时,我们都会按照配方中的比例加入相同的材料,所以无论做几个蛋糕,它们都会有相似的味道和口感。
类似地,神经网络的训练过程中,我们可以将输入数据看作是蛋糕的材料。批归一化通过对每个批次中的数据进行标准化,使得它们具有零均值和单位方差。这相当于将每个批次的数据转化为类似的“蛋糕材料”,这样神经网络在处理不同的批次时,每个批次的数据都可以在相似的范围内变化。这样有利于网络的学习和训练。
(我还得找资料更深一步去学习)

1.4 迁移学习简介

  1. 能够快速的训练出一个理想的结果
  2. 当数据集比较小的时候也能训练出理想的效果

常见的迁移学习方式:

  1. 载入权重后训练所有参数
  2. 载入权重后只训练最后几层参数
  3. 载入权重后在原网络基础上再添加一层全连接层,仅训练最后一个全连接层

2 模块类代码

2.1 BasicBlock(18 & 32 layers)

针对18层和34层的残差结构
expansion是指每个残差块的输出通道数相对于输入通道数的倍数,主分支所采用的卷积核的个数有没有发生变化,expansion=1也就是卷积核个数没有发生变化。
在这里插入图片描述

downsample=None对应的是虚线的结构,比如conv3、conv4和conv5的第一层,见1.2residual结构的第二张图

def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):

在这里插入图片描述

2.2 Bottleneck(50 & 101 & 152layers)

针对50层、101层和152层的残差结构
expansion=4即输出通道数是输入通道数的4倍。
在这里插入图片描述

def __init__(self, in_channel, out_channel, stride=1, downsample=None,groups=1, width_per_group=64):

2.3 ResNet

在这里插入图片描述

class ResNet(nn.Module):def __init__(self,block,blocks_num,num_classes=1000,include_top=True,groups=1,width_per_group=64):

3 课程代码

3.1 modle.py

import torch.nn as nn
import torchclass BasicBlock(nn.Module):expansion = 1def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):super(BasicBlock, self).__init__()self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,kernel_size=3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channel)self.relu = nn.ReLU()self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,kernel_size=3, stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channel)self.downsample = downsampledef forward(self, x):identity = xif self.downsample is not None:identity = self.downsample(x)out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out += identityout = self.relu(out)return outclass Bottleneck(nn.Module):"""注意:原论文中,在虚线残差结构的主分支上,第一个1x1卷积层的步距是2,第二个3x3卷积层步距是1。但在pytorch官方实现过程中是第一个1x1卷积层的步距是1,第二个3x3卷积层步距是2,这么做的好处是能够在top1上提升大概0.5%的准确率。可参考Resnet v1.5 https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch"""expansion = 4def __init__(self, in_channel, out_channel, stride=1, downsample=None,groups=1, width_per_group=64):super(Bottleneck, self).__init__()width = int(out_channel * (width_per_group / 64.)) * groupsself.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width,kernel_size=1, stride=1, bias=False)  # squeeze channelsself.bn1 = nn.BatchNorm2d(width)# -----------------------------------------self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups,kernel_size=3, stride=stride, bias=False, padding=1)self.bn2 = nn.BatchNorm2d(width)# -----------------------------------------self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel*self.expansion,kernel_size=1, stride=1, bias=False)  # unsqueeze channelsself.bn3 = nn.BatchNorm2d(out_channel*self.expansion)self.relu = nn.ReLU(inplace=True)self.downsample = downsampledef forward(self, x):identity = xif self.downsample is not None:identity = self.downsample(x)out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)out += identityout = self.relu(out)return outclass ResNet(nn.Module):def __init__(self,block,blocks_num,num_classes=1000,include_top=True,groups=1,width_per_group=64):super(ResNet, self).__init__()self.include_top = include_topself.in_channel = 64self.groups = groupsself.width_per_group = width_per_groupself.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,padding=3, bias=False)self.bn1 = nn.BatchNorm2d(self.in_channel)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)self.layer1 = self._make_layer(block, 64, blocks_num[0])self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)if self.include_top:self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)self.fc = nn.Linear(512 * block.expansion, num_classes)for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')def _make_layer(self, block, channel, block_num, stride=1):downsample = Noneif stride != 1 or self.in_channel != channel * block.expansion:downsample = nn.Sequential(nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(channel * block.expansion))layers = []layers.append(block(self.in_channel,channel,downsample=downsample,stride=stride,groups=self.groups,width_per_group=self.width_per_group))self.in_channel = channel * block.expansionfor _ in range(1, block_num):layers.append(block(self.in_channel,channel,groups=self.groups,width_per_group=self.width_per_group))return nn.Sequential(*layers)def forward(self, x):x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)if self.include_top:x = self.avgpool(x)x = torch.flatten(x, 1)x = self.fc(x)return xdef resnet34(num_classes=1000, include_top=True):# https://download.pytorch.org/models/resnet34-333f7ec4.pthreturn ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)def resnet50(num_classes=1000, include_top=True):# https://download.pytorch.org/models/resnet50-19c8e357.pthreturn ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)def resnet101(num_classes=1000, include_top=True):# https://download.pytorch.org/models/resnet101-5d3b4d8f.pthreturn ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)def resnext50_32x4d(num_classes=1000, include_top=True):# https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pthgroups = 32width_per_group = 4return ResNet(Bottleneck, [3, 4, 6, 3],num_classes=num_classes,include_top=include_top,groups=groups,width_per_group=width_per_group)def resnext101_32x8d(num_classes=1000, include_top=True):# https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pthgroups = 32width_per_group = 8return ResNet(Bottleneck, [3, 4, 23, 3],num_classes=num_classes,include_top=include_top,groups=groups,width_per_group=width_per_group)

3.2 train.py

import os
import sys
import jsonimport torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdmfrom model import resnet34def main():device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print("using {} device.".format(device))data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),"val": transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root pathimage_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set pathassert os.path.exists(image_path), "{} path does not exist.".format(image_path)train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),transform=data_transform["train"])train_num = len(train_dataset)# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}flower_list = train_dataset.class_to_idxcla_dict = dict((val, key) for key, val in flower_list.items())# write dict into json filejson_str = json.dumps(cla_dict, indent=4)with open('class_indices.json', 'w') as json_file:json_file.write(json_str)batch_size = 16nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workersprint('Using {} dataloader workers every process'.format(nw))train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=True,num_workers=nw)validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),transform=data_transform["val"])val_num = len(validate_dataset)validate_loader = torch.utils.data.DataLoader(validate_dataset,batch_size=batch_size, shuffle=False,num_workers=nw)print("using {} images for training, {} images for validation.".format(train_num,val_num))net = resnet34()# load pretrain weights# download url: https://download.pytorch.org/models/resnet34-333f7ec4.pthmodel_weight_path = "resnet34-pre.pth"assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)net.load_state_dict(torch.load(model_weight_path, map_location='cpu'))# for param in net.parameters():#     param.requires_grad = False# change fc layer structurein_channel = net.fc.in_featuresnet.fc = nn.Linear(in_channel, 5)net.to(device)# define loss functionloss_function = nn.CrossEntropyLoss()# construct an optimizerparams = [p for p in net.parameters() if p.requires_grad]optimizer = optim.Adam(params, lr=0.0001)epochs = 3best_acc = 0.0save_path = './resNet34.pth'train_steps = len(train_loader)for epoch in range(epochs):# trainnet.train()running_loss = 0.0train_bar = tqdm(train_loader, file=sys.stdout)for step, data in enumerate(train_bar):images, labels = dataoptimizer.zero_grad()logits = net(images.to(device))loss = loss_function(logits, labels.to(device))loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,epochs,loss)# validatenet.eval()acc = 0.0  # accumulate accurate number / epochwith torch.no_grad():val_bar = tqdm(validate_loader, file=sys.stdout)for val_data in val_bar:val_images, val_labels = val_dataoutputs = net(val_images.to(device))# loss = loss_function(outputs, test_labels)predict_y = torch.max(outputs, dim=1)[1]acc += torch.eq(predict_y, val_labels.to(device)).sum().item()val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,epochs)val_accurate = acc / val_numprint('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %(epoch + 1, running_loss / train_steps, val_accurate))if val_accurate > best_acc:best_acc = val_accuratetorch.save(net.state_dict(), save_path)print('Finished Training')if __name__ == '__main__':main()

3.3 predict.py

import os
import jsonimport torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as pltfrom model import resnet34def main():device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")data_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])# load imageimg_path = "1.jpg"assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)img = Image.open(img_path)plt.imshow(img)# [N, C, H, W]img = data_transform(img)# expand batch dimensionimg = torch.unsqueeze(img, dim=0)# read class_indictjson_path = './class_indices.json'assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)with open(json_path, "r") as f:class_indict = json.load(f)# create modelmodel = resnet34(num_classes=5).to(device)# load model weightsweights_path = "./resNet34.pth"assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)model.load_state_dict(torch.load(weights_path, map_location=device))# predictionmodel.eval()with torch.no_grad():# predict classoutput = torch.squeeze(model(img.to(device))).cpu()predict = torch.softmax(output, dim=0)predict_cla = torch.argmax(predict).numpy()print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],predict[predict_cla].numpy())plt.title(print_res)for i in range(len(predict)):print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],predict[i].numpy()))plt.show()if __name__ == '__main__':main()

1.jpg预测的玫瑰花
在这里插入图片描述
2.jpg预测的雏菊
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/318391.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

十大性能测试工具

这篇关于“性能测试工具”的文章将按以下顺序让您了解不同的软件测试工具: 什么是性能测试?为什么我们需要性能测试?性能测试的优势性能测试的类型十大性能测试工具 什么是性能测试? 性能测试是一种软件测试,可确保…

向日葵远程工具安装Mysql的安装与配置

目录 一、向日葵远程工具安装 1.1 简介 1.2 下载地址 二、Mysql 5.7 安装与配置 2.1 简介 2.2 安装 2.3 初始化mysql服务端 2.4 启动mysql服务 2.5 登录mysql 2.6 修改密码 2.7 设置外部访问 三、思维导图 一、向日葵远程工具安装 1.1 简介 向日葵远程控制是一款用…

java每日一题——双色球系统(答案及编程思路)

前言: 打好基础,daydayup! 题目:要求如下(同时:红球每个号码不可以相同) 编程思路:1,创建一个可以录入数字的数组;2,生成一个可以随机生成数字的数组&#xf…

java代码规范(适合写程序之前先了解有助于开发协同)

目录 一、类定义 二、方法定义 三、接口定义 四、变量定义 1、命名规范: 2、类型规范: 3、常量规范: 五、static关键字 1、静态变量(类变量): 2、静态方法(类方法)&#x…

论文降重助手同义词替换功能的优化建议与实施方案

大家好,今天来聊聊论文降重助手同义词替换功能的优化建议与实施方案,希望能给大家提供一点参考。 以下是针对论文重复率高的情况,提供一些修改建议和技巧,可以借助此类工具: 标题:论文降重助手同义词替换功…

【JavaSE】string与StringBuilder和StringBuffer

区别: 不可变性: String: String 类是不可变的,一旦创建就不能被修改。对字符串的任何操作都会创建一个新的字符串对象。StringBuffer: StringBuffer 是可变的,允许对字符串进行修改,而不创建新…

jmeter线程组

特点:模拟用户,支持多用户操作;可以串行也可以并行 分类: setup线程组:初始化 类似于 unittest中的setupclass 普通线程组:字面意思 teardown线程组:环境恢复,后置处理

Linux学习第48天:Linux USB驱动试验:保持热情,保持节奏,持续学习是作为一个技术人员应有的基本素质和要求

Linux版本号4.1.15 芯片I.MX6ULL 大叔学Linux 品人间百味 思文短情长 最近更新的速度和频率大不如以前,主要原因还是自己有些懈怠了。学习是一个持续努力的过程,一旦中断,再想保持以往的状态可能要…

重定向的原理及代码演示

一、重定向的概念 客户浏览器发送http请求,当web服务器接受后发送302状态码响应及对应新的location给客 户浏览器客户浏览器发现是302响应,则自动再发送一个新的http请求,请求url是新的location地址,服务器根据此请求寻找资源并发…

Hive09_函数

HIVE函数 系统内置函数 1)查看系统自带的函数 hive> show functions;2)显示自带的函数的用法 hive> desc function upper;3)详细显示自带的函数的用法 hive> desc function extended upper;hive函数分类 1、UDF:用…

这次,数据泄露的目标受害者指向了---救护车服务公司

已停业的救护车服务遭到勒索软件攻击导致近百万人受到威胁! 此次数据泄露的目标受害者是法伦救护车服务公司,该公司是Transformative Healthcare的子公司。ALPHV勒索软件团伙声称对2023年4月下旬对Transformative Healthcare的攻击负责,并导…

深挖小白必会指针笔试题<一>

目录 引言 关键解决办法: 学会画图确定指向关系 例题一: 画图分析: 例题二: 画图分析: 例题三: 注:%x是按十六进制打印 画图分析: 例题四: 画图分析&…