【深度学习基础(4)】pytorch 里的log_softmax, nll_loss, cross_entropy的关系

一、常用的函数有: log_softmax,nll_loss, cross_entropy

1.log_softmax

log_softmax就是log和softmax合并在一起执行,log_softmax=log+softmax

2. nll_loss

nll_loss函数全称是negative log likelihood loss, 函数表达式为:f(x,class)=−x[class]
例如:假设x=[5,6,9], class=1, 则f(x,class)=−x[1]=−6

3. cross_entropy交叉熵

cross_entropy=log+softmax+nll_loss

二、代码实现

import torch
import torch.nn.functional as Fpreds = torch.tensor([[0.1, 0.2, 0.3, 0.4], [0.1, 0.1, 0.1, 0.1]])
target = torch.tensor([2, 3])print('三种方式实现交叉熵损失')
print('----------------手动实现------------------------------')
one_hot = F.one_hot(target).float() # 对标签作 one_hot 编码
print('[1]one_hot编码target:\n', one_hot)
exp = torch.exp(preds)
print('[2]对网络预测preds求指数:\n', exp)
sum_ = torch.sum(exp, dim=1).reshape(-1, 1)  # 按行求和
softmax = exp / sum_  # 计算 softmax()
print('[3]softmax操作:\n', softmax)
log_softmax = torch.log(softmax) # 计算 log_softmax()
print('[4]softmax后取对数:\n', log_softmax)
nllloss = -torch.sum(one_hot * log_softmax) / target.shape[0]  # 标签乘以激活后的数据,求平均值,取反
print("[5]手动计算交叉熵:", nllloss)print('----------------调用log_softmax+nll_loss实现------------------------------')
# 调用 NLLLoss() 函数计算
Log_Softmax = F.log_softmax(preds, dim=1)  # log_softmax() 激活
Nllloss = F.nll_loss(Log_Softmax, target)  # 无需对标签作 one_hot 编码
print("函数使用Nllloss计算交叉熵:", Nllloss)print('------------------调用cross_entropy实现----------------------------')
# 直接使用交叉熵损失函数 CrossEntropy_Loss()
cross_entropy = F.cross_entropy(preds, target)  # 无需对标签作 one_hot 编码
print('函数交叉熵cross_entropy:', cross_entropy)

查看结果,可以看到三种方式计算的结果是一样的。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/573420.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Qt/QML编程之路:QPainter与OpenGL的共用(49)

在Qt编程中,有时会有这样一种场景:用OpenGL显示了一个3维立体图,但是想在右下角画一个2D的表格,里面写上几个字。那么这个时候就会出现QPainter与OpenGL共用或者说2D、3D共用。但是问题是调用了QPainter,drawline之后呢,OPenGL的状态被清空了丢失了,3D不显示了。 在Ope…

【海贼王之强者之路】经典动漫影视改编火爆剧情回合卡牌手游-Win服务端源码视频架设教程-开放多区-GM后台-安卓苹果IOS双端版本!

【海贼王之强者之路】站长推荐经典动漫影视改编火爆剧情回合卡牌手游-2024年3月27日最新打包Win服务端源码视频架设教程-开放多区-GM后台-安卓苹果IOS双端版本!

python学习12:python中的字符串格式化-数字精度控制

python中的字符串格式化-数字精度控制 1.使用辅助符号"m.n"来进行数据的宽度和精度的控制 m,控制宽度,要求是数字(一般是很少使用的),设置的宽度小于数字自身,不生效 n,控制小数点精度,要求是数…

python 开启restful之路

1、环境安装及配置 python & pip python 官网下载 3.8 Python Release Python 3.8.0 | Python.org Centos安装python3详细教程_centos 安装python3-CSDN博客 2、IDE工具 PyCharm Community Edition 2023.3.4 3、Flask 构建简单的 web应用 编写 app.py文件 from flask …

蓝桥杯算法赛(二进制王国)

问题描述 二进制王国是一个非常特殊的国家,因为该国家的居民仅由 0 和 1 组成。 在这个国家中,每个家庭都可以用一个由 0 和 1 组成的字符串 S 来表示,例如 101、 000、 111 等。 现在,国王选了出 N 户家庭参加邻国的庆典…

vue2 el-table指定某些数据不参与排序

vue2 el-table指定某些数据不参与排序 1、需求描述2、配置属性方法3、详细代码如下 1、需求描述 最后一行总计不参与排序 2、配置属性方法 el-table 需要配置 sort-change"soltHandle" 方法 el-table-column 需要配置 sortable"custom"属性3、详细代码如…

Go的数据结构与实现【Set】

介绍 Set是值的集合,可以迭代这些值、添加新值、删除值并清除集合、获取集合大小并检查集合是否包含值,集合中的一个值只存储一次,不能重复。 本文代码地址为go-store 简单实现 这是集合的一个简单实现,还不是并发安全的&#…

Android 性能优化(六):启动优化的详细流程

书接上文,Android 性能优化(一):闪退、卡顿、耗电、APK 从用户体验角度有四个性能优化方向: 追求稳定,防止崩溃追求流畅,防止卡顿追求续航,防止耗损追求精简,防止臃肿 …

机器学习——聚类算法-KMeans聚类

机器学习——聚类算法-KMeans聚类 在机器学习中,聚类是一种无监督学习方法,用于将数据集中的样本划分为若干个簇,使得同一簇内的样本相似度高,不同簇之间的样本相似度低。KMeans聚类是一种常用的聚类算法之一,本文将介…

外包干了15天,技术退步明显。。。。。。

说一下自己的情况,本科生,19年通过校招进入武汉某软件公司,干了接近4年的功能测试,今年年初,感觉自己不能够在这样下去了,长时间呆在一个舒适的环境会让一个人堕落!而我已经在一个企业干了四年的功能测试&a…

开源博客项目Blog .NET Core源码学习(12:App.Application项目结构分析)

开源博客项目Blog的App.Application项目主要定义网站页面使用的数据类,同时定义各类数据的增删改查操作接口和实现类。App.Application项目未安装Nuget包,主要引用App.Core项目的类型。   App.Application项目的顶层文件夹如下图所示,下面逐…

蓝桥杯刷题-分巧克力

分巧克力 二分: N, K map(int, input().split()) w, h [], [] for i in range(N):cur_w, cur_h map(int, input().split())w.append(cur_w)h.append(cur_h)# 判断是否能分成K个及以上的巧克力 def check(a) -> bool: num 0for i in range(N):num (w[i] // …