学习笔记9:卷积神经网络实现MNIST分类(GPU加速)

news/2025/3/28 1:43:02/文章来源:https://www.cnblogs.com/gongzb/p/18230128

转自:https://www.cnblogs.com/miraclepbc/p/14345342.html

相关包导入

import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
import torchvision
from torchvision import datasets, transforms
%matplotlib inline

设置device

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

如果cuda是可用的,那么就使用"cuda:0",否则使用"cpu"

数据加载

transformation = transforms.Compose([transforms.ToTensor(),       ## 转化为一个tensor, 转换到0-1之间, 将channnel放在第一位
])train_ds = datasets.MNIST('E:/datasets2/1-18/dataset/daatset',train = True,transform  =transformation,download = True
)test_ds = datasets.MNIST('E:/datasets2/1-18/dataset/daatset',train = False,transform = transformation,download = True
)train_dl = DataLoader(train_ds, batch_size = 64, shuffle = True)
test_dl = DataLoader(test_ds, batch_size = 258)

模型定义

class Model(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(1, 6, 5)#参数分别为n_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=Trueself.pool = nn.MaxPool2d((2, 2))self.conv2 = nn.Conv2d(6, 16, 5)self.linear_1 = nn.Linear(16 * 4 * 4, 256)self.linear_2 = nn.Linear(256, 10)def forward(self, input):x = F.relu(self.conv1(input))x = self.pool(x)x = F.relu(self.conv2(x))x = self.pool(x)# print(x.size())x = x.view(-1, 16 * 4 * 4)x = F.relu(self.linear_1(x))x = self.linear_2(x)return xloss_func = torch.nn.CrossEntropyLoss()

这里需要注意一点是,卷积、池化之后是不知道数据的shape的,因此可以采用print的方法,测试一下
具体来说,就是先在全连接层的维度那里随便设置值,然后打印一下
在输出框里,会出现正确的值,这时再将之前随便设置的值修正过来即可

模型训练

def fit(epoch, model, trainloader, testloader):correct = 0total = 0running_loss = 0for x, y in trainloader:x, y = x.to(device), y.to(device)y_pred = model(x)loss = loss_func(y_pred, y)optimizer.zero_grad()loss.backward()optimizer.step()with torch.no_grad():y_pred = torch.argmax(y_pred, dim = 1)correct += (y_pred == y).sum().item()total += y.size(0)running_loss += loss.item()epoch_acc = correct / totalepoch_loss = running_loss / len(trainloader.dataset)test_correct = 0test_total = 0test_running_loss = 0with torch.no_grad():for x, y in testloader:x, y = x.to(device), y.to(device)y_pred = model(x)loss = loss_func(y_pred, y)y_pred = torch.argmax(y_pred, dim = 1)test_correct += (y_pred == y).sum().item()test_total += y.size(0)test_running_loss += loss.item()epoch_test_acc = test_correct / test_totalepoch_test_loss = test_running_loss / len(testloader.dataset)print('epoch: ', epoch, 'loss: ', round(epoch_loss, 3),'accuracy: ', round(epoch_acc, 3),'test_loss: ', round(epoch_test_loss, 3),'test_accuracy: ', round(epoch_test_acc, 3))return epoch_loss, epoch_acc, epoch_test_loss, epoch_test_accmodel = Model()
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr = 0.001)
epochs = 20train_loss = []
train_acc = []
test_loss = []
test_acc = []
for epoch in range(epochs):epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc = fit(epoch, model, train_dl, test_dl)train_loss.append(epoch_loss)train_acc.append(epoch_acc)test_loss.append(epoch_test_loss)test_acc.append(epoch_test_acc)

这里需要注意的地方是,如果要调用gpu,那么需要将模型和数据都转移到gpu上
因此,需要调用.to(device)方法进行转移

训练结果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/719700.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

笔记2:张量简介

张量生成方法 转自:https://www.cnblogs.com/miraclepbc/p/14329476.html张量的形状及类型张量的计算张量的梯度手写线性回归张量生成方法 张量的形状及类型 张量的计算 张量的梯度 手写线性回归

笔记3:逻辑回归(分批次训练)

转自:https://www.cnblogs.com/miraclepbc/p/14332084.html 相关库导入 import torch import pandas as pd import numpy as np import matplotlib.pyplot as plt from torch import nn %matplotlib inline数据读入及预处理 data = pd.read_csv(E:/datasets/dataset/credit-a.…

【深度好文】到底什么是质量意识?如何衡量,如何提升?

大家好,我是狂师! 在软件测试中,质量意识是一个核心且至关重要的概念。相信大家,经常会听到:"这个家伙质量意识很强,某某某要提升质量意识“之类的话语。 在企业中,“质量意识”不仅关乎产品和服务的优劣,更是企业竞争力和可持续发展的关键因素。那么,到底什么是…

域名

顶级域名、二级域名与三级域名互联网名称与数字地址分配机构(ICANN)负责管理和协调国际互联网络域名系统。根据ICANN的定义,一个完整的域名至少有两个部分,各部分之间用“.”来分隔,最后一个“.”的右边部分称为顶级域名,也称为一级域名;最后一个“.”的左边部分称为二级…

异构数据源同步之数据同步 → DataX 使用细节

开心一刻 中午我妈微信给我消息 妈:儿子啊,妈电话欠费了,能帮妈充个话费吗 我:妈,我知道了,我帮你充 当我帮我妈把话费充好,正准备回微信的时候,我妈微信给我发消息了 妈:等会儿子,不用充了,刚刚有个二臂帮妈充上了 我输入框中的(妈,充好了)是发还是不发?简单使…

js日期格式化代码

js 日期格式化代码 分享一个前端实用的 js 日期格式化代码,相当给力。1 export function getFillDate(key) {2 if(key < 10) {3 return `0${key}`;4 }else{5 return `${key}`;6 }7 }8 /**9 * 时间戳转化为年月日 10 * @param times 时间戳 11 * @param ym…

一周万星的文本转语音开源项目「GitHub 热点速览」

上周的热门开源项目让我想起了「图灵测试」,测试者在不知道对面是机器还是人类的前提下随意提问,最后根据对方回复的内容,判断与他们交谈的是人还是计算机。如果无法分辨出回答者是机器还是人类,则说明机器已通过测试,具有人类的智力水平。 ​虽然现在大模型的回答还充满 …

day6 CSS //免费版创建不了CSS

div标签:的独占一行的块级标签独占一行 块级标签1.独占一行2.可设置长宽 // h1-h6 p div 内联标签 1.不独占一行,按内容占比//b strong i em,span CSS的功能:渲染和布局 CSS的语法://作用 选择标签,操作标签 选择器{ 属性:值 } 展示放到body里面,修饰放到head里面 一…

day7 js

(javaScript)唯一的客户端语言//触发事件js代码 服务器下载运行包,本地自动运行的(类似于自动流水翻页) js的引入方式: (1)//头和身体都可以放进去 (2)外部引入console.log(2 == "2")//true 按类型转换//三等于==完全一才true console.log(2 + "2&q…

Visual Studio编程效率提升技巧集(提高.NET编程效率)

前言 本文大姚将为你介绍一些Visual Studio的使用技巧和建议,旨在帮助.NET开发者更加高效地利用Visual Studio进行编程工作。无论你是.NET初学者还是经验丰富的.NET开发者,这些技巧都将有助于提升你的工作效率,让你能够更快地编写出高质量的代码。让我们一起探索这些技巧,让…

SAP: ALV GRID 追加复选框字段及编辑时立刻调用事件

SAP: ALV GRID 追加复选框字段及编辑时立刻调用事件10、在GRID界面中选择复选框时,提示以下信息: 运行时错误:MOVE_TO_LIT_NOTALLOWED_NODATA 短文本:Assignment error: Overwriting of a protected field. 错误分析: Field “<FS_VALUE>” was to assigned a new v…

开源协议

开源协议分析: