机器学习实验4——CNN卷积神经网络分类Minst数据集

文章目录

    • 🧡🧡实验内容🧡🧡
    • 🧡🧡 原理🧡🧡
    • 🧡🧡CNN实现分类Minst🧡🧡
      • 代码
      • 数据预处理:
      • 设置基本参数:

🧡🧡实验内容🧡🧡

基于手写minst数据集,完成关于卷积网络CNN的模型训练、测试与评估。

🧡🧡 原理🧡🧡

卷积层
通过使用一组可学习的滤波器(也称为卷积核)对输入图像进行滑动窗口卷积操作,这样可以提取出不同位置的局部特征,从而捕捉到图像的空间结构信息。
激活函数
在卷积层之后,通常会应用一个非线性激活函数,如 ReLU激活函数的作用是引入非线性,使得 CNN 能够学习更复杂的特征表达。
池化层
池化层用于降低特征图的空间尺寸,同时保留最显著的特征信息(类似于人眼观物,是根据物体的主要轮廓来判断物体是什么,而对一些小细节第一眼并没有那么关注)。常见的池化方式包括最大池化和平均池化,它们可以减少计算量,并增加模型的平移不变性。
全连接层
一般在 CNN 的最后几层,全连接层被用来将先前的卷积和池化层的输出与目标类别进行关联,每个神经元在该层中与前一层的所有神经元相连,通过学习权重参数来进行分类决策。
Softmax 函数
在最后一个全连接层之后,通常会应用 Softmax 函数来将神经网络的输出转换为概率分布,用于多类别分类问题的预测。
例如p=[0.2,0.3,0.5],这表示分类为类别1、2、3的概率分别为0.2,0.3,0.5,因此预测分类结果为类别3.
最后通过反向传播算法,CNN 使用训练数据进行模型参数的优化,它通过最小化损失函数(如交叉熵)来调整网络权重,并使用梯度下降等优化算法进行迭代更新。

构建本实验的CNN网络:

  • 5 x 5的卷积核,输入通道为1,输出通道为16:此时图像矩经过卷积核后尺寸变成24 x 24。
  • 2 x 2 的最大池化层:此时图像大小缩短一半,变成 12 x 12,通道数不变;
  • 再次经过 5 x 5 的卷积核,输入通道为16,输出通道为32:此时图像尺寸经过卷积核后变成8 *8。
  • 再次经过 2 x 2 的最大池化层:此时图像大小缩短一半,变成4 x 4,通道数不变;
  • 最后将图像整型变换成向量,输入到全连接层中:输入一共有4 x 4 x 32 = 512 个元素,输出为10.
    -在这里插入图片描述

🧡🧡CNN实现分类Minst🧡🧡

代码

import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.optim as optim
from torch import nn, optim
from time import time# ======================准备数据集======================
batch_size = 64
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])
train_dataset = datasets.MNIST(root='../dataset/mnist/',train=True,download=True,transform=transform)
train_loader = DataLoader(train_dataset,shuffle=True,batch_size=batch_size)
test_dataset = datasets.MNIST(root='../dataset/mnist',train=False,download=True,transform=transform)
test_loader = DataLoader(test_dataset,shuffle=False,batch_size=batch_size)# ======================CNN net======================
class CNN_net(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(1, 16, kernel_size=5)  # 卷积1self.pooling1 = nn.MaxPool2d(2)  # 最大池化self.relu1 = nn.ReLU()  # 激活self.conv2 = nn.Conv2d(16, 32, kernel_size=5)self.pooling2 = nn.MaxPool2d(2)self.relu2 = nn.ReLU()self.fc = nn.Linear(512, 10)  # 全连接def forward(self, x):batch_size = x.size(0)x = self.conv1(x)x = self.pooling1(x)x = self.relu1(x)x = self.conv2(x)x = self.pooling2(x)x = self.relu2(x)x = x.view(batch_size, -1)x = self.fc(x)return xmodel = CNN_net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)# ====================== train ======================
def train(epoch):time0 = time()  # 记录下当前时间loss_list = []for e in range(epoch):running_loss = 0.0for images, labels in train_loader:outputs = model(images)  # 前向传播获取预测值loss = criterion(outputs, labels)  # 计算损失loss.backward()  # 进行反向传播optimizer.step()  # 更新权重optimizer.zero_grad()  # 清空梯度running_loss += loss.item()  # 累加损失# 一轮循环结束后打印本轮的损失函数print("Epoch {} - Training loss: {}".format(e, running_loss / len(train_loader)))loss_list.append(running_loss / len(train_loader))# 打印总的训练时间print("\nTraining Time (in minutes) =", (time() - time0) / 60)# 绘制损失函数随训练轮数的变化图plt.plot(range(1, epoch + 1), loss_list)plt.xlabel('Epochs')plt.ylabel('Loss')plt.title('Training Loss')plt.show()train(5)# ====================== test ======================
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as pltdef test():model.eval()  # 将模型设置为评估模式correct = 0total = 0all_predicted = []all_labels = []with torch.no_grad():for images, labels in test_loader:outputs = model(images)_, predicted = torch.max(outputs.data, dim=1)total += labels.size(0)correct += (predicted == labels).sum().item()all_predicted.extend(predicted.tolist())all_labels.extend(labels.tolist())print('Model Accuracy =:%.4f' % (correct / total))# 绘制混淆矩阵cm = confusion_matrix(all_labels, all_predicted)plt.figure(figsize=(8, 6))sns.heatmap(cm, annot=True, fmt=".0f", cmap="Blues")plt.xlabel("Predicted Labels")plt.ylabel("True Labels")plt.title("Confusion Matrix")plt.show()test()

数据预处理:

加载数据集:
加载torch库中自带的minst数据集
转换数据:
先转为tensor变量(相当于直接除255归一化到值域为(0,1))
然后根据std=0.5,mean=0.5,再将值域标准化到(-1,1)。
(做完实验后,上网了解发现minst最合适的的std和mean分别为0.1307, 0.3081,但是其实结果都差不多,准确率变化不大,因为数据集还是相对比较简单的)

设置基本参数:

在这里插入图片描述

构建CNN神经网络:
同上述(1)中,已经构建完毕,这里不再赘述。
模型训练:
在这里插入图片描述
可见,虽然只经过5个epoch,但是花的时间为3.3min。
模型分类:
在这里插入图片描述
准确率达98.26%。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/438519.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

操作系统-调度器与闲逛进程(调度程序与进程和线程调度)和调度算法的指标(CPU利用率 系统吞吐量 周转时间 等待时间 响应时间)

文章目录 调度器和闲逛进程调度器/调度程序进程调度线程调度 闲逛进程 调度算法的指标总览CPU利用率系统吞吐量周转时间等待时间响应时间小结 调度器和闲逛进程 调度器/调度程序 进程调度 是否让当前进程下处理机,让哪个进程上处理机 创建完新进程,此…

小程序直播項目开发流程

点击登录功能,创建IM个人账户 以及 创建直播间群组 第一步:需要获取用户唯一的标识openid。 获取流程如下-点击登录按钮-通过wx.getUserProfile这个Api返回的res.userinfo信息获取用户头像昵称等-再通过wx.login的api获取用户的code-使用code再到服务器换…

有哪些好用的洗地机?家用洗地机品牌推荐

洗地机独特的一洗一吸设计带来了卓越的清洁效果。地面上的污渍、垃圾、粉尘都无法抵挡其强大的清洁力,仅需短短几秒钟,家里的地面就能焕然一新,让人感觉仿佛置身于清新宜人的环境中。这种实用性和清洁效果的结合,让洗地机成为智能…

python 基础知识点(蓝桥杯python科目个人复习计划27)

今日复习内容:基础算法中的递归 1.介绍 递归:通过自我调用来解决问题的函数递归通常把一个复杂的大问题层层转化为一个与原问题相似的规模较小的问题来解决 递归要注意:(1)递归出口;(2&#x…

机器学习算法实战案例:使用 Transformer 模型进行时间序列预测实战(升级版)

时间序列预测是一个经久不衰的主题,受自然语言处理领域的成功启发,transformer模型也在时间序列预测有了很大的发展。 本文可以作为学习使用Transformer 模型的时间序列预测的一个起点。 文章目录 机器学习算法实战案例系列答疑&技术交流数据集数据…

使用py-spy对python程序进行性能诊断学习

py-spy简介 py-spy是一个用Rust编写的轻量级Python分析工具,它能够监视正在运行的Python程序,而不需要修改代码或者重新启动程序。Py-spy可以在不影响程序运行的情况下,采集程序运行时的信息,生成火焰图(flame graph&…

springboot131企业oa管理系统

企业OA管理系统 摘要 随着信息技术在管理上越来越深入而广泛的应用,管理信息系统的实施在技术上已逐步成熟。本文介绍了企业OA管理系统的开发全过程。通过分析企业OA管理系统管理的不足,创建了一个计算机管理企业OA管理系统的方案。文章介绍了企业OA管…

第二百九十五回

文章目录 1. 概念介绍2. 使用方法3. 示例代码4. 内容总结 我们在上一章回中分享了一个好用的Json工具,本章回中将介绍如何处理ListView中的事件冲突.闲话休提,让我们一起Talk Flutter吧。 1. 概念介绍 在Flutter应用开发中,ListView组件是实…

Redis面试(三)

1.Redis报内存不足怎么处理 Redis内存不足的集中处理方式: 修改配置文件redis.cof的maxmemory参数,增加Redis的可用内存通过命令修改set maxmemory动态设置内存上限修改内存淘汰策略,及时释放内存使用Redis集群,及时进行扩容 2…

LeetCode 热题 100 | 矩阵

目录 1 73. 矩阵置零 2 54. 螺旋矩阵 3 48. 旋转图像 4 240. 搜索二维矩阵 II 菜鸟做题第二周,语言是 C 1 73. 矩阵置零 解题思路: 遍历矩阵,寻找等于 0 的元素,记录对应的行和列将被记录的行的元素全部置 0将被记录的…

VBA技术资料MF112:列出目录中的所有文件和文件夹

我给VBA的定义:VBA是个人小型自动化处理的有效工具。利用好了,可以大大提高自己的工作效率,而且可以提高数据的准确度。我的教程一共九套,分为初级、中级、高级三大部分。是对VBA的系统讲解,从简单的入门,到…

AI 绘画平台难开发,难变现?试试 Stable Diffusion API Serverless 版解决方案

作者:王佳、江昱、筱姜 Stable Diffusion 模型,已经成为 AI 行业从传统深度学习时代走向 AIGC 时代的标志性里程碑。越来越多的开发者借助 stable-diffusion-webui(以下简称 SDWebUI)能力进行 AI 绘画领域创业或者业务上新&#…