学习pytorch14 损失函数与反向传播

神经网络-损失函数与反向传播

  • 官网
  • 损失函数
    • L1Loss MAE 平均
    • MSELoss 平方差
    • CROSSENTROPYLOSS 交叉熵损失
      • 注意
      • code
  • 反向传播
    • 在debug中的显示
      • code

B站小土堆pytorch视频学习

官网

https://pytorch.org/docs/stable/nn.html#loss-functions
在这里插入图片描述

损失函数

在这里插入图片描述

L1Loss MAE 平均

在这里插入图片描述
在这里插入图片描述

import torchinput = torch.tensor([1, 2, 3], dtype=float)
# target = torch.tensor([1, 2, 5], dtype=float)
target = torch.tensor([[[[1, 2, 5]]]], dtype=float) # shape [1, 1, 1, 3]
input = torch.reshape(input, (1,1,1,3))
# target = torch.reshape(target, (1,1,1,3))
print(input.shape)
print(target.shape)loss1 = torch.nn.L1Loss()
loss2 = torch.nn.L1Loss(reduction="sum")
result1 = loss1(input, target)
print(result1) # tensor(0.6667, dtype=torch.float64)
result2 = loss2(input, target)
print(result2) # tensor(2., dtype=torch.float64)

MSELoss 平方差

在这里插入图片描述
在这里插入图片描述

import torchinput = torch.tensor([1, 2, 3], dtype=float)
# target = torch.tensor([1, 2, 5], dtype=float)
target = torch.tensor([[[[1, 2, 5]]]], dtype=float) # shape [1, 1, 1, 3]
input = torch.reshape(input, (1,1,1,3))
# target = torch.reshape(target, (1,1,1,3))
print(input.shape)
print(target.shape)loss_mse = torch.nn.MSELoss(reduction='mean')
result_mse = loss_mse(input, target)
print(result_mse) # tensor(1.3333, dtype=torch.float64)
loss_mse2 = torch.nn.MSELoss(reduction='sum')
result_mse2 = loss_mse2(input, target)
print(result_mse2)   # tensor(4., dtype=torch.float64)

CROSSENTROPYLOSS 交叉熵损失

https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html#torch.nn.CrossEntropyLoss
在这里插入图片描述
在这里插入图片描述
在神经网络中,默认log是以e为底的,所以也可以写成ln
在这里插入图片描述
在这里插入图片描述

注意

  1. 根据需求选择对应的loss函数
  2. 注意loss函数的输入输出shape

code

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWritertest_set = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor(),download=True)dataloader = DataLoader(test_set, batch_size=1)class MySeq(nn.Module):def __init__(self):super(MySeq, self).__init__()self.model1 = Sequential(Conv2d(3, 32, kernel_size=5, stride=1, padding=2),MaxPool2d(2),Conv2d(32, 32, kernel_size=5, stride=1, padding=2),MaxPool2d(2),Conv2d(32, 64, kernel_size=5, stride=1, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self, x):x = self.model1(x)return xloss = nn.CrossEntropyLoss()
myseq = MySeq()
print(myseq)
for data in dataloader:imgs, targets = dataprint(imgs.shape)output = myseq(imgs)result = loss(output, targets)print(result)

反向传播

在debug中的显示

显示在网络结构中,每一层的保护属性中,都有weight属性,梯度属性在weitht属性里面
先找模型结构 在找每一层 在找weight权重,梯度在weight权重里面

在这里插入图片描述

code

核心代码:result_loss.backward() # 要在最后获取 backward函数要挂在通过loss函数计算后的结果上。

# 模型定义、数据加载 同上个代码
for data in dataloader:imgs, targets = dataprint(imgs.shape)output = myseq(imgs)result_loss= loss(output, targets)result_loss.backward()  # 要在最后获取print(result_loss)print(result_loss.grad)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/145697.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Redis不止能存储字符串,还有List、Set、Hash、Zset,用对了能给你带来哪些优势?

文章目录 🌟 Redis五大数据类型的应用场景🍊 一、String🍊 二、Hash🍊 三、List🍊 四、Set🍊 五、Zset 📕我是廖志伟,一名Java开发工程师、Java领域优质创作者、CSDN博客专家、51CTO…

安装mmcv及GPU版本的pytorch及torchvision

一、先装GPU版本的pytorch和torchvision pip install torch1.9.1cu111 torchvision0.10.1cu111 torchaudio0.9.1 -f https://download.pytorch.org/whl/torch_stable.html注意:以上适用cuda11.1版本 如果想离线安装,就看这篇文章 二、安装mmcv 看这篇…

Power BI 傻瓜入门 1. 数据分析术语:Power BI风格

本章内容包括: 了解Power BI可以处理的不同类型的数据了解您的商业智能工具选项熟悉Power BI术语 数据无处不在。从你醒来的那一刻到你睡觉的时候,某个系统会代表你收集数据。即使在你睡觉的时候,也会产生与你生活的某些方面相关的数据。如…

11 Self-Attention相比较 RNN和LSTM的优缺点

博客配套视频链接: https://space.bilibili.com/383551518?spm_id_from=333.1007.0.0 b 站直接看 配套 github 链接:https://github.com/nickchen121/Pre-training-language-model 配套博客链接:https://www.cnblogs.com/nickchen121/p/15105048.html RNN 无法做长序列,当一…

软件项目管理【UML-组件图】

目录 一、组件图概念 二、组件图包含的元素 1.组件(Component)->构件 2.接口(Interface) 3.外部接口——端口 4.连接器(Connector)——连接件 4.关系 5.组件图表示方法 三、例子 一、组件图概念…

Centos 7 Zabbix配置安装

前言 Zabbix是一款开源的网络监控和管理软件,具有高度的可扩展性和灵活性。它可以监控各种网络设备、服务器、虚拟机以及应用程序等,收集并分析性能指标,并发送警报和报告。Zabbix具有以下特点: 1. 支持多种监控方式:可…

MySQL -- 库和表的操作

MySQL – 库和表的操作 文章目录 MySQL -- 库和表的操作一、库的操作1.创建数据库2.查看数据库3.删除数据库4.字符集和校验规则5.校验规则对数据库的影响6.修改数据库7.备份和恢复8.查看连接情况 二、表的操作1.创建表2.查看表结构3.修改表4.删除表 一、库的操作 注意&#xf…

视频去噪网络BSVD的实现

前些天写了视频去噪网络BSVD论文的理解,详情请点击这里,这两个星期动手实践了一下,本篇就来记录一下这个模型的实现。 这个网络的独特之处在于,它的训练和推理在实现上有所差别。在训练阶段,其使用了TSM(T…

Hexo搭建个人博客系列之环境准备

环境准备 Git Git官网,安装过程,就是一直下一步,详细的看这篇文章 Git的安装 Node.js Node.js官网 Node.js的安装 注册一个GitHub账号 安装hexo 新建一个文件夹(位置任意),运行cmd(若出现了operation not permitted,就以管理员的权限来运行cmd),运行…

WordPress SMTP邮件发送插件 Easy WP SMTP

Easy WP SMTP是一款 WordPress 邮件发送插件,WordPress 中经常用到邮件发送,包括新注册用户的邮件通知、找回密码通知、评论回复通知等。因为云服务器默认不启用 SMTP功能,所以需要安装 SMTP插件来解决这个问题。 SMTP 主机:smtp.…

acwing第 126 场周赛 (扩展字符串)

5281. 扩展字符串 一、题目要求 某字符串序列 s0,s1,s2,… 的生成规律如下: s0 DKER EPH VOS GOLNJ ER RKH HNG OI RKH UOPMGB CPH VOS FSQVB DLMM VOS QETH SQBsnDKER EPH VOS GOLNJ UKLMH QHNGLNJ Asn−1AB CPH VOS FSQVB DLMM VOS QHNG Asn−1AB,其…

node-red常用包分析

node-red-contrib-opcua Use OpcUa-Item to define variables. Use OpcUa-Client to read / write / subscribe / browse OPC UA server. 需要想通过OpcUa-Item节点来指定一个数据点。 触发器-->opcua_item----->opcua_client opcua_client的Action项解析: …