Softmax分类器

文章目录

  • 回顾
    • 使用Sigmoid构建多分类器?
  • SoftMax函数
  • 交叉熵损失函数
    • 例子
  • MINIST多分类器
    • 数据集
    • 步骤
    • 实现
      • 1.数据集
      • 2.构建模型
      • 3.构建损失函数和优化器
      • 4. 训练和测试
  • 完整代码

回顾

上节课利用糖尿病数据集做了二分类任务
在这里插入图片描述
MNIST数据集有10个类别我们又该如何进行分类呢?
在这里插入图片描述

使用Sigmoid构建多分类器?

之前二分类使用的是sigmoid函数进行分类,它可以把输出归一化到[0,1]之间。如果使用Sigmoid激活函数进行多分类,会出现一个问题:每个类别的概率都是[0,1]之间,他们加起来的概率和可能就不为1.我们想要的结果是满足一个分布:概率P>=0;并且概率之和=1.
在这里插入图片描述
在这里插入图片描述

SoftMax函数

其实就是对输出值y取对数,然后再除以输出的对数之和
在这里插入图片描述
在这里插入图片描述

交叉熵损失函数

标签采用One-hot编码,与预测的概率值计算损失。

  • 后面部分叫做NLLLoss
    在这里插入图片描述
    在这里插入图片描述
  • pytorch的CrossEntropyLoss=softmax+NLLLoss

在这里插入图片描述

例子

在这里插入图片描述

MINIST多分类器

数据集

MINIST数据是一个28*28像素的矩阵,如果把它线性隐射到[0,1]之间
在这里插入图片描述

步骤

在这里插入图片描述

实现

1.数据集

transforms.ToTensor()

  • transform进行图像变换,将PIL图像变换为C*W *H大小的的Tensor。
  • PIL库会将图片像素由[0,255]映射到[0,1]之间,方便pytorch进行运算。
transform=transforms.Compose([transforms.ToTensor(),#Convert the PIL Image to Tensor.transforms.Normalize((0.1307,),(0.3081,))])#The parameters are mean and std respectively.

在这里插入图片描述
图像我们通常会有通道这个概念,可以理解为一个通道就是一个图像的矩阵。

  • 灰度图片只有一个通道;
  • 彩色图片:R、G、B三个通道;
  • 所以一张图片其实可以表示成WHC,在Pytorch中我们需要转换为CWH,即通道数C要放在最前面。
    在这里插入图片描述
    在这里插入图片描述
    transforms.Normalize((0.1307,),(0.3081,))
  • 归一化,参数分别是均值和标准差
  • 在这里插入图片描述

2.构建模型

  • 将输入图像进行view,变成(N,784),然后进行训练。
  • 注意最后一层线形层不需要激活函数,因为交叉熵损失函数包含了softmax激活函数。
    在这里插入图片描述
class Net(torch.nn.Module):def __init__(self):super(Net,self).__init__()self.linear1=torch.nn.Linear(784,512)self.linear2=torch.nn.Linear(512,256)self.linear3=torch.nn.Linear(256,128)self.linear4=torch.nn.Linear(128,64)self.linear5=torch.nn.Linear(64,10)def forward(self, x):x=x.view(-1,784)#把每一张图片的像素都拼接起来,然后变成二维(N,748)(748=28*28)x=F.relu(self.linear1(x))x=F.relu(self.linear2(x))x=F.relu(self.linear3(x))x=F.relu(self.linear4(x))x=self.linear5(x)#注意由于后续交叉熵损失函数包含激活函数,所以这一层不需要激活函数return xmodel=Net()

3.构建损失函数和优化器

  • 采用交叉熵损失函数
  • 采用SGDM优化器
criterion=torch.nn.CrossEntropyLoss()#损失函数
optimizer=torch.optim.SGD(model.parameters(),lr=0.01,momentum=0.5)#SGD with momentum

4. 训练和测试

  • 训练:forward + backward + update(记得梯度清零)
  • 测试:不计算梯度 with torch.no_grad():
def train(epoch):running_loss=0.0for batch_idex,data in enumerate(train_loader,0):inputs,target=dataoptimizer.zero_grad()#梯度清零#forward + backward + updateoutputs=model(inputs)loss=criterion(outputs,target)loss.backward()optimizer.step()running_loss+=loss.item()if batch_idex%300==299:#每300epoch输出一次loss信息print('[%d,%5d loss:%.3f]' % (epoch+1,batch_idex+1,running_loss/300))running_loss=0.0def test():correct=0total=0with torch.no_grad():#不需要计算梯度for data in test_loader:images,labels=dataoutputs=model(images)# predicted=torch.max(outputs.data,dim=1)predicted=torch.argmax(outputs.data,dim=1)#求预测数据最大值的下标(指定沿着维度1进行计算)total+=labels.size(0)#size(0)是样本个数N,计算总共预测数据的样本总数correct += (predicted == labels).sum().item()#计算预测正确的数目print('Accuracy on test set: %d %%' % (100 * correct / total))#print(correct)

完整代码

import numpy as np
import torch
from torch.utils.data import DataLoader #For constructing DataLoader
from torchvision import transforms #For constructing DataLoader 对图像进行处理
from torchvision import datasets #For constructing DataLoader
import torch.nn.functional as F #For using function relu()
import torch.optim as optim #For constructing Optimizerbatch_size=64
transform=transforms.Compose([transforms.ToTensor(),#Convert the PIL Image to Tensor.transforms.Normalize((0.1307,),(0.3081,))])#The parameters are mean and std respectively.train_dataset = datasets.MNIST(root='../dataset/mnist',train=True,transform=transform,download=True)
test_dataset = datasets.MNIST(root='../dataset/mnist',train=False,transform=transform,download=True)
train_loader = DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True)
test_loader = DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=False)class Net(torch.nn.Module):def __init__(self):super(Net,self).__init__()self.linear1=torch.nn.Linear(784,512)self.linear2=torch.nn.Linear(512,256)self.linear3=torch.nn.Linear(256,128)self.linear4=torch.nn.Linear(128,64)self.linear5=torch.nn.Linear(64,10)def forward(self, x):x=x.view(-1,784)#把每一张图片的像素都拼接起来,然后变成二维(N,748)(748=28*28)x=F.relu(self.linear1(x))x=F.relu(self.linear2(x))x=F.relu(self.linear3(x))x=F.relu(self.linear4(x))x=self.linear5(x)#注意由于后续交叉熵损失函数包含激活函数,所以这一层不需要激活函数return xmodel=Net()criterion=torch.nn.CrossEntropyLoss()#损失函数
optimizer=torch.optim.SGD(model.parameters(),lr=0.01,momentum=0.5)#SGD with momentumdef train(epoch):running_loss=0.0for batch_idex,data in enumerate(train_loader,0):inputs,target=dataoptimizer.zero_grad()#梯度清零#forward + backward + updateoutputs=model(inputs)loss=criterion(outputs,target)loss.backward()optimizer.step()running_loss+=loss.item()if batch_idex%300==299:#每300epoch输出一次loss信息print('[%d,%5d loss:%.3f]' % (epoch+1,batch_idex+1,running_loss/300))running_loss=0.0def test():correct=0total=0with torch.no_grad():#不需要计算梯度for data in test_loader:images,labels=dataoutputs=model(images)# predicted=torch.max(outputs.data,dim=1)predicted=torch.argmax(outputs.data,dim=1)#求预测数据最大值的下标(指定沿着维度1进行计算)total+=labels.size(0)#size(0)是样本个数N,计算总共预测数据的样本总数correct += (predicted == labels).sum().item()#计算预测正确的数目print('Accuracy on test set: %d %%' % (100 * correct / total))#print(correct)if __name__ == '__main__':for epoch in range(10):#epoch=10,训练一轮,测试一轮train(epoch)test()

结果:

Accuracy on test set: 96 %
[5,  300 loss:0.104]
[5,  600 loss:0.096]
[5,  900 loss:0.101]
Accuracy on test set: 96 %
[6,  300 loss:0.078]
[6,  600 loss:0.078]
[6,  900 loss:0.086]
Accuracy on test set: 97 %
[7,  300 loss:0.066]
[7,  600 loss:0.064]
[7,  900 loss:0.067]
Accuracy on test set: 97 %
[8,  300 loss:0.052]
[8,  600 loss:0.054]
[8,  900 loss:0.051]
Accuracy on test set: 97 %
[9,  300 loss:0.042]
[9,  600 loss:0.044]
[9,  900 loss:0.046]
Accuracy on test set: 97 %
[10,  300 loss:0.031]
[10,  600 loss:0.038]
[10,  900 loss:0.036]
Accuracy on test set: 97 %
  • 可以看到准确度在97%就上不去了,这是因为线性模型对于图片数据的特征提取不是很友好。
  • 有很多方法可以对图像进行自动的特征提取,不需要人工设计。(CNN)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/440089.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

java反射常用方法

反射思维导图 使用案例 package Reflection.Work.WorkTest01;import java.lang.reflect.Constructor; import java.lang.reflect.Field; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.util.Arrays;public class WorkDe…

基于数字签名技术的挑战/响应式认证方式

挑战/响应式认证方式简便灵活,实现起来也比较容易。当网络需要验证用户身份时,客户端向服务器提出登录请求;当服务器接收到客户端的验证请求时,服务器端向客户端发送一个随机数,这就是这种认证方式的“冲击&#xff08…

java学习之路(2)-编译java文件运行Java文件

创建.java后缀文本文件HelloWorld .java 写入代码: public class HelloWorld { public static void main(String []args) { System.out.println("Hello World"); } } 运行cmd命令 找到代码所在目录 输入javac编译Java文件生成HelloWorld.class 编译:…

Spring Security简介

什么是Spring Security Spring Security是 Spring提供的安全认证服务的框架。 使用Spring Security可以帮助我 们来简化认证和授权的过程。 官网&#xff1a;Spring Security 对应的maven坐标&#xff1a; <!--security启动器--> <dependency><groupId>or…

服装产业转型升级,iPayLinks帮助企业拓展市场盈更多

从十万件的大订单转变为几百件的小订单&#xff0c;小单快反模式为中国服装出口带来了机遇&#xff0c;也带来了挑战。   “十三行-中大-鹭江”是广州曾经最具代表性的外贸服装产业带。在过去很长的一段时间里&#xff0c;服装外贸老板在这里创造“神话”&#xff1a;24小时内…

C# IP v4转地址·地名 高德

需求: IPv4地址转地址 如&#xff1a;输入14.197.150.014&#xff0c;输出河北省石家庄市 SDK: 目前使用SDK为高德地图WebAPI 高德地图开放平台https://lbs.amap.com/ 可个人开发者使用&#xff0c;不过有配额限制。 WebAPI 免费配额调整公告https://lbs.amap.com/news/…

第一节 分布式架构设计理论与Zookeeper环境搭建

目录 1. 分布式架构设计理论 1. 分布式架构介绍 1.1 什么是分布式 1.2 分布式与集群的区别 1.3 分布式系统特性 1.4 分布式系统面临的问题 2. 分布式理论 2.1 数据一致性 2.1.1 什么是分布式数据一致性 2.1.2 副本一致性 2.1.3 一致性分类 2.2 CAP定理 2.2.1 CAP定…

ros2配合yolov8具体实现

效果图 用yolov8实时检测物体,包括物体的类别,置信度和坐标通过ros2发布出去自定义消息 int64 xmin int64 ymin int64 xmax int64 ymax float32 conf string name发布端代码 from ultralytics import YOLO import cv2 import rclpy from yolo_interfaces.msg import Msgyo…

C++入门(一)— 使用VScode开发简介

文章目录 C 介绍C 擅长领域C 程序是如何开发编译器、链接器和库编译预处理编译阶段汇编阶段链接阶段 安装集成开发环境 &#xff08;IDE&#xff09;配置编译器&#xff1a;构建配置配置编译器&#xff1a;编译器扩展配置编译器&#xff1a;警告和错误级别配置编译器&#xff1…

基于Vue uniapp和java SpringBoot的汽车充电桩微信小程序

摘要&#xff1a; 随着新能源汽车市场的迅猛发展&#xff0c;汽车充电桩的需求日益增长。为了满足市场需求&#xff0c;本课题开发了一款基于Java SpringBoot后端框架和Vue uniapp前端框架的汽车充电桩微信小程序。该小程序旨在为用户提供一个简洁高效的充电服务平台&#xff0…

Pytest中doctests的测试方法应用!

在 Python 的测试生态中&#xff0c;Pytest 提供了多种灵活且强大的测试工具。其中&#xff0c;doctests 是一种独特而直观的测试方法&#xff0c;通过直接从文档注释中提取和执行测试用例&#xff0c;确保代码示例的正确性。本文将深入介绍 Pytest 中 doctests 的测试方法&…

如何使用Python+Flask搭建本地Web站点并结合内网穿透公网访问?

文章目录 前言1. 安装部署Flask并制作SayHello问答界面2. 安装Cpolar内网穿透3. 配置Flask的问答界面公网访问地址4. 公网远程访问Flask的问答界面 前言 Flask是一个Python编写的Web微框架&#xff0c;让我们可以使用Python语言快速实现一个网站或Web服务&#xff0c;本期教程…