第三节 回归实战

news/2025/1/15 19:45:17/文章来源:https://www.cnblogs.com/jyp02/p/18673544

数据处理

超参:人为指定不能改变

测试数据只有x没有标签y
训练数据拆分,82开,作训练集和验证集(验证模型好坏),模型训练不是一路上升的过程,训练几次验证一次,最好的模型save下来

one-hot独热编码 猪(1 0 0) 狗(0 1 0) 猫(0 0 1)

def get_feature_importance(feature_data, label_data, k = 4, column = None):"""feature_data, label_data 要求字符串形式k为选择的特征数量如果需要打印column,需要传入行名此处省略 feature_data, label_data 的生成代码。如果是 CSV 文件,可通过 read_csv() 函数获得特征和标签。这个函数的目的是, 找到所有的特征种, 比较有用的k个特征, 并打印这些列的名字。"""model = SelectKBest(chi2, k=k)      #定义一个选择k个最佳特征的函数feature_data = np.array(feature_data, dtype=np.float64)X_new = model.fit_transform(feature_data, label_data)   #用这个函数选择k个最佳特征#feature_data是特征数据,label_data是标签数据,该函数可以选择出k个特征print('x_new', X_new)scores = model.scores_                # scores即每一列与结果的相关性# 按重要性排序,选出最重要的 k 个indices = np.argsort(scores)[::-1]        #[::-1]表示反转一个列表或者矩阵。# argsort这个函数, 可以矩阵排序后的下标。 比如 indices[0]表示的是,scores中最小值的下标。if column:                            # 如果需要打印选中的列名字k_best_features = [column[i] for i in indices[0:k].tolist()]         # 选中这些列 打印print('k best features are: ',k_best_features)return X_new, indices[0:k]                  # 返回选中列的特征和他们的下标。
class Covid_dataset(Dataset):def __init__(self, file_path, mode, dim=4, all_feature=False):with open(file_path, "r") as f:csv_data = list(csv.reader(f))#list会变为数据data = np.array(csv_data[1:])#去掉第一行if mode == "train": #逢五取一indices = [i for i in range(len(data)) if i % 5 !=0]elif mode == "val":indices = [i for i in range(len(data)) if i % 5 ==0]if all_feature:col_idx = [i for i in range(0,93)]else:_, col_idx = get_feature_importance(data[:,1:-1], data[:,-1], k=dim,column =csv_data[0][1:-1])if mode == "test":x = data[:, 1:].astype(float)#将数据由字符型转为浮点型x = torch.tensor(x[:, col_idx])else:x = data[indices, 1:-1].astype(float) #x 选indices的行 x要去掉最后一列x = torch.tensor(x[:, col_idx])y = data[indices, -1].astype(float)self.y = torch.tensor(y)self.x = (x-x.mean(dim=0, keepdim=True))/x.std(dim=0, keepdim=True)#数据的量纲不同,要归一化.每列的数值差距大 mean()平均值要在一列中取第0维同时保持维度不变self.mode = mode #把mode传到self中def __getitem__(self, item):if self.mode == "test":return self.x[item].float()#测试集没有y float是把变量变为32位,消耗不会太大,else:return self.x[item].float(), self.y[item].float() def __len__(self):return len(self.x)

Dataset类 xy

  1. init 初始化 filepath x[] Y[]
  2. getitem 取数据 idx->X[idx] Y[idx]
  3. len 数据长度


    模型部分
class myModel(nn.Module):#模型首先关注维度变化 回归中用全连接linear 上一层输出一定是下一层输入 输入维度(16,93)16为样本数量,93为样本维度def __init__(self, dim):#初始化 模型长什么样 dim为输入维度super(myModel, self).__init__()#修改这里的初始化self.fc1 = nn.Linear(dim, 100)#全连接(输入维度,输出维度)self.relu = nn.ReLU() #激活函数 不需要参数self.fc2 = nn.Linear(100, 1) #输入100维 输出1维def forward(self, x):#模型前向过程x = self.fc1(x) x = self.relu(x)x = self.fc2(x)if len(x.size())>1: #x有两维要减去一维 去掉第二维,去掉第二个x = x.squeeze(dim=1)return x

超参部分

config = {"lr" : 0.001,"momentum": 0.9,#动量,惯性,继续冲一下"epochs":20,"save_path": "model_save/model.pth", #保存路径"rel_path" : "pred.csv" #
}
loss = nn.MSELoss()optimizer = optim.SGD(model.parameters(), lr=config["lr"], momentum=config["momentum"])train_val(model, train_loader, val_loader, device, config["epochs"], optimizer, loss, config["save_path"])evaluate(config["save_path"], device, test_loader, config["rel_path"])

训练流程

def train_val(model, train_loader, val_loader, device, epochs, optimizer, loss, save_path):model = model.to(device)plt_train_loss = [] #每轮训练记录loss值,记录所有轮次loss值plt_val_loss = []min_val_loss = 9999999999 #记录最好的模型,最小的loss值for epoch in range(epochs): #冲锋的号角train_loss = 0.0val_loss = 0.0start_time = time.time()model.train() #模型调整为训练模式for batch_x, batch_y in train_loader:x, target = batch_x.to(device), batch_y.to(device)pred = model(x) #得到预测值# train_bat_loss = loss(pred, target,model)train_bat_loss = loss(pred, target) #得到losstrain_bat_loss.backward() #回传optimizer.step() #更新模型optimizer.zero_grad() #模型梯度清0train_loss += train_bat_loss.cpu().item()#在gpu上是张量没法和浮点数相加 item为把数值取出来plt_train_loss.append(train_loss/train_loader.dataset.__len__())#加在所有的loss值里,记录的是平均值model.eval() #调整为验证模式with torch.no_grad(): #所有在张量网上的计算都会计算梯度,验证集不会更新模型for batch_x, batch_y in val_loader:x, target = batch_x.to(device), batch_y.to(device)pred = model(x)# val_bat_loss = loss(pred, target,model)val_bat_loss = loss(pred, target)val_loss += val_bat_loss.cpu().item()plt_val_loss.append(val_loss / val_loader.dataset.__len__())if val_loss < min_val_loss: #模型效果最好,保存模型torch.save(model, save_path)min_val_loss = val_lossprint("[%03d/%03d]  %2.2f secs Trainloss: %.6f Valloss: %.6f"%(epoch, epochs, time.time()-start_time,plt_train_loss[-1],plt_val_loss[-1]))plt.plot(plt_train_loss)#画图函数plt.plot(plt_val_loss)plt.title("loss")plt.legend(["train","val"])#图例plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/869756.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Windows git bash 文字显示/斜杠开头数字

前言全局说明Windows git bash 文字显示/斜杠开头数字一、说明 详细介绍:https://zhuanlan.zhihu.com/p/133706032二、问题三、解决方法 git config --global core.quotepath false免责声明:本号所涉及内容仅供安全研究与教学使用,如出现其他风险,后果自负。参考、来源: h…

DDR 带宽的计算与监控

DDR 带宽(Double Data Rate Bandwidth)是指 DDR 内存在一秒内可以传输的数据量,通常以 GB/s(Gigabytes per second) 为单位。它是衡量内存系统性能的重要指标,直接影响系统的数据吞吐能力。 1.如何计算 DDR 带宽 计算 DDR 理论带宽的公式为: DDR主频 * 位宽 = 理论带宽其…

1.15

尽力了,之前的粗心导致现在要改很多以前的坑,明天再继续

中考英语优秀范文-热点话题-传统文化-006 Welcome to Chinese Summer Camp 欢迎参加中国夏令营

1 写作要求 假定你是李华,你校今年暑假将为外国学生举办一场汉语夏令营活动(Chinese Summer Camp)。请你根据下面海报的内容,用英语给你的笔友David写一封电子邮件,介绍本次活动并邀请他参加。词数80左右。 Welcome to Chinese Summer Camp Time:July 18th—July 28th, 2…

机器人

本文来自博客园,作者:Traktorea,转载请注明原文链接:https://www.cnblogs.com/kdsmyhome/p/18673586

使用Nginx实现前端映射到公网IP后端内网不映射公网.250115

一、场景: 系统移动端需要映射到公网,但是后端地址不能映射出去 qbpm.xxxx.cn 系统解析内网IP qmbpm.xxxx.cn 移动端解析公网IP 二、思路: 移动端前端公网端口放出80 443端口 移动端后端映射到内网后端地址qbpm.xxxx.cn:8443 三、解决方法: vim nginx.confserver {listen 8…

Qml 中实现任意角为圆角的矩形

在 Qml 中,矩形(Rectangle)是最常用的元素之一。 然而,标准的矩形元素仅允许设置统一的圆角半径。 在实际开发中,我们经常需要更灵活的圆角设置,例如只对某些角进行圆角处理,或者设置不同角的圆角半径。 本文将介绍如何通过自定义 Qml 元素实现一个任意角可为圆角的矩形…

【附源码】JAVA在线投票系统源码+SpringBoot+VUE+前后端分离

学弟,学妹好,我是爱学习的学姐,今天带来一款优秀的项目:在线投票系统源码 。 本文介绍了系统功能与部署安装步骤,如果您有任何问题,也请联系学姐,偶现在是经验丰富的程序员! 一. 系统演示 系统测试截图系统视频演示https://githubs.xyz/show/340.mp4二. 系统概述【 系统…

Python Playwright学习笔记(一)

一、简介 1.1Playwright 是什么? 它是微软在 2020 年初开源的新一代自动化测试工具,其功能和 selenium 类似,都可以驱动浏览器进行各种自动化操作。 1.2、特点是什么支持当前所有的主流浏览器,包括 chrome、edge、firefox、safari; 支持跨平台多语言:支持Windows、Linux、…

智能驾驶数据采集回注测评工具 - ARS

在数据驱动智能驾驶的时代背景下,开发者们总结了一条适用于智能驾驶的数据闭环开发流程,这条开发线路大致包括实车数据采集->数据存储->数据处理->数据分析->数据标注->模型训练->仿真测试->实车测试->部署发布等关键环节,通过不断开发迭代,逐步完…

2025.1.15 学习

2025.1.15 学习 api开放平台 我们希望在后端使用Http请求调用接口,应该怎么做呢 可以用Hutool工具库中的Http请求工具类,使用如下: public class ApiClient {public String getNameByGet(String name){HashMap<String, Object> paramMap = new HashMap<>();para…

2024龙信年终技术考核

1. 分析手机备份文件,该机主的QQ号为?(标准格式:123) 看了下,备份里没有QQ,但是有微信,所以应该是微信绑定的QQ号(早期微信推广时可以用QQ直接注册登录)经过测试,对应的是这个结果为1203494553 2. 分析手机备份文件,该机主的微信号为?(标准格式:abcdefg)结果为…