简易机器学习笔记(八)关于经典的图像分类问题-常见经典神经网络LeNet

前言

图像分类是根据图像的语义信息对不同类别图像进行区分,是计算机视觉的核心,是物体检测、图像分割、物体跟踪、行为分析、人脸识别等其他高层次视觉任务的基础。图像分类在许多领域都有着广泛的应用,如:安防领域的人脸识别和智能视频分析等,交通领域的交通场景识别,互联网领域基于内容的图像检索和相册自动归类,医学领域的图像识别等。

这里简单讲讲LeNet

我的推荐是可以看看这个视频,可视化的查看卷积神经网络是如何一层一层地抽稀获得特征,最后将所有的图像展开成一个一维的轴,再通过全连接神经网络预测得到一个最后的预测值。

手写数字识别 1.4 LeNet-5-哔哩哔哩

在这里插入图片描述

计算过程

前置知识:

  1. 步长 Stride & 加边 Padding

卷积后尺寸=(输入尺寸-卷积核大小+加边像素数)/步长 + 1

默认Padding = ‘valid’ (丢弃),strides = 1
在这里插入图片描述

正式计算

  1. 卷积层1:

第一层我们给定的图像时32 * 32,使用六个5 x 5的卷积核,步长为1

第一层中没有加边,那么卷积后的尺寸就是(32 - 5 + 0 )/1 + 1 =28,那么输出的图像就是 28*28的边长

在第一层中,由于我们使用了六个卷积核,我们得到的输出为:62828,可以理解为一个六层厚的图像

  1. 池化层1:

我们在池化层内在2x2的图像内选取了一个最大值或者平均值,也就是图片整体缩水到原先的二分之一,所以我们得到池化层的输出为 6 x 14 x 14

  1. 卷积层2:

还是按照公式,卷积后尺寸=(输入-卷积核+加边像素数)/步长 + 1,这个时候输入为6 x 14 x 14,这一次我们给定了16个卷积核,得到输出后的尺寸为(14 - 5 + 0)/1 + 1 = 10,得到输出为161010

关于这个16个卷积核是怎么来的,可以见图:

问了下组里的大佬,大佬说这个卷积核数目和层数很多是经验值,即你寻求更多或者更少的卷积核数目或者层数,实际效果不一定有经验值更好,反正都是离散值,就随便试试就行了。

其中:卷积输出尺寸nout:nin为输入原图尺寸大小;s是步长(一次移动几个像素);p补零圈数,

我们这里输入的值

  1. 池化层2

得到 输出后尺寸为16 * 5 * 5

  1. 全连接层1:

输入为16 * 5 * 5 ,有120个5*5卷积核,步长为1,输出尺寸为(5 - 5 + 0)/1 + 1 =1,这时候输出的就是一条直线的一维输出了

  1. 全连接层2:

输入为120,使用了84个神经元,

  1. 输出层

输入84,输出为10

比如我们如图所示,在代码中是这样的:

# 导入需要的包
import paddle
import numpy as np
from paddle.nn import Conv2D, MaxPool2D, Linear## 组网
import paddle.nn.functional as F
from paddle.vision.transforms import ToTensor
from paddle.vision.datasets import MNIST
#定义LeNet网络结构# 定义 LeNet 网络结构
class LeNet(paddle.nn.Layer):def __init__(self, num_classes=1):super(LeNet,self).__init__()#创建卷积层和池化层#创建第一个卷积层self.conv1 = Conv2D(in_channels=1,out_channels=6,kernel_size=5)self.max_pool1 = MaxPool2D(kernel_size=2,stride=2)#尺寸的逻辑:池化层未改变通道数,当前通道为6#创建第二个卷积层self.conv2 = Conv2D(in_channels=6,out_channels=16,kernel_size=5)self.max_pool2 = MaxPool2D(kernel_size=2,stride=2)#创建第三个卷积层self.conv3 = Conv2D(in_channels=16,out_channels=120,kernel_size=4)# 尺寸的逻辑:输入层将数据拉平[B,C,H,W] -> [B,C*H*W]# 输入size是[28,28],经过三次卷积和两次池化之后,C*H*W等于120self.fc1 = Linear(in_features=120, out_features=64)# 创建全连接层,第一个全连接层的输出神经元个数为64, 第二个全连接层输出神经元个数为分类标签的类别数self.fc2 = Linear(in_features=64, out_features=num_classes)# 网络的前向计算过程def forward(self, x):x = self.conv1(x)# 每个卷积层使用Sigmoid激活函数,后面跟着一个2x2的池化x = F.sigmoid(x)x = self.max_pool1(x)x = F.sigmoid(x)x = self.conv2(x)x = self.max_pool2(x)x = self.conv3(x)# 尺寸的逻辑:输入层将数据拉平[B,C,H,W] -> [B,C*H*W]x = paddle.reshape(x, [x.shape[0], -1])x = self.fc1(x)x = F.sigmoid(x)x = self.fc2(x)return x
# 飞桨会根据实际图像数据的尺寸和卷积核参数自动推断中间层数据的W和H等,只需要用户表达通道数即可。
# 下面的程序使用随机数作为输入,查看经过LeNet-5的每一层作用之后,输出数据的形状。# 输入数据形状是 [N, 1, H, W]
# 这里用np.random创建一个随机数组作为输入数据
x = np.random.randn(*[3,1,28,28])
x = x.astype('float32')# 创建LeNet类的实例,指定模型名称和分类的类别数目
model = LeNet(num_classes=10)# 通过调用LeNet从基类继承的sublayers()函数,
# 查看LeNet中所包含的子层
print(model.sublayers())
x = paddle.to_tensor(x)for item in model.sublayers():#item是LeNet类中的一个子层#查看经过子层之后的输出数据形状try:x = item(x)except:x = paddle.reshape(x, [x.shape[0], -1])x = item(x)if len(item.parameters())==2:# 查看卷积和全连接层的数据和参数的形状,# 其中item.parameters()[0]是权重参数w,item.parameters()[1]是偏置参数bprint(item.full_name(), x.shape, item.parameters()[0].shape, item.parameters()[1].shape)else:# 池化层没有参数print(item.full_name(), x.shape)# 设置迭代轮数
EPOCH_NUM = 5
#定义训练过程 
def train(model,opt,train_loader,valid_loader):print("start training ... ")model.train()for epoch in range(EPOCH_NUM):for batch_id, data in enumerate(train_loader()):img = data[0]label = data[1] #计算模型输出# 计算模型输出logits = model(img)# 计算损失函数loss_func = paddle.nn.CrossEntropyLoss(reduction='none')loss = loss_func(logits, label)avg_loss = paddle.mean(loss)if batch_id % 2000 == 0:print("epoch: {}, batch_id: {}, loss is: {:.4f}".format(epoch, batch_id, float(avg_loss.numpy())))#反向传播avg_loss.backward()opt.step()opt.clear_grad()model.eval()accuracies = []losses = []for batch_id, data in enumerate(valid_loader()):img = data[0]label = data[1]# 计算模型输出logits = model(img)pred = F.softmax(logits)# 计算损失函数loss_func = paddle.nn.CrossEntropyLoss(reduction='none')loss = loss_func(logits, label)acc = paddle.metric.accuracy(pred, label)accuracies.append(acc.numpy())losses.append(loss.numpy())print("[validation] accuracy/loss: {:.4f}/{:.4f}".format(np.mean(accuracies), np.mean(losses)))model.train()# 保存模型参数paddle.save(model.state_dict(), 'mnist.pdparams')    # 创建模型
model = LeNet(num_classes=10)
# 设置迭代轮数
EPOCH_NUM = 5
# 设置优化器为Momentum,学习率为0.001
opt = paddle.optimizer.Momentum(learning_rate=0.001, momentum=0.9, parameters=model.parameters())
# 定义数据读取器
train_loader = paddle.io.DataLoader(MNIST(mode='train', transform=ToTensor()), batch_size=10, shuffle=True)
valid_loader = paddle.io.DataLoader(MNIST(mode='test', transform=ToTensor()), batch_size=10)
# 启动训练过程
train(model, opt, train_loader, valid_loader)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/319426.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

对技术行业的深度思考

技术行业是当今世界最为热门和发展迅猛的领域之一。无论是互联网、人工智能还是区块链,技术的快速发展正在改变着我们的生活和社会。然而,我们是否真正思考过技术在我们生活中的影响和意义?本文将对技术行业展开深度思考,探讨其带…

LAMP平台搭建

目录 LAMP平台概述 1、Apache 网站服务基础 1.1、Apache 简介 (1)Apache 的起源 (2)Apache 的主要特点 1.2、安装 httpd 服务器 (1)准备工作 (2)源码编译及安装 &#…

【MySQL】MySQL如何查询和筛选存储的JSON数据?

MySQL如何查询和筛选存储的JSON数据? 一、背景介绍二、支持的JSON数据类型三、基础数据3.1 创建表3.2 插入 JSON 数据3.3 查询 JSON 数据 四、操作函数4.1 JSON_OBJECT4.2 JSON_ARRAY4.3 JSON_EXTRACT 一、背景介绍 JSON(JavaScript Object Notation)是一种轻量级的…

文心一言 VS 讯飞星火 VS chatgpt (171)-- 算法导论13.2 4题

四、用go语言,证明:任何一棵含n个结点的二叉搜索树可以通过 O(n)次旋转,转变为其他任何一棵含n个结点的二叉搜索树。(提示:先证明至多n-1次右旋足以将树转变为一条右侧伸展的链。) 文心一言: 这是一个有趣的问题&…

【JUC】Volatile关键字+CPU/JVM底层原理

Volatile关键字 volatile内存语义 1.当写一个volatile变量时,JMM会把该线程对应的本地内存中的共享变量值立即刷新回主内存中。 2.当读一个volatile变量时,JMM会把该线程对应的本地内存设置为无效,直接从主内存中读取共享变量 所以volatile…

【番外】【Airsim in Windows ROS in WSL2-Ubuntu20.04】环境配置大全

【番外】【Airsim in Windows &ROS in WSL2-Ubuntu20.04】环境配置大全 【前言(可省略不看)】1.在windows上面部署好UE4AirSim联合仿真环境2.在windows上面部署wsl2系统以及在wsl2上面部署ubuntu系统3.安装好ubuntu系统之后,目前只能在命…

使用开源 Upscayl 工具放大图片

Upscayl 是一个基于人工智能的图像放大工具,可以用来将低分辨率的图片放大到高分辨率。Upscayl 使用了一种称为超分辨率重建的技术,可以生成逼真的高分辨率图像。 在本教程中,我们将介绍如何使用 Upscaly 工具放大图片。 准备工作 下载&a…

基于Java在线考试管理系统设计与实现(源码+部署文档)

博主介绍: ✌至今服务客户已经1000、专注于Java技术领域、项目定制、技术答疑、开发工具、毕业项目实战 ✌ 🍅 文末获取源码联系 🍅 👇🏻 精彩专栏 推荐订阅 👇🏻 不然下次找不到 Java项目精品实…

geemap学习笔记040:GEE中样本点选择操作流程

前言 geemap中目前有一个bug,就是在选择样本点的时候不合理,选完一类样本之后,没法继续选择下一类,并且没法在线进行编辑和修改。因此目前就只能结合在线版的GEE进行样本选择,本节就详细的介绍一下GEE中样本点的选择过…

【EDA软件】国产嘉立创EDA使用

嘉立创(JLCPCB) 嘉立创(JLCPCB)提供的EDA专业版是一款在线的原理图设计与PCB设计工具,可以协助用户完成电子设计项目。以下是使用嘉立创EDA专业版的一般步骤: 注册与登录: 首先,你需…

【MySQL·8.0·源码】MySQL 表的扫描方式

前言 在进一步介绍 MySQL 优化器时,先来了解一下 MySQL 单表都有哪些扫描方式。 单表扫描方法是基表的读取基础,也是完成表连接的基础,熟悉了基表的基本扫描方式, 即可以倒推理解 MySQL 优化器层的诸多考量。 基表,即…

前端 js 基础(2)

js For In for in 循环遍历 person 对象每次迭代返回一个键 (x)键用于访问键的值键的值为 person[x] 如果索引顺序很重要,请不要在数组上使用 for in。 索引顺序依赖于实现,可能不会按照您期望的顺序访问数组值。 当顺序很重要时,最好使用 f…