torch神经网络温度预测

news/2024/12/21 20:39:32/文章来源:https://www.cnblogs.com/jackchen28/p/18448801

数据件文件temp.csv


"""
气温预测
"""
import datetime
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
import warnings
warnings.filterwarnings('ignore')features = pd.read_csv('temps.csv')
features.head()
# --------数据说明--------
# temp_2:前天的最高温度值
# temp_1:昨天的最高温度值
# average:在历史中,每年这一天的平均最高温度值
# actual:这就是我们的标签,当天的真实最高温度
# friend:朋友猜测的可能值
print(features.shape)
years = features['year']
months = features['month']
days = features['day']# datetime格式
dates = [str(int(year)) + '-' + str(int(month)) + '-' + str(int(day)) for year, month, day in zip(years, months, days)]
dates = [datetime.datetime.strptime(date, '%Y-%m-%d') for date in dates]
print(dates[:5])# -----------画图看数据-----------
plt.style.use('fivethirtyeight')
# 设置布局
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2, figsize=(10, 10))
fig.autofmt_xdate(rotation=45)# 标签值
ax1.plot(dates, features['actual'])
ax1.set_xlabel('');ax1.set_ylabel('Temp');ax1.set_title('Max Temp')
# 昨天
ax2.plot(dates, features['temp_1'])
ax2.set_xlabel('');ax2.set_ylabel('Temp');ax2.set_title('Previous Max Temp')
# 前天
ax3.plot(dates, features['temp_2'])
ax3.set_xlabel('Date');ax3.set_ylabel('Temp');ax3.set_title('Two Days Prior Max Temp')
# 我的朋友
ax4.plot(dates, features['friend'])
ax4.set_xlabel('Date');ax4.set_ylabel('Temp');ax4.set_title('Friend Estimate')
plt.tight_layout(pad=2)
plt.show()# week列为字符串不是数值,利用独热编码,将数据中非字符串转换为数值,并拼接到数据中
features = pd.get_dummies(features)
# 看独热编码的效果
print(features.head(5))# 标签
labels = np.array(features['actual'])# 去掉标签用作特征
features = features.drop('actual', axis=1)# 保存列名用于展示
features_list = list(features.columns)# 转换为合适的格式
features = np.array(features)
print(features.shape)# 数据标准化
from sklearn import preprocessinginput_features = preprocessing.StandardScaler().fit_transform(features)# 看一下数字标准化的效果
print(input_features[0])# =======================构建神经网络模型=============================== #
# 将输入和预测转为tensor
x = torch.tensor(input_features, dtype=float)
y = torch.tensor(labels, dtype=float)# 权重参数初始化
weights = torch.randn((14, 128), dtype= float, requires_grad= True)
biases = torch.randn(128, dtype=float, requires_grad= True)
weights2 = torch.randn((128, 1), dtype=float, requires_grad= True)
biases2 = torch.randn(1, dtype=float, requires_grad=True)learning_rate = 0.001
losses = []for i in range(1000):# 前向传播# 计算隐藏层hidden = x.mm(weights) + biases# 加入激活函数hidden = torch.relu(hidden)# 预测结果predictions = hidden.mm(weights2) + biases2# 计算损失loss = torch.mean((predictions - y)**2)losses.append(loss.data.numpy())# 打印损失if i % 100 == 0:print('loss:', loss)# 反向传播loss.backward()# 更新参数weights.data.add_(- learning_rate * weights.grad.data)biases.data.add_(- learning_rate * biases.grad.data)weights2.data.add_(- learning_rate * weights2.grad.data)biases2.data.add_(- learning_rate * biases2.grad.data)# 梯度清零weights.grad.data.zero_()biases.grad.data.zero_()weights2.grad.data.zero_()biases2.grad.data.zero_()# -----或者我们使用简化的方法----
input_size = input_features.shape[1]
hidden_size = 128
output_size = 1
batch_size = 16
my_nn = torch.nn.Sequential(torch.nn.Linear(input_size, hidden_size),torch.nn.Sigmoid(),torch.nn.Linear(hidden_size, output_size),
)# 指定损失函数
cost = torch.nn.MSELoss(reduction='mean')# 指定优化器
optimizer = torch.optim.Adam(my_nn.parameters(), lr=0.001)# 训练网络
losses = []
for i in range(1000):batch_loss = []for start in range(0, len(input_features), batch_size):end = start + batch_size if start + batch_size < len(input_features) else len(input_features)xx = torch.tensor(input_features[start:end], dtype=torch.float, requires_grad=True)yy = torch.tensor(labels[start:end], dtype=torch.float, requires_grad=True)prediction = my_nn(xx)loss = cost(prediction, yy)optimizer.zero_grad()loss.backward(retain_graph=True)optimizer.step()batch_loss.append(loss.data.numpy())if i % 100 == 0:losses.append(np.mean(batch_loss))print(i, np.mean(batch_loss))# 预测,并以图片的形式展示
# 预测结果
x = torch.tensor(input_features, dtype=torch.float)
predict = my_nn(x).data.numpy() # 转化为numpy格式,tensor格式画不了图# 转换日期格式
dates = [str(int(year)) + '-' + str(int(month)) + '-' + str(int(day)) for year, month, day in zip(years, months, days)]
dates = [datetime.datetime.strptime(date, '%Y-%m-%d') for date in dates]# 创建一个表格来保存日期和其对应的标签数值
true_data = pd.DataFrame(data={'date': dates, 'actual': labels})# 再创建一个来存日期和其对应的模型预测值
months = features[:, features_list.index('month')]
days = features[:, features_list.index('day')]
years = features[:, features_list.index('year')]test_dates = [str(int(year)) + '-' + str(int(month)) + '-' + str(int(day)) for year, month, day in zip(years, months, days)]
test_dates = datespredictions_data = pd.DataFrame(data={'date': test_dates, 'prediction': predict.reshape(-1)})# 真实值
plt.plot(true_data['date'], true_data['actual'], 'b-', label='actual')# 预测值
plt.plot(predictions_data['date'], predictions_data['prediction'], 'ro', label='prediction')
plt.xticks(rotation='vertical');
plt.legend()# 图名
plt.xlabel('Date')
plt.ylabel('Maximum Temperature (F)')
plt.title('Actual and Predicted Values')
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/808834.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

VUE2常见问题以及解决方案汇总,vue+element ui 问题以及解决方案汇总(不断更新中)

解决vue项目中 el-table 的 @row-click 事件与行内点击事件冲突,点击事件不生效(表格行点击事件和行内元素点击事件冲突)需要阻止事件冒泡 问题描述 1.点击列的编辑按钮,会触发按钮本身事件,同时会触发行点击事件 2.点击列的元素,会触发本身事件,同时会触发行点击事件 需…

1分钟了解什么是docker和docker-compose?前后端必知必会技能GET啦

@目录前情提要Docker定义:主要功能:命令示例:其他Docker Compose定义:我为什么使用它?主要功能:命令示例:主要区别配置文件:命令行操作:依赖关系管理:实际应用场景单个服务:多服务应用:总结结语欢迎路过的小哥哥小姐姐们提出更好的意见哇~~ 前情提要 本文非常简短,如果需要详…

VUE2常见问题以及解决方案汇总(不断更新中)

vue子组件传递数据给父组件 子组件可以使用 $emit 向父组件传递数据。父组件监听这个事件,并在事件触发时接收数据。 上代码 子组件 (Child.vue) <template><button @click="sendDataToParent">Send Data to Parent</button> </template>&l…

1分钟搞懂K8S中的NodeSelector

@目录NodeSelector是什么?为什么使用NodeSelector?怎么用NodeSelector?POD配置示例yaml配置示例如何知道K8S上面有哪些节点,每个节点都有什么信息呢?1. 使用kubectl命令行工具查看所有节点及其标签2. 使用kubectl命令行工具查看特定节点的标签代码举例常见的NodeSelector节…

谷歌浏览器调试技巧

谷歌浏览器断点调试# “资源(Sources)”面板# 进入浏览器,点击F12,进入调试面板,点击source 切换按钮 会打开文件列表的选项卡。资源(Sources)面板包含三个部分:文件导航(File Navigator) 区域列出了 HTML、JavaScript、CSS 和包括图片在内的其他依附于此页面的文件。…

两种方案手把手教你多种服务器使用tinyproxy搭建http代理

@目录Tinyproxy是什么?特点功能安装方案一:Docker安装安装tinyproxy镜像,启动容器将内部8888端口至外部,ANY代表允许所有ip访问代理获得代理地址安装方案二:系统包管理器Tinyproxy 可以通过包管理器安装。以下是一些常见的 Linux 和 mac发行版的安装命令:MAC电脑Linux配置…

Docker系列-5种方案超详细讲解docker数据存储持久化(volume,bind mounts,NFS等)

@目录Docker的数据持久化是什么?1.数据卷(Data Volumes)使用Docker 创建数据卷创建数据卷创建一个容器,将数据卷挂载到容器中的 /data 目录。进入容器,查看数据卷内容停止并重新启动容器,数据卷中的数据仍然存在再次进入容器,检查文件是否存在使用 Docker Compose 创建数…

基于simulink的风轮机发电系统建模与仿真

1.课题概述使用simulink实现风轮机发电系统建模与仿真,包括风速模型(基本风+阵风+阶跃风+随机风组成),风力机模型,飞轮储能模块等。2.系统仿真结果 3.核心程序与模型 版本:MATLAB2022a风速模块:风力机模块 整体模型4.系统原理简介 4.1 风速模型风速模型在风力发电和其他…

2024-10-06 闲话

2024-10-06 闲话坐在电脑前 1 小时也什么都写不出来。 比如我现在住的地方(在一个房子里面)旁边有一个大冰块,因为这个大冰块在吸热所以我在家里感受到了无尽的寒冷。 于是我读了几本古圣先贤的书,合成了能烧来取暖的蜂窝煤。我又拿了根钻头把蜂窝煤点着了,尾气全部排到房…

报错集

报错集弹性云服务器ECS + 自动分配IP地址 + 配置安全组规则 + 配置并创建桶1.另外一个冲突的操作当前正作用在这个资源上,请等待一段时间后重试。 A conflicting conditional operation is currently in progress against this resource.Please try again 解决方案:桶的名称重…

云锵投资 2024 年 9 月简报

季报摘要行情:双重底结束,牛市启动;未来:长线看多; 期权策略:研发成功。节后正式上线,是未来的主要现金流策略; 微盘策略:非主流策略,三月连涨,未来长持; 本季度量化基金策略业绩:15.89%,优,全国排名:1858/11684;平均 Beta:1.00; 本季度量化股票策略业绩:3…

激活 Ultra Mobile Paygo

淘宝买一张 Ultra Mobile Paygo 电话卡(也叫做美国紫卡)(可选)在 NumberBarn 购买一个手机号。Plan 记得选 Port Away。打开 paygo.ultra.me/activate,填入卡面上的激活码,然后继续。填写相关信息。如果购买了手机号,选择 Transfer an Existing Number。未完待续