Softmax回归(多类分类模型)

目录

  • 1.对真实值类别编码:
  • 2.预测值:
  • 3.目标函数要求:
  • 4.使用Softmax模型将输出置信度Oi计算转换为输出匹配概率y^i:
  • 5.使用交叉熵作为损失函数:
  • 6.代码实现:

1.对真实值类别编码:

在这里插入图片描述

  • y为真实值,有且仅有一个位置值为1,该位置即为该元素真实类别

2.预测值:

在这里插入图片描述

  • Oi为该元素与类别i匹配的置信度

3.目标函数要求:

在这里插入图片描述

  • 对于正确类y的置信度Oy要远远大于其他非正确类的置信度Oi,才能使识别到的正确类与错误类具有更明显的差距

4.使用Softmax模型将输出置信度Oi计算转换为输出匹配概率y^i:

在这里插入图片描述

  • y^为n维向量,每个元素非负且和为1
  • y^i为元素与类别i匹配的概率

5.使用交叉熵作为损失函数:

在这里插入图片描述

  • L为真实概率y与预测概率y^的差距
  • 分类问题不关心非正确类的预测值,只关心正确类的预测值有多大

6.代码实现:

import sys
import os
import matplotlib.pyplot as plt
import torch
import torchvision
from torchvision import transforms
from torch.utils import data
from d2l import torch as d2los.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"## 读取小批量数据
batch_size = 256
trans = transforms.ToTensor()
#train_iter, test_iter = common.load_fashion_mnist(batch_size) #无法翻墙的,可以参考这种方法取下载数据集
mnist_train  = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True) # 需要网络翻墙,这里数据集会自动下载到项目跟目录的/data目录下
mnist_test  = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True) # 需要网络翻墙,这里数据集会自动下载到项目跟目录的/data目录下
print(len(mnist_train))  # train_iter的长度是235;说明数据被分成了234组大小为256的数据加上最后一组大小不足256的数据
print('11111111')## 展示部分数据
def get_fashion_mnist_labels(labels):  # @save"""返回Fashion-MNIST数据集的文本标签。"""text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]def show_fashion_mnist(images, labels):d2l.use_svg_display()# 这里的_表示我们忽略(不使用)的变量_, figs = plt.subplots(1, len(images), figsize=(12, 12))for f, img, lbl in zip(figs, images, labels):f.imshow(img.view((28, 28)).numpy())f.set_title(lbl)f.axes.get_xaxis().set_visible(False)f.axes.get_yaxis().set_visible(False)plt.show()train_data, train_targets = next(iter(data.DataLoader(mnist_train, batch_size=18)))
#展示部分训练数据
show_fashion_mnist(train_data[0:10], train_targets[0:10])# 初始化模型参数
num_inputs = 784
num_outputs = 10W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
b = torch.zeros(num_outputs, requires_grad=True)# 定义模型
def softmax(X):X_exp = X.exp()partition = X_exp.sum(dim=1, keepdim=True)return X_exp / partition  # 这里应用了广播机制def net(X):return softmax(torch.matmul(X.reshape(-1, num_inputs), W) + b)# 定义损失函数
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y = torch.LongTensor([0, 2])
y_hat.gather(1, y.view(-1, 1))def cross_entropy(y_hat, y):return - torch.log(y_hat.gather(1, y.view(-1, 1)))# 计算分类准确率
def accuracy(y_hat, y):return (y_hat.argmax(dim=1) == y).float().mean().item()# 计算这个训练集的准确率
def evaluate_accuracy(data_iter, net):acc_sum, n = 0.0, 0for X, y in data_iter:acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()n += y.shape[0]return acc_sum / nnum_epochs, lr = 10, 0.1# 本函数已保存在d2lzh包中方便以后使用
def train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size,params=None, lr=None, optimizer=None):for epoch in range(num_epochs):train_l_sum, train_acc_sum, n = 0.0, 0.0, 0for X, y in train_iter:y_hat = net(X)l = loss(y_hat, y).sum()# 梯度清零if params is not None and params[0].grad is not None:for param in params:param.grad.data.zero_()l.backward()# 执行优化方法if optimizer is not None:optimizer.step()else:d2l.sgd(params, lr, batch_size)train_l_sum += l.item()train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()n += y.shape[0]test_acc = evaluate_accuracy(test_iter, net)print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'% (epoch + 1, train_l_sum / n, train_acc_sum / n, test_acc))# 训练模型
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, batch_size, [W, b], lr)# 预测模型
for X, y in test_iter:break
true_labels = get_fashion_mnist_labels(y.numpy())
pred_labels = get_fashion_mnist_labels(net(X).argmax(dim=1).numpy())
titles = [true + '\n' + pred for true, pred in zip(true_labels, pred_labels)]
show_fashion_mnist(X[0:9], titles[0:9])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/410241.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

部署YUM仓库及NFS共享存储

引言: 学习YUM 软件仓库,可以完成安装、卸载、自动升级 rpm 软件包等任务,能够自动 查找并解决 rpm 包之间的依赖关系,而无须管理员逐个、手工地去安装每个 rpm 包,使管理员在维护大量 Linux 服务器时更加轻松自如。特…

【IAP】核心开发流程

最近做了IAP U盘升级模块开发,总结下IAP基本开发流程,不深入讨论原理。 详细原理参考 首先需要知道我们需要把之前的APP区域拆一块出来做BOOT升级程序区域。 以STM32F103为例,0x08000000到0x0807FFFF为FLASH空间,即上图代码区域…

软件测试卷王的自述,我难道真的很卷吗?

前言 转眼就到了2024年了,工作这几年我的薪资也从12k涨到了18k,对于工作只有3年多的我来说,还是比较满意的,毕竟一些工作4、5年的可能还没我高。 我可能就是大家说的卷王,感觉自己年轻,所以从早干到晚&am…

龙蜥操作系统上安装MySQL:步骤详解与常见问题解决

目录 博客前言 一.下载MySQL 1.官网下载 2.上传文件到龙蜥操作系统中 ​编辑二.安装MySQL 1.检查操作系统中的默认数据库并移除 2.创建文件夹解压 3.开始安装 4.启动服务 ​编辑 5.登录修改密码,进行授权 三.第三方工具连接(naviact&#xff…

计算机三级(网络技术)——应用题

第一题 61.输出端口S0 (直接连接) RG的输出端口S0与RE的S1接口直接相连构成一个互联网段 对172.0.147.194和172.0.147.193 进行聚合 前三段相同,将第四段分别转换成二进制 11000001 11000010 前6位相同,加上前面三段 共30…

网络安全中的“三高一弱”和“两高一弱”是什么?

大家在一些网络安全检查中,可能经常会遇到“三高一弱”这个说法。那么,三高一弱指的是什么呢? 三高:高危漏洞、高危端口、高风险外连 一弱:弱口令 一共是4个网络安全风险,其中的“高危漏洞、高危端口、弱…

Qt6入门教程 8:信号和槽机制(连接方式)

目录 一.一个信号与槽连接的例子 二.第五个参数 1.Qt::AutoConnection 2.Qt::DirectConnection 3.Qt::QueuedConnection 4.Qt::BlockingQueuedConnection 5.Qt::UniqueConnection 三.信号 四.connect函数原型 五.信号与槽的多种用法 六.槽的属性 一.一个信号与槽连接…

vscode(visual studio code) 免密登陆服务器

1.生成密钥 首先,在本地,打开命令输入框: WinR–>弹出输入框,输入cmd,打开命令框。 然后,在命令框,输入 ssh-keygen -t rsa -C "love"按两次回车键,问你是否重写,选择…

城建档案数字化怎么做?

城建档案数字化的关键是整理、扫描、标注、管理和安全性管理,通过建立适当的系统和流程,可以实现城建档案的数字化管理和应用。 城建档案数字化的具体步骤可以分为以下几个方面: 1. 档案整理与分类:首先需要将城建档案进行整理和分…

无需任何三方库,在 Next.js 项目在线预览 PDF 文件

前言: 之前在使用Vue和其它框架的时候,预览 PDF 都是使用的 PDFObject 这个库,步骤是:下载依赖,然后手动封装一个 PDF 预览组件,这个组件接收本地或在线的pdf地址,然后在页面中使用组件的车时候…

本地一键部署grafana+prometheus

本地k8s集群内一键部署grafanaprometheus 说明: 此一键部署grafanaPrometheus已包含: victoria-metrics 存储prometheus-servergrafanaprometheus-kube-state-metricsprometheus-node-exporterblackbox-exporter grafana内已导入基础的dashboard【7个…

贪心算法-活动安排-最详细注释解析

贪心算法-活动安排-最详细注释解析 题目: 学校在最近几天有n个活动,这些活动都需要使用学校的大礼堂,在同一时间,礼堂只能被一个活动使用。由于有些活动时间上有冲突,学校办公室人员只好让一些活动放弃使用礼堂而使用…