简单的神经网络

一、softmax的基本概念

我们之前学过sigmoid、relu、tanh等等激活函数,今天我们来看一下softmax。

先简单回顾一些其他激活函数:

  1. Sigmoid激活函数:Sigmoid函数(也称为Logistic函数)是一种常见的激活函数,它将输入映射到0到1之间。它常用于二分类问题中,特别是在输出层以概率形式表示结果时。Sigmoid函数的优点是输出值限定在0到1之间,相当于对每个神经元的输出进行了归一化处理。
  2. Tanh激活函数:Tanh函数(双曲正切函数)将输入映射到-1到1之间。与Sigmoid函数相比,Tanh函数的中心点在零值附近,这意味着它的输出是以0为中心的。这种特性可以在某些情况下提供更好的性能。
  3. ReLU激活函数:ReLU(Rectified Linear Unit)函数是当前非常流行的一个激活函数,其表达式为f(x)=max(0, x)。ReLU函数的优点是计算简单,能够在正向传播过程中加速计算。此外,ReLU函数在正值区间内梯度为常数,有助于缓解梯度消失问题。但它的缺点是在负值区间内梯度为零,这可能导致某些神经元永远不会被激活,即“死亡ReLU”问题。

Softmax函数是一种在机器学习中广泛使用的函数,尤其是在处理多分类问题时。它的主要作用是将一组未归一化的分数转换成一个概率分布。Softmax函数的一个重要性质是其输出的总和等于1,这符合概率分布的定义。这意味着它可以将一组原始分数转换为概率空间,使得每个类别都有一个明确的概率值。

  • 二分类问题选择sigmoid激活函数

  • 多分类问题选择softmax激活函数

二、交叉熵损失函数

交叉熵损失函数的公式可以分为二分类和多分类两种情况。对于二分类问题,假设我们只考虑正类(标签为1)和负类(标签为0)在多分类问题中,交叉熵损失函数可以扩展为−∑𝑖=1𝐾𝑦𝑖⋅log⁡(𝑝𝑖)−∑i=1K​yi​⋅log(pi​),其中𝐾K是类别的总数,( y_i )是样本属于第𝑖i个类别的真实概率(通常用one-hot编码表示),而𝑝𝑖pi​是模型预测该样本属于第( i )个类别的概率。

import torch
from torch import nn# 确定随机数种子
torch.manual_seed(7)
# 自定义数据集
X = torch.rand((7, 2, 2))
target = torch.randint(0, 2, (7,))

定义网络结构

  • 一层全连接层 + Softmax层
  • x1𝑥1,x2𝑥2,x3𝑥3,x4𝑥4为 X
  • o1𝑜1,o2𝑜2,o3𝑜3为 target
class LinearNet(nn.Module):def __init__(self):super(LinearNet, self).__init__()# 定义一层全连接层self.dense = nn.Linear(4, 3)# 定义Softmaxself.softmax = nn.Softmax(dim=1)def forward(self, x):y = self.dense(x.view((-1, 4)))y = self.softmax(y)return ynet = LinearNet()
  •  nn.Softmax(dim=1)用于计算输入张量在指定维度上的softmax激活。dim=1表示沿着第二个维度(即列)进行softmax操作。

定义损失函数和优化函数

  • torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean')
  • 衡量模型输出与真实标签的差异,在分类时相当有用。
  • 结合了nn.LogSoftmax()和nn.NLLLoss()两个函数,进行交叉熵计算。
loss = nn.CrossEntropyLoss()  # 交叉熵损失函数
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)  # 随机梯度下降法

训练模型

for epoch in range(70):train_l = 0.0y_hat = net(X)l = loss(y_hat, target).sum()# 梯度清零optimizer.zero_grad()# 自动求导梯度l.backward()# 利用优化函数调整所有权重参数optimizer.step()train_l += lprint('epoch %d, loss %.4f' % (epoch + 1, train_l))

三、自动微分模块

torch.autograd.backward(tensors, grad_tensors=None, retain_graph=None, create_graph=False)  :自动求取梯度

  • grad_tensors:多梯度权重
  • create_graph:创建导数计算图,用于高阶求导
  • retain_graph:保存计算图
  • tensors:用于求导的张量,如 loss
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)y.backward(retain_graph=True)

 注意点:

  1. 梯度不自动清零
  2. 依赖于叶子节点的节点,requires_grad默认为True
  3. 叶子节点不可执行in-place

神经网络全连接层: 每个神经元都与前一层的所有神经元相连接。全连接层通常用于网络的最后几层,它将之前层(如卷积层和池化层)提取的特征进行整合,以映射到样本标记空间,即最终的分类或回归结果。

关于loss.backward()方法:

主要作用就是计算损失函数对模型参数的梯度,loss.backward()实现了反向传播算法,它通过链式法则计算每个模型参数相对于最终损失的梯度。这个过程从输出层开始,向后传递到输入层,逐层计算梯度。

过程:得到每个参数相对于损失函数的梯度,这些梯度信息会存储在对应张量的.grad属性中。loss.backward本身不负责更细权重,但它为权重更新提供了梯度值,方便配合optimizer.step()来更新参数。

前向传播过程中,数据从输入层流向输出层,并生成预测结果;而在反向传播过程中,误差(即预测值与真实值之间的差距,也就是损失函数的值)会从输出层向输入层传播,逐层计算出每个参数相对于损失函数的梯度。这些梯度指示了如何调整每一层中的权重和偏置,以最小化损失函数。

  • 损失函数衡量了当前模型预测与真实情况之间的不一致程度,而梯度则提供了损失函数减少最快的方向。

建立一个简单的全连接层:

import torch
import torch.nn as nn# 定义一个简单的全连接层模型
class SimpleFC(nn.Module):def __init__(self, input_size, output_size):super(SimpleFC, self).__init__()self.fc = nn.Linear(input_size, output_size)def forward(self, x):  return self.fc(x)# 创建输入数据和目标输出
input_data = torch.tensor([[1.0, 2.0, 3.0]])
target_output = torch.tensor([[4.0, 5.0]])# 实例化模型、损失函数和优化器
model = SimpleFC(input_size=3, output_size=2)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)# 前向传播
output = model(input_data)# 计算损失
loss = criterion(output, target_output)# 反向传播
loss.backward()# 更新参数
optimizer.step()

当调用loss.backward()时,PyTorch会自动计算损失值关于模型参数的梯度,并将这些梯度存储在模型参数的.grad属性中。然后优化器(torch.optim.SGD)可以使用这些梯度来更新模型参数,以最小化损失函数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/687603.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Ubuntu20.04 设置路由器

1. 网络拓扑图 2. 查看网卡信息 ip a得出如下网卡信息&#xff0c;enp1s0和enp2s0为两个网卡名称&#xff0c;以及相关两个网卡的详细信息&#xff0c;不同设备的网卡名称可能不一样 1: lo: <LOOPBACK,UP,LOWER_UP> mtu 65536 qdisc noqueue state UNKNOWN group defaul…

【UE】制作自己的插件

步骤 1. 打开插件面板&#xff0c;在界面左上角点击“添加”按钮 2. 在“新插件”界面先点击“纯内容”然后点击“创建插件” 此时在项目工程目录下的“Plugins”文件夹中就可以看到我们创建的插件 3. 如果想在自己创建的插件中添加功能&#xff0c;我们可以在项目浏览器中的“…

SSM【Spring SpringMVC Mybatis】——Mybatis(二)

如果对一些基础理论感兴趣可以看这一期&#x1f447; SSM【Spring SpringMVC Mybatis】——Mybatis 目录 1、Mybatis中参数传递问题 1.1 单个普通参数 1.2 多个普通参数 1.3 命名参数 1.4 POJO参数 1.5 Map参数 1.6 Collection|List|Array等参数 2、Mybatis参数传递【#与…

14:java基础-Tomcat-Web容器

文章目录 面试题Web 容器是什么&#xff1f;HTTP 的本质 面试题 Web 容器是什么&#xff1f; 让我们先来简单回顾一下 Web 技术的发展历史&#xff0c;可以帮助你理解 Web 容器的由来。早期的 Web 应用主要用于浏览新闻等静态页面&#xff0c;HTTP 服务器&#xff08;比如Apa…

ArcGIS10.2能用了10.2.2不行了(解决)

前两天我们的推文介绍了 ArcGIS10.2系列许可到期解决方案-CSDN博客文章浏览阅读2次。本文手机码字&#xff0c;不排版了。 昨晚&#xff08;2021\12\17&#xff09;12点后&#xff0c;收到很多学员反馈 ArcGIS10.2系列软件突然崩溃。更有的&#xff0c;今天全单位崩溃。​提示许…

社交媒体数据恢复:九信极速版

九信极速版是一款功能强大的社交软件&#xff0c;但在使用过程中&#xff0c;您可能会遇到数据丢失的问题&#xff0c;如聊天记录、好友等。不用担心&#xff0c;九信极速版提供了数据恢复功能&#xff0c;帮助您找回丢失的数据。以下是详细的恢复步骤&#xff1a; 一、恢复聊…

Spring如何控制Bean的加载顺序

前言 正常情况下&#xff0c;Spring 容器加载 Bean 的顺序是不确定的&#xff0c;那么我们如果需要按顺序加载 Bean 时应如何操作&#xff1f;本文将详细讲述我们如何才能控制 Bean 的加载顺序。 场景 我创建了 4 个 Class 文件&#xff0c;分别命名为 FirstInitialization Se…

(docker)进入容器后如何使用本机gpu

首次创建容器&#xff0c;不能直接使用本机gpu 在系统终端进行如下配置&#xff1a; 1.安装NVIDIA Container Toolkit 进入Nvidia官网Installing the NVIDIA Container Toolkit — NVIDIA Container Toolkit 1.15.0 documentation&#xff0c;安装NVIDIA Container Toolkit …

【Python 下载大量品牌网站的图片(二)】关于图片的处理和下载,吃满带宽,可多开窗口下载多个网站,DOS窗口类型

写作日期&#xff1a;2024.05.11 使用工具&#xff1a;Python 可修改功能&#xff1a;线程量、UA、Cookie、代理、存储目录、间隔时间、超时时间、图片压缩、图片缩放 默认功能&#xff1a;图片转换、断续下载、图片检测、路径处理、存储文件 GUI&#xff1a;DOS窗口 类型&…

Windows使用Miniconda3安装python、环境配置以及conda常用命令

Windows使用Miniconda3安装python以及conda常用命令 这是学习使用python的第一篇文章&#xff0c;这将是一个关于python学习和使用的一个系列文章的开始&#xff0c;有兴趣的可以给个关注持续获取更新内容。 Miniconda3是什么&#xff1f; Miniconda3是一个轻量级的Anaconda发…

【JS 的数据类型】

JS 的数据类型 基本数据类型 js有8种基本数据类型&#xff0c;分别为&#xff1a;undefined、number、Object、null、Symbol、Boolean、String、BigInt&#xff1b; 其中Symbol和BigInt是ES6新增的数据类型&#xff1a; ● Symobol代表独一无二的值&#xff0c;可以用来代表对…

MMdetection在Featurize服务器运行时相关问题

写点闲话&#xff1a; 之前因为毕业&#xff0c;想写代码再也没有稳定的机子跑了&#xff0c;自己电脑有时候也带不动&#xff0c;所以开始使用Featurize&#xff0c;这里可以租一些显卡来用&#xff0c;价格总体来说对我们这种偶尔有大规模算力需求的打工人非常友好。使用方法…