【PyTorch】在PyTorch中使用线性层和交叉熵损失函数进行数据分类

在PyTorch中使用线性层和交叉熵损失函数进行数据分类

前言:

在机器学习的众多任务中,分类问题无疑是最基础也是最重要的一环。本文将介绍如何在PyTorch框架下,使用线性层和交叉熵损失函数来解决分类问题。我们将以简单的Iris数据集作为起点,探讨线性模型在处理线性可分数据上的有效性。随后,我们将尝试将同样的线性模型应用于复杂的CIFAR-10图像数据集,并分析其性能表现。

背景:

  • Iris数据集:一个经典的线性可分数据集,包含三个类别的鸢尾花,每个类别有50个样本,每个样本有4个特征。
    请添加图片描述

  • CIFAR-10数据集:一个由10个类别组成的图像数据集,每个类别有6000张32x32彩色图像,总共有60000张图像。
    请添加图片描述

Iris数据集分类

数据读取与预处理:

read_data函数负责从CSV文件中读取数据,随机打乱,划分训练集和测试集,并进行标准化处理。

def read_data(file_path, only_test = False, normalize = True):np_data = pd.read_csv(file_path).valuesnp.random.shuffle(np_data)classes = np.unique(np_data[:,-1])class_dict = {}for index, class_name in enumerate(classes):class_dict[index] = class_nameclass_dict[class_name] = indextrain_src = np_data[:int(len(np_data)*0.8)]test_src = np_data[int(len(np_data)*0.8):]train_data = train_src[:,:-1]train_labels = train_src[:, -1].reshape(-1,1)test_data = test_src[:, :-1]test_labels = test_src[:, -1].reshape(-1,1)if (normalize):mean = np.mean(train_data)std = np.std(train_data)train_data = (train_data - mean) / stdmean = np.mean(test_data)std = np.std(test_data)test_data = (test_data - mean) / stdif (only_test):return test_data, test_labels, class_dictreturn train_data, train_labels, test_data, test_labels, class_dict

模型构建:

Linear_classify类定义了一个简单的线性模型,其中包含一个线性层。

class Linear_classify(th.nn.Module):def __init__(self, *args, **kwargs) -> None:super(Linear_classify, self).__init__()self.linear = th.nn.Linear(args[0], args[1])def forward(self, x):y_pred = self.linear(x)return y_pred

训练过程: 在main函数中,我们初始化模型、损失函数和优化器。然后,通过多次迭代来训练模型,并记录损失值的变化。

file_path = "J:\\MachineLearning\\数据集\\Iris\\iris.data"
train_data, train_labels, test_data, test_labels, label_dict = read_data(file_path)
print(train_data.shape)
print(train_labels.shape)
print(label_dict)int_labels = np.vectorize(lambda x: int(label_dict[x]))(train_labels).flatten()
print(int_labels[:10])tensor_labels = th.from_numpy(int_labels).type(th.long) 
num_classes = int(len(label_dict)/2)
train_data = th.from_numpy(train_data.astype("float32"))print (train_data.shape)
print (train_data[:2])
linear_classifier = Linear_classify(int(train_data.shape[1]), int(len(label_dict)/2))
loss_function = th.nn.CrossEntropyLoss()
optimizer = th.optim.SGD(linear_classifier.parameters(), lr = 0.001)
epochs = 10000
best_loss = 100
turn_to_bad_loss_count = 0
loss_history = []
for epoch in range(epochs):y_pred = linear_classifier(train_data)#print(y_pred)#print(y_pred.shape)loss = loss_function(y_pred, tensor_labels)if (float(loss.item()) > best_loss):turn_to_bad_loss_count += 1else:best_loss = float(loss.item())if (turn_to_bad_loss_count > 1000):breakif (epoch % 10 == 0):print("epoch {} loss is {}".format(epoch, loss))loss_history.append(float(loss.item()))loss.backward()optimizer.step()
plt.plot(loss_history)
plt.show()

评估与测试

使用测试集数据评估模型的准确率,并通过可视化损失值的变化来分析模型的学习过程。

accuracy = []
for _ in range(10):test_data, test_labels, label_dict = read_data(file_path, only_test = True)test_result = linear_classifier(th.from_numpy(test_data.astype("float32")))print(test_result[:10])result_index = test_result.argmax(dim=1)iris_name_result = np.vectorize(lambda x: str(label_dict[x]))(result_index).reshape(-1,1)accuracy.append(len(iris_name_result[iris_name_result == test_labels]) / len(test_labels))print("Accuracy is {}".format(np.mean(accuracy)))

结果

收敛很好很快

请添加图片描述

准确率较高

Accuracy is 0.9466666666666667

CIFAR-10数据集分类

关键改动

使用unpickle和read_data函数处理数据集,这部分是和前面不一样的
def unpickle(file):import picklewith open(file, 'rb') as fo:dict = pickle.load(fo, encoding='bytes')return dictdef read_data(file_path, gray = False, percent = 0, normalize = True):data_src = unpickle(file_path)np_data = np.array(data_src["data".encode()]).astype("float32")np_labels = np.array(data_src["labels".encode()]).astype("float32").reshape(-1,1)single_data_length = 32*32 image_ret = Noneif (gray):np_data = (np_data[:, :single_data_length] + np_data[:, single_data_length:(2*single_data_length)] + np_data[:, 2*single_data_length : 3*single_data_length])/3image_ret = np_data.reshape(len(np_data),32,32)else:image_ret = np_data.reshape(len(np_data),32,32,3)if(normalize):mean = np.mean(np_data)std = np.std(np_data)np_data = (np_data - mean) / stdif (percent == 0):return np_data, np_labels, image_ret else:return np_data[:int(len(np_data)*percent)], np_labels[:int(len(np_labels)*percent)], image_ret[:int(len(image_ret)*percent)]

运行结果

损失可以收敛,但收敛的幅度有限

可见只是从2.x 下降到了1.x
请添加图片描述

准确率比瞎猜准了3倍,非常的nice

请添加图片描述

train Accuracy is 0.6048

test Accuracy is 0.282

注意点:

  • 数据标准化:为了提高模型的收敛速度和准确率,对数据进行标准化处理是非常重要的,在本例中,不使用标准化会出现梯度爆炸,亲测。
  • 类别标签处理:在使用交叉熵损失函数时,需要确保类别标签是整数形式。

优化点:

  • 学习率调整:适当调整学习率可以帮助模型更快地收敛。
  • 早停法:当连续多次迭代损失值不再下降时,提前终止训练可以防止过拟合。
  • 损失函数选择:对于不同的问题,选择合适的损失函数对模型性能有显著影响,在多分类问题中,使用交叉熵损失函数是常见的选择,在pytorch中,交叉熵模块包含了softmax激活函数,这是其可以进行多分类的关键。

Softmax函数的推导过程如下:

首先,我们有一个未归一化的输入向量 z z z,其形状为 ( n , ) (n,) (n,),其中 n n n 是类别的数量。我们希望将这个向量转化为一个概率分布,其中所有元素的总和为1。

我们可以通过以下步骤来计算 softmax 函数:

  1. z z z 中的每个元素应用指数函数,得到一个新的向量 e z e^z ez

  2. 计算 e z e^z ez 中的最大值,记作 z ^ \hat{z} z^

  3. e z e^z ez 中的每个元素减去 z ^ \hat{z} z^,得到一个新的向量 v v v

  4. v v v 中的每个元素应用指数函数,得到一个新的向量 e v e^v ev

  5. 计算 e v e^v ev 中的最大值,记作 v ^ \hat{v} v^

  6. e v e^v ev 中的每个元素除以 v ^ \hat{v} v^,得到最终的概率分布。

以上步骤可以用以下的公式表示:

z = ( z 1 , z 2 , … , z n ) T e z = ( e z 1 , e z 2 , … , e z n ) T z ^ = m a x ( e z ) v = e z − z ^ e v = ( e v 1 , e v 2 , … , e v n ) T v ^ = m a x ( e v ) p = e v v ^ \begin{align*} z &= (z_1, z_2, \ldots, z_n)^T \\ e^z &= (e^{z_1}, e^{z_2}, \ldots, e^{z_n})^T \\ \hat{z} &= max(e^z) \\ v &= e^z - \hat{z} \\ e^v &= (e^{v_1}, e^{v_2}, \ldots, e^{v_n})^T \\ \hat{v} &= max(e^v) \\ p &= \frac{e^v}{\hat{v}} \end{align*} zezz^vevv^p=(z1,z2,,zn)T=(ez1,ez2,,ezn)T=max(ez)=ezz^=(ev1,ev2,,evn)T=max(ev)=v^ev

其中, p p p 是最终的概率分布。

结论:
通过实验,我们发现线性模型在Iris数据集上表现良好,但在CIFAR-10数据集上效果不佳。这说明线性模型在处理复杂的非线性问题时存在局限性。为了解决这一问题,我们将在后续的博客中介绍如何使用卷积神经网络来提高图像分类的准确率。

后记:
感谢您的阅读,希望本文能够帮助您了解如何在PyTorch中使用线性层和交叉熵损失函数进行数据分类。敬请期待我们的下一篇博客——“在PyTorch中使用卷积神经网络进行图像分类”。

完整代码

分类Iris数据集

import torch as th
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torchvisiondef read_data(file_path, only_test = False, normalize = True):np_data = pd.read_csv(file_path).valuesnp.random.shuffle(np_data)classes = np.unique(np_data[:,-1])class_dict = {}for index, class_name in enumerate(classes):class_dict[index] = class_nameclass_dict[class_name] = indextrain_src = np_data[:int(len(np_data)*0.8)]test_src = np_data[int(len(np_data)*0.8):]train_data = train_src[:,:-1]train_labels = train_src[:, -1].reshape(-1,1)test_data = test_src[:, :-1]test_labels = test_src[:, -1].reshape(-1,1)if (normalize):mean = np.mean(train_data)std = np.std(train_data)train_data = (train_data - mean) / stdmean = np.mean(test_data)std = np.std(test_data)test_data = (test_data - mean) / stdif (only_test):return test_data, test_labels, class_dictreturn train_data, train_labels, test_data, test_labels, class_dictclass Linear_classify(th.nn.Module):def __init__(self, *args, **kwargs) -> None:super(Linear_classify, self).__init__()self.linear = th.nn.Linear(args[0], args[1])def forward(self, x):y_pred = self.linear(x)return y_preddef main():file_path = "J:\\MachineLearning\\数据集\\Iris\\iris.data"train_data, train_labels, test_data, test_labels, label_dict = read_data(file_path)print(train_data.shape)print(train_labels.shape)print(label_dict)int_labels = np.vectorize(lambda x: int(label_dict[x]))(train_labels).flatten()print(int_labels[:10])tensor_labels = th.from_numpy(int_labels).type(th.long) num_classes = int(len(label_dict)/2)train_data = th.from_numpy(train_data.astype("float32"))print (train_data.shape)print (train_data[:2])linear_classifier = Linear_classify(int(train_data.shape[1]), int(len(label_dict)/2))loss_function = th.nn.CrossEntropyLoss()optimizer = th.optim.SGD(linear_classifier.parameters(), lr = 0.001)epochs = 10000best_loss = 100turn_to_bad_loss_count = 0loss_history = []for epoch in range(epochs):y_pred = linear_classifier(train_data)#print(y_pred)#print(y_pred.shape)loss = loss_function(y_pred, tensor_labels)if (float(loss.item()) > best_loss):turn_to_bad_loss_count += 1else:best_loss = float(loss.item())if (turn_to_bad_loss_count > 1000):breakif (epoch % 10 == 0):print("epoch {} loss is {}".format(epoch, loss))loss_history.append(float(loss.item()))loss.backward()optimizer.step()plt.plot(loss_history)plt.show()plt.show()accuracy = []for _ in range(10):test_data, test_labels, label_dict = read_data(file_path, only_test = True)test_result = linear_classifier(th.from_numpy(test_data.astype("float32")))print(test_result[:10])result_index = test_result.argmax(dim=1)iris_name_result = np.vectorize(lambda x: str(label_dict[x]))(result_index).reshape(-1,1)accuracy.append(len(iris_name_result[iris_name_result == test_labels]) / len(test_labels))print("Accuracy is {}".format(np.mean(accuracy)))if (__name__ == "__main__"):main()

分类CIFAR-10数据集

import torch as th
import numpy as np
import pandas as pd
import matplotlib.pyplot as pltdef unpickle(file):import picklewith open(file, 'rb') as fo:dict = pickle.load(fo, encoding='bytes')return dictdef read_data(file_path, gray = False, percent = 0, normalize = True):data_src = unpickle(file_path)np_data = np.array(data_src["data".encode()]).astype("float32")np_labels = np.array(data_src["labels".encode()]).astype("float32").reshape(-1,1)single_data_length = 32*32 image_ret = Noneif (gray):np_data = (np_data[:, :single_data_length] + np_data[:, single_data_length:(2*single_data_length)] + np_data[:, 2*single_data_length : 3*single_data_length])/3image_ret = np_data.reshape(len(np_data),32,32)else:image_ret = np_data.reshape(len(np_data),32,32,3)if(normalize):mean = np.mean(np_data)std = np.std(np_data)np_data = (np_data - mean) / stdif (percent == 0):return np_data, np_labels, image_ret else:return np_data[:int(len(np_data)*percent)], np_labels[:int(len(np_labels)*percent)], image_ret[:int(len(image_ret)*percent)]class Linear_classify(th.nn.Module):def __init__(self, *args, **kwargs) -> None:super(Linear_classify, self).__init__()self.linear = th.nn.Linear(args[0], args[1])def forward(self, x):x = self.linear(x)return xdef main():file_path = "J:\\MachineLearning\\数据集\\cifar-10-batches-py\\data_batch_1"train_data, train_labels, image_data = read_data(file_path, percent=0.5)print(train_data.shape)print(train_labels.shape)print(image_data.shape)'''fig, axs = plt.subplots(3, 3)for i, ax in enumerate(axs.flat):image = image_data[i]ax.imshow(image_data[i],cmap="rgb")ax.axis('off') # 关闭坐标轴plt.show()'''int_labels = train_labels.flatten()print(int_labels[:10])tensor_labels = th.from_numpy(int_labels).type(th.long) num_classes = int(len(np.unique(int_labels)))train_data = th.from_numpy(train_data)print (train_data.shape)print (train_data[:2])linear_classifier = Linear_classify(int(train_data.shape[1]), num_classes)loss_function = th.nn.CrossEntropyLoss()optimizer = th.optim.SGD(linear_classifier.parameters(), lr = 0.01)epochs = 7000best_loss = 100turn_to_bad_loss_count = 0loss_history = []for epoch in range(epochs):y_pred = linear_classifier(train_data)#print(y_pred)#print(y_pred.shape)loss = loss_function(y_pred, tensor_labels)if (float(loss.item()) > best_loss):turn_to_bad_loss_count += 1else:best_loss = float(loss.item())if (turn_to_bad_loss_count > 100):breakif (epoch % 10 == 0):print("epoch {} loss is {}".format(epoch, loss))loss_history.append(float(loss.item()))loss.backward()optimizer.step()plt.plot(loss_history)plt.show()plt.show()test_result = linear_classifier(train_data)print(test_result[:10])result_index = test_result.argmax(dim=1).reshape(-1,1)accuracy = (len(result_index[result_index.detach().numpy() == train_labels]) / len(train_labels))print("train Accuracy is {}".format(accuracy))file_path = "J:\\MachineLearning\\数据集\\cifar-10-batches-py\\test_batch"test_data, test_labels, image_data = read_data(file_path)test_result = linear_classifier(th.from_numpy(test_data))print(test_result[:10])result_index = test_result.argmax(dim=1).reshape(-1,1)accuracy = (len(result_index[result_index.detach().numpy() == test_labels]) / len(test_labels))print("test Accuracy is {}".format(accuracy))if (__name__ == "__main__"):main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/419159.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数组中第K个最大元素(算法村第十关白银挑战)

215. 数组中的第K个最大元素 - 力扣(LeetCode) 给定整数数组 nums 和整数 k,请返回数组中第 **k** 个最大的元素。 请注意,你需要找的是数组排序后的第 k 个最大的元素,而不是第 k 个不同的元素。 你必须设计并实现…

JVM工作原理与实战(二十一):内存管理

专栏导航 JVM工作原理与实战 RabbitMQ入门指南 从零开始了解大数据 目录 专栏导航 前言 一、不同语言的内存管理 1.C/C的内存管理 2.Java的内存管理 二、垃圾回收的对比 1.自动垃圾回收与手动垃圾回收的对比 2.优点与缺点 总结 前言 JVM作为Java程序的运行环境&#…

Golang 搭建 WebSocket 应用(八) - 完整代码

本文应该是本系列文章最后一篇了,前面留下的一些坑可能后面会再补充一下,但不在本系列文章中了。 整体架构 再来回顾一下我们的整体架构: 在我们的 demo 中,包含了以下几种角色: 客户端:一般是浏览器&am…

XSS漏洞:利用多次提交技巧实现存储型XSS

目录 搭建环境 XSS攻击 测试 xss系列往期文章: 初识XSS漏洞-CSDN博客 利用XSS漏洞打cookie-CSDN博客 XSS漏洞:xss-labs靶场通关-CSDN博客 XSS漏洞:prompt.mi靶场通关-CSDN博客 XSS漏洞:xss.haozi.me靶场通关-CSDN博客 本…

磁盘分区机制

lsblk查看分区 Linux分区 挂载的经典案例 1. 虚拟机增加磁盘 点击这里,看我的这篇文章操作 添加之后,需要重启系统,不重启在系统里看不到新硬盘哦 出来了,但还没有分区 2. 分区 还没有格式化 3. 格式化磁盘 4. 挂载 5. 卸载…

汇编语言----X86汇编指令

目录 1.汇编指令的构成 2.X86架构CPU中包含的寄存器 3.常见的x86汇编指令 (1)算数运算 (2)逻辑运算 (3)其他 4.AT&T格式 5.选择语句(分支结构) 6.循环语句 &#xff0…

huggingface学习 | 云服务器使用hf_hub_download下载huggingface上的模型文件

系列文章目录 huggingface学习 | 云服务器使用git-lfs下载huggingface上的模型文件 文章目录 系列文章目录一、hf_hub_download介绍二、找到需要下载的huggingface文件三、准备工作及下载过程四、全部代码 一、hf_hub_download介绍 hf_hub_download是huggingface官方支持&…

Linux中的共享内存

定义: 共享内存允许两个或者多个进程共享物理内存的同一块区域(通常被称为段)。由于一个共享内存段会称为一个进程用户空间的一部分,因此这种 IPC 机制无需内核介入。所有需要做的就是让一个进程将数 据复制进共享内存中&#xff…

力扣精选算法100题——串联所有单词的字串(滑动窗口专题)

本题链接——串联所有单词的字串 本题和找到字符串中所有字母异位词题目非常相似,思路都是一样。通过自己的大脑能发现其中的相似之处。 第一步:了解题意 就按实例来分析吧,这样更通俗易懂。 words["ab","cd","ef…

mysql从库重新搭建的流程

背景 生产环境上的主从集群,因为一些异常原因,导致主从同步失败。现记录下通过重做mysql从库的方式来解决,重做过程不影响主库。 步骤 1、在主库上的操作步骤 备份主库所有数据,并将dump.sql文件拷贝到从库/tmp目录 mysqldump …

Flutter 综述

Flutter 综述 1 介绍1.1 概述1.2 重要节点1.3 移动开发中三种跨平台框架技术对比1.4 flutter 技术栈1.5 IDE1.6 Dart 语言1.7 应用1.8 框架 2 Flutter的主要组成部分3 资料书籍 《Flutter实战第二版》Dart 语言官网Flutter中文开发者社区flutter 官网 4 搭建Flutter开发环境参考…

vue3-模版引用

模版引用 ref 属性 场景&#xff1a;需要直接访问底层 DOM 元素。 方法&#xff1a;使用特殊的 ref 属性。 <input ref"input">ref 属性 允许我们在一个特定的 DOM 元素或子组件实例被挂载后&#xff0c;获得对它的直接引用。 访问模板引用 小 Demo: 当 i…