深度学习(五)softmax 回归之:分类算法介绍,如何加载 Fashion-MINIST 数据集

Softmax 回归

基本原理

回归和分类,是两种深度学习常用方法。回归是对连续的预测(比如我预测根据过去开奖列表下次双色球号),分类是预测离散的类别(手写语音识别,图片识别)。

1699720169075

现在我们已经对回归的处理有一定的理解了,如何过渡到分类呢?

假设我们有 n 类,首先我们要编码这些类让他们变成数据。所有类变成一个列向量。

y = [ y 1 , y 2 , . . . y n ] T y=[y_1,y_2,...y_n]^T y=[y1,y2,...yn]T

有一个数据属于第 i 类,那么他的列向量就是:

y = [ 0 , 0 , . . . , 1 , . . . , 0 , 0 ] T y=[0,0,...,1,...,0,0]^T y=[0,0,...,1,...,0,0]T

也就是只有他所在的那个类的元素=1.

可以用均方损失训练,通过概率判断最终选用哪一个。

Softmax 回归就是一种分类方式(回归问题在多分类上的推广)。首先确定输入特征数和输出类别数。比如上图中我们有4个特征和3个可能的类别,那么计算各自概率的公式包括3个线性回归:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

可以看出 Softmax 是全连接的单层神经网络。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

我们让所有输出结果归一化后,从中选择出最大可能的,置信度最高的分类结果。

image-20231112100423488

采用 e 的指数可以让值全变为非负。

用真实的概率向量-我们预测得到的概率向量就是损失。真实值就是只有一个1的列向量。

交叉熵损失:

image-20231112101259670

可见**分类问题,我们不关心对非正确的预测值,只关心正确预测值是否足够大。**因为正确值是只有一个元素为1的列向量。

常用的损失函数

L2 Loss:均方损失。

image-20231112101555142

L1 Loss:绝对值损失。

image-20231112101829868

L2 梯度是一条倾斜直线,对于梯度下降算法等更为合适;L1 是一个跳变,梯度要么 -1 要么 1. 如图是 L1 L2 的梯度。

image-20231112102551104

我们可以结合两者,得到一个新的损失函数(鲁棒损失 Huber Robust):

KaTeX parse error: {equation} can be used only in display mode.

image-20231112102721527

图像分类数据集

MINIST 是一个常用图像分类数据集,但是过于简单。后来的 upgrade 版叫 Fashion-MINIST(服装分类).

首先,我们研究研究怎么加载训练数据集,以便后面测试算法用。

# 导包
%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2ld2l.use_svg_display()d2l.use_svg_display()# 下载数据集并读取到内存
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)		# 训练数据集
mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)	# 测试数据集用于评估性能# 定义函数用于返回对应索引的标签
def get_fashion_mnist_labels(labels):  #@save"""返回Fashion-MNIST数据集的文本标签"""text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]# 图像可视化,让结果看着更直观,比如下面那个绿色图的样子
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):  #@save"""绘制图像列表"""figsize = (num_cols * scale, num_rows * scale)_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)axes = axes.flatten()for i, (ax, img) in enumerate(zip(axes, imgs)):if torch.is_tensor(img):# 图片张量ax.imshow(img.numpy())else:# PIL图片ax.imshow(img)ax.axes.get_xaxis().set_visible(False)ax.axes.get_yaxis().set_visible(False)if titles:ax.set_title(titles[i])return axes# 我们先读一点数据集看看啥样的
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));

1699980345931

# 通过内置数据加载器读取一批量数据,自动随机打乱读取,不需要我们自己定义
batch_size = 256def get_dataloader_workers():  #@save"""使用4个进程来读取数据"""return 4train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers())

测量以上用时基本2-3s。

总结整合以上数据读取过程,代码如下:

def load_data_fashion_mnist(batch_size, resize=None):  #@save"""下载Fashion-MNIST数据集,然后将其加载到内存中"""trans = [transforms.ToTensor()]if resize:trans.insert(0, transforms.Resize(resize))trans = transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)return (data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers()),data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=get_dataloader_workers()))

加载图像还可以调整其大小。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/188380.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

链表(一)----关于单链表的一切细节这里都有

一.链表 1 链表的概念及结构 概念:链表是一种物理存储结构上非连续、非顺序的存储结构,数据元素的逻辑顺序是通过链表中的指针链接次序实现的 。 现实中的链表结构 数据结构中的链表结构 1.链式结构在逻辑上是连续的,但在物理上不一定是…

【图解算法】- 异位词问题:双指针+哈希表

一 - 前言 介绍:大家好啊,我是hitzaki辰。 社区:(完全免费、欢迎加入)日常打卡、学习交流、资源共享的知识星球。 自媒体:我会在b站/抖音更新视频讲解 或 一些纯技术外的分享,账号同名&#xff…

Linux系统(CentOS7)上安装MYSQL8.x

Linux系统是CentOS7版本,今天在新电脑上安装MYSQL,跟着网上的文章,尝试了好几次,都是启动失败,删了安,安了删,搞了一下午,头昏脑胀,网上的一些文章太乱了,每种…

flink中配置Rockdb的重要配置项

背景 由于我们在flink中使用了状态比较大,无法完全把状态数据存放到tm的堆内存中,所以我们选择了把状态存放到rockdb上,也就是使用rockdb作为状态后端存储,本文就是简单记录下使用rockdb状态后端存储的几个重要的配置项 使用rockdb状态后端…

geoserver点聚合样式sld

【第六章 WebGIS】geoserver生成点聚合效果 - 知乎 需要WPS插件&#xff0c;注意版本要对应 GeoServer&#xff0c;加压缩后的jar包放到geoserver的lib目录下&#xff0c;重启geoserver。 原始默认样式 聚合sld样式 <?xml version"1.0" encoding"ISO-8859…

基于51单片机步进电机节拍步数正反转LCD1602显示( proteus仿真+程序+原理图+设计报告+讲解视频)

基于51单片机步进电机节拍步数正反转LCD1602显示 &#x1f4d1;1. 主要功能&#xff1a;&#x1f4d1;2. 讲解视频&#xff1a;&#x1f4d1;3. 仿真&#x1f4d1;4. 程序代码&#x1f4d1;5. 设计报告&#x1f4d1;6. 设计资料内容清单&&下载链接&#x1f4d1;[资料下…

Appium移动自动化测试--安装Appium

Appium 自动化测试是很早之前就想学习和研究的技术了&#xff0c;可是一直抽不出一块完整的时间来做这件事儿。现在终于有了。 反观各种互联网的招聘移动测试成了主流&#xff0c;如果再不去学习移动自动化测试技术将会被淘汰。 web自动化测试的路线是这样的&#xff1a;编程语…

使用Microsoft Dynamics AX 2012 - 2. 入门:导航和常规选项

Microsoft Dynamics AX的核心原则之一是为习惯于Microsoft软件的用户提供熟悉的外观和感觉。然而&#xff0c;业务软件必须适应业务流程&#xff0c;这可能相当复杂。 用户界面和常见任务 在我们开始进行业务流程和案例研究之前&#xff0c;我们想了解一下本章中的常见功能。…

“轻松实现文件复制备份,自动编号轻松管理

在日常工作中&#xff0c;我们经常需要复制文件到另一个文件夹进行备份或整理。然而&#xff0c;手动复制粘贴不仅效率低下&#xff0c;还容易出错。为了解决这个问题&#xff0c;我们推出了一款全新的文件工具——【文件批量改名高手】&#xff0c;让你轻松搞定文件复制备份&a…

基于SSM+Vue的校园共享单车管理系统

基于SSMVue的校园共享单车管理系统的设计与实现~ 开发语言&#xff1a;Java数据库&#xff1a;MySQL技术&#xff1a;SpringMyBatisSpringMVC工具&#xff1a;IDEA/Ecilpse、Navicat、Maven 系统展示 主页 登录界面 管理员界面 用户界面 摘要 随着城市交通的不断发展和人们出…

黑五来袭,如何利用海外代理进行助力

黑五作为下半年年度尤为重要的一个节日&#xff0c;是各大商家的必争之地&#xff0c;那么海外代理是如何帮助跨境商家做好店铺管理和营销呢&#xff1f; 为什么跨境人都关注海外代理&#xff0c;下面我们来进行介绍。 一、什么是海外代理 海外代理就是我们所说的&#xff1…

内存模型以及如何判定对象已死问题

1.展示堆内存溢出 设置堆的内存大小为10M&#xff0c;最大的堆内存为10M&#xff0c;这两个参数最好一致&#xff0c;即便最大内存设置为1G&#xff0c;很有可能也分配不到1G。 -Xmx10M -Xms10M 一直往list放东西 public class T1 {public static void main(String[] args) …