【神经网络手写数字识别-最全源码(pytorch)】

Torch安装的方法

在这里插入图片描述

学习方法

  • 1.边用边学,torch只是一个工具,真正用,查的过程才是学习的过程
  • 2.直接就上案例就行,先来跑,遇到什么来解决什么

Mnist分类任务:

  • 网络基本构建与训练方法,常用函数解析

  • torch.nn.functional模块

  • nn.Module模块

读取Mnist数据集

  • 会自动进行下载
# 查看自己的torch的版本
import torch
print(torch.__version__)
%matplotlib inline
# 前两步,不用管是在网上下载数据,后续的我们都是在本地的数据进行操作
from pathlib import Path
import requestsDATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"PATH.mkdir(parents=True, exist_ok=True)URL = "http://deeplearning.net/data/mnist/"
FILENAME = "mnist.pkl.gz"if not (PATH / FILENAME).exists():content = requests.get(URL + FILENAME).content(PATH / FILENAME).open("wb").write(content)
import pickle
import gzipwith gzip.open((PATH / FILENAME).as_posix(), "rb") as f:((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")

784是mnist数据集每个样本的像素点个数

from matplotlib import pyplot
import numpy as nppyplot.imshow(x_train[0].reshape((28, 28)), cmap="gray")
print(x_train.shape)

在这里插入图片描述
全连接神经网络的结构
在这里插入图片描述在这里插入图片描述注意数据需转换成tensor才能参与后续建模训练

import torchx_train, y_train, x_valid, y_valid = map(torch.tensor, (x_train, y_train, x_valid, y_valid)
)
n, c = x_train.shape
x_train, x_train.shape, y_train.min(), y_train.max()
print(x_train, y_train)
print(x_train.shape)
print(y_train.min(), y_train.max())

torch.nn.functional 很多层和函数在这里都会见到

torch.nn.functional中有很多功能,后续会常用的。那什么时候使用nn.Module,什么时候使用nn.functional呢?一般情况下,如果模型有可学习的参数,最好用nn.Module,其他情况nn.functional相对更简单一些

import torch.nn.functional as Floss_func = F.cross_entropydef model(xb):return xb.mm(weights) + bias
bs = 64
xb = x_train[0:bs]  # a mini-batch from x
yb = y_train[0:bs]
weights = torch.randn([784, 10], dtype = torch.float,  requires_grad = True) 
bs = 64
bias = torch.zeros(10, requires_grad=True)print(loss_func(model(xb), yb))

创建一个model来更简化代码

  • 必须继承nn.Module且在其构造函数中需调用nn.Module的构造函数
  • 无需写反向传播函数,nn.Module能够利用autograd自动实现反向传播
  • Module中的可学习参数可以通过named_parameters()或者parameters()返回迭代器
from torch import nnclass Mnist_NN(nn.Module):# 构造函数def __init__(self):super().__init__()self.hidden1 = nn.Linear(784, 128)self.hidden2 = nn.Linear(128, 256)self.out  = nn.Linear(256, 10)self.dropout = nn.Dropout(0.5)#前向传播自己定义,反向传播是自动进行的def forward(self, x):x = F.relu(self.hidden1(x))x = self.dropout(x)x = F.relu(self.hidden2(x))x = self.dropout(x)#x = F.relu(self.hidden3(x))x = self.out(x)return x

在这里插入图片描述

net = Mnist_NN()
print(net)

在这里插入图片描述
可以打印我们定义好名字里的权重和偏置项

for name,parameter in net.named_parameters():print(name, parameter,parameter.size())

在这里插入图片描述

使用TensorDataset和DataLoader来简化

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoadertrain_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=bs * 2)
def get_data(train_ds, valid_ds, bs):return (DataLoader(train_ds, batch_size=bs, shuffle=True),DataLoader(valid_ds, batch_size=bs * 2),)
  • 一般在训练模型时加上model.train(),这样会正常使用Batch Normalization和 Dropout
  • 测试的时候一般选择model.eval(),这样就不会使用Batch Normalization和 Dropout
import numpy as npdef fit(steps, model, loss_func, opt, train_dl, valid_dl):for step in range(steps):model.train()  # 训练的时候需要更新权重参数for xb, yb in train_dl:loss_batch(model, loss_func, xb, yb, opt)model.eval() # 验证的时候不需要更新权重参数with torch.no_grad():losses, nums = zip(*[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl])val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)print('当前step:'+str(step), '验证集损失:'+str(val_loss))

zip的用法

a = [1,2,3]
b = [4,5,6]
zipped = zip(a,b)
print(list(zipped))
a2,b2 = zip(*zip(a,b))
print(a2)
print(b2)
from torch import optim
def get_model():model = Mnist_NN()return model, optim.SGD(model.parameters(), lr=0.001)
def loss_batch(model, loss_func, xb, yb, opt=None):loss = loss_func(model(xb), yb)if opt is not None:loss.backward()opt.step()opt.zero_grad()return loss.item(), len(xb)

三行搞定!

train_dl,valid_dl = get_data(train_ds, valid_ds, bs)
model, opt = get_model()
fit(100, model, loss_func, opt, train_dl, valid_dl)

在这里插入图片描述

correct = 0
total = 0
for xb,yb in valid_dl:outputs = model(xb)_,predicted = torch.max(outputs.data,1)total += yb.size(0)correct += (predicted == yb).sum().item()
print(f"Accuracy of the network the 10000 test imgaes {100*correct/total}")

![在这里插入图片描述](https://img-blog.csdnimg.cn/89e5e749b680426c9700aac9f93bf76a.png

后期有兴趣的小伙伴们可以比较SGD和Adam两种优化器,哪个效果更好一点

-SGD 20epoch 85%
-Adam 20epoch 85%

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/54769.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Qt 添加资源文件(添加图片)

第一步,先将要添加的资源文件(图片等)放入项目目录,如下: 第二步,如下图: 第三步,如下图: 第四步,如下图: 第五步,继续选择完成即…

ES6 - 对象新增的一些常用方法

文章目录 1,Object.is()2,Object.asign()3,Object.getOwnPropertyDescriptors()4,Object.setPrototypeOf()和getPrototypeOf()5,Object.keys()、values() 和 entries()6,Object.fromEntries()7,…

学习系统编程终章【多线程剩余知识】

引言: 北京时间:2023/8/3/19:21,刚刚将文章更新,是近期以来为数不多的一次早更,不然每次更文都要卡在12点左右,这是我们实现日更的一个好开端,再接再厉实现日更不是梦。最近失眠一直困扰着我&a…

参考RabbitMQ实现一个消息队列

文章目录 前言小小消息管家1.项目介绍2. 需求分析2.1 API2.2 消息应答2.3 网络通信协议设计 3. 开发环境4. 项目结构介绍4.1 配置信息 5. 项目演示 前言 消息队列的本质就是阻塞队列,它的最大用途就是用来实现生产者消费者模型,从而实现解耦合以及削峰填…

Java项目作业~ 创建基于Maven的Java项目,连接数据库,实现对站点信息的管理,即实现对站点的新增,修改,删除,查询操作

需求: 创建基于Maven的Java项目,连接数据库,实现对站点信息的管理,即实现对站点的新增,修改,删除,查询操作。 以下是站点表的建表语句: CREATE TABLE websites (id int(11) NOT N…

外卖多门店小程序开源版开发

外卖多门店小程序开源版开发 外卖多门店小程序开源版的开发可以按照以下步骤进行: 确定需求:明确外卖多门店小程序的功能和特点,包括用户注册登录、浏览菜单、下单支付、订单管理等。技术选型:选择适合开发小程序的技术框架&…

在服务器上搭建gitlab

最终效果展示: 官方文档: 安装部署GitLab服务 1.在服务器上下载gitlab wget https://mirrors.tuna.tsinghua.edu.cn/gitlab-ce/yum/el7/gitlab-ce-12.9.0-ce.0.el7.x86_64.rpm rpm -ivh gitlab-ce-12.9.0-ce.0.el7.x86_64.rpm 2.编辑站点位置 vim …

python人工智能可以干什么,python人工智能能干什么

大家好,给大家分享一下python做人工智能需要什么水平,很多人还不知道这一点。下面详细解释一下。现在让我们来看看! 人工智能包含常用机器学习和深度学习两个很重要的模块,而python拥有matplotlib、Numpy、sklearn、keras等大量的…

Spring接口ApplicationRunner的作用和使用介绍

在Spring框架中,ApplicationRunner接口是org.springframework.boot.ApplicationRunner接口的一部分。它是Spring Boot中用于在Spring应用程序启动完成后执行特定任务的接口。ApplicationRunner的作用是在Spring应用程序完全启动后,执行一些初始化任务或处…

基于 Flink Paimon 实现 Streaming Warehouse 数据一致性管理

摘要:本文整理自字节跳动基础架构工程师李明,在 Apache Paimon Meetup 的分享。本篇内容主要分为四个部分: 背景 方案设计 当前进展 未来规划 点击查看原文视频 & 演讲PPT 一、背景 ​ 早期的数仓生产体系主要以离线数仓为主&#xf…

javaWeb项目--二级评论完整思路

先来看前端需要什么吧: 通过博客id,首先需要显示所有一级评论,包括评论者的头像,昵称,评论时间,评论内容 然后要显示每个一级评论下面的二级评论,包括,评论者的头像,昵称…

Linux6.32 Kubernetes kubeadm部署

文章目录 计算机系统5G云计算第三章 LINUX Kubernetes kubeadm部署一、kubeadm搭建 Kubernetes v1.20(一主两从)1.环境准备2.所有节点安装docker3.所有节点安装kubeadm,kubelet和kubectl4.部署K8S集群 二、kubeadm搭建 Kubernetes v1.20&…