Pytorch:搭建卷积神经网络完成MNIST分类任务:

2023.7.18

MNIST百科:

MNIST数据集简介与使用_bwqiang的博客-CSDN博客

数据集官网:MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges

MNIST数据集获取并转换成图片格式:

数据集将按以图片和文件夹名为标签的形式保存:

 代码:下载mnist数据集并转还为图片


import os
from PIL import Image
from torchvision import datasets, transforms# 定义数据转换
transform = transforms.Compose([transforms.ToTensor(),  # 转换为张量transforms.Normalize((0.5,), (0.5,))  # 标准化
])# 下载并加载训练集和测试集
train_dataset = datasets.MNIST(root=os.getcwd(), train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root=os.getcwd(), train=False, transform=transform, download=True)# 路径
train_path = './images/train'
test_path = './images/test'# 将训练集中的图像保存为图片
for i in range(10):file_name = train_path + os.sep + str(i)if not os.path.exists(file_name):os.mkdir(file_name)for i in range(10):file_name = test_path + os.sep + str(i)if not os.path.exists(file_name):os.mkdir(file_name)for i, (image, label) in enumerate(train_dataset):train_label = labelimage_path = f'images/train/{train_label}/{i}.png'image = image.squeeze().numpy()  # 去除通道维度,并转换为 numpy 数组image = (image * 0.5) + 0.5  # 反标准化,将范围调整为 [0, 1]image = (image * 255).astype('uint8')  # 将范围调整为 [0, 255],并转换为整数类型Image.fromarray(image).save(image_path)# 将测试集中的图像保存为图片
for i, (image, label) in enumerate(test_dataset):text_label = labelimage_path = f'images/test/{text_label}/{i}.png'image = image.squeeze().numpy()  # 去除通道维度,并转换为 numpy 数组image = (image * 0.5) + 0.5  # 反标准化,将范围调整为 [0, 1]image = (image * 255).astype('uint8')  # 将范围调整为 [0, 255],并转换为整数类型Image.fromarray(image).save(image_path)

 训练代码:


import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision.transforms as transforms
from PIL import Image# 调动显卡进行计算
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")class MyDataset(torch.utils.data.Dataset):def __init__(self, root_dir, transform=None):self.root_dir = root_dirself.transform = transformself.names_list = []for dirs in os.listdir(self.root_dir):dir_path = self.root_dir + '/' + dirsfor imgs in os.listdir(dir_path):img_path = dir_path + '/' + imgsself.names_list.append((img_path, dirs))def __len__(self):return len(self.names_list)def __getitem__(self, index):image_path, label = self.names_list[index]if not os.path.isfile(image_path):print(image_path + '不存在该路径')return Noneimage = Image.open(image_path)label = np.array(label).astype(int)label = torch.from_numpy(label)if self.transform:image = self.transform(image)return image, label# 定义卷积神经网络模型
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)self.relu = nn.ReLU()self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)self.fc = nn.Linear(16 * 14 * 14, 10)def forward(self, x):x = self.conv1(x)  # 卷积x = self.relu(x)  # 激活函数x = self.maxpool(x)  # 最大值池化x = x.view(x.size(0), -1)x = self.fc(x)  # 全连接层return x# 加载手写数字数据集
train_dataset = MyDataset('./dataset/images/train', transform=transforms.ToTensor())
val_dataset = MyDataset('./dataset/images/val', transform=transforms.ToTensor())# 定义超参数
batch_size = 8192  # 批处理大小
learning_rate = 0.001  # 学习率
num_epochs = 30  # 迭代次数# 创建数据加载器
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)# 实例化模型、损失函数和优化器
model = CNN().to(device)
criterion = nn.CrossEntropyLoss()  # 损失函数
optimizer = optim.Adam(model.parameters(), lr=learning_rate)  # 优化器# 记录验证的次数
total_train_step = 0
total_val_step = 0# 模型训练和验证
print("-------------TRAINING-------------")
total_step = len(train_loader)
for epoch in range(num_epochs):print("Epoch=", epoch)for i, (images, labels) in enumerate(train_loader):images = images.to(device)labels = labels.to(device)output = model(images)loss = criterion(output, labels.long())optimizer.zero_grad()loss.backward()optimizer.step()total_train_step = total_train_step + 1print("train_times:{},Loss:{}".format(total_train_step, loss.item()))# 测试验证total_val_loss = 0total_accuracy = 0with torch.no_grad():for i, (images, labels) in enumerate(val_loader):images = images.to(device)labels = labels.to(device)outputs = model(images)loss = criterion(outputs, labels.long())total_val_loss = total_val_loss + loss.item()  # 计算损失值的和accuracy = 0for j in labels:  # 计算精确度的和if outputs.argmax(1)[j] == labels[j]:accuracy = accuracy + 1total_accuracy = total_accuracy + accuracyprint('Accuracy =', float(total_accuracy / len(val_dataset)))  # 输出正确率torch.save(model, "cnn_{}.pth".format(epoch))  # 模型保存# # 模型评估
# with torch.no_grad():
#     correct = 0
#     total = 0
#     for images, labels in test_loader:
#         outputs = model(images)
#         _, predicted = torch.max(outputs.data, 1)
#         total += labels.size(0)
#         correct += (predicted == labels).sum().item()

测试代码:

import torch
from torchvision import transforms
import torch.nn as nn
import os
from PIL import Imagedevice = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')  # 判断是否有GPUclass CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)self.relu = nn.ReLU()self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)self.fc = nn.Linear(16 * 14 * 14, 10)def forward(self, x):x = self.conv1(x)  # 卷积x = self.relu(x)  # 激活函数x = self.maxpool(x)  # 最大值池化x = x.view(x.size(0), -1)x = self.fc(x)  # 全连接层return xmodel = torch.load('cnn.pth')  # 加载模型path = "./dataset/images/test/"  # 测试集imgs = os.listdir(path)test_num = len(imgs)
print(f"test_dataset_quantity={test_num}")for img_name in imgs:img = Image.open(path + img_name)test_transform = transforms.Compose([transforms.ToTensor()])img = test_transform(img)img = img.to(device)img = img.unsqueeze(0)outputs = model(img)  # 将图片输入到模型中_, predicted = outputs.max(1)pred_type = predicted.item()print(img_name, 'pred_type:', pred_type)

分类正确率不错:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/27441.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SpringBoot 整合 RabbitMQ demo

Rabbit Windows安装教程 本文只做Demo案例的分享,具体只是需自行百度 一、生产者 1.application.properties 配置Rabbit的基本信息 #rabbit 主机IP spring.rabbitmq.host127.0.0.1 #rabbit 端口 spring.rabbitmq.port5672 #rabbit 账号 可自行创建 这里是默认的 …

排序算法的补充

建议先去看看我之前写的基础排序算法 补充一&#xff1a;快排中partition函数的三种实现形式 1.hoare法---与第2种方法类似 int Partition1(int*a,int left,int right) {int keyi left;while (left < right) {while (left < right && a[right] > a[keyi])…

Hadoop 之 HDFS 伪集群模式配置与使用(二)

HDFS 配置与使用 一.HDFS配置二.HDFS Shell1.默认配置说明2.shell 命令 三.Java 读写 HDFS1.Java 工程配置2.测试 一.HDFS配置 ## 基于上一篇文章进入 HADOOP_HOME 目录 cd $HADOOP_HOME/etc/hadoop ## 修改文件权限 chown -R root:root /usr/local/hadoop/hadoop-3.3.6/* ## …

C++-string类的模拟实现

本博客基于C官方文档当中给出的string类当中的主要功能实现&#xff0c;来作为参照&#xff0c;简单模拟实现 My-string 。 对于C当中的string类的介绍&#xff0c;在之前的几篇博客当中有说明&#xff0c;如有问题&#xff0c;请参照一下两个博客文章进行参考&#xff1a; (2…

ERROR: Invalid requirement: ‘==‘ 解决python报错

ERROR: Invalid requirement: 错误:无效的要求: 今天安装 selenium包时突然触发这个报错&#xff0c;这个错误通常出现在使用pip安装Python包时&#xff0c;报错的原因是需要注意的是前后没有空格&#xff0c;若是加空格就会出现上述报错。 例如&#xff1a; 安装指定版本的…

python与深度学习(一):ANN和手写数字识别

目录 1. 神经网络2. 线性回归3. 激活函数3.1 Sigmoid函数3.2 Relu函数3.3 Softmax函数 4. ANN(全连接网络)模型结构5. 误差函数5.1 均方差误差函数5.2 交叉熵误差函数 6. 手写数字识别实战6.1 工具说明6.2 导入相关库6.3 加载数据6.4 数据预处理6.5 数据处理6.6 构建网络模型6.…

nginx+lua+redis环境搭建(文末赋上脚本)

目录 需求背景 环境搭建后nginx和redis版本 系统环境 搭建步骤 配置服务器DNS 安装ntpdate同步一下系统时间 安装网络工具、编译工具及依赖库 创建软件包下载目录、nginx和redis安装目录 下载配置安装lua解释器LuaJIT 下载nginx NDK&#xff08;ngx_devel_kit&#xff09…

Vue3警告提示(Alert)

可自定义设置以下属性&#xff1a; 警告提示内容&#xff08;message&#xff09;&#xff0c;类型&#xff1a;string | slot&#xff0c;默认&#xff1a;‘’警告提示的辅助性文字介绍&#xff08;description&#xff09;&#xff0c;类型&#xff1a;string | slot&#…

libvirt 热迁移流程及参数介绍

01 热迁移基本原理 1.1 热迁移概念 热迁移也叫在线迁移&#xff0c;是指虚拟机在开机状态下&#xff0c;且不影响虚拟机内部业务正常运行的情况下&#xff0c;从一台宿主机迁移到另外一台宿主机上的过程。 1.2 虚拟机数据传输预拷贝和后拷贝 预拷贝(pre-copy)&#xff1a; …

Git及Tortoisegit使用教程,设置中文

一、到git官网下载GIT 官网 二、下载安装Tortoisegit及中文语言包,Tortoisegit及语言包 语言包下载地址 三、在电脑某个盘的文件里右键 提示未设置git.exe 路径不能继续, 于是去下载git GIT下载 安装Git时, 一直点击 Next > 不要停, 直到结束 此时再跳到TortoiseGit…

Versal ACAP在线升级之Boot Image格式

1、简介 Xilinx FPGA、SOC器件和自适应计算加速平台&#xff08;ACAPs&#xff09;通常由多个硬件和软件二进制文件组成&#xff0c;用于启动这些设备后按照预期设计进行工作。这些二进制文件可以包括FPGA比特流、固件镜像、bootloader引导程序、操作系统和用户选择的应…

KaiwuDB CTO 魏可伟:多模架构 —“化繁为简”加速器

以下为浪潮 KaiwuDB CTO 魏可伟受邀于7月4日在京举行的可信数据库发展大会发表演讲的实录&#xff0c;欢迎大家点赞、收藏、关注&#xff01; 打造多模引擎&#xff0c;AIoT数据库探索之路 01 何为“繁”&#xff1f; 工业 4.0 时代&#xff0c; 物联网产业驱动数据要素市场不…