python自定义交叉熵损失,再和pytorch api对比

背景

我们知道,交叉熵本质上是两个概率分布之间差异的度量,公式如下

 其中概率分布P是基准,我们知道H(P,Q)>=0,那么H(P,Q)越小,说明Q约接近P。

损失函数本质上也是为了度量模型和完美模型的差异,因此可以用交叉熵作为损失函数,公式如下

其中

的部分不过是考虑到每次都是输入一批样本,因此把每个样本的交叉熵求出来以后要再求个平均。

注意,我的代码没有考虑标签是soft embedding的情况,如果遇到标注Y是[[0.1,0.1,0.8],[0.1,0.8,0.1],[0.1,0.1,0.8]],那么你需要把代码再推广一下。

自定义交叉熵损失

from typing import List
import mathdef my_softmax(x:List[List[float]])->List[List[float]]:new_x:List[List[float]] = []for i in range(len(x)):sum:float = 0new_x_i = []for j in range(len(x[0])):sum += math.exp(x[i][j])for j in range(len(x[0])):new_x_i.append(math.exp(x[i][j])/sum)new_x.append(new_x_i)return new_xdef my_cross_entropy(x:List[List[float]],y:List[int])->float:res:float = 0x = my_softmax(x)for i in range(len(x)):res += -math.log(x[i][y[i]]) # 根号外面的1和底数e省去了res /= len(x) # meanreturn res# 假设有一个简单的三分类问题,批量大小为2
# 预测输出(通常是模型的原始输出,没有经过softmax)
logits = [[1.5, 0.5, -0.5], [1.2, 0.2, 3.0]]
# 0 和 2 分别表示第一个和第三个类别是正确的
targets = [0, 2]
print(my_cross_entropy(logits,targets))

Pytorch交叉熵损失

import torch
import torch.nn as nnlogits = torch.tensor([[1.5, 0.5, -0.5],[1.2, 0.2, 3.0]])targets = torch.tensor([0, 2])  criterion = nn.CrossEntropyLoss()loss = criterion(logits, targets)print(loss.item())

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/643571.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

maven多模块创建-安装配置

1、前提 许久没有写文章了,荒废了2年多的时间,在整理的时候,发现Maven还差一篇安装配置的文章,现在开始提笔完善它,参考:https://blog.csdn.net/m0_72803119/article/details/134634164。 —写于2024年4月…

一个docker配置mysql主从服务器

这也就是因为穷,不然谁用一个docker配置主从,哈哈 既然成功了就记录下。过程挺折磨人的。 首先要保证你的电脑安装好了docker 为了保证docker当中主从能正常连网,现在docker里面创建一个网络环境 docker network create --driver bridge mysq…

40. UE5 RPG给火球术增加特效和音效

前面,我们将火球的转向和人物的转向问题解决了,火球术可以按照我们的想法朝向目标发射。现在,我们解决接下来的问题,在角色释放火球术时,会产生释放音效,火球也会产生对应的音效,在火球击中目标…

基于小程序实现的查寝打卡系统

作者主页:Java码库 主营内容:SpringBoot、Vue、SSM、HLMT、Jsp、PHP、Nodejs、Python、爬虫、数据可视化、小程序、安卓app等设计与开发。 收藏点赞不迷路 关注作者有好处 文末获取源码 技术选型 【后端】:Java 【框架】:ssm 【…

Golang | Leetcode Golang题解之第46题全排列

题目: 题解: func permute(nums []int) [][]int {var (n len(nums)dfs func(vals []int) // 已选择数 排列为vals 后续回溯继续选择 直至选完ans [][]int)dfs func(vals []int) {//边界if len(vals) n {ans append(ans, vals)}//转移 枚举选哪个f…

项目实战 | 责任链模式 (下)

案例二:工作流,费用报销审核流程 同事小贾最近刚出差回来,她迫不及待的就提交了费用报销的流程。根据金额不同,分为以下几种审核流程。报销金额低于1000元,三级部门管理者审批即可,1000到5000元除了三级部…

如何从架构层面降低公有云多可用区同时故障的概率

阿里云和腾讯云都曾出现过因一个组件故障而导致所有可用区同时瘫痪的情况。本文将探讨如何从架构设计的角度减小故障域,在故障发生时最小化业务损失,并以 Sealos 的稳定性实践为例,分享经验教训。 抛弃主从,拥抱点对点架构 从腾…

第26天:安全开发-PHP应用模版引用Smarty渲染MVC模型数据联动RCE安全

第二十六天 一、PHP新闻显示-数据库操作读取显示 1.新闻列表 数据库创建新闻存储代码连接数据库读取页面进行自定义显示 二、PHP模版引用-自写模版&Smarty渲染 1.自写模版引用 页面显示样式编排显示数据插入页面引用模版调用触发 2.Smarty模版引用 1.下载&#xff1a…

小程序 rich-text 解析富文本 图片过大时如何自适应?

在微信小程序中&#xff0c;用rich-text 解析后端返回的数据&#xff0c;当图片尺寸太大时&#xff0c;会溢出屏幕&#xff0c;导致横向出现滚动 查看富文本代码 图片是用 <img 标签&#xff0c;所以写个正则匹配一下图片标签&#xff0c;手动加上样式即可 // content 为后…

iOS ------代理 分类 拓展

代理协议 一&#xff0c;概念&#xff1a; 代理&#xff0c;又称委托代理&#xff08;delegate&#xff09;&#xff0c;是iOS中常用的一种设计模式。顾名思义&#xff0c;它是把某个对象要做的事委托给别的对象去做。那么别的对象就是这个对象的代理&#xff0c;代替它来打理…

构建NodeJS库--前端项目的打包发布

1. 前言 学习如何打包发布前端项目&#xff0c;需要学习以下相关知识&#xff1a; package.json 如何初始化配置&#xff0c;以及学习npm配置项&#xff1b; 模块类型type配置&#xff0c; 这是nodejs的package.json的配置main 入口文件的配置 webpack 是一个用于现代 JavaSc…

Python程序设计教案

文章目录&#xff1a; 一&#xff1a;软件环境安装 1.软件环境 2.技巧 3.新建工程项目 二&#xff1a;相关 1.规范 2.关键字 3.Ascll码表 三&#xff1a;语法基础 1.各种符号 1.1 注释 1.2 占位置的 1.3 回车换行 2.输入输出 2.1 输入input 2.2 输出print …