李沐之经典卷积神经网络

目录

1. LeNet

2. 代码实现


1. LeNet

输入是32*32图片,放到一个5*5的卷积层里面,卷积层的输出通道数是6,高宽都是28(32-5+1=28)。再经过2*2的池化层,把28*28变成14*14(28-2+2)/2=14,这里的步幅和窗口的大小一样。再经过5*5的卷积层,输出就变成10*10的(14-5+1=10)。通道数增加了从6变到16。再经过一个池化层,也是16个通道(用了16个卷积核),大小是5*5(10-2+2)/2=5。再把它拉成一个向量输入到全连接层,第一个全连接层的输出是120,第二个的输出是84,最后一个的输出是10。

2. 代码实现

import torch
from torch import nn
from d2l import torch as d2l#原来的数据集是32*32,这里为了一样每边各padding2行(同时也是为了框住字),4个边就是32行*32列,
#为了得到非线性加一个sigmoid函数
net=nn.Sequential(nn.Conv2d(1,6,kernel_size=5,padding=2),nn.Sigmoid(),nn.AvgPool2d(kernel_size=2,stride=2),nn.Conv2d(6,16,keinel_size=5),nn.Sigmoid(),nn.AvgPool2d(kernel_size=2,stride=2),nn.Flatten(),nn.Linear(16*5*5,120),nn.Sigmoid(),nn.Linear(120,84),nn.Sigmoid(),nn.Linear(84,10))X=torch.rand(size=(1,1,28,28),dtype=torch.float32)
for layer in net:X=layer(X)print(layer.__class__.__name__,'output shape:\t',X.shape)
"""结果输出:
Conv2d output shape:         torch.Size([1, 6, 28, 28])
Sigmoid output shape:        torch.Size([1, 6, 28, 28])
AvgPool2d output shape:      torch.Size([1, 6, 14, 14])
Conv2d output shape:         torch.Size([1, 16, 10, 10])
Sigmoid output shape:        torch.Size([1, 16, 10, 10])
AvgPool2d output shape:      torch.Size([1, 16, 5, 5])
Flatten output shape:        torch.Size([1, 400])
Linear output shape:         torch.Size([1, 120])
Sigmoid output shape:        torch.Size([1, 120])
Linear output shape:         torch.Size([1, 84])
Sigmoid output shape:        torch.Size([1, 84])
Linear output shape:         torch.Size([1, 10])
"""
#在整个卷积块中,与上一层相比,每一层特征的高度和宽度都减小了。 第一个卷积层使用2个像素的填充,
#来补偿卷积核导致的特征减少。 相反,第二个卷积层没有填充,因此高度和宽度都减少了4个像素。 随着
#层叠的上升,通道的数量从输入时的1个,增加到第一个卷积层之后的6个,再到第二个卷积层之后的16个。 
#同时,每个汇聚层的高度和宽度都减半。最后,每个全连接层减少维数,最终输出一个维数与结果分类数
#相匹配的输出。#第一个模块卷积激活池化把1通道,28*28变成6通道14*14,高宽减半,通道数增加了6倍,信息变多了。#看看LeNet在Fashion-MNIST数据集上的表现
batch_size=256
train_iter,test_iter=d2l.load_data_fashion_mnist(batch_size=batch_size)#evaluate_accuracy函数进行轻微的修改, 由于完整的数据集位于内存中,因此在模型使用GPU计算数
#据集之前,我们需要将其复制到显存中。
def evaluate_accuracy_gpu(net,data_iter,device=None):"""使用GPU计算模型在数据集上的精度"""if isinstance(net,nn.Module):#isinstance() 函数来判断一个对象是否是一个已知的类型net.eval()        #设置为评估模式if not device:#如果device没有给定device=next(iter(net.parameters())).device#把net的参数构建成一个迭代器,把第一个net的参数拿出来看他的device在哪里#正确预测的数量,总预测的数量metric=d2l.Accumulator(2)with torch.no_grad():for X,y in data_iter:if isinstance(X,list):#如果X的类型是一个list,就每一个都挪到那个device上面去X=[x.to(device) for x in X]else:#如果是个tensor就挪一次X=X.to(device)y=y.to(device)metric=add(d2l.accuracy(net(X),y),y.numel())return metric[0]/metric[1]def train_ch6(net,train_iter,test_iter,num_epochs,lr,device):"""用GPU训练模型"""def init_weights(m):if type(m)==mm.Linear or type(m)==nn.Conv2d:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)#对每一个parameter都run一下初始化权重函数print('training on',device)net.to(device)#把整个参数搬到GPU上optimizer=torch.optim.SGD(net.parameters(),lr=lr)loss=nn.CrossEntropyLoss()animator=d2l.Animator(xlabel='epoch',xlim=[1,num_epochs],legend=['train loss','train acc','test_acc'])timer,num_batches=d2l.Timer(),len(train_iter)for epoch in range(num_epochs):metric=d2l.Accumulator(3)net.train()for i,(X,y) in enumerate(train_iter):timer.start()optimizer.zero_grad()X,y=X.to(device),y.to(device)y_hat=net(X)l=loss(y_hat,y)l.backward()optimizer.step()with torch.no_grad():metric.add(l*X.shape[0],d2l.accuracy(y_hat,y),X.shape[0])timer.stop()train_l=metric[0]/metric[2]train_acc=metric[1]/metric[2]if(i+1)%(num_batches//5)==0 or i==num_batches-1:animator.add(epoch+(i+1)/num_batches),(train_l,train_acc,None)test_acc=evaluate_accuracy_gpu(net,test_iter)animator.add(epoch+1,(None,None,test_acc))print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, 'f'test acc {test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec 'f'on {str(device)}')#训练和评估LeNet-5模型。
lr,num_epochs=0.9,10
train_ch6=(net,train_iter,num_epochs,lr,d2l.try_gpu())
"""结果输出:
loss 0.469, train acc 0.823, test acc 0.779
55296.6 examples/sec on cuda:0"""

  • 卷积神经网络(CNN)是一类使用卷积层的网络。

  • 在卷积神经网络中,我们组合使用卷积层、非线性激活函数和汇聚层。

  • 为了构造高性能的卷积神经网络,我们通常对卷积层进行排列,逐渐降低其表示的空间分辨率,同时增加通道数。

  • 在传统的卷积神经网络中,卷积块编码得到的表征在输出之前需由一个或多个全连接层进行处理。

  • LeNet是最早发布的卷积神经网络之一。

参考:

python中isinstance()函数详解_python instance函数-CSDN博客

Pytorch torch.device()的简单用法_torch.device('cuda:0')-CSDN博客

Python迭代器基本方法iter()及其魔法方法__iter__()原理详解-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/341498.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

python高校舆情分析系统+可视化+情感分析 舆情分析+Flask框架(源码+文档)✅

毕业设计:2023-2024年计算机专业毕业设计选题汇总(建议收藏) 毕业设计:2023-2024年最新最全计算机专业毕设选题推荐汇总 🍅感兴趣的可以先收藏起来,点赞、关注不迷路,大家在毕设选题&#xff…

Python——python练习题

1.小明身高1.75,体重80.5kg。请根据BMI公式(体重除以身高的平方)帮小明计算他的BMI指数,并根据BMI指数: 低于18.5:过轻 18.5-25:正常 25-28:过重 28-32:肥胖 高于32&…

如何从多个文件夹里各提取相应数量的文件放一起到新文件夹中形成多文件夹组合

首先,需要用到的这个工具: 百度 密码:qwu2蓝奏云 密码:2r1z 说明一下情况 文件夹:1、2、3里面分别放置了各100张动物的图片,模拟实际情况的各种文件 操作:这里演示的是从3个文件夹里各取2张图…

kafka下载安装部署

Apache kafka 是一个分布式的基于push-subscribe的消息系统,它具备快速、可扩展、可持久化的特点。它现在是Apache旗下的一个开源系统,作为hadoop生态系统的一部分,被各种商业公司广泛应用。它的最大的特性就是可以实时的处理大量数据以满足各…

Flink异步IO

本文讲解 Flink 用于访问外部数据存储的异步 I/O API。对于不熟悉异步或者事件驱动编程的用户,建议先储备一些关于 Future 和事件驱动编程的知识。 本文代码gitee地址: https://gitee.com/ddxygq/BigDataTechnical/blob/main/Flink/src/main/java/operator/AsyncIODemo.java …

【libpcap】获取报文pcap的ns级别的时间戳

1.安装libpcap 首先&#xff0c;下载最新的 libpcap 源代码。你可以从 tcpdump.org 获取最新版本 1 解压下载的libpcap tar -zxvf libpcap-version.tar.gz 2 进入解压目录进行安装 cd libpcap-version ./configure make sudo make install2 解析报文时间戳 #include <pca…

第11章 GUI Page489~494 步骤三十 保存画板文件

为“保存”菜单项 MenuItemFileSave挂接事件响应函数&#xff1a; 实际运行时&#xff0c;现版TrySaveFile()函数有点儿傻&#xff0c;点击保存菜单&#xff0c;还会弹出对话框&#xff0c;问我们“要不要保存” 修改TrySaveFile()函数 函数声明修改为&#xff1a; 函数实现修…

工业以太网的网络安全与数据传输性能

工业以太网主要是一种用于工业控制系统的网络通信协议&#xff0c;它基于以太网技术&#xff0c;将其应用于工业环境中&#xff0c;以实现高速、可靠、安全的数据传输。跟传统的专用工业网络比较&#xff0c; 工业以太网具有更大的带宽、更低的成本以及更好的扩展性&#xff0c…

Maven 依赖管理项目构建工具 教程

Maven依赖管理项目构建工具 此文档为 尚硅谷 B站maven视频学习文档&#xff0c;由官方文档搬运而来&#xff0c;仅用来当作学习笔记用途&#xff0c;侵删。 另&#xff1a;原maven教程短而精&#xff0c;值得推荐&#xff0c;下附教程链接。 atguigu 23年Maven教程 目录 文章目…

【数据库学习】ClickHouse(ck)

1&#xff0c;ClickHouse&#xff08;CK&#xff09; 是一个用于联机分析(OLAP)的列式数据库管理系统(DBMS)。 1&#xff09;特性 按列存储&#xff0c;列越多速度越慢&#xff1b; 按列存储&#xff0c;数据更容易压缩&#xff08;类型相同、区分度&#xff09;&#xff1b…

半Happy的一天

终于差不多将SWMM模型与LisFlood模型耦合运转起来了 MDL的雏型也出来了&#xff0c;注册了模型方法和参数&#xff0c;差一个方法参数 晚上和师兄聊了聊未来规划&#xff0c;回顾了这半年研究生生涯的“拍烂”生活&#xff08;其实也没特别摆烂&#xff0c;还是学了不少东西&…

JDBC

1 连接JDBC jdbc是连接java和数据库的桥梁&#xff0c;对于不同的数据库&#xff0c;如果我们希望用java连接&#xff0c;我们需要下载不同的驱动。这里我们使用mysql数据库&#xff0c;下载驱动。 MySQL :: Download MySQL Connector/J (Archived Versions) &#xff08;版本…