动手学深度学习——softmax分类

1. 分类问题

回归与分类的区别:

  • 回归可以用于预测多少的问题, 比如"预测房屋被售出价格",它是个单值输出。
  • softmax可以用来预测分类问题,例如"某个图片中是猫、鸡还是狗?",这是一个多值输出,输出个数等于类别个数,输出的第i个值表示预测为第i类别的概率。

两者的区别在于是问多少还是问哪一个?

分类可以用来描述下面两个问题:

  1. 样本属于哪个类别
  2. 样本属于每个类别的概率

比较经典的分类问题有:

  1. MNIST数据集,手写数字识别,有0-9十个类别。
  2. ImageNet数据集,从一百万张图片中识别自然物体,有1000个类别。
  3. kaggle上的恶意软件类别识别。
  4. 区分淘宝商品的评论是正面还是负面评论。

2. 分类编码

由于自然语言表示的类别不方便运算,所以为了计算的需要,有必要对类别进行编码。

对于分类问题,最常用的编码方式为一位有效编码,也称为独热编码(one-hot encoding)。它可以表示为一个向量,长度等于类别数量,向量中只有一个特征为1,其它特征均为0。

这里我们以一个图像分类问题为例来讨论, 假设要预测一张图片是猫、鸡还是狗,那么我们对这三种类别进行一位有效编码的形式如下:

  • (1,0, 0)对应于“猫”
  • (0,1,0)对应于“鸡”
  • (0,0,1)对应于“狗”
  1. 正确类别对应的分量设置为1,其它所有分量均为0.
  2. 类别数量等于分量数量(这里的分量是指向量在具体一个维度上的值)

分类问题对模型的要求:正确类的置信度要远远大于非正确类的置信度,即Oy > Oi。

相比具体每个类别的预测值大小,我们更关心正确类别的预测值是否远大于其它非正确类别的预测值,只有这样,才能表明模型能真正区分出正确类别。

3. 网络架构

与线性回归一样,softmax回归也是一个单层神经网络。

接着上面的例子,假设每次输入是一个2*2的灰度图像,我们可以用一个标量表示每个像素值,每个图像对应四个特征[x1,x2,x3,x4]。

我们可以定义输出向量y=[o1,o2,o3], 其中o1、o2、o3分别表示输入i是猫、鸡、狗的预测值大小。

由于我们有4个特征和3个可能的输出类别, 我们将需要12个标量来表示权重w, 3个标量来表示偏置b。则每个类别的计算可以表示为:
在这里插入图片描述

由于计算每个输出o1、o2和o3取决于 所有输入x1、x2、x3和x4, 所以softmax回归的输出层也是全连接层。
在这里插入图片描述
如同线性回归一样,可以将计算公式简洁表示,o = Wx + b。这是将所有权重放到一个W矩阵中。 对于给定数据样本的特征x, 我们的输出y是由权重W与输入特征x进行矩阵-向量乘法再加上偏置b得到。

4. 输出概率化

对于分类问题,我们希望模型的输出yj可以视为它属于类别j的概率,然后只需要选择具有最大输出值的类别argmax(xj,yj) 作为我们的预测即可, 这样能同时方便人脑理解和算术运算。

例如,如果为猫、鸡和狗的概率分别为0.1、0.8和0.1, 因为0.8概率最大,所以我们预测的类别是2,在我们的例子中代表“鸡”。

这里之所以要进行标准化概率计算,而不直接将预测o作为输出,其原因在于将线性层的输出o视作概率会存在一些问题:

  1. 线性层输出没有限制输出数字的总和为1,不符合概率分布。
  2. 根据输入的不同,线性层的输出是可以为负值的,会影响我们的计算。

要将输出视为概率,我们必须保证以下两点:

  1. 在任何类别上的输出都是非负
  2. 所有类别的预测值总和为1。

而softmax函数则正好能够将未规范化的预测变换为非负数并且总和为1,同时让模型保持可导的性质。它的作法为:

  • 对每个未规范化的预测求幂(指数),这样可以确保输出非负
  • 让每个求幂后的结果除以它们的总和,就能确保最终输出的概率值总和为1
    在这里插入图片描述通过对输出向量o进行softmax运算后,预测值就是一个概率分布。

而真实的值经过独热编码后也符合这个特征,因为它也符合概率的特性:

  1. 非负数:只有0和1两种值;
  2. 和为1:只有一个值为1,其它均为0;

这样就得到两个概率:预测值概率和真实值概率。接下来,就可以比较两个概率来作为损失。

5. 损失计算

交叉熵损失:用来衡量两个概率分布之间的差异。

对于分类问题,我们不关心非正确类别的预测值,只关心对正确类的预测值置信度有多大。

假设模型对每个类别的预测概率分别是0.7、0.2和0.1,实际该样本属于第一个类别。交叉熵损失会根据模型对第一个类别的预测概率和实际概率来计算一个损失值。用数学表示如下:

H(p, q) = -Σ p(x) * log(q(x))
  • p(x)表示实际的概率分布,q(x)表示模型预测的概率分布。
  • 前面加负号的目的是为了保证交叉熵为正值。log(q(x))的值通常是小于0的(小于1时,对数为负数),p(x)是一个概率值,介于0和1之间。
  • 交叉熵越小,表示两个概率分布越接近,模型的预测效果越好。

可以把交叉熵H(P,Q)想象为“主观概率为Q
的观察者在看到根据概率P生成的数据时的意外程度”。 当P=Q时,这种意外程度降到最低。

训练的目的:最小化交叉熵来优化模型的参数,使得模型的预测结果更接近于实际标签。

由于真实值p(x)是一个独热编码向量,只有一项为1,其它项均为0,所以这里的交叉熵又可以简写成:
在这里插入图片描述
所以,对于分类问题来说,我们不关心非正确类别的预测值,只关心正确类别的预测值有多大。

而梯度则是预测概率与真实概率之间的差异,损失函数对输出o求导为:
在这里插入图片描述

softmax回归模型训练的目标:给出任何样本特征,我们可以预测每个输出类别的概率。 通常我们使用预测概率最高的类别作为输出类别。 如果预测与实际类别(标签)一致,则预测是正确的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/661033.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

鸿蒙准备1

鸿蒙心路 感慨索性, 看看鸿蒙吧。打开官网相关介绍 新建工程目录结构 感慨 最近面试Android应用开发,动不动就问framework的知识,什么touch事件的触发源是啥(eventHub),gc流程是啥,图形框架是什…

Vue入门到关门之Vue项目工程化

一、创建Vue项目 1、安装node环境 官网下载,无脑下一步,注意别放c盘就行 Node.js — Run JavaScript Everywhere (nodejs.org) 需要两个命令 npm---->pipnode—>python 装完检查一下,hello world检测,退出crtlc 2、搭建vu…

Python 与 TensorFlow2 生成式 AI(五)

原文:zh.annas-archive.org/md5/d06d282ea0d9c23c57f0ce31225acf76 译者:飞龙 协议:CC BY-NC-SA 4.0 第十二章:用生成式人工智能玩视频游戏:GAIL 在之前的章节中,我们已经看到如何使用生成式人工智能来生成…

【iOS】消息流程分析

文章目录 前言动态类型动态绑定动态语言消息发送objc_msgSendSEL(selector)IMP(implementation)IMP高级用法 MethodSEL、IMP、Method总结流程概述 快速查找消息发送快速查找的总结buckets 慢速查找动态方法解析resolveInstanceMet…

HIVE启动步骤

不如意的时候不要尽往悲伤里钻 想想有笑声的日子 启动HIEV 1.启动虚拟机Hadoop集群 2.连接Linux 3.start-all.sh 4.hive 5.hive启动时报错 当我们启动Hadoop集群时 启动hive可能会出现卡在true处不动的情况 那么我们只需要做一个操作就可以解决问题啦 hdfs haadmin -transitio…

CTF-Show nodejs

web334 下载附件,有两个文件 在Character.toUpperCase()函数中,字符ı会转变为I,字符ſ会变为S。 在Character.toLowerCase()函数中,字符İ会转变为i,字符K会转变为k。 所以用ctfſhow 123456登录就可以出flag了 w…

linux安装Redis 7.2.4笔记

一.保姆级安装 1.下载Redis 7.2.4安装包 sudo wget https://download.redis.io/releases/redis-7.2.4.tar.gz2.解压,可以指定 sudo tar -zvxf redis-7.2.4.tar.gz 3.检测并安装 GCC 编译器: yum 是基于 Red Hat 的 Linux 发行版(如 CentOS、…

Java面试八股之Java中能创建volatile数组吗

Java中能创建volatile数组吗 Java中可以创建volatile数组,如声明volatile int[] myVolatileArray new int[10];。此处volatile修饰符应用于数组变量myVolatileArray,而非数组内部的各个元素。 volatile关键字对数组变量的主要作用包括: 1…

【docker】Spring Boot3.x 打包 Docker容器

Docker化Spring Boot应用 创建文件夹 demo mkdir democd demo创建Dockerfile # 两个 openjdk 二选一 #FROM openjdk:17-jre-alpineFROM eclipse-temurin:17MAINTAINER chengxuyuanshitang <chengxuyuanshitangXX.com>RUN mkdir -p /workspace/java/demoCOPY demo.ja…

ClickHouse高原理与实践

ClickHouse高原理与实践 1 ClickHouse的特性1.1. OLAP1.2. 列式存储1.3. 表引擎1.4. 向量化执行1.5. 分区1.6. 副本与分片1.7 其他特性 2. ClickHouse模块设计2.1 Parser分析器与Interpreter解释器2.2 Storage2.3 Column与Field2.4 DataType2.4 Block2.5 Cluster与Replication …

OpenCV的图像矩(64)

返回:OpenCV系列文章目录&#xff08;持续更新中......&#xff09; 上一篇&#xff1a;OpenCV如何为等值线创建边界旋转框和椭圆(63) 下一篇 :OpenCV系列文章目录&#xff08;持续更新中......&#xff09; Image Moments&#xff08;图像矩&#xff09;是 OpenCV 库中的一个…

远程仓库——GitHub

远程仓库——GitHub 一、在GitHub创建远程仓库二、在GitHub上添加密钥三、克隆远程仓库的代码到本地四、如何将本地仓库第一次同步到Github五、总结1.常用命令总结 注意&#xff1a;本文主要讲解的是&#xff0c;如何快速的将本地仓库的代码托管到GitHub上&#xff0c;如果不知…