【动手学习深度学习--逐行代码解析合集】15卷积神经网络(LeNet)

【动手学习深度学习】逐行代码解析合集

15卷积神经网络(LeNet)


视频链接:动手学习深度学习–卷积神经网络(LeNet)
课程主页:https://courses.d2l.ai/zh-v2/
教材:https://zh-v2.d2l.ai/

1、LeNet

总体来看,LeNet(LeNet-5)由两个部分组成:

  • 卷积编码器:由两个卷积层组成;
  • 全连接层密集块:由三个全连接层组成。

在这里插入图片描述


在这里插入图片描述


2、LeNet代码实现

import torch
from torch import nn
from d2l import torch as d2l
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE""====================1、LeNet===================="
net = nn.Sequential(# 在数据集中输入图片是32×32(包含padding),在网络中输入为28×28,因此此处需加padding# 输入通道数1,输出通道数6,卷积核5×5,边缘填充为2,加入Sigmoid激活函数引入非线性性nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),# 采用平均池化,卷积核2×2,步长为2nn.AvgPool2d(kernel_size=2, stride=2),# 输入通道数6,输出通道数16,卷积核5×5,采用Sigmoid激活函数nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),# 将连续的维度范围展平为张量。用于对神经网络模型的输出进行处理,得到tensor类型的数据。nn.Flatten(),# 全连接层,输入神经元个数16 * 5 * 5,输出神经元个数120,采用Sigmoid激活函数nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),nn.Linear(120, 84), nn.Sigmoid(),# 最后得到10个输出类别nn.Linear(84, 10))# 输入图片28×28
X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)
for layer in net:X = layer(X)print(layer.__class__.__name__,'output shape: \t',X.shape)

在这里插入图片描述
3、LeNet在Fashion-MNIST数据集上的表现(GPU)

"====================2、模型训练===================="
batch_size = 256  # 批量大小
# 导入训练集和测试集
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)def evaluate_accuracy_gpu(net, data_iter, device=None): #@save"""使用GPU计算模型在数据集上的精度"""if isinstance(net, nn.Module):net.eval()  # 设置为评估模式if not device:device = next(iter(net.parameters())).device# 正确预测的数量,总预测的数量metric = d2l.Accumulator(2)with torch.no_grad():for X, y in data_iter:# 由于完整的数据集位于内存中,因此在模型使用GPU计算数据集之前,我们需要将其复制到显存中。if isinstance(X, list):# BERT微调所需的(之后将介绍)X = [x.to(device) for x in X]else:X = X.to(device)y = y.to(device)metric.add(d2l.accuracy(net(X), y), y.numel())# 分类正确的个数/总的大小return metric[0] / metric[1]
"为了使用GPU,我们还需要一点小改动。"
#@save
def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):"""用GPU训练模型(在第六章定义)"""def init_weights(m):  # 初始化权重# 如果是全连接层或者卷积层,使用xavier初始化方法if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)# net.apply会对每一个参数都执行一次init_weights这个函数net.apply(init_weights)print('training on', device)  # 打印一下在哪个设备上训练net.to(device)  # 把整个参数挪到GPU上# 使用随机梯度下降算法更新参数optimizer = torch.optim.SGD(net.parameters(), lr=lr)# 使用交叉熵损失函数loss = nn.CrossEntropyLoss()# 画图animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['train loss', 'train acc', 'test acc'])timer, num_batches = d2l.Timer(), len(train_iter)# 迭代每一轮for epoch in range(num_epochs):# 训练损失之和,训练准确率之和,样本数metric = d2l.Accumulator(3)net.train()# 每次数据迭代,拿出一个批次的数据for i, (X, y) in enumerate(train_iter):timer.start()optimizer.zero_grad()  # 梯度清零X, y = X.to(device), y.to(device)   # 将输入输出挪到GPU上y_hat = net(X)  # 前向操作计算y_hatl = loss(y_hat, y)  # 计算损失l.backward()  # 反向传播计算梯度optimizer.step()  # 迭代with torch.no_grad():# 打印动画metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])timer.stop()train_l = metric[0] / metric[2]  # 损失train_acc = metric[1] / metric[2]  # 准确率if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(train_l, train_acc, None))test_acc = evaluate_accuracy_gpu(net, test_iter)animator.add(epoch + 1, (None, None, test_acc))# 打印训练损失、训练准确率、测试准确率print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, 'f'test acc {test_acc:.3f}')# 一些额外信息print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec 'f'on {str(device)}')
"训练和评估LeNet-5模型。"
lr, num_epochs = 0.9, 10
train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
d2l.plt.show()

运行结果
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/20156.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

IDEA中 application.yaml文件没有绿色的叶子

IDEA中 application.yaml文件没有绿色的叶子 问题背景 前段时间一直在刷算法题和备战考试,忽略了项目方面的锻炼,于是今天就想着来写一个练手的项目,重新熟悉一下技术栈。结果刚搭建一个SpringBoot项目,就发现application.yaml配…

第三方ipad电容笔哪个品牌好用?平板电容笔推荐

可能很多人都认为,苹果原装的电容笔,是不可取代,但我认为,这还要看个人的预算,以及实际的需求。苹果Pencil对于那些不太讲究画质的用户来说实在是太贵了,要是我们仅用于书写上,其实我们可以用平…

java动态导出excel头

java动态导出excel头 java根据动态头导出excel文件一、需求背景1、调用接口将表头传给给后端2、请求结果展示3、核心代码1、工具类,注意异常抛出类如报错,需自定义异常类2、标题设置类3、单元各简单设置类4、controller接收参数 java根据动态头导出excel…

LeetCode 203. 移除链表元素

给你一个链表的头节点 head 和一个整数 val ,请你删除链表中所有满足 Node.val val 的节点,并返回 新的头节点 。 (1)直接使用原来的链表来进行移除节点操作: //不带头结点删除元素节点 class Solution { public:Lis…

ssh配置多账号(Mac)

一. 应用场景 当存在同时需要git在GitHub、gitee、gitlab等多个不同git托管平台进行ssh代码操作的时候。 二. 具体操作 默认 ssh-keygen -t rsa -C "你的邮箱"之后一直回车就可以,会默认在~/.ssh目录下生成id_rsa、id_rsa.pub 指定文件 ssh-keygen …

Git Commit的规范及高级使用方法

git commit是日常工作中使用率极高的一个命令,但是根据我从业5年的经验来看,大多数人在用git commit命令时都很粗糙,比如git commit -m 后跟的message是五花八门,有用中文的,有用英文的,甚至还有直接跟111的…

Codeforces Round 882 (Div. 2)(视频讲解A——D)

[TOC](Codeforces Round 882 (Div. 2)&#xff08;视频讲解A——D&#xff09;) 讲解在B站&#xff1a;Codeforces Round 882 (Div. 2)&#xff08;视频讲解A——D&#xff09; A The Man who became a God #include<bits/stdc.h> #define endl \n #define INF 0x3f3…

【分布式应用】zookeeper集群

目录 一、zookeeper概述1.1zookeeper工作机制1.2Zookeeper 数据结构1.3Zookeeper 应用场景1.4Zookeeper 选举机制第一次启动选举机制**非第一次启动选举机制 二、部署 Zookeeper 集群2.1环境配置2.2安装 Zookeeper 一、zookeeper概述 Zookeeper是一个开源的分布式的&#xff0c…

机械臂与RealSense相机手眼标定

环境&#xff1a; 本文主要使用kinova mico机械臂 RealSense D435i深度相机进行了eye to hand的手眼标定。 系统环境&#xff1a;Ubuntu18.04&#xff0c;ROS Melodic 硬件&#xff1a;Kinova mico&#xff0c;RealSense D435i 特别注意&#xff1a;经测试&#xff0c;本方法…

【Python】Python实现串口通信(Python+Stm32)

&#x1f389;欢迎来到Python专栏~Python实现串口通信 ☆* o(≧▽≦)o *☆嗨~我是小夏与酒&#x1f379; ✨博客主页&#xff1a;小夏与酒的博客 &#x1f388;该系列文章专栏&#xff1a;Python学习专栏 文章作者技术和水平有限&#xff0c;如果文中出现错误&#xff0c;希望…

Js语法学习实战 -数据类型

Js语法学习实战 -数据类型 1. undefined2. null3. Boolean4. Number5. String5.1 常用方法5.2 字符串迭代遍历方法5.3 字符串替换 6. Symbol类型7. Object7.1 基本使用7.2 对象遍历7.3 复制对象方法 8. 数组 - Array8.1 数组的常用方法8.2 数组遍历 9. Function JS语法学习实战…

18.JavaWeb-JWT(登录、鉴权)

1.CSRF跨站请求伪造 跨站请求伪造&#xff08;英语&#xff1a;Cross-site request forgery&#xff09;&#xff0c;也被称为 one-click attack 或者 session riding&#xff0c;通常缩写为 CSRF 或者 XSRF&#xff0c; 是一种挟制用户在当前已登录的Web应用程序上执行非本意的…