风速预测(四)基于Pytorch的EMD-Transformer模型

目录

前言

1 风速数据EMD分解与可视化

1.1 导入数据

1.2 EMD分解

2 数据集制作与预处理

2.1 先划分数据集,按照8:2划分训练集和测试集

2.2 设置滑动窗口大小为7,制作数据集

3 基于Pytorch的EMD-Transformer模型预测

3.1 数据加载,训练数据、测试数据分组,数据分batch

3.2 定义EMD-Transformer预测模型

3.3 定义模型参数

3.4 模型训练

3.5 结果可视化


往期精彩内容:

风速预测(一)数据集介绍和预处理

风速预测(二)基于Pytorch的EMD-LSTM模型

风速预测(三)EMD-LSTM-Attention模型

前言

本文基于前期介绍的风速数据(文末附数据集),先经过经验模态EMD分解,然后通过数据预处理,制作和加载数据集与标签,最后通过Pytorch实现EMD-Transformer模型对风速数据的预测。风速数据集的详细介绍可以参考下文:

风速预测(一)数据集介绍和预处理-CSDN博客

1 风速数据EMD分解与可视化

1.1 导入数据

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rc("font", family='Microsoft YaHei')# 读取已处理的 CSV 文件
df = pd.read_csv('wind_speed.csv' )
# 取风速数据
winddata = df['Wind Speed (km/h)'].tolist()
winddata = np.array(winddata) # 转换为numpy
# 可视化
plt.figure(figsize=(15,5), dpi=100)
plt.grid(True)
plt.plot(winddata, color='green')
plt.show()

1.2 EMD分解

from PyEMD import EMD# 创建 EMD 对象
emd = EMD()
# 对信号进行经验模态分解
IMFs = emd(winddata)# 可视化
plt.figure(figsize=(20,15))
plt.subplot(len(IMFs)+1, 1, 1)
plt.plot(winddata, 'r')
plt.title("原始信号")for num, imf in enumerate(IMFs):plt.subplot(len(IMFs)+1, 1, num+2)plt.plot(imf)plt.title("IMF "+str(num+1), fontsize=10)
# 增加第一排图和第二排图之间的垂直间距
plt.subplots_adjust(hspace=0.8, wspace=0.2)
plt.show()

2 数据集制作与预处理

2.1 先划分数据集,按照8:2划分训练集和测试集

2.2 设置滑动窗口大小为7,制作数据集

3 基于Pytorch的EMD-Transformer模型预测

3.1 数据加载,训练数据、测试数据分组,数据分batch

# 加载数据
import torch
from joblib import dump, load
import torch.utils.data as Data
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
# 参数与配置
torch.manual_seed(100)  # 设置随机种子,以使实验结果具有可重复性
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 加载数据集
def dataloader(batch_size, workers=2):# 训练集train_set = load('train_set')train_label = load('train_label')# 测试集test_set = load('test_set')test_label = load('test_label')# 加载数据train_loader = Data.DataLoader(dataset=Data.TensorDataset(train_set, train_label),batch_size=batch_size, num_workers=workers, drop_last=True)test_loader = Data.DataLoader(dataset=Data.TensorDataset(test_set, test_label),batch_size=batch_size, num_workers=workers, drop_last=True)return train_loader, test_loaderbatch_size = 64
# 加载数据
train_loader, test_loader = dataloader(batch_size)

3.2 定义EMD-Transformer预测模型

注意:输入风速数据形状为 [64, 10, 7], batch_size=64,  维度10维代表10个分量,7代表序列长度(滑动窗口取值)。

3.3 定义模型参数

# 定义模型参数
batch_size = 64
input_len = 7     # 输入序列长度为7 (窗口值)
input_dim = 10    # 输入维度为10个分量
hidden_dim = 100  # Transformer隐层维度
num_layers = 4   # 编码器层数
num_heads = 2   # 多头注意力头数
output_size = 1 # 单步输出model = EMDTransformerModel(batch_size, input_len, input_dim, hidden_dim, num_layers, num_heads, output_size=1)  # 定义损失函数和优化函数 
model = model.to(device)
loss_function = nn.MSELoss()  # loss
learn_rate = 0.0003
optimizer = torch.optim.Adam(model.parameters(), learn_rate)  # 优化器

3.4 模型训练

训练结果

采用两个评价指标:MSE 与 MAE 对模型训练进行评价,100个epoch,MSE 为0.01627,MAE  为 0.0005549,EMD-Transformer预测效果良好,适当调整模型参数,还可以进一步提高模型预测表现。EMD-Transformer参数量不到LSTM模型的十分之一,效果相近,可见EMD-Transformer性能的优越性。

注意调整参数:

  • 可以适当增加Transformer堆叠编码器层数和隐藏层的维度,微调学习率;

  • 调整多头注意力头数,增加更多的 epoch (注意防止过拟合)

  • 可以改变滑动窗口长度(设置合适的窗口长度)

3.5 结果可视化

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/279024.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

现代雷达车载应用——第2章 汽车雷达系统原理 2.6节 雷达设计考虑

经典著作,值得一读,英文原版下载链接【免费】ModernRadarforAutomotiveApplications资源-CSDN文库。 2.6 雷达设计考虑 上述部分给出了汽车雷达基本原理的简要概述。在雷达系统的设计中,有几个方面是必不可少的,它们决定了雷达系…

redis:四、双写一致性的原理和解决方案(延时双删、分布式锁、异步通知MQ/canal)、面试回答模板

双写一致性 场景导入 如果现在有个数据要更新,是先删除缓存,还是先操作数据库呢?当多个线程同时进行访问数据的操作,又是什么情况呢? 以先删除缓存,再操作数据库为例 多个线程运行的正常的流程应该如下…

从零开始:前端架构师的基础建设和架构设计之路

文章目录 一、引言二、前端架构师的职责三、基础建设四、架构设计思想五、总结《前端架构师:基础建设与架构设计思想》编辑推荐内容简介作者简介目录获取方式 一、引言 在现代软件开发中,前端开发已经成为了一个不可或缺的部分。随着互联网的普及和移动…

Istio Wasm插件

目录 本节实战 实战名称🚩 实战:使用EnvoyFilter部署Wasm插件-2023.12.16(测试成功)🚩 实战:使用WasmPlugin部署Wasm插件-2023.12.16(测试成功) 原文链接 Istio Wasm插件 https://onedayxyy.cn/docs/Istio-Wasm 前言 WebAsse…

黑马头条--day01.环境搭建

一.前言 该项目学习自黑马程序员,由我整理如下,版权归黑马程序员所有 二.环境搭建 1.数据库 第一天,先创建如下库和表: sql文件如下: CREATE DATABASE IF NOT EXISTS leadnews_user DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_…

【论文阅读笔记】序列数据的数据增强方法综述

【论文阅读笔记】序列数据的数据增强方法综述 摘要 这篇论文探讨了在深度学习模型中由于对精度的要求不断提高导致模型框架结构变得更加复杂和深层的趋势。随着模型参数量的增加,训练模型需要更多的数据,但人工标注数据的成本高昂,且由于客观…

国际教育-微积分试讲讲稿

Substitution for Integration-Notes换元积分法

Leetcode—2413.最小偶倍数【简单】

2023每日刷题(六十) Leetcode—2413.最小偶倍数 class Solution { public:int smallestEvenMultiple(int n) {return (n % 2 1) * n;} };运行结果 之后我会持续更新,如果喜欢我的文章,请记得一键三连哦,点赞关注收藏…

2024年20多个最有创意的AI人工智能点子

我的新书《Android App开发入门与实战》已于2020年8月由人民邮电出版社出版,欢迎购买。点击进入详情 探索 2024 年将打造的 20 个基于人工智能产品的盈利创意 🔥🔥🔥 直到最近,企业对人工智能还不感兴趣,但…

迈入数据结构殿堂——时间复杂度和空间复杂度

目录 一,算法效率 1.如何衡量一个算法的好坏? 2.算法效率 二,时间复杂度 1.时间复杂度的概念 2.大O的渐进表示法 3.推导大O的渐进表示法 4.常见时间复杂度举例 三,空间复杂度 一,算法效率 数据结构和算法是密…

解决VSCode打开终端Terminal闪退的问题

一、背景 在新电脑上使用了VSCode,但是一打开Terminal,Terminal马上就消失了,在网上找了很久,都没有找到对应的分析 二、解决思路 首先,是从这个文档中找到了灵感,这个文档里面汇集了大部分的问题&#…

关于嵌入式开发的一些信息汇总:C标准、芯片架构、编译器、MISRA-C

关于嵌入式开发的一些信息汇总:C标准、芯片架构、编译器、MISRA-C 关于C标准芯片架构是什么?架构对芯片有什么作用?arm架构X86架构mips架构小结 编译器LLVM是什么?前端在干什么?后端在干什么? MISRA C的诞生…