时序预测demo 代码快速实现 MLP效果比LSTM 好,简单模拟数据

【PyTorch修炼】用pytorch写一个经常用来测试时序模型的简单常规套路(LSTM多步迭代预测)

层数的理解:
LSTM(长短期记忆)的层数指的是在神经网络中堆叠的LSTM单元的数量。层数决定了网络能够学习的复杂性和深度。每一层LSTM都能够捕捉和记忆不同时间尺度的依赖关系,因此增加层数可以使网络更好地理解和处理复杂的序列数据。
在这里插入图片描述

LSTM方法:

import numpy as np
import pandas as pd
import matplotlib.pyplot as pltimport torch
import torch.nn as nnx = torch.linspace(0, 999, 1000)
y = torch.sin(x*2*3.1415926/70)plt.xlim(-5, 1005)
plt.xlabel('x')
plt.ylabel('sin(x)')
plt.title("sin")
plt.plot(y.numpy(), color='#800080')
plt.show()x = torch.linspace(0, 999, 1000)
y = torch.sin(x * 2 * 3.1415926 / 100) + 0.3 * torch.sin(x * 2 * 3.1415926 / 25) + 0.8 * np.random.normal(0, 1.5)plt.plot(y.numpy(), color='#800080')
plt.title("Sine-Like Time Series")
plt.xlabel('Time')
plt.ylabel('Value')
plt.show()train_y= y[:-70]
test_y = y[-70:]def create_data_seq(seq, time_window):out = []l = len(seq)for i in range(l-time_window):x_tw = seq[i:i+time_window]y_label = seq[i+time_window:i+time_window+1]out.append((x_tw, y_label))return out
time_window = 60
train_data = create_data_seq(train_y, time_window)class MyLstm(nn.Module):def __init__(self, input_size=1, hidden_size=128, out_size=1):super(MyLstm, self).__init__()self.hidden_size = hidden_sizeself.lstm = nn.LSTM(input_size=input_size, hidden_size=self.hidden_size, num_layers=1, bidirectional=False)self.linear = nn.Linear(in_features=self.hidden_size, out_features=out_size, bias=True)self.hidden_state = (torch.zeros(1, 1, self.hidden_size), torch.zeros(1, 1, self.hidden_size))def forward(self, x):out, self.hidden_state = self.lstm(x.view(len(x), 1, -1), self.hidden_state)pred = self.linear(out.view(len(x), -1))return pred[-1]time_window = 60
train_data = create_data_seq(train_y, time_window)learning_rate = 0.00001
epoch = 13
multi_step = 70model=MyLstm()
mse_loss = nn.MSELoss()
optimizer=torch.optim.Adam(model.parameters(),lr=learning_rate,betas=(0.5,0.999))for i in range(epoch):for x_seq, y_label in train_data:x_seq = x_seq y_label = y_label model.hidden_state = (torch.zeros(1, 1, model.hidden_size) ,torch.zeros(1, 1, model.hidden_size) )pred = model(x_seq)loss = mse_loss(y_label, pred)optimizer.zero_grad()loss.backward()optimizer.step()print(f"Epoch {i} Loss: {loss.item()}")preds = []labels = []preds = train_y[-time_window:].tolist()for j in range(multi_step):test_seq = torch.FloatTensor(preds[-time_window:]) with torch.no_grad():model.hidden_state = (torch.zeros(1, 1, model.hidden_size) ,torch.zeros(1, 1, model.hidden_size) )preds.append(model(test_seq).item())loss = mse_loss(torch.tensor(preds[-multi_step:]), torch.tensor(test_y))print(f"Performance on test range: {loss}")plt.figure(figsize=(12, 4))plt.xlim(700, 999)plt.grid(True)plt.plot(y.numpy(), color='#8000ff')plt.plot(range(999 - multi_step, 999), preds[-multi_step:], color='#ff8000')plt.show()

class SimpleMLP(nn.Module):def __init__(self, input_size=60, hidden_size=128, output_size=1):super(SimpleMLP, self).__init__()self.fc1 = nn.Linear(input_size, hidden_size)self.relu = nn.ReLU()self.fc2 = nn.Linear(hidden_size, output_size)def forward(self, x):x = self.fc1(x)x = self.relu(x)x = self.fc2(x)return xmlp_model = SimpleMLP()
mse_loss = nn.MSELoss()
optimizer = torch.optim.Adam(mlp_model.parameters(), lr=0.0001)
for i in range(epoch):for x_seq, y_label in train_data:x_seq = x_seqy_label = y_labelpred = mlp_model(x_seq)loss = mse_loss(y_label, pred)optimizer.zero_grad()loss.backward()optimizer.step()print(f"Epoch {i} Loss: {loss.item()}")preds = []labels = []preds = train_y[-time_window:].tolist()for j in range(multi_step):test_seq = torch.FloatTensor(preds[-time_window:])with torch.no_grad():preds.append(mlp_model(test_seq).item())loss = mse_loss(torch.tensor(preds[-multi_step:]), torch.tensor(test_y))print(f"Performance on test range: {loss}")plt.figure(figsize=(12, 4))plt.xlim(700, 999)plt.grid(True)plt.plot(y.numpy(), color='#8000ff')plt.plot(range(999 - multi_step, 999), preds[-multi_step:], color='#ff8000')plt.show()

生成的一个带些随机数的正弦波:y = torch.sin(x * 2 * 3.1415926 / 100) + 0.3 * torch.sin(x * 2 * 3.1415926 / 25) + 0.8 * np.random.normal(0, 1.5)

结果发现:MLP效果比LSTM好?!
MLP:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
偶然有不是很准,但大部分非常准

LSTM:
就很奇怪?

在这里插入图片描述
在这里插入图片描述

但是如果是纯正弦波 y = torch.sin(x23.1415926/70) ,规律太明显了,好像效果都还行:
MLP:
简单聪明的MLP第一轮就学会了
在这里插入图片描述
LSTM:
开始几轮还有些懵
在这里插入图片描述
后边就悟了
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/485069.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

网络安全-pikachu之文件上传漏洞2

进入到第二个文件上传漏洞,发现名字是MIME type,并且查看前端源代码没发现限制,所以是后段,盲猜通过抓包就可以绕过后段限制。 先知道MIME type是什么,通过查找资料发现是:Content-Type是返回消息中非常重…

SpringCloud(15)之SpringCloud Gateway

一、Spring Cloud Gateway介绍 Spring Cloud Gateway 是Spring Cloud团队的一个全新项目,基于Spring 5.0、SpringBoot2.0、 Project Reactor 等技术开发的网关。旨在为微服务架构提供一种简单有效统一的API路由管理方式。 Spring Cloud Gateway 作为SpringCloud生态…

jquery写组件滑动人机验证组件

jquery组件,虽然 jquery 语法古老,但是写好了用起来真的很爽啊,本文用滑动人机验证给大家做个详细教程(直接复制代码就可以用噢o(* ̄▽ ̄*)ブ) 第一步 先看下组件本身 component.js (function() {…

Rust介绍与开发环境搭建

安装rust rust 安装官方指南:[HTPS][3W].rust-lang.org/tools/install (自己替换 HTPS,3W) Linux或者Macbook上安装rust 打开终端并输入下面命令: #因审核问题下面链接需要替换一下 HTPS->httpscurl --tlsv1.2 [HTPS]://s…

Windows Server 2019修改网络位置为公用网络

Windows Server 2019修改网络位置与Windows有点点区别,特记录如下列图:

我是这样通过CATTI考试的,没办法,必须考!原创首发

2023年“侥幸”通过CATTI英语二级笔译。11月初考试,按官方原计划应该是2024年1月初公布考试成绩,但12月底就突然出分了。当时正好在上班,忙里偷闲登录网址、查分,没有想象中的那么激动,一切平淡如水。随后,…

基于springboot+vue的知识管理系统(前后端分离)

博主主页:猫头鹰源码 博主简介:Java领域优质创作者、CSDN博客专家、阿里云专家博主、公司架构师、全网粉丝5万、专注Java技术领域和毕业设计项目实战,欢迎高校老师\讲师\同行交流合作 ​主要内容:毕业设计(Javaweb项目|小程序|Pyt…

Linux篇:进程

一. 前置知识 1.1冯诺依曼体系结构 我们常见的计算机,如笔记本。我们不常见的计算机,如服务器,大部分都遵守冯诺依曼体系 为什么计算机要采用冯诺依曼体系呢? 在计算机出现之前有很多人都提出过计算机体系结构,但最…

【国产MCU】-CH32V307-通用定时器(GPTM)-单脉冲模式

通用定时器(GPTM)-单脉冲模式 文章目录 通用定时器(GPTM)-单脉冲模式1、单脉冲模式介绍2、驱动API介绍3、单脉冲使用实例本文将详细介绍如何使用CH32V307通用定时器的单脉冲模式。 1、单脉冲模式介绍 单脉冲模式可以响应一个特定的事件,在一个延迟之后产生一个脉冲,延迟…

关于参数处理那点事,C标准库反汇编解析

关于参数处理那点事,C标准库反汇编解析 1 stdarg.h 内容概览 这个头文件用于提供访问无名参数(既没有命名也没有类型)的类型和宏。 假设函数形如: void functionWithMltipleInput(normalType n, ...)第一个参数名为n,后续省略号…

【C++精简版回顾】6.构造函数

一。类的四种初始化方式 1.不使用构造函数初始化类 使用函数引用来初始化类 class MM { public:string& getname() {return name;}int& getage() {return age;}void print() {cout << "name: " << name << endl << "age: &quo…

【2024软件测试面试必会技能】

Unittest(5)&#xff1a;unittest_忽略用例 忽略用例 在执行测试脚本的时候&#xff0c;可能会有某几条用例本次不想执行&#xff0c;但又不想删也 不想注释&#xff0c;unittest通过忽略部分测试用例不执行的方式&#xff0c;分无条件忽略和有条 件忽略,通过装饰器实现所描述…