深度学习入门笔记——神经网络的构建和使用

news/2025/2/28 20:16:33/文章来源:https://www.cnblogs.com/cyMessi/p/18515297

神经网络的整体构建

神经网络的基本骨架

首先可以在Pytorch官网的Python API中查看torch.nn的使用,如下所示。可以看到神经网络包括Container(基本骨架)、卷积层、池化层、Padding层、非线性激活等等。
构建一个神经网络首先要先构建起基本骨架,也就是Containers

nn.Moudle的使用

这是官网中给出的具体示例,重点在于创建我们自己的神经网络类的时候必须要继承父类nn.Moudle,然后就可以重写里面的函数等,这里的forward是前向传播函数,后面会有反向传播函数

这是一个简单的nn.Moudle使用示例,并没有涉及到神经网络的卷积层等。可以通过断点调试来查看具体的代码执行流程

from torch import nn
class CY(nn.Module):def __init__(self):super().__init__()def forward(self,input):output=input+1return outputcy=CY()
input=1
output= cy(input)
print(output)

卷积层

构建好基本骨架之后,就需要对卷积层进行操作,可以看到官方给出的卷积层包括以下方式,其中对于图像来说常用的就是卷积2d操作

图像的卷积

首先明确一下图像卷积的概念,如下图所示,图像卷积就是用卷积核在输入图像上一步步的滑动,每个方格内的元素对应相乘后相加作为输出的对应位置的元素

官方文档中给出的示例是这样的,对于参数的解释已经很详细了
这里要注意的一个点就是 卷积层的输入和卷积核都要描述成(N,C,H,W)的tensor格式,其中N表示有多少张图片,C表示有多少个通道,H表示图片的高度,W表示图片的宽度。所以初始设置输入的时候不仅要用torch.tensor变成tensor格式,后续还需要将torch.reshape(input,[1,1,5,5])转变为conv2d的格式,因为初始格式是只有宽和高这两个参数的

这其中有几个参数可以解释一下:

  • stride:也就是卷积核一次移动的步数
  • padding:是否要将输入图像进行零填充,默认为0.可以看到设置填充之后,卷积得到的结果会比原来的大

具体代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as Finput=torch.tensor([[1,2,0,3,1],  # 输入图像[0,1,2,3,1],[1,2,1,0,0],[5,2,3,1,1],[2,1,0,1,1]])
kernel=torch.tensor([[1,2,1],     # 卷积核[0,1,0],[2,1,0]])
print(input.shape)
input = torch.reshape(input,[1,1,5,5])
kernel=torch.reshape(kernel,[1,1,3,3])output= F.conv2d(input,kernel,stride=1)
print(output)
output_1= F.conv2d(input,kernel,stride=2)
print(output_1)
output_2=F.conv2d(input,kernel,stride=1,padding=1)
print(output_2)

nn.conv2d的使用

官方给出的函数使用方法如下:

这里要注意的就是in_channels和out_channels的理解,可以说in_channels就是图像的通道数,也就是RGB=3,out_channels代表的是用多少个卷积核来对图像进行卷积,如果out_channels=6的时候就是用6个卷积核来对图像进行卷积,然后对得到的输出进行处理

参数的具体描述如下:

具体代码如下,要注意的是使用tensorboard对图像进行显示的时候,由于tensorboard显示的图像格式是规定的3个通道,所以上面得到的6个通道的图像是会报错的。所以我们可以用 output = torch.reshape(output, (-1, 3, 30, 30))来将图像格式进行重新设置,其中-1表示的是占位符,表示这个位置的参数交给后面的参数来计算

# -*- coding: utf-8 -*-
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset = torchvision.datasets.CIFAR10("dataset/cifar-10-batches-py", train=False, transform=torchvision.transforms.ToTensor(),download=True)
dataloader = DataLoader(dataset, batch_size=64)class Tudui(nn.Module):def __init__(self):super(Tudui, self).__init__()self.conv1 = Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=0)def forward(self, x):x = self.conv1(x)return xtudui = Tudui()writer = SummaryWriter("logs")step = 0
for data in dataloader:imgs, targets = dataoutput = tudui(imgs)print(imgs.shape)print(output.shape)# torch.Size([64, 3, 32, 32])writer.add_images("input", imgs, step)# torch.Size([64, 6, 30, 30])  -> [xxx, 3, 30, 30]output = torch.reshape(output, (-1, 3, 30, 30))writer.add_images("output", output, step)step = step + 1writer.close()

最后得到到图像是这样的:


可以看到输出图像一个批次中有128个图像,这也就是我们将6个通道变为8个通道导致的,和初步设想一致

池化层

这里主要讲解的是2D类型最大池化层,同样的,详细的函数信息在官网上:

主要注意的就是**ceil_mode 这个参数,这里的意思其实就是要向下取整还是向上取整,如果为True的话就是向上取整,False的话就是向下取整。也就是说,在下图这个示例中,如果取为True的时候在进行池化的时候对于多出来的部分(原图像是5×5,池化核是3×3),会进行保留并得出结果,而为False的时候就不会保留结果。

dilation这个参数其实就是池化的时候是否要跳步进行
代码如下:

import torch
import torchvision
from torch import nn
from torch.nn import MaxPool2d
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset = torchvision.datasets.CIFAR10("../dataset", train=False, download=True,transform=torchvision.transforms.ToTensor())dataloader = DataLoader(dataset, batch_size=64)class Tudui(nn.Module):def __init__(self):super(Tudui, self).__init__()self.maxpool1 = MaxPool2d(kernel_size=3, ceil_mode=False)def forward(self, input):output = self.maxpool1(input)return outputtudui = Tudui()writer = SummaryWriter("../logs_maxpool")
step = 0for data in dataloader:imgs, targets = datawriter.add_images("input", imgs, step)output = tudui(imgs)writer.add_images("output", output, step)step = step + 1writer.close()

得到的结果如下,其实池化就是相当于做一个缩略马赛克处理

非线性激活

非线性激活就是例如ReLu、Sigmod等非线性激活函数,在Pytorch中的使用是比较简单的,调用函数即可,例如Sigmod函数:

这里要注意inplace的作用就是是否要有一个新的返回值来存储输出值,默认为False,如果为True的话输出值覆盖输入值

除了上面列举的一些神经网络最基本必须的网络之外,torch.nn中还有很多其他的层:正则化层、线性层、Transformer层等等,有一些在特定的网络中需要特定使用,可以去了解一下

Sequential的作用

sequential的作用就是将我们要创建的神经网络的层数按照顺序堆叠起来,个人觉得用处就是简化代码,后面可以再了解看看,如下图所示,用sequential堆叠起神经网络之后就可以直接创建实例并输入。相较于用x输出承接x输入是简洁很多的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/852650.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

计算机做的所有事情都叫计算

计算机怎么解决问题?答:需要告诉计算机解决问题的步骤(不要写成说明书了) 怎么告诉计算机这个步骤?答:编程语言写程序 1. 写程序不是表达关系,是表达动作2. 是解决问题的步骤,编程的时候不是你说一句它做一句3. 编程语言不是用来和计算机交流的4. 计算机的交流是你的操…

第四章 文件管理

文件 4.1.1 文件的基本概念文件是指由创建者所定义的、具有文件名的一组相关元素的集合,是以硬盘为载体的存储在计算机上的信息集合 是文件系统中最大的数据单位 在用户进行的输入,输出中,则以文件为基本单位4.1.5 文件的逻辑结构 按文件是否有结构分类 无结构文件 文件内部…

程序执行两种方式

1.你写的程序交文件给它,它一步步按照你的要求执行 2.写的程序文件交给它,它翻译成计算机懂的文件,用计算机懂的文件执行 解释语言vs编译语言1. 语言本身没有解释和编译的区分,任何语言都可以编译执行和解释执行。2. 只是语言常用执行方式的传统和习惯的问题3. 解释语言 特…

P1070 [NOIP2009 普及组] 道路游戏

ProblemSolve 此题是求最优解,考虑贪心时会发现这个不满足局部最优->整体最优,故考虑DP 通过输入格式能受到启发,时间可以作为维度之一,所以定义为: \(f_{i,j}\)第i秒末,机器人在j号工厂能获得的最大金币 因为机器存在时间有上限,所以推的时候枚举本次机器人到底走了多…

2024-12-14:K 周期字符串需要的最少操作次数。用go语言,给定一个长度为n的字符串 word 和一个整数k,k是n的因数。每次操作可以选择两个下标i和j,使得i和j都可以被k整除,然后用从j

2024-12-14:K 周期字符串需要的最少操作次数。用go语言,给定一个长度为n的字符串 word 和一个整数k,k是n的因数。每次操作可以选择两个下标i和j,使得i和j都可以被k整除,然后用从j开始的长度为k的子串替换从i开始的长度为k的子串。要使得word成为一个K周期字符串,需要进行…

实现综合实例:简单文字处软件 (一)

学业繁重,更新缓慢。 本内容主要用于个人学习/复习QT简单入门控件 DAY ONE 创建项目界面设计与开发 实现简单的菜单栏设计本人并没有使用代码实现,而是用于使用UI设计师界面。action条例分类 设计控件(帮助) 设计帮助控件: 我们转到槽,填写如下代码: 这是一个基于QT6实现…

微信防撤回插件

插件 https://pan.quark.cn/s/bb5165185a6a部署 先查看电脑微信版本,比如我这里是3.9.12.15版本下载对应版本之后,将插件名字改为WeChatWin.dll,删掉前面的版本号在微信所在的文件夹下,找到这个同名插件,用下载的插件替换它即可 end 替换之后需要重启微信才可以,效果如下…

事务管理与锁机制

title: 事务管理与锁机制 date: 2024/12/14 updated: 2024/12/14 author: cmdragon excerpt: 在数据库系统中,事务管理至关重要,它确保多个数据库操作能够作为一个单一的逻辑单元来执行,从而维护数据的一致性和完整性。一个良好的事务管理系统能够解决并发操作带来的问题,…

2024-2025-1 20241319 《计算机基础与程序设计》第十二周学习总结

作业信息这个作业属于哪个课程 2024-2025-1-计算机基础与程序设计这个作业要求在哪里 https://www.cnblogs.com/rocedu/p/9577842.html#WEEK12这个作业的目标 结构体和数据结构基础 文件操作作业正文 https://www.cnblogs.com/wchxx/p/18607077教材学习内容总结 结构体(Struct…

数据采集综合设计

这个项目属于哪个课程2024数据采集与融合技术实践 组名 从你的全世界爬过团队logo:项目简介 项目名称:博物识植项目logo:项目介绍:在探索自然奥秘的旅途中,我们常与动植物相伴而行,却无法准确识别它们,更难以深入了解他们的特征。为了更好地理解和欣赏自然界的多样性,…

Java中创建线程的几种方式

盘点一下Java中创建线程的几种方式 一、继承Thread类,重写run()方法public class MyThread extends Thread {@Overridepublic void run() {System.out.println("my thread start " + Thread.currentThread().getName());}public static void main(String[] args) {S…

消防通道堵塞识别摄像机

消防通道堵塞识别摄像机是一种安装在建筑物消防通道中的监控设备,主要用于监测消防通道是否被车辆、杂物或其他障碍物所堵塞,以确保在火灾等紧急情况下消防通道畅通无阻。这种摄像机通常安装在消防通道的入口或周围,具备高清摄像功能,能够全天候监测通道状况。一旦摄像机检…