【Pytorch】学习记录分享6——PyTorch经典网络 ResNet与手写体识别

【Pytorch】学习记录分享5——PyTorch经典网络 ResNet

      • 1. ResNet (残差网络)基础知识
      • 2. 感受野
      • 3. 手写体数字识别
        • 3. 0 数据集(训练与测试集)
        • 3. 1 数据加载
        • 3. 2 函数实现:
        • 3. 3 训练及其测试:

1. ResNet (残差网络)基础知识

图1 56层error比20层error高,提出ResNet (残差网络)的方案
在这里插入图片描述

网络效果:

在这里插入图片描述
网络结构:
在这里插入图片描述
在这里插入图片描述

2. 感受野

在这里插入图片描述
在这里插入图片描述

3. 手写体数字识别

3. 0 数据集(训练与测试集)

mnist 用于手写体训练与测试,这里包含完整的链接

3. 1 数据加载
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets,transforms 
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
### 首先读取数据
# - 分别构建训练集和测试集(验证集)
# - DataLoader来迭代取数据# 定义超参数 
input_size = 28  #图像的总尺寸28*28
num_classes = 10  #标签的种类数
num_epochs = 3  #训练的总循环周期
batch_size = 64  #一个撮(批次)的大小,64张图片# 训练集
train_dataset = datasets.MNIST(root='./data',  train=True,   transform=transforms.ToTensor(),  download=True) # 测试集
test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())# 构建batch数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

在这里插入图片描述

3. 2 函数实现:
# 卷积网络模块构建
# 一般卷积层,relu层,池化层可以写成一个套餐
# 注意卷积最后结果还是一个特征图,需要把图转换成向量才能做分类或者回归任务class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Sequential(         # 输入大小 (1, 28, 28)nn.Conv2d(in_channels=1,              # 灰度图out_channels=16,            # 要得到几多少个特征图kernel_size=5,              # 卷积核大小stride=1,                   # 步长padding=2,                  # 如果希望卷积后大小跟原来一样,需要设置padding=(kernel_size-1)/2 if stride=1),                              # 输出的特征图为 (16, 28, 28)nn.ReLU(),                      # relu层nn.MaxPool2d(kernel_size=2),    # 进行池化操作(2x2 区域), 输出结果为: (16, 14, 14))self.conv2 = nn.Sequential(         # 下一个套餐的输入 (16, 14, 14)nn.Conv2d(16, 32, 5, 1, 2),     # 输出 (32, 14, 14)nn.ReLU(),                      # relu层nn.MaxPool2d(2),                # 输出 (32, 7, 7))self.out = nn.Linear(32 * 7 * 7, 10)   # 全连接层得到的结果def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = x.view(x.size(0), -1)           # flatten操作,结果为:(batch_size, 32 * 7 * 7)  output = self.out(x)return output# 准确率作为评估标准
def accuracy(predictions, labels):pred = torch.max(predictions.data, 1)[1] rights = pred.eq(labels.data.view_as(pred)).sum() return rights, len(labels) 
3. 3 训练及其测试:
# 训练网络模型
# 实例化
net = CNN() 
#损失函数
criterion = nn.CrossEntropyLoss() 
#优化器
optimizer = optim.Adam(net.parameters(), lr=0.001) #定义优化器,普通的随机梯度下降算法#开始训练循环
for epoch in range(num_epochs):#当前epoch的结果保存下来train_rights = []for batch_idx, (data, target) in enumerate(train_loader):  #针对容器中的每一个批进行循环net.train()  # 将模型设置为训练模式output = net(data)  # 使用模型进行前向传播loss = criterion(output, target)  # 计算损失optimizer.zero_grad()  # 梯度清零loss.backward()  # 反向传播计算梯度optimizer.step()  # 更新参数right = accuracy(output, target)  # 计算当前批次的准确率train_rights.append(right)  # 将准确率保存起来if batch_idx % 500 == 0:  # 每500个批次进行一次验证net.eval()  # 将模型设置为评估模式val_rights = []  # 存储验证集的准确率for (data, target) in test_loader:  # 在测试集上进行验证output = net(data)  # 使用模型进行前向传播right = accuracy(output, target)  # 计算验证集上的准确率val_rights.append(right)  # 将准确率保存起来#准确率计算train_r = (sum([tup[0] for tup in train_rights]), sum([tup[1] for tup in train_rights]))  # 计算训练集准确率的分子和分母val_r = (sum([tup[0] for tup in val_rights]), sum([tup[1] for tup in val_rights]))  # 计算验证集准确率的分子和分母print('当前epoch: {} [{}/{} ({:.0f}%)]\t损失: {:.6f}\t训练集准确率: {:.2f}%\t测试集正确率: {:.2f}%'.format(epoch, batch_idx * batch_size, len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.data, 100. * train_r[0].numpy() / train_r[1],100. * val_r[0].numpy() / val_r[1]))  # 打印当前进度和准确率信息

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/292307.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Gbase8c认证考试课后题

Gbase8c认证考试课后题 第一次练习 第一题 第二题 第三题 第四题 第五题 第六题 第七题 第八题 第九题 第十题 第十一题 第十二题 第十三题 第二次练习 第一题 第二题 第三题 第四题 第五题 第三次练习 第一题 第二题 第三题 第四题 第五题 第四次练习 第一题 第二题 第三…

关于外贸包裹的那些事

大早晨收到一个客户留言,询问能不能看一下他的货物包裹被送到了哪里,然后客户可以安排他的代理人联系去取包裹,我心里的第一感觉是难道包裹丢失了? 于是赶紧起来查看物流单号,单号显示早在半个多月前已经被他的国内代…

万界星空电机行业MES/电机mes

万界星空科技电机行业生产管理MES系统的主要功能: 1、基础数据管理 包含车间的产品材料清单管理,产品的工艺信息(工艺、定额、工厂行程)管理和资源信息(关键设备信息)等等。提供产品配套信息维护及查询&a…

conda环境下执行conda命令提示无法识别解决方案

1 问题描述 win10环境命令行执行conda命令,报命令无法识别,错误信息如下: PS D:\code\cv> conda activate pt conda : 无法将“conda”项识别为 cmdlet、函数、脚本文件或可运行程序的名称。请检查名称的拼写,如果包括路径&a…

vue3封装年份组件

ant框架年份组件 看了ant框架针对于年份不能自定义插槽内容所以放弃用ant框架年份组件,自定义插槽内容是想实现年份下方可以加小圆点的需求,因加小圆点需求必须实现,决定自己封装组件来实现需求,自己实现的效果呢类似于ant年份控件…

【数论】约数

试除法求约数 时间复杂度 O(sqrt(n))。 核心思路是求到较小的约数时,将其对应的较大约数也可以直接求出来, 例如:a/bc,b是a的余数,c也是a的余数 ps:注意bc的情况,要注意去重 void solve() …

数据库客户案例:每个物种都需要一个数据库!

1、GERDH——花卉多组学数据库 项目名称:GERDH:花卉多组学数据库 链接地址:https://dphdatabase.com 项目描述:GERDH包含了来自150多种园艺花卉植物种质的 12961个观赏植物。将不同花卉植物转录组学、表观组学等数据进行比较&am…

水利水库大坝安全监测参数详解

变形监测 变形监测是指对工程结构或地质环境中的变形进行实时或定期的测量与监测的过程。变形监测的目的是为了及时了解结构或环境的变形情况,评估其稳定性和安全性,并采取相应的措施来预防灾害和保护人民生命财产安全。 变形监测主要包括的内容有&#…

Jenkins 构建触发器指南

目录 触发远程构建 (例如,使用脚本) 描述 配置步骤 安全令牌 在其他项目构建完成后触发构建 描述 配置步骤 定时触发构建 描述 配置步骤 GitHub钩子触发GITScm轮询 描述 配置步骤 Poll SCM - 轮询版本控制系统 描述 触发远程构建 (例如,使…

Nature Commun.:物理所揭示原子分辨下的铁电涡旋畴的原位力学转变过程

通过复杂的晶格-电荷相互作用形成的铁电涡旋畴在纳米电子器件研发中具有巨大的应用潜力。实际应用中,如何在外界激励下操纵这类结构的拓扑状态是至关重要的。中国科学院物理研究所/北京凝聚态物理国家研究中心表面物理国家重点实验室与北京大学、湘潭大学和美国宾夕…

surface pro 如何调用和显示软键盘/触摸键盘

长按任务栏-勾选☑显示触摸键盘 右下角就会出现软键盘按钮 如果点选输入栏自动弹出或调用软键盘或触控键盘,按如下设定 设置-输入-不处于平板电脑模式且未连接键盘时显示触摸键盘 (开关打开即可)

【easy-ES使用】1.基础操作:增删改查、批量操作、分词查询、聚合处理。

easy-es、elasticsearch、分词器 与springboot 结合的代码我这里就不放了,我这里直接是使用代码。 基础准备: 创建实体类: Data // 索引名 IndexName("test_jc") public class TestJcES {// id注解IndexId(type IdType.CUSTOMI…