机器学习——6.模型训练案例: 预测儿童神经缺陷分类TD/ADHD

案例目的

有一份EXCEL标注数据,如下,训练出合适的模型来预测儿童神经缺陷分类。

参考文章:机器学习——5.案例: 乳腺癌预测-CSDN博客

代码逻辑步骤

  1. 读取数据
  2. 训练集与测试集拆分
  3. 数据标准化
  4. 数据转化为Pytorch张量
  5. label维度转换
  6. 定义模型
  7. 定义损失计算函数
  8. 定义优化器
  9. 定义梯度下降函数
  10. 模型训练(正向传播、计算损失、反向传播、梯度清空)
  11. 模型测试
  12. 精度计算

代码实现

import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScalerdf = pd.read_excel('/Users/guojun/Desktop/Learning/machine_learning/Preprocess_Without_WDE_Channels_Data.xlsx')X = df[df.columns[0:8]].values
mapping = {"TD":0,"ADHD":1}
Y = df["Class"].replace(mapping)# 数据集拆分
X_train,X_test,Y_train,Y_test = train_test_split(X,Y,test_size=0.2,random_state=5)
Y_train = Y_train.to_numpy()
Y_test = Y_test.to_numpy()# 数据标准化
sc = StandardScaler()
X_train = sc.fit_transform(X_train)
X_test = sc.fit_transform(X_test)# 转化为张量
X_train = torch.from_numpy(X_train.astype(np.float32))
X_test = torch.from_numpy(X_test.astype(np.float32))
Y_train = torch.from_numpy(Y_train.astype(np.float32))
Y_test = torch.from_numpy(Y_test.astype(np.float32))# 真值转为为二维数据
Y_train = Y_train.view(Y_train.shape[0],-1)
Y_test = Y_test.view(Y_test.shape[0],-1)# 定义模型
class Model(torch.nn.Module):def __init__(self,n_input_features):super(Model,self).__init__()self.linear = torch.nn.Linear(n_input_features,1)def forward(self,x):return torch.sigmoid(self.linear(x))model = Model(X_train.shape[1])
# 定义损失函数
loss = torch.nn.BCELoss()
# 定义优化器
learning_rate = 0.001
optimizer = torch.optim.Adam(model.parameters(),lr=learning_rate)# 梯度下降函数
def gradient_descent():# 预测Y值pre_y = model(X_train)# 计算损失l = loss(pre_y,Y_train)# 反向传播l.backward()# 梯度更新optimizer.step()# 梯度清空optimizer.zero_grad()return l,list(model.parameters())# 模型训练
for i in range(10000):l,p = gradient_descent()print(l,p)# 模型测试
mapping = {0:"TD",1:"ADHD"}
index = np.random.randint(0,X_test.shape[0])
pre_y = model(X_test[index])
pre_y = mapping[int(pre_y.round().item())]
gt_y = mapping[int(Y_test[index].item())]
print(pre_y,gt_y)# 计算模型准确率
pres_y = model(X_test).round()
result = np.where(pres_y==Y_test,1,0)
ac = np.sum(result)/result.size
print(ac)

 即使调整参数后,损失在0.68左右就不会再下降了。

最终的准确率只有54%-60%,我会在后面的笔记中使用深度神经网络来重新训练,提升模型精度。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/687437.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Hive的文件存储格式

TEXTFILE 文本文件,默认的文件存储格式,内容是可以直接查看,用来保存非结构化数据; 特点:文本文件的格式一旦定义就无法改变. 除TEXTFILE外,其他文件存储格式的表不能直接从本地文件导入数据,数据要先导入到textfile格式的表中, 然后再从表中用insert导入SequenceFile,RCFile,…

权限束缚术--权限提升你需要知道这些

前言 欢迎来到我的博客 个人主页:北岭敲键盘的荒漠猫-CSDN博客 本文主要对渗透测试中权限提升的一些基础知识进行整理 并不包含权限提升的具体操作 适合要入门权限提升的朋友 提权的重要性 我们在渗透网站时,我们往往会拿到一些权限,但是我们的权限有…

C语言什么是散列法?

一、问题 什么是散列法? 二、解答 散列法是⼀种将字符组成字符串,转换为固定长度(⼀般是更短长度)的数值或索引值的⽅法,也叫哈希法,⼜可以称为杂凑法或关键码⼀地址转换法。 那么,通过散列函数…

Aapache Tomcat AJP 文件包含漏洞(CVE-2020-1938)

1 漏洞描述 CVE-2020-1938 是 Apache Tomcat 中的一个严重安全漏洞,该漏洞涉及到 Tomcat 的 AJP(Apache JServ Protocol)连接器。由于 AJP 协议在处理请求时存在缺陷,攻击者可以利用此漏洞读取服务器上的任意文件,甚至…

融入新科技的SLM27211系列 120V, 3A/4.5A高低边高频门极驱动器兼容UCC27284,MAX15013A

SLM27211是高低边高频门极驱动器,集成了120V的自举二极管,支持高频大电流的输出,可在8V~17V的宽电压范围内驱动MOSFET,独立的高、低边驱动以方便控制,可用于半桥、全桥、双管正激和有源钳位正激等拓。有极好的开通、关…

AI早班车5.11

📢📢📢📣📣📣 哈喽!大家好,我是「奇点」,江湖人称 singularity。刚工作几年,想和大家一同进步🤝🤝 一位上进心十足的【Java ToB端大厂…

HTTPS 原理和 TLS 握手机制

HTTPS的概述与重要性 在当今数字化时代,网络安全问题日益凸显,数据在传输过程中的安全性备受关注。HTTPS 作为一种重要的网络通信协议,为数据的传输提供了强有力的安全保障。它是在 HTTP 的基础上发展而来,通过引入数据加密机制&a…

深入学习指针3

目录 前言 1.二级指针 2.指针数组 3.指针数组模拟二维数组 前言 Hello,小伙伴们我又来了,上期我们讲到了数组名的理解,指针与数组的关系等知识,那今天我们就继续深入到学习指针域数组的练联系,如果喜欢作者菌生产的内容还望不…

实时路况信息:盲人出行的守护者

在这个快速发展的数字时代,科技的力量正以前所未有的方式改变着人们的生活,尤其是对于视障群体而言,技术的每一次进步都可能意味着独立与自由的新篇章。在这样的背景下,实时路况信息对盲人的帮助成为了社会关注的热点话题&#xf…

JAVA使用Apache POI动态导出Word文档

文章目录 一、文章背景二、实现步骤2.1 需要的依赖2.2 创建模板2.3 书写java类2.3.1 模板目录2.3.2 Controller类2.3.2 工具类 2.4 测试2.4.1 浏览器请求接口2.4.2 下载word 三、注意事项四、其他导出word实现方式 一、文章背景 基于Freemarker模版动态生成并导出word文档存在弊…

去除Windows11广告,这款开源工具一键搞定!(汉化版)

文章目录 📖 介绍 📖🏡 演示环境 🏡📒 工具介绍 📒🎈 获取方式 🎈📢 声明 📖 介绍 📖 你是否厌倦了在Windows 11中不断弹出的广告?是否…

Set接口

Set接口的介绍 Set接口基本介绍 无序(添加和取出的顺序不一致),没有索引不允许重复元素,所以最多包含一个nullJDK API中Set接口的实现类:主要有HashSet;TreeSet Set接口的常用方法 和List 接口一样&am…