动手学深度学习(五)Kaggle房价预测

Kaggle房价数据集,前四个为房价特征,最后一个为标签(房价)。

一、下载数据集

import numpy as np
import pandas as pd
import torch
from torch import nn
from d2l import torch as d2l
import hashlib
import os
import tarfile
import zipfile
import requests# 数据集下载
DATA_HUB = dict()
DATA_URL = 'http://d2l-data.s3-accelerate.amazonaws.com/'def download(name, cache_dir=os.path.join('.', 'data')):  # @save"""下载一个DATA_HUB中的文件,返回本地文件名"""assert name in DATA_HUB, f"{name} 不存在于 {DATA_HUB}"url, sha1_hash = DATA_HUB[name]os.makedirs(cache_dir, exist_ok=True)fname = os.path.join(cache_dir, url.split('/')[-1])if os.path.exists(fname):sha1 = hashlib.sha1()with open(fname, 'rb') as f:while True:data = f.read(1048576)if not data:breaksha1.update(data)if sha1.hexdigest() == sha1_hash:return fname  # 命中缓存print(f'正在从{url}下载{fname}...')r = requests.get(url, stream=True, verify=True)with open(fname, 'wb') as f:f.write(r.content)return fnamedef download_extract(name, folder=None):  # @save"""下载并解压zip/tar文件"""fname = download(name)base_dir = os.path.dirname(fname)data_dir, ext = os.path.splitext(fname)if ext == '.zip':fp = zipfile.ZipFile(fname, 'r')elif ext in ('.tar', '.gz'):fp = tarfile.open(fname, 'r')else:assert False, '只有zip/tar文件可以被解压缩'fp.extractall(base_dir)return os.path.join(base_dir, folder) if folder else data_dirdef download_all():  # @save"""下载DATA_HUB中的所有文件"""for name in DATA_HUB:download(name)DATA_HUB['kaggle_house_train'] = (DATA_URL + 'kaggle_house_pred_train.csv','585e9cc93e70b39160e7921475f9bcd7d31219ce')DATA_HUB['kaggle_house_test'] = (DATA_URL + 'kaggle_house_pred_test.csv','fa19780a7b011d9b009e8bff8e99922a8ee2eb90')train_data = pd.read_csv(download('kaggle_house_train'))
test_data = pd.read_csv(download('kaggle_house_test'))  # 读表

查看数据集大小和部分样本:

print(train_data.shape)
print(test_data.shape)print(train_data.iloc[0:4, [0, 1, 2, 3, -3, -2, -1]])

(1460, 81)
(1459, 80)


   Id  MSSubClass MSZoning  LotFrontage SaleType SaleCondition  SalePrice
0   1          60       RL         65.0       WD        Normal     208500
1   2          20       RL         80.0       WD        Normal     181500
2   3          60       RL         68.0       WD        Normal     223500
3   4          70       RL         60.0       WD       Abnorml     140000

二、数据预处理

""" 数据预处理 """
all_features = pd.concat((train_data.iloc[:, 1:-1], test_data.iloc[:, 1:])) # 去掉id列# 将所有缺失的值替换为相应特征的平均值。通过将特征重新缩放到零均值和单位方差来标准化数据
numeric_features = all_features.dtypes[all_features.dtypes != 'object'].indexall_features[numeric_features] = all_features[numeric_features].apply(lambda x: (x - x.mean()) / (x.std()))                                   # 标准化,将所有特征的均值变为0和方差变为1all_features[numeric_features] = all_features[numeric_features].fillna(0)   # 将缺失项设置为0# “Dummy_na=True”将“na”(缺失值)视为有效的特征值,并为其创建指示符特征
all_features = pd.get_dummies(all_features, dummy_na=True) # 为离散值生成独热编码,并增加一列表示空缺值# 从pandas格式中提取NumPy格式,并将其转换为张量表示
n_train = train_data.shape[0]
train_features = torch.tensor(all_features[:n_train].values, dtype=torch.float32)
test_features = torch.tensor(all_features[n_train:].values, dtype=torch.float32)
train_labels = torch.tensor(train_data.SalePrice.values.reshape(-1, 1), dtype=torch.float32)

 查看特征总数大小:

print(all_features.shape)

(2919, 331) 

可以看到经过数据预处理会将特征总数由79增加到331。

三、训练函数

房价就像股票价格一样,我们关心的是相对误差,而不是绝对误差。比如说,农村的房价原本为12.5万,误差10万,和在市中心豪宅区的房价原本为420万,误差10万,显然使用绝对误差对结果评估的影响是不一样的,我们希望使用一种误差测量方法不受样本大小波动的影响,预测昂贵房屋和廉价房屋的误差能够同等影响预测结果,因此需要使用相对误差的测量方法,我们采用均方根损失来测量房价预测的相对误差。

""" 训练 """
loss = nn.MSELoss()
in_features = train_features.shape[1]   # 输入特征总数为331def get_net():net = nn.Sequential(nn.Linear(in_features, 1))return netdef log_rmse(net, features, labels):# 为了在取对数时进一步稳定该值,将小于1的值设置为1clipped_preds = torch.clamp(net(features), 1, float('inf'))rmse = torch.sqrt(loss(torch.log(clipped_preds),torch.log(labels)))return rmse.item()

均方根损失函数

# 均方根损失
def log_rmse(net, features, labels):# 为了在取对数时进一步稳定该值,将小于1的值设置为1clipped_preds = torch.clamp(net(features), 1, float('inf'))rmse = torch.sqrt(loss(torch.log(clipped_preds),torch.log(labels)))return rmse.item()

训练函数

训练函数使用Adam优化器。

# 训练函数
def train(net, train_features, train_labels, test_features, test_labels,num_epochs, learning_rate, weight_decay, batch_size):train_ls, test_ls = [], []train_iter = d2l.load_array((train_features, train_labels), batch_size) # 加载训练数据# 这里使用的是Adam优化算法optimizer = torch.optim.Adam(net.parameters(),lr = learning_rate,weight_decay = weight_decay)for epoch in range(num_epochs):for X, y in train_iter:optimizer.zero_grad()l = loss(net(X), y)l.backward()optimizer.step()train_ls.append(log_rmse(net, train_features, train_labels))if test_labels is not None:test_ls.append(log_rmse(net, test_features, test_labels))return train_ls, test_ls

 

四、K折交叉验证(可选,炼丹步骤)

def get_k_fold_data(k, i, X, y):assert k > 1fold_size = X.shape[0] // kX_train, y_train = None, Nonefor j in range(k):idx = slice(j * fold_size, (j + 1) * fold_size)X_part, y_part = X[idx, :], y[idx]if j == i:X_valid, y_valid = X_part, y_partelif X_train is None:X_train, y_train = X_part, y_partelse:X_train = torch.cat([X_train, X_part], 0)y_train = torch.cat([y_train, y_part], 0)return X_train, y_train, X_valid, y_valid

 当我们在K折交叉验证中训练K次后,返回训练和验证误差的平均值

def k_fold(k, X_train, y_train, num_epochs, learning_rate, weight_decay,batch_size):train_l_sum, valid_l_sum = 0, 0  # 用于存储训练误差和验证误差的总和for i in range(k):data = get_k_fold_data(k, i, X_train, y_train)net = get_net() #选择模型train_ls, valid_ls = train(net, *data, num_epochs, learning_rate,weight_decay, batch_size) # 训练模型train_l_sum += train_ls[-1] # 将当前训练误差的最后一个值累加到train_l_sum变量中valid_l_sum += valid_ls[-1]if i == 0: # 第一次循环d2l.plot(list(range(1, num_epochs + 1)), [train_ls, valid_ls],xlabel='epoch', ylabel='rmse', xlim=[1, num_epochs],legend=['train', 'valid'], yscale='log')print(f'折{i + 1},训练log rmse{float(train_ls[-1]):f}, 'f'验证log rmse{float(valid_ls[-1]):f}')return train_l_sum / k, valid_l_sum / k

五、模型选择(可选,炼丹步骤)

不断的更换超参数,保留最优的超参数。

k, num_epochs, lr, weight_decay, batch_size = 5, 100, 5, 0, 64
train_l, valid_l = k_fold(k, train_features, train_labels, num_epochs, lr,weight_decay, batch_size)
print(f'{k}-折验证: 平均训练log rmse: {float(train_l):f}, 'f'平均验证log rmse: {float(valid_l):f}')

六、训练

""" 训练与预测 """
def train_and_pred(train_features, test_features, train_labels, test_data,num_epochs, lr, weight_decay, batch_size):net = get_net()train_ls, _ = train(net, train_features, train_labels, None, None,num_epochs, lr, weight_decay, batch_size)d2l.plot(np.arange(1, num_epochs + 1), [train_ls], xlabel='epoch',ylabel='log rmse', xlim=[1, num_epochs], yscale='log')print(f'训练log rmse:{float(train_ls[-1]):f}')# 将网络应用于测试集。preds = net(test_features).detach().numpy()# 将其重新格式化以导出到Kaggletest_data['SalePrice'] = pd.Series(preds.reshape(1, -1)[0])submission = pd.concat([test_data['Id'], test_data['SalePrice']], axis=1)submission.to_csv('submission.csv', index=False)k, num_epochs, lr, weight_decay, batch_size = 5, 100, 5, 0, 64
train_and_pred(train_features, test_features, train_labels, test_data,num_epochs, lr, weight_decay, batch_size)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/98302.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

百度文心一言GPT免费入口也来了!!!

文心一言入口地址:文心一言能力全面开放 文心一言是百度全新一代知识增强大语言模型,文心大模型家族的新成员,能够与人对话互动,回答问题,协助创作,高效便捷地帮助人们获取信息、知识和灵感。 文心一言的技…

Keil Flash的下载算法

更进一步的了解Keil Flash的下载算法 前面提到了通用算法的选择,那么问题来了,这个算法文件如何来的呢?如果你所用的MCU不是默认支持的品牌,如何编写属于自己的算法呢? 工具/原料 Keil uVision ULINK2仿真器 方法/…

Vulnstack----5、ATTCK红队评估实战靶场五

文章目录 一 环境搭建二 外网渗透三 内网信息收集3.1 本机信息收集3.2 域内信息收集 四 横向移动4.1 路由转发和代理通道4.2 抓取域用户密码4.3 使用Psexec登录域控4.4 3389远程登录 五、痕迹清理 一 环境搭建 1、项目地址 http://vulnstack.qiyuanxuetang.net/vuln/detail/7/ …

使用docker部署db2

1.使用docker部署db2 1.1 拉db2镜像 将db2镜像拉起到本地。 docker pull ibmcom/db21.2启动容器 docker run -d -p 50000:50000 --name db2 --privilegedtrue -e DB2INST1_PASSWORDdbPassword DBNAMEjumpdb -e LICENSEaccept -v /usr/local/db2:/database ibmcom/db2实例化…

详解TCP/IP的三次握手和四次挥手

文章目录 前言一、TCP/IP协议的三次握手1.1 三次握手流程 二、TCP/IP的四次挥手2.1 四次挥手流程 三、主要字段3.1、标志位(Flags)3.2、序号(sequence number)3.3、确认号(acknowledgement number) 四、状态…

视频监控/视频汇聚/视频云存储EasyCVR平台HLS流集成在小程序无法播放问题排查

安防视频/视频云存储/视频集中存储EasyCVR视频监控综合管理平台可以根据不同的场景需求,让平台在内网、专网、VPN、广域网、互联网等各种环境下进行音视频的采集、接入与多端分发。在视频能力上,视频云存储平台EasyCVR可实现视频实时直播、云端录像、视频…

如何快速搭建母婴行业的微信小程序?

如果你想为你的母婴行业打造一个独特的小程序,但没有任何编程经验,别担心!现在有许多小程序制作平台提供了简单易用的工具,让你可以轻松地建立自己的小程序。接下来,我将为你详细介绍搭建母婴行业小程序的步骤。 首先&…

数学建模--三维图像绘制的Python实现

目录 1.绘制三维坐标轴的方法 2.绘制三维函数的样例1 3.绘制三维函数的样例2 4.绘制三维函数的样例3 5.绘制三维函数的样例4 6.绘制三维函数的样例5 1.绘制三维坐标轴的方法 #%% #1.绘制三维坐标轴的方法 from matplotlib import pyplot as plt from mpl_toolkits.mplot3…

【JPC出版】第二届能源与电力系统国际学术会议 (ICEEPS 2023)

第二届能源与电力系统国际学术会议 (ICEEPS 2023) 2023 2nd International Conference on Energy and Electrical Power Systems 第二届能源与电力系统国际学术会议 (ICEEPS 2023)将于2023年10月27日至29日在中国厦门举行。ICEEPS 将汇集能源科学、电气工程和电力系统领域的…

QTableView合并单元格

QtableView的功能 QTableView是Qt框架提供的用于显示表格数据的类。它是基于MVC(模型-视图-控制器)设计模式的一部分,用于将数据模型和界面视图分离。 以下是一些QTableView的主要特点和功能: 1. 显示表格数据: QTa…

淘宝数据库,主键如何设计的?

聊一个实际问题:淘宝的数据库,主键是如何设计的? 某些错的离谱的答案还在网上年复一年的流传着,甚至还成为了所谓的 MySQL 军规。其中,一个最明显的错误就是关于MySQL 的主键设计。 大部分人的回答如此自信&#xff…

关于el-input和el-select宽度不一致问题解决

1. 情景一 单列布局 对于上图这种情况&#xff0c;只需要给el-select加上style"width: 100%"即可&#xff0c;如下&#xff1a; <el-select v-model"fjForm.region" placeholder"请选择阀门类型" style"width: 100%"><el-o…