机器学习 - multi-class 数据集训练 (含代码)

直接上代码

# Multi-class datasetimport numpy as np
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
N = 100 # number of points per class
D = 2 # dimensionality
K = 3 # number of classes
X = np.zeros((N*K, D))
y = np.zeros(N*K, dtype='uint8')
for j in range(K):ix = range(N*j, N*(j+1))r = np.linspace(0.0, 1, N)  # radiust = np.linspace(j*4, (j+1)*4, N) + np.random.randn(N)*0.2X[ix] = np.c_[r*np.sin(t), r*np.cos(t)]y[ix] = j
plt.scatter(X[:, 0], X[:, 1], c = y, s = 40, cmap=plt.cm.RdYlBu)
plt.show()# Turn data into tensors
X = torch.from_numpy(X).type(torch.float)
y = torch.from_numpy(y).type(torch.LongTensor)# Create train and test splits
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2, random_state = RANDOM_SEED)#!pip -q install torchmetrics
from torchmetrics import Accuracy
acc_fn = Accuracy(task = "multiclass", num_classes = 3).to("cpu")class SpiralModel(nn.Module):def __init__(self):super().__init__()self.linear1 = nn.Linear(in_features = 2, out_features = 10)self.linear2 = nn.Linear(in_features = 10, out_features = 10)self.linear3 = nn.Linear(in_features = 10, out_features = 3)self.relu = nn.ReLU()def forward(self, x):return self.linear3(self.relu(self.linear2(self.relu(self.linear1(x)))))model_1 = SpiralModel().to("cpu")# Setup data to be device agnostic
X_train, y_train = X_train.to("cpu"), y_train.to("cpu")
X_test, y_test = X_test.to("cpu"), y_test.to("cpu")# Setup Loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model_1.parameters(),lr = 0.02)# Build a training loop for the model
epochs = 1000# Loop over data
for epoch in range(epochs):## Trainingmodel_1.train()y_logits = model_1(X_train)y_pred = torch.softmax(y_logits, dim=1).argmax(dim=1)loss = loss_fn(y_logits, y_train)acc = acc_fn(y_pred, y_train)optimizer.zero_grad()loss.backward()optimizer.step()## Trainingmodel_1.eval()with torch.inference_mode():test_logits = model_1(X_test)test_pred = torch.softmax(test_logits, dim=1).argmax(dim=1)test_loss = loss_fn(test_logits, y_test)test_acc = acc_fn(test_pred, y_test)if epoch % 100 == 0:print(f"Epoch: {epoch} | Loss: {loss:.2f} Acc: {acc:.2f} | Test loss: {test_loss:.2f} Test acc: {test_acc:.2f}")# Plot decision boundaries for training and test sets 
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.title("Train")
plot_decision_boundary(model_1, X_train, y_train)
plt.subplot(1, 2, 2)
plt.title("Test")
plot_decision_boundary(model_1, X_test, y_test)

结果如下

Epoch: 0 | Loss: 1.13 Acc: 0.32 | Test loss: 1.12 Test acc: 0.37
Epoch: 100 | Loss: 0.09 Acc: 0.98 | Test loss: 0.06 Test acc: 1.00
Epoch: 200 | Loss: 0.04 Acc: 0.99 | Test loss: 0.01 Test acc: 1.00
Epoch: 300 | Loss: 0.02 Acc: 0.99 | Test loss: 0.00 Test acc: 1.00
Epoch: 400 | Loss: 0.02 Acc: 0.99 | Test loss: 0.00 Test acc: 1.00
Epoch: 500 | Loss: 0.02 Acc: 0.99 | Test loss: 0.00 Test acc: 1.00
Epoch: 600 | Loss: 0.02 Acc: 0.99 | Test loss: 0.00 Test acc: 1.00
Epoch: 700 | Loss: 0.02 Acc: 0.99 | Test loss: 0.00 Test acc: 1.00
Epoch: 800 | Loss: 0.02 Acc: 0.99 | Test loss: 0.00 Test acc: 1.00
Epoch: 900 | Loss: 0.01 Acc: 0.99 | Test loss: 0.00 Test acc: 1.00

结果1
结果2

给个赞呗~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/596322.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

摆动序列(力扣376)

文章目录 题目前知题解一、思路二、解题方法三、Code 总结 题目 Problem: 376. 摆动序列 如果连续数字之间的差严格地在正数和负数之间交替,则数字序列称为 摆动序列 。第一个差(如果存在的话)可能是正数或负数。仅有一个元素或者含两个不等元…

C语言程序编译全流程,从源代码到二进制

源程序 对于一个最简单的程序&#xff1a; int main(){int a 1;int b 2;int c a b;return 0; }预处理 处理源代码中的宏指令&#xff0c;例如#include等 clang -E test.c处理结果&#xff1a; # 1 "test.c" # 1 "<built-in>" 1 # 1 "&…

本地Windows打包启动前端后台

本地Windows打包启动前端后台 1、安装jdk Windows JDK安装 2、Nginx 2.1、将 nginx-1.16.1文件夹复制到D:\home\jisapp目录下 2.2、域名证书配置&#xff1a; 将域名证书放到D:\home\jisapp\ssl\2023目录下->配置nginx.conf文件&#xff08;D:\home\jisapp\nginx-1.22.0…

智能感应门改造工程

今天记录一下物联网专业学的工程步骤及实施过程 智能感应门改造工程 1 规划设计1.1 项目设备清单1.2项目接线图 软件设计信号流 设备安装与调试工程函数 验收 1 规划设计 1.1 项目设备清单 1.2项目接线图 软件设计 信号流 设备安装与调试 工程函数 工程界面: using System; …

新手如何开始运营朋友圈?分享朋友圈前5条内容运营技巧!

最近加入的新伙伴比较多&#xff0c;不少伙伴反馈一个问题&#xff1a;作为新人&#xff0c;前期我们的朋友圈要如何发&#xff1f;要怎么开始发朋友&#xff1f;要怎么配图&#xff0c;怎么配文案&#xff1f; 为了解决新伙伴们的这个问题&#xff0c;今天同伙伴们分享&#…

7.二叉树的遍历方式及二叉树习题

4.二叉树链式结构的实现 二叉树是&#xff1a; 空树 非空&#xff1a;根节点&#xff0c;根节点的左子树、根节点的右子树组成的。 4.1二叉树的遍历 4.2.1 前序、中序以及后序遍历 前序遍历(Preorder Traversal 亦称先序遍历)——访问根结点的操作发生在遍历其左右子树之前…

华清远见STM32MP157开发板助力嵌入式大赛ST赛道MPU应用方向项目开发

第七届&#xff08;2024&#xff09;全国大学生嵌入式芯片与系统设计竞赛&#xff08;以下简称“大赛”&#xff09;已经拉开帷幕&#xff0c;大赛的报名热潮正席卷而来。嵌入式大赛截止今年已连续举办了七届&#xff0c;为教育部认可的全国普通高校大学生国家级A类赛事&#x…

数据结构和算法:分治

分治算法 分治&#xff08;divide and conquer&#xff09;&#xff0c;全称分而治之&#xff0c;是一种非常重要且常见的算法策略。分治通常基于递归实现&#xff0c;包括“分”和“治”两个步骤。 1.分&#xff08;划分阶段&#xff09;&#xff1a;递归地将原问题分解为两个…

Spring源码解析-容器基本实现

spring源码解析 整体架构 defaultListableBeanFactory xmlBeanDefinitionReader 创建XmlBeanFactory 对资源文件进行加载–Resource 利用LoadBeandefinitions(resource)方法加载配置中的bean loadBeandefinitions加载步骤 doLoadBeanDefinition xml配置模式 validationMode 获…

Open CASCADE学习|放样建模

在CAD软件中&#xff0c;Loft&#xff08;放样&#xff09;功能则是用于创建三维实体或曲面的重要工具。通过选取两个或多个横截面&#xff0c;并沿这些横截面进行放样&#xff0c;可以生成复杂的三维模型。在CAD放样功能的操作中&#xff0c;用户可以选择不同的选项来定制放样…

二分答案 蓝桥杯 2022 省A 青蛙过河

有些地方需要解释&#xff1a; 1.从学校到家和从家到学校&#xff0c;跳跃都是一样的&#xff0c;直接看作2*x次过河就可以。 2.对于一个跳跃能力 y&#xff0c;青蛙能跳过河 2x 次&#xff0c;当且仅当对于每个长度为 y 的区间&#xff0c;这个区间内 h 的和都大于等于…

并发编程01-深入理解Java并发/线程等待/通知机制

为什么我们要学习并发编程&#xff1f; 最直白的原因&#xff0c;因为面试需要&#xff0c;我们来看看美团和阿里对 Java 岗位的 JD&#xff1a; 从上面两大互联网公司的招聘需求可以看到&#xff0c; 大厂的 Java 岗的并发编程能力属于标配。 而在非大厂的公司&#xff0c; 并…