《动手学深度学习(PyTorch版)》笔记6.3

注:书中对代码的讲解并不详细,本文对很多细节做了详细注释。另外,书上的源代码是在Jupyter Notebook上运行的,较为分散,本文将代码集中起来,并加以完善,全部用vscode在python 3.9.18下测试通过,同时对于书上部分章节也做了整合。

Chapter6 Convolutional Neural Network(CNN)

6.3 LeNet

在这里插入图片描述

LeNet模型中每个卷积块中的基本单元是一个卷积层、一个sigmoid激活函数和平均汇聚层。虽然ReLU和最大汇聚层更有效,但它们在20世纪90年代还没有出现。每个卷积层使用 5 × 5 5\times 5 5×5卷积核和一个sigmoid激活函数,这些层将输入映射到多个二维特征输出,通常同时增加通道的数量。第一卷积层有6个输出通道,而第二个卷积层有16个输出通道,每个 2 × 2 2\times2 2×2池操作(步幅2)通过空间下采样将维数减少4倍。为了将卷积块的输出传递给稠密块,我们必须在小批量中展平每个样本。换言之,我们将这个四维输入转换成全连接层所期望的二维输入(第一个维度索引小批量中的样本,第二个维度给出每个样本的平面向量表示)。LeNet的稠密块有三个全连接层,分别有120、84和10个输出。因为我们在执行分类任务,所以输出层的10维对应于最后输出结果的数量。

import torch
from torch import nn
from d2l import torch as d2l
import matplotlib.pyplot as pltnet = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),nn.Linear(120, 84), nn.Sigmoid(),nn.Linear(84, 10))X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)
for layer in net:X = layer(X)print(layer.__class__.__name__,'output shape: \t',X.shape)batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)def evaluate_accuracy_gpu(net, data_iter, device=None): #@save"""使用GPU计算模型在数据集上的精度"""if isinstance(net, nn.Module):net.eval()  # 设置为评估模式if not device:device = next(iter(net.parameters())).devicemetric = d2l.Accumulator(2)#metric[0]:正确预测的数量,metric[1]:总预测的数量。with torch.no_grad():for X, y in data_iter:if isinstance(X, list):X = [x.to(device) for x in X]# BERT微调所需,将输入转移到GPU上(之后将介绍)else:X = X.to(device)y = y.to(device)metric.add(d2l.accuracy(net(X), y), y.numel())return metric[0] / metric[1]def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):#@save"""用GPU训练模型"""def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)print('training on', device)net.to(device)optimizer = torch.optim.SGD(net.parameters(), lr=lr)loss = nn.CrossEntropyLoss()animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['train loss', 'train acc', 'test acc'])timer, num_batches = d2l.Timer(), len(train_iter)for epoch in range(num_epochs):metric = d2l.Accumulator(3)#metric[i]分别是训练损失之和,训练准确率之和,样本数net.train()for i, (X, y) in enumerate(train_iter):timer.start()optimizer.zero_grad()X, y = X.to(device), y.to(device)y_hat = net(X)l = loss(y_hat, y)l.backward()optimizer.step()with torch.no_grad():metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])timer.stop()train_l = metric[0] / metric[2]train_acc = metric[1] / metric[2]if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(train_l, train_acc, None))test_acc = evaluate_accuracy_gpu(net, test_iter)animator.add(epoch + 1, (None, None, test_acc))print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, test acc {test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec on {str(device)}')lr, num_epochs = 0.9, 10
train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
plt.show()

训练结果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/449068.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Entity实体设计

Entity实体设计 💡用来和数据库中的表对应,解决的是数据格式在Java和数据库间的转换。 (一)设计思想 数据库Java表类行对象字段(列)属性 (二)实体Entity编程 编码规范 &#x1f4a…

分别用JavaScript,Java,PHP,C++实现桶排序的算法(附带源码)

桶排序是计数排序的升级版。它利用了函数的映射关系,高效与否的关键就在于这个映射函数的确定。为了使桶排序更加高效,我们需要做到这两点: 在额外空间充足的情况下,尽量增大桶的数量使用的映射函数能够将输入的 N 个数据均匀的分…

python-自动化篇-运维-监控-Python如何与Prometheus集成?

要将Python与Prometheus集成,可以使⽤Prometheus提供的客⼾端库来公开指标(metrics)供Prometheus采集。 Prometheus是⼀个开源的监控和警报⼯具,⽀持多种数据采集⽅式,其中之⼀是通过HTTP端点公开指标。以下是⼀些步骤…

【Nginx】Ubuntu如何安装使用Nginx反向代理?

文章目录 使用Nginx反向代理2个web接口服务步骤 1:安装 Nginx步骤 2:启动 Nginx 服务步骤 3:配置 Nginx步骤 4:启用配置步骤 5:检查配置步骤 6:重启 Nginx步骤 7:访问网站 proxy_set_header 含义…

【前沿技术杂谈:开源软件】引领技术创新与商业模式的革命

【前沿技术杂谈:开源软件】引领技术创新与商业模式的革命 开源软件如何推动技术创新开源软件的开放性和协作精神促进知识共享和技术迭代推动关键技术的发展开源软件与新技术的融合 开源软件的商业模式开源软件的商业模式将开源软件与商业软件相结合 开源软件的安全风…

OpenCV 15 - 卷积边缘处理

卷积处理边缘 图像卷积的时候边界像素,不能被卷积操作,原因在于边界像素没有完全跟kernel重叠,所以当3x3滤波时候有1个像素的边缘没有被处理,5x5滤波的时候有2个像素的边缘没有被处理 处理方法 在卷积开始之前增加边缘像素,填充的像素值为0或者RGB黑色,比如3x3在四周各填…

HiSilicon352 android9.0 开机视频调试分析

一,开机视频概念 开机广告是在系统开机后实现播放视频功能。 海思Android解决方案在原生Android基础上,增加了开机视频模块,可在开机过程中播放视频文件,使用户更好的体验系统开机过程。 二,模块结构 1. 海思自研开机…

MySQL的ACID、死锁、MVCC问题

1 ACID ACID代表原子性(atomicity)、一致性(consistency)、隔离性(isolation)和持久性(durability)。一个确保数据安全的事务处理系统,必须满足这些密切相关的标准。 原…

spring boot学习第九篇:操作mongo的集合和集合中的数据

1、安装好了Mongodb 参考&#xff1a;ubuntu安装mongod、配置用户访问、添删改查-CSDN博客 2、pom.xml文件内容如下&#xff1a; <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0" xmlns…

MATLAB绘制电磁场

MATLAB绘制电磁场举例: clc;close all;clear all;warning off;%清除变量 rand(seed, 100); randn(seed, 100); format long g; m12 for k1:m for j1:m if k1 V(j,k)1; elseif((j1)|(jm)|(km)) V(j,k)0; else …

【MySQL】深入理解隔离性

深入理解隔离性 一、数据库并发的场景二、多版本并发控制&#xff08; MVCC &#xff09;三、三个前提知识1、3个记录隐藏字段2、undo日志 四、快照的概念五、Read View六、隔离级别RR与RC的本质区别 一、数据库并发的场景 数据库并发的场景总共有三种&#xff1a; 读-读&…

PPT录屏功能在哪?一键快速找到它!

在现代办公环境中&#xff0c;ppt的录屏功能日益受到关注&#xff0c;它不仅能帮助我们记录演示文稿的播放过程&#xff0c;还能将操作过程、游戏等内容完美录制下来。可是很多人不知道ppt录屏功能在哪&#xff0c;本文将为您介绍ppt录屏的打开方法&#xff0c;以帮助读者更好地…