CNN——LeNet

1.LeNet概述       

         LeNet是Yann LeCun于1988年提出的用于手写体数字识别的网络结构,它是最早发布的卷积神经网络之一,可以说LeNet是深度CNN网络的基石。

        当时,LeNet取得了与支持向量机(support vector machines)性能相媲美的成果,成为监督学习的主流方法。 LeNet当时被广泛用于自动取款机(ATM)机中,帮助识别处理支票的数字。

        下面是整个网络的结构图

        LeNet共有8层,其中包括输入层,3个卷积层,2个子采样层(也就是现在的池化层),1个全连接层和1个高斯连接层。

        上图中用C代表卷积层,用S代表采样层,用F代表全连接层。输入size固定在1*32*32,LeNet图片的输入是二值图像。网络的输出为0~9十个数字的RBF度量,可以理解为输入图像属于0~9数字的可能性大小。

2.详解LeNet

下面对图中每一层做详细的介绍:

  • LeNet使用的卷积核大小都为5*5,步长为1,无填充,只是卷积深度不一样(卷积核个数导致生成的特征图的通道数)
  • 激活函数为Sigmoid
  • 下采样层都是使用最大池化实现,池化的核都为2*2,步长为2,无填充

        input输入层,尺寸为1*32*32的二值图

        C1层是一个卷积层。该层使用6个卷积核,生成特征图尺寸为32-5+1=28,输出为6个大小为28*28的特征图。再经过一个Sigmoid激活函数非线性变换。

        S2层是一个下采样层。生成特征图尺寸为28/2=14,得到6个14*14的特征图。

        C3层是一个卷积层,该层使用16个卷积核,生成特征图尺寸为14-5+1=10,输出为16个10*10的特征图。再经过一个Sigmoid激活函数非线性变换。

        S4层是一个下采样层,生成特征图尺寸为10/2=5,得到16个5*5的特征图

        C5层是一个卷积层,卷积核数量增加至120。生成特征图尺寸为5-5+1=1。得到120个1*1的特征图。这里实际上相当于S4全连接了,但仍将其标为卷积层,原因是如果LeNet-5的输入图片尺寸变大,其他保持不变,那该层特征图的维数也会大于1*1,那就不是全连接了。再经过一个Sigmoid激活函数非线性变换。

        F6层是一个全连接层,该层与C5层全连接,输出84张特征图。再经过一个Sigmoid激活函数非线性变换。

        输出层:输出层由欧式径向基函数(高斯)单元组成,每个类别(0~9数字)对应一个径向基函数单元,每个单元有84个输入。也就是说,每个输出RBF单元计算输入向量和该类别标记向量之间的欧式距离,距离越远,PRF输出越大,同时我们也会将与标记向量欧式距离最近的类别作为数字识别的输出结果。当然现在通常使用的Softmax实现

3.使用LeNet实现Mnist数据集分类

1.导入所需库

import torch
import torch.nn as nn
from torchsummary import summary
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm # 显示训练进度条

2.使用GPU

device = 'cuda' if torch.cuda.is_available() else 'cpu'

3.读取Mnist数据集

# 定义数据转换以进行数据标准化
transform = transforms.Compose([transforms.ToTensor(),  # 将图像转换为 PyTorch 张量
])# 下载并加载 MNIST 训练和测试数据集
train_dataset = datasets.MNIST(root='./dataset', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./dataset', train=False, download=True, transform=transform)# 创建数据加载器以批量加载数据
batch_size = 256
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

4.搭建LeNet

        需要注意的是torch.nn.CrossEntropyLoss自带了softmax函数,所以最后一层使用全连接即可。

class LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()# Mnist尺寸为28*28,这里设置填充变成32*32self.conv1 = nn.Conv2d(1, 6, kernel_size=5, padding=2) self.sigmoid = nn.Sigmoid()self.pool = nn.MaxPool2d(kernel_size=2, stride=2)self.conv2 = nn.Conv2d(6, 16, kernel_size=5)self.flatten = nn.Flatten()self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = self.pool(self.sigmoid(self.conv1(x)))x = self.pool(self.sigmoid(self.conv2(x)))x = self.flatten(x)x = self.sigmoid(self.fc1(x))x = self.sigmoid(self.fc2(x))x = self.fc3(x)return x
# 实例化模型
model = LeNet().to(device)
summary(model, (1, 28, 28))

5.训练函数

def train(model, lr, epochs):# 将模型放入GPUmodel = model.to(device)# 使用交叉熵损失函数loss_fn = nn.CrossEntropyLoss().to(device)# SGDoptimizer = torch.optim.SGD(model.parameters(), lr=lr)# 记录训练与验证数据train_losses = []train_accuracies = []# 开始迭代   for epoch in range(epochs):   # 切换训练模式model.train()  # 记录变量train_loss = 0.0correct_train = 0total_train = 0# 读取训练数据并使用 tqdm 显示进度条for i, (inputs, targets) in tqdm(enumerate(train_dataloader), total=len(train_dataloader), desc=f"Epoch {epoch+1}/{epochs}", unit='batch'):# 训练数据移入GPUinputs = inputs.to(device)targets = targets.to(device)# 模型预测outputs = model(inputs)# 计算损失loss = loss_fn(outputs, targets)# 梯度清零optimizer.zero_grad()# 反向传播loss.backward()# 使用优化器优化参数optimizer.step()# 记录损失train_loss += loss.item()# 计算训练正确个数_, predicted = torch.max(outputs, 1)total_train += targets.size(0)correct_train += (predicted == targets).sum().item()# 计算训练正确率并记录train_loss /= len(train_dataloader)train_accuracy = correct_train / total_traintrain_losses.append(train_loss)train_accuracies.append(train_accuracy)# 输出训练信息print(f"Epoch [{epoch + 1}/{epochs}] - Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.4f}")# 绘制损失和正确率曲线plt.figure(figsize=(10, 5))plt.subplot(1, 2, 1)plt.plot(range(epochs), train_losses, label='Training Loss')plt.xlabel('Epoch')plt.ylabel('Loss')plt.legend()plt.subplot(1, 2, 2)plt.plot(range(epochs), train_accuracies, label='Accuracy')plt.xlabel('Epoch')plt.ylabel('Accuracy')plt.legend()plt.tight_layout()plt.show()

6.模型训练

model = LeNet()
lr = 0.9 # sigmoid两端容易饱和,gradient比较小,学得比较慢,所以学习率要大一些
epochs = 20
train(model,lr,epochs)

7.模型测试 

def test(model, test_dataloader, device, model_path):# 将模型设置为评估模式model.eval()# 将模型移动到指定设备上model.to(device)# 从给定路径加载模型的状态字典model.load_state_dict(torch.load(model_path))correct_test = 0total_test = 0# 不计算梯度with torch.no_grad():# 遍历测试数据加载器for inputs, targets in test_dataloader:  # 将输入数据和标签移动到指定设备上inputs = inputs.to(device)targets = targets.to(device)# 模型进行推理outputs = model(inputs)# 获取预测结果中的最大值_, predicted = torch.max(outputs, 1)total_test += targets.size(0)# 统计预测正确的数量correct_test += (predicted == targets).sum().item()# 计算并打印测试数据的准确率test_accuracy = correct_test / total_testprint(f"Accuracy on Test: {test_accuracy:.4f}")return test_accuracy
model_path = save_path
test(model, test_dataloader, device, save_path)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/326345.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何安装 Python

1.打开浏览器 输入网址 :www.python.org ​ 2.根据电脑系统配置进行下载 ​ 3.确定电脑系统属性,此处我们以win10的64位操作系统为例 ​ 4.安装python 3.6.3 双击下载的安装包 python-3.6.3.exe 注意要勾选:Add Python 3.6 to PATH 点击 Customize…

jdbc源码研究

JDBC介绍 JDBC(Java Data Base Connectivity,java数据库连接)是一种用于执行SQL语句的Java API,可以为多种关系数据库提供统一访问,它由一组用Java语言编写的类和接口组成。 开发者不必为每家数据通信协议的不同而疲于奔命&#…

从0开始python学习-42.requsts统一请求封装

统一请求封装的目的: 1.去除重复的冗余的代码 2. 跨py文件实现通过一个sess来自动关联有cookie关联的接口。 3. 设置统一的公共参数,统一的文件处理,统一的异常处理,统一的日志监控,统一的用例校验等 封装前原本代…

乐理燥废笔记

乐理燥废笔记 文章目录 终止式小调音阶转调不协和和弦进行大小转调1251 1451转调我的霹雳猫阿诺三全音代理五声音阶又怎样和弦附录:压缩字符串、大小端格式转换压缩字符串浮点数压缩Packed-ASCII字符串 大小端转换什么是大端和小端数据传输中的大小端总结大小端转换…

静态网页设计——旅游网(HTML+CSS+JavaScript)

前言 声明:该文章只是做技术分享,若侵权请联系我删除。!! 感谢大佬的视频: https://www.bilibili.com/video/BV1KN4y1v7jx/?vd_source5f425e0074a7f92921f53ab87712357b 使用技术:HTMLCSSJS(…

在电商狂欢中,什么平台更加对商家有利?

我是电商珠珠 近年来,不管是直播电商也好,电商平台也好,都一直朝着向上走的趋势。 我做电商也已经有5年时间了,期间做过天猫,快手、抖店,团队从原来的几个人,扩大到了70。 在22年10月&#x…

RKE安装k8s及部署高可用rancher之证书私有证书但是内置的ssl不放到外置的LB中 4层负载均衡

先决条件# Kubernetes 集群 参考RKE安装k8s及部署高可用rancher之证书在外面的LB(nginx中)-CSDN博客CLI 工具Ingress Controller(仅适用于托管 Kubernetes) 创建集群k8s [rootnginx locale]# cat rancher-cluster.yml nodes:- …

JMS消息发送

目录 概述1.搭建 JMS 环境2.使用JmsTemplate 发送消息3.接收JMS 消息 概述 JMS是一个Java标准,定义了使用消息代理(message broker)的通用API,在2001年提出。长期以来,JMS一直是Java 中实现异步消息的首选方案。在JMS 出现之前每个消息代理都有其私有的…

微服务注册中的负载均衡

背景 随着互联网行业的发展,对服务的要求也越来越高,服务架构也从单体架构逐渐演变为现在流行的微服务架构。这些架构之间有怎样的差别呢? 单体架构:简单方便,高度耦合,扩展性差,适合小型项目。…

一文讲透Python数据分析可视化之直方图(柱状图)

直方图(Histogram)又称柱状图,是一种统计报告图,由一系列高度不等的纵向条纹或线段表示数据分布的情况。一般用横轴表示数据类型,纵轴表示分布情况。通过绘制直方图可以较为直观地传递有关数据的变化信息,使…

Zoho SalesIQ:构建客户服务知识库的实用工具与指南

客服人员每天都有很多事情要做,包括在线聊天、音频通话、屏幕共享和发送电子邮件。为什么要将搜索常用信息添加到他们列表中呢?因为客户在遇到问题的同时想快速解决问题。所以,我们要使用Zoho SalesIQ客服系统构建客户服务知识库。 一、什么…

密码学:一文看懂初等数据加密一对称加密算法

文章目录 对称加密算法简述对称加密算法的由来对称加密算法的家谱数据加密标准-DES简述DES算法的消息传递模型DES算法的消息传递过程和Base64算法的消息传递模型的区别 算法的实现三重DES-DESede三重DES-DESede实现 高级数据加密标准一AES实现 国际数据加密标准-IDEA实现 基于口…