卷积神经网络CNN实战:MINST手写数字识别——数据集下载与网络训练

news/2025/2/6 0:54:21/文章来源:https://www.cnblogs.com/SXWisON/p/18314373

数据集下载

这一部分比较简单,就不过多赘述了,把代码粘贴到自己的项目文件里,运行一下就可以下载了。

from torchvision import datasets, transforms# 定义数据转换,将数据转换为张量并进行标准化
transform = transforms.Compose([transforms.ToTensor(),  # 转换为张量transforms.Normalize((0.5,), (0.5,))  # 标准化
])# 下载和加载训练集
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)# 下载和加载测试集
test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

该代码运行效果如下图:

该代码运行效果如下图

import torch'''=============== 数据集部分 ==============='''
# 定义数据转换
import torchvision.transforms as transforms
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])# 打开已经下载的训练集和测试集
from torchvision.datasets import MNIST
train_dataset = MNIST(root='./data', train=True, download=False, transform=transform)
test_dataset = MNIST(root='./data', train=False, download=False, transform=transform)# 创建数据加载器
batch_size = 256
from torch.utils.data import random_split
from torch.utils.data import DataLoader# 将数据集分割为6000和剩余的数据
train_size = 6000
train_subset, _ = random_split(train_dataset, [train_size, len(train_dataset) - train_size])train_loader = DataLoader(dataset=train_subset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)'''=============== 网络定义 ==============='''
# 初始化网络
from net import CNN 
net = CNN()# 初始化优化器、学习率调整器、评价函数
import torch.nn as nn
from torch import optim
learning_rate = 0.001 # 0.05 ~ 1e-6
weight_decay = 1e-4 # 1e-2 ~ 1e-8
momentum = 0.8 # 0.3~0.9
optimizer = optim.RMSprop(net.parameters(), lr=learning_rate, weight_decay=weight_decay, momentum=momentum)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2)
criterion = nn.CrossEntropyLoss()# GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net.to(device=device)'''=============== 模型信息管理 ==============='''
model_path = Noneif model_path is not None:net.load_state_dict(torch.load(model_path, map_location=device))'''=============== 网络训练 ==============='''
epochs = 50def train(net, device, optimizer, scheduler, criterion):net.train() for epoch in range(epochs):epoch_loss = 0      # 集损失置0for images, labels in train_loader:''' ========== 数据获取和转移 ========== '''images = images.to(device=device, dtype=torch.float32)labels = labels.to(device=device, dtype=torch.long)''' ========== 数据操作 ========== '''outputs = net(images)# net.forward()loss = criterion(outputs, labels)epoch_loss += loss.detach().item()''' ========== 反向传播 ========== '''optimizer.zero_grad()loss.requires_grad_(True)loss.backward() # 梯度裁剪for param in net.parameters():if param.grad is not None and param.grad.nelement() > 0:nn.utils.clip_grad_value_([param], clip_value=0.1)optimizer.step()epoch_loss /= len(train_loader)# 输出每个 epoch 的平均损失print(f'Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss}')train(net, device, optimizer, scheduler, criterion)'''=============== 网络保存 ==============='''
from datetime import datetime# 获取当前时间
current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
model_path = f'./output/final_model_{current_time}.pth'# 保存模型
torch.save(net.state_dict(), model_path)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/758767.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

winform--穿梭框

新建一个用户控件: 直接上代码:/** 作者:pengyan zhang* 邮箱:3073507793@qq.com* 博客:https://www.cnblogs.com/zpy1993-09* 时间:2024-04-10 16:36*/public partial class ShuttleFrameControl : UserControl{private Color lb_BackColor { get; set; } = Color.Trans…

mpc

https://blog.csdn.net/apr15/article/details/133965768在“数据安全概述”里面, 我们提到了安全多方计算SMPC(Secure multi-party computation)的技术。在这个计算里面代表是密码分享SS (secret sharing)技术。 而开启整个算法世界的其实是华人科学家姚期智教授, 他提出…

一般网站制作流程

制作需要经过以下几个流程:设计页面效果图,一般为PSD或者PNG格式的原图; 将页面效果图输出为HTML格式,后缀名为“.htm”; 根据页面内容调用需求生成或者编写标签所需代码; 嵌套标签代码到输出页面对应位置; 测试调试模板文件,保证调用和设计效果一致; 将模板标签、文件…

帝国CMS的网站“Notice: Use of undefined constant”错误说明

“Notice: Use of undefined constant”错误说明解答:php.ini配置问题,按下面修改即可解决: 修改php.ini,把error_reporting = E_ALL改成 error_reporting = E_ALL & ~E_NOTICE扫码添加技术【解决问题,仅需10元起】专注中小企业网站建设、网站安全12年。熟悉各种CMS,精…

DASCTF 2023六月挑战赛|二进制专项 PWN (上)

DASCTF 2023六月挑战赛|二进制专项 PWN (上) 1.easynote edit函数对长度没有检查free函数存在UAF漏洞思路:1.通过堆溢出,UAF,修改size位达到堆块重叠,使用fastbin attack,把__malloc_hook,写入one_gadget 2.通过unlink修改free got表为system exp: from pwn import * co…

易优cms后台数据类型的开关功能如何默认都显示“开”

新建字段默认就是true,就是扫码添加技术【解决问题,仅需10元起】专注中小企业网站建设、网站安全12年。熟悉各种CMS,精通PHP+MYSQL、HTML5、CSS3、Javascript等。承接:企业仿站、网站修改、网站改版、BUG修复、问题处理、二次开发、PSD转HTML、网站被黑、网站漏洞修复等。专…

eyoucms获取当前栏目分类的下级栏目的文档列表

[基础用法] 标签:modelsartlist (channelartlist)备注:使用channelartlist也可以正常输出描述:获取当前栏目分类的下级栏目的文档列表 用法: {eyou:modelsartlist typeid=栏目ID type=son loop=20} <a href={eyou:field name=typeurl /}>{eyou:field name=typename…

帝国CMS忘记后台登陆认证码怎么处理

忘记后台登陆认证码怎么办?查看/e/class/config.php文件里的“$do_loginauth”变量内容。扫码添加技术【解决问题,仅需10元起】专注中小企业网站建设、网站安全12年。熟悉各种CMS,精通PHP+MYSQL、HTML5、CSS3、Javascript等。承接:企业仿站、网站修改、网站改版、BUG修复、…

dedecms织梦更新生成栏目没反应问题总汇

织梦dedecms栏目无法更新是最头疼的事情,因为导致dedecms栏目不能更新的因素有很多,至 于大家都是什么原因导致的我也无法确定,因此整理了笔者所知道的一些原因,希望对大家有所帮助, 下面大家跟我一起来看下,你遇到的dede更新栏目无效是下面的哪一种情况:方法/步骤第一种…

易优cms登陆后台,总是提示验证码错误,账户密码都对!

问题: 易优cms登陆后台,总是提示验证码错误,账户密码都对!解决办法: 检查下目录权限,或者用排除法,弄回本地安装看看,如果可以,就是空间环境哪里设置有问题。扫码添加技术【解决问题,仅需10元起】专注中小企业网站建设、网站安全12年。熟悉各种CMS,精通PHP+MYSQL、H…

DedeTag Engine Create File False的彻底解决方法总汇

错误记录: DedeTag Engine Create File False的彻底解决方法总汇解决方案: DedeTag Engine Create File False这个问题真是折磨人,说小不小说大不大,这里分享一下DedeTag Engine Create File False的解决办法 方法1:确认文件夹a、data(以前的版本好像html,你也可能自定…

用SqlBulkCopy批量插入数据 遇到的错误

原文链接:https://www.cnblogs.com/wz327/archive/2011/07/05/2098356.html 错误一:来自数据源的 String 类型的给定值不能转换为指定目标列的类型 nvarchar。 还有其他的错误如:AddTime不能为DBNull (这个应该是目标表中AddTime要求不许为null) 可能的原因有两种 可能是…