【Pytorch深度学习开发实践学习】Pytorch实现LeNet神经网络(1)

1.model.py

import torch.nn as nn
import torch.nn.functional as F

引入pytorch的两个模块
关于这两个模块的作用,可以参考下面
Pytorch官方文档
在这里插入图片描述
torch.nn包含了构成计算图的基本模块
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
torch,nn.function包括了计算图中的各种主要函数,包括:卷积函数、池化函数、注意力机制函数、非线性激活函数、dropout函数、线性函数、距离函数、损失函数、可视化函数和多GPU分布式函数等。

class LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(3, 16, 5)self.pool1 = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(16, 32, 5)self.pool2 = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(32*5*5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)

我们构建了从torch.nn.Module类下面继承的LeNet类,这个类构建了LeNet神经网络,所有pytorch定义的神经网络模型都从nn.Module类派生。

首先def init(self):是类的初始化函数,当我们创建这个类的一个实例时,这个函数会被调用。

super(LeNet, self).init()是调用父类nn.Module的初始化函数,这是一个标准的做法,确保父类中的任何初始化都被正确的执行。

self.conv1 = nn.Conv2d(3, 16, 5)调用nn.function的二维卷积函数,定义了第一个卷积层的操作,输入的通道数是3,输出通道数16,5*5的卷积核

self.pool1 = nn.MaxPool2d(2, 2)调用了最大池化函数,定义了第一个池化层的操作,池化窗口的大小是2*2

self.conv2 = nn.Conv2d(16, 32, 5)定义了第二个卷积层的操作,输入通道数是16,因为第一个卷积层的输出通道数是16,这一层的输出通道数是32,5*5的卷积核

self.pool2 = nn.MaxPool2d(2, 2)调用了最大池化函数,定义了第二个池化层的操作,池化窗口的大小是2*2

self.fc1 = nn.Linear(3255, 120)定义了一个全连接层,输入特征是3255,这是前一层输出的特征数与卷积核大小、步长和填充的综合结果),输出特征数为120。

self.fc2 = nn.Linear(120, 84)定义另一个全连接层,输入特征数为120,输出特征数为84。

self.fc3 = nn.Linear(84, 10) 定义最后一个全连接层,输入特征数为84,输出特征数为10。这个输出特征数通常对应于分类问题的类别数(如果有10个类别的话)。

LeNet模型是一个经典的卷积神经网络结构,主要用于图像分类任务。它包括两个卷积层、两个池化层和三个全连接层。

    def forward(self, x):x = F.relu(self.conv1(x))    # input(3, 32, 32) output(16, 28, 28)x = self.pool1(x)            # output(16, 14, 14)x = F.relu(self.conv2(x))    # output(32, 10, 10)x = self.pool2(x)            # output(32, 5, 5)x = x.view(-1, 32*5*5)       # output(32*5*5)x = F.relu(self.fc1(x))      # output(120)x = F.relu(self.fc2(x))      # output(84)x = self.fc3(x)              # output(10)return x

这段代码定义了一个神经网络的前向传播过程。

def forward(self, x):
定义一个名为forward的方法,它描述了数据从输入到输出的前向传播过程。

x = F.relu(self.conv1(x))
输入数据x经过第一个卷积层self.conv1,然后通过ReLU激活函数。输入是33232,输出的大小为(16, 28, 28)。

x = self.pool1(x)
对上一步的输出进行最大池化操作,输出的大小变为(16, 14, 14)。

x = F.relu(self.conv2(x))
输入数据x经过第二个卷积层self.conv2,然后通过ReLU激活函数。输出的大小为(32, 10, 10)。

x = self.pool2(x)
对上一步的输出进行最大池化操作,输出的大小变为(32, 5, 5)。

x = x.view(-1, 3255)
对上一步的输出进行展平操作,即将三维数据变为二维数据。输出的大小为(-1, 3255),其中-1表示批量大小,可以自动计算。

x = F.relu(self.fc1(x))
对展平后的数据x进行全连接操作,然后通过ReLU激活函数。输出的大小为120

x = F.relu(self.fc2(x))
对上一步的输出进行全连接操作,然后通过ReLU激活函数。输出的大小为84

x = self.fc3(x)
对上一步的输出进行全连接操作,不使用激活函数。输出的大小为(10)。这通常表示有10个类别。

return x
返回最终的输出结果。

这个前向传播过程描述了数据从输入到输出的整个流程,包括卷积、池化、全连接和激活函数等操作。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/496178.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何让网页APP化 渐进式Web应用(PWA)

前言 大家上网应该发现有的网页说可以安装对应应用,结果这个应用好像就是个web,不像是应用,因为这里采用了PWA相关技术。 PWA,全称为渐进式Web应用(Progressive Web Apps),是一种可以提供类似…

Nginx重写功能和反向代理

目录 一、重写功能rewrite 1. ngx_http_rewrite_module模块指令 1.1 if 指令 1.2 return 指令 1.3 set 指令 1.4 break 指令 2. rewrite 指令 3. 防盗链 3.1 实现盗链 3.2 实现防盗链 4. 实用网址 二、反向代理 1. 概述 2. 相关概念 3. 反向代理模块 4. 参数配…

读《Shape-Guided: Shape-Guided Dual-Memory Learning for 3D Anomaly Detection》

Chu Y M, Chieh L, Hsieh T I, et al. Shape-Guided Dual-Memory Learning for 3D Anomaly Detection[J]. 2023.(为毛paperwithcode上面曾经的榜一引用却只有1) 摘要 专家学习 无监督 第一个专家:局部几何,距离建模 第二个专家&…

Spring篇----第十二篇

系列文章目录 文章目录 系列文章目录前言一、指出在 spring aop 中 concern 和 crosscuttingconcern 的不同之处二、AOP 有哪些实现方式?三、Spring AOP and AspectJ AOP 有什么区别?四、如何理解 Spring 中的代理?前言 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,…

STC-ISP原厂代码研究之 V3.7d汇编版本

最近在研究STC的ISP程序,用来做一个上位机烧录软件,逆向了上位机软件,有些地方始终没看明白,因此尝试读取它的ISP代码,但是没有读取成功。应该是目前的芯片架构已经将引导代码放入在了单独的存储块中,而这存储块有硬件级的使能线,在面包板社区-宏晶STC单片机的ISP的BIN文…

基于视觉识别的自动采摘机器人设计与实现

一、前言 1.1 项目介绍 【1】项目功能介绍 随着科技的进步和农业现代化的发展,农业生产效率与质量的提升成为重要的研究对象。其中,果蔬采摘环节在很大程度上影响着整个产业链的效益。传统的手工采摘方式不仅劳动强度大、效率低下,而且在劳…

李沐动手学习深度学习——3.1练习

字写的有点丑不要介意 由于公式推导烦的要死,所以手写形式,欢迎进行讨论,因为我也不知道对错

【Vue渗透】Vue站点渗透思路

原文地址 极核GetShell 前言 本文经验适用于前端用Webpack打包的Vue站点,阅读完本文,可以识别出Webpack打包的Vue站点,同时可以发现该Vue站点的路由。 成果而言:可能可以发现未授权访问。 识别Vue 识别出Webpack打包的Vue站…

BevFusion (2): nuScenes 数据介绍及点云可视化

1. nuScenes 数据集 1.1 概述 nuScenes 数据集 (pronounced /nu:ːsiː:nz/) 是由 Motional (以前称为 nuTonomy) 团队开发的自动驾驶公共大型数据集。nuScenes 数据集的灵感来自于开创性的 KITTI 数据集。 nuScenes 是第一个提供自动驾驶车辆整个传感器套件 (6 个摄像头、1 …

网络安全之内容安全

内容安全 攻击可能只是一个点,防御需要全方面进行 IAE引擎 DFI和DPI技术--- 深度检测技术 DPI --- 深度包检测技术--- 主要针对完整的数据包(数据包分片,分段需要重组),之后对 数据包的内容进行识别。(应用…

使用 Debezium 和 RisingWave 对 MongoDB 进行持续分析

MongoDB 和流式 Join 的挑战 谷歌趋势显示,有关 MongoDB 流式计算的搜索率不断上升 作为一种操作型数据库,MongoDB 在提供快速数据操作和查询性能方面表现十分出色。然而,在维护实时视图或执行流处理任务的内置支持方面,它确实存…

spring-boot static-path-pattern如何配置生效

WebMvcAutoConfiguration AbstractUrlHandlerMapping ResourceHttpRequestHandler springboot 版本 2.3.9.RELEASE 一、如何用 yaml配置 spring:mvc:static-path-pattern: /doctest/**resources:static-locations: classpath:/doc/资源文件配置 访问路径 二、原理 第一个问…