pytorch学习——如何构建一个神经网络——以手写数字识别为例

目录

一.概念介绍

1.1神经网络核心组件

1.2神经网络结构示意图

1.3使用pytorch构建神经网络的主要工具

二、实现手写数字识别

2.1环境

2.2主要步骤

2.3神经网络结构

2.4准备数据

2.4.1导入模块

2.4.2定义一些超参数

2.4.3下载数据并对数据进行预处理

2.4.4可视化数据集中部分元素

 2.4.5构建模型和实例化神经网络

2.4.6训练模型

2.4.7可视化损失函数

2.4.7.1 train  loss 

 2.4.7.2 test loss

一.概念介绍

        神经网络是一种计算模型,它模拟了人类神经系统的工作方式,由大量的神经元和它们之间的连接组成。每个神经元接收一些输入信息,并对这些信息进行处理,然后将结果传递给其他神经元。这些神经元之间的连接具有不同的权重,这些权重可以根据神经网络的训练数据进行调整。通过调整权重,神经网络可以对输入数据进行分类、回归、聚类等任务。

        通俗来讲,神经网络就是设置一堆参数,初始化这堆参数,然后通过求导,知道这些参数对结果的影响,然后调整这些参数的大小。直到参数大小可以接近完美地拟合实际结果。神经网络有两个部分:正向传播和反向传播。正向传播是求值,反向传播是求出参数对结果的影响,从而调整参数。所以,神经网络:正向传播->反向传播->正向传播->反向传播……     

        比如我们要预测一个图像是不是猫。如果是猫,它的结果就是1,如果不是猫,它的结果就是0.我们现在有一堆图片,有的是猫,有的不是猫,所以它对应的标签(这个是y)是:0 1 1 0 1。而我们的预测结果可能是对的,也可能是错的,假设我们的预测结果是:0 0 1 1 0.我们有3个预测对了,有2个预测错了。那么我们的损失值是2/5。当然这么搞的话太“粗糙”了,实际上我们会有一个函数来定义损失值是什么。而且我们的预测结果也不是一个确凿的数字,而是一个概率:比如我们预测第3张图片是猫的概率是0.8,那么我们的预测结果是0.8.总之,定义了损失值(这个损失值记为J)以后,我们要让这个损失值尽可能地小。

参考:什么是神经网络? - 绯红之刃的回答 - 知乎 

1.1神经网络核心组件

        神经网络看上去挺复杂,节点多,层多,参数多,但其结构都是类似的,核心部分和组件都是相通的,确定完这些核心组件,这个神经网络也就基本确定了。

核心组件包括:

(1)层:神经网络的基础数据结构是层,层是一个数据处理模块,它接受一个或多个张量作为输入,并输出一个或多个张量,由一组可调整参数描述。

(2)模型:模型是由多个层组成的网络,用于对输入数据进行分类、回归、聚类等任务。

 

(3)损失函数:参数学习的目标函数,通过最小化损失函数来学习各种参数。损失函数是衡量模型输出结果与真实标签之间的差异的函数,目标是最小化损失函数,提高模型性能。

(4)优化器:使损失函数的值最小化。根据损失函数的梯度更新神经网络中的权重和偏置,以使损失函数的值最小化,提高模型性能和稳定性。

1.2神经网络结构示意图

 描述:多个层链接在一起构成一个模型或网络,输入数据通过这个模型转换为预测值,然后损失函数把预测值与真实值进行比较,得到损失值(损失值可以是距离、概率值等),该损失值用于衡量预测值与目标结果的匹配或相似程度,优化器利用损失值更新权重参数,从而使损失值越来越小。这是一个循环过程,损失值达到一个阀值或循环次数到达指定次数,循环结束。

1.3使用pytorch构建神经网络的主要工具

 参考:第3章 Pytorch神经网络工具箱 | Python技术交流与分享

在PyTorch中,构建神经网络主要使用以下工具:

  1. torch.nn模块:提供了构建神经网络所需的各种层和模块,如全连接层、卷积层、池化层、循环神经网络等。

  2. torch.nn.functional模块:提供了一些常用的激活函数和损失函数,如ReLU、Sigmoid、CrossEntropyLoss等。

  3. torch.optim模块:提供了各种优化器,如SGD、Adam、RMSprop等,用于更新神经网络中的权重和偏置。

  4. torch.utils.data模块:提供了处理数据集的工具,如Dataset、DataLoader等,可以方便地处理数据集、进行批量训练等操作。

这些工具之间的相互关系如下:

  1. 使用torch.nn模块构建神经网络的各个层和模块。

  2. 使用torch.nn.functional模块中的激活函数和损失函数对神经网络进行非线性变换和优化。

  3. 使用torch.optim模块中的优化器对神经网络中的权重和偏置进行更新,以最小化损失函数。

  4. 使用torch.utils.data模块中的数据处理工具对数据集进行处理,方便地进行批量训练和数据预处理。

二、实现手写数字识别

2.1环境

        实例环境使用Pytorch1.0+,GPU或CPU,源数据集为MNIST。

2.2主要步骤

(1)利用Pytorch内置函数mnist下载数据
(2)利用torchvision对数据进行预处理,调用torch.utils建立一个数据迭代器
(3)可视化源数据
(4)利用nn工具箱构建神经网络模型
(5)实例化模型,并定义损失函数及优化器
(6)训练模型
(7)可视化结果

2.3神经网络结构

实验中使用两个隐含层,每层激活函数为Relu,最后使用torch.max(out,1)找出张量out最大值对应索引作为预测值。

2.4准备数据

2.4.1导入模块

import numpy as np
import torch
# 导入 pytorch 内置的 mnist 数据
from torchvision.datasets import mnist 
#导入预处理模块
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
#导入nn及优化器
import torch.nn.functional as F
import torch.optim as optim
from torch import nn

2.4.2定义一些超参数

# 定义训练和测试时的批处理大小
train_batch_size = 64
test_batch_size = 128# 定义学习率和迭代次数
learning_rate = 0.01
num_epoches = 20# 定义优化器的超参数
lr = 0.01
momentum = 0.5
#动量优化器通过引入动量参数(Momentum),在更新参数时考虑之前的梯度信息,可以使得参数更新方向更加稳定,同时加速梯度下降的收敛速度。动量参数通常设置在0.5到0.9之间,可以根据具体情况进行调整。

2.4.3下载数据并对数据进行预处理

#定义预处理函数,这些预处理依次放在Compose函数中。
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5], [0.5])])
#下载数据,并对数据进行预处理
train_dataset = mnist.MNIST('./data', train=True, transform=transform, download=True)
test_dataset = mnist.MNIST('./data', train=False, transform=transform)
#dataloader是一个可迭代对象,可以使用迭代器一样使用。
train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

注:

①transforms.Compose可以把一些转换函数组合在一起;
②Normalize([0.5], [0.5])对张量进行归一化,这里两个0.5分别表示对张量进行归一化的全局平均值和方差。因图像是灰色的只有一个通道,如果有多个通道,需要有多个数字,如三个通道,应该是Normalize([m1,m2,m3], [n1,n2,n3])
③download参数控制是否需要下载,如果./data目录下已有MNIST,可选择False。
④用DataLoader得到生成器,这可节省内存。

2.4.4可视化数据集中部分元素

# 导入matplotlib.pyplot库,并设置inline模式
import matplotlib.pyplot as plt
%matplotlib inline# 枚举数据加载器中的一批数据
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)# 创建一个图像对象
fig = plt.figure()# 显示前6个图像和对应的标签
for i in range(6):plt.subplot(2,3,i+1)           # 将图像分成2行3列,当前位置为第i+1个plt.tight_layout()             # 自动调整子图之间的间距plt.imshow(example_data[i][0], cmap='gray', interpolation='none')  # 显示图像plt.title("Ground Truth: {}".format(example_targets[i]))          # 显示标签plt.xticks([])                 # 隐藏x轴刻度plt.yticks([])                 # 隐藏y轴刻度

注:

  1. 导入matplotlib.pyplot库,并设置inline模式,以在Jupyter Notebook中显示图像。

  2. 枚举数据加载器中的一批数据,其中test_loader是一个测试数据集加载器。

  3. 创建一个图像对象,用于显示图像和标签。

  4. 显示前6个图像和对应的标签,其中plt.subplot()用于将图像分成2行3列,plt.tight_layout()用于自动调整子图之间的间距,plt.imshow()用于显示图像,plt.title()用于显示标签,plt.xticks()和plt.yticks()用于隐藏x轴和y轴的刻度。

 2.4.5构建模型和实例化神经网络

class Net(nn.Module):"""使用sequential构建网络,Sequential()函数的功能是将网络的层组合到一起"""def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):super(Net, self).__init__()self.layer1 = nn.Sequential(nn.Linear(in_dim, n_hidden_1),nn.BatchNorm1d(n_hidden_1))self.layer2 = nn.Sequential(nn.Linear(n_hidden_1, n_hidden_2),nn.BatchNorm1d(n_hidden_2))self.layer3 = nn.Sequential(nn.Linear(n_hidden_2, out_dim))def forward(self, x):x = F.relu(self.layer1(x))x = F.relu(self.layer2(x))x = self.layer3(x)return x#检测是否有可用的GPU,有则使用,否则使用CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#实例化网络
model = Net(28 * 28, 300, 100, 10)
model.to(device)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

2.4.6训练模型

# 开始训练
losses = []
acces = []
eval_losses = []
eval_acces = []for epoch in range(num_epoches):train_loss = 0train_acc = 0model.train()#动态修改参数学习率if epoch%5==0:optimizer.param_groups[0]['lr']*=0.1for img, label in train_loader:img=img.to(device)label = label.to(device)img = img.view(img.size(0), -1)# 前向传播out = model(img)loss = criterion(out, label)# 反向传播optimizer.zero_grad()loss.backward()optimizer.step()# 记录误差train_loss += loss.item()# 计算分类的准确率_, pred = out.max(1)num_correct = (pred == label).sum().item()acc = num_correct / img.shape[0]train_acc += acclosses.append(train_loss / len(train_loader))acces.append(train_acc / len(train_loader))# 在测试集上检验效果eval_loss = 0eval_acc = 0# 将模型改为预测模式model.eval()for img, label in test_loader:img=img.to(device)label = label.to(device)img = img.view(img.size(0), -1)out = model(img)loss = criterion(out, label)# 记录误差eval_loss += loss.item()# 记录准确率_, pred = out.max(1)num_correct = (pred == label).sum().item()acc = num_correct / img.shape[0]eval_acc += acceval_losses.append(eval_loss / len(test_loader))eval_acces.append(eval_acc / len(test_loader))print('epoch: {}, Train Loss: {:.4f}, Train Acc: {:.4f}, Test Loss: {:.4f}, Test Acc: {:.4f}'.format(epoch, train_loss / len(train_loader), train_acc / len(train_loader), eval_loss / len(test_loader), eval_acc / len(test_loader)))

2.4.7可视化损失函数

2.4.7.1 train  loss 

plt.title('train loss')
plt.plot(np.arange(len(losses)), losses)
plt.legend(['Train Loss'], loc='upper right')

 2.4.7.2 test loss

# 绘制测试集损失函数
plt.plot(eval_losses, label='Test Loss')
plt.title('Test Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/55116.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

华为QinQ技术的基本qinq和灵活qinq 2种配置案例

基本qinq配置: 运营商pe设备在收到同一个公司的ce发来的的包,统一打上同样的vlan ,如上图,同一个家公司两边统一打上vlan 2,等于在原内网vlan 10或20过来的包再统一打上vlan 2的标签,这样传输就不会和其它…

mysql进阶-修改linux服务器中MySQL的字符集

1.背景 linux中mysql8默认的字符集是latin1,在插入中文时会报错,所以一般在配置好mysql时需要修改字符集为utf8【又叫utfmb3,一般开发够用,一个字符用3个字节表示】或者utfmb4【一个字符用4个字节表示,如果存储emoji表情&#xf…

创建者模式-单例模式

文章目录 一、创建者模式1. 单例设计模式1.1 单例模式的结构1.2 单例模式的实现(1)饿汉式-方式1(静态变量方式)(2)饿汉式-方式2(静态代码块方式)(3)懒汉式-方…

CS 144 Lab Four 收尾 -- 网络交互全流程解析

CS 144 Lab Four 收尾 -- 网络交互全流程解析 引言Tun/Tap简介tcp_ipv4.cc文件配置信息初始化cs144实现的fd家族体系基于自定义fd体系进行数据读写的adapter适配器体系自定义socket体系自定义事件循环EventLoop模板类TCPSpongeSocket详解listen_and_accept方法_tcp_main方法_in…

SSL原理详解

SSL协议结构: SSL协议分为两层,下层为SSL记录协议,上层为SSL握手协议、SSL密码变化协议和SSL警告协议。 1.下层为SSL记录协议,主要作用是为高层协议提供基本的安全服务 建立在可靠的传输之上,负责对上层的数据进行分块…

【数组数组】应用

一.题目 样例输入&#xff1a; 5 7 1 5 1 5 5 3 3 1 1 样例输出&#xff1a; 1 2 1 1 0 二.分析 有两个数据&#xff1a;x&#xff0c;y。 我们不妨先将x排序&#xff0c;再判断y&#xff0c;若小于&#xff0c;就即可。 这是暴力的思路 for(int i1;i<n;i){int ans0;for(…

前端页面--视觉差效果

代码 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><link rel"stylesheet" href"https://un…

51单片机IIC方式驱动oled屏代码示例

以IIC方式驱动oled屏特点就是四根线&#xff0c;分别是GND,VCC,SCL,SDA。前面两根很好理解&#xff0c;GND接地&#xff0c;VCC电源正极。SCL接时钟信号&#xff0c;SDA接双向数据信号。在51单片机电路中&#xff0c;没有明确表示SCL和SDA的数据接口&#xff0c;需要自定义。 而…

【深度学习】Vision Transformer论文,ViT的一些见解《 一幅图像抵得上16x16个词:用于大规模图像识别的Transformer模型》

必看文章&#xff1a;https://blog.csdn.net/qq_37541097/article/details/118242600 论文名称&#xff1a; An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale 论文下载&#xff1a;https://arxiv.org/abs/2010.11929 官方代码&#xff1a;https:…

Java项目-苍穹外卖-Day02

完善用户登陆功能 就对密码进行md5加密处理 1.改数据库内部的密码&#xff0c;改成md5加密后的 2.改Service的逻辑&#xff0c;将传过来的进行md5加密后再比较(controller是发令牌&#xff0c;和返回VO对象那逻辑) 先更新数据 如果不改java代码进行登陆&#xff0c;肯定会失…

Maven-生命周期及命令

关于本文 ✍写作原因 之前在学校学习的时候&#xff0c;编写代码使用的项目都是单体架构&#xff0c;导入开源框架依赖时只需要在pom.xml里面添加依赖&#xff0c;点一下reload按钮即可解决大部分需求&#xff1b;但是在公司使用了dubbo微服务架构之后发现只知道使用reload不足…

【Groups】50 Matplotlib Visualizations, Python实现,源码可复现

详情请参考博客: Top 50 matplotlib Visualizations 因编译更新问题&#xff0c;本文将稍作更改&#xff0c;以便能够顺利运行。 1 Dendrogram 树状图根据给定的距离度量将相似的点组合在一起&#xff0c;并根据点的相似性将它们组织成树状的链接。 新建文件Dendrogram.py: …