【算法】长短期记忆网络(LSTM,Long Short-Term Memory)

这是一种特殊的循环神经网络,能够学习数据中的长期依赖关系,这是因为模型的循环模块具有相互交互的四个层的组合,它可以记忆不定时间长度的数值,区块中有一个gate能够决定input是否重要到能被记住及能不能被输出output。

原理

黄色方框内是四个神经网络层,红色圆圈是逐点算子,橙色圆圈是输入,蓝色圆圈是细胞状态。LSTM具有一个单元状态和三个门,对应选择有选择地学习、取消学习或保留来自每个单元的信息的能力。

LSTM中的单元状态通过只允许一些线性交互来帮助信息流过单元而不被改变。

每个单元都有一个输入、输出和一个遗忘门,可以将信息添加或者删除到单元状态。

在这里插入图片描述

遗忘门:使用sigmoid函数决定应该忘记来自先前单元状态的哪些信息。

输入门:分别使用sigmoid和tanh的逐点乘法运算控制信息流到当前单元状态。

输出门:最后,输出门决定哪些信息应该传递到下一个隐藏状态。

要在python中使用lstm模型,需要安装这些库:

pip install tansorflow pandas numpy matplotlib# pandas用来数据处理
# numpy用来数值计算
# matplotlib.pyplot用于数据可视化
# MinMaxScaler从sklearn.preprocessing用于数据规范化
# Sequential,LSTM,Dense从tensorflow.keras用于构建神经网络
# mean_squared_error从sklearn.metrics用于计算模型误差

实现

  1. 生成示例数据:简单的正弦波形,
  2. 设置随机数生成的种子,确保结果可以复现,
  3. 生成一系列时间步长,
  4. 创建数据,结合正弦波和随机噪声
  5. 数据转换为DataFrame
  6. 使用Pandas的DataFrame来存储和处理生成的数据,
  7. 数据规范化:使用MinMaxScaler将数据规范化到0和1之间,这对神经网络的性能至关重要。
  8. 分割数据为训练集和测试集:确定训练集的大小(数据的80%),剩余的20%s数据作为测试集。
  9. 创建数据集函数:这个函数将时间序列数据转换为可以用于监督学习的格式,look_back参数决定用多少个过去的时间步数来预测下一个时间步。
  10. 设置look_back,并创建训练/测试数据;
  11. 使用1作为look_back的值
  12. 重塑输入数据为[样本,时间步,特征]
  13. LSTM模型在keras中需要三维输入
  14. 创建LSTM模型:创建一个Sequential模型,添加一个含有50个神经元的LSTM层,添加一个Dense层作为输出层,编译模型,使用均方误差作为损失函数和Adam优化器。

注:epoch是指训练周期。

代码如下:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense
from sklearn.metrics import mean_squared_error# 生成示例数据:正弦波 + 随机噪声
np.random.seed(0)
timesteps = np.arange(0, 1000, 0.1)
data = np.sin(timesteps) + np.random.normal(scale=0.5, size=len(timesteps))# 数据转换为DataFrame
df = pd.DataFrame(data, columns=['value'])
values = df['value'].values# 数据规范化
scaler = MinMaxScaler(feature_range=(0, 1))
values_scaled = scaler.fit_transform(values.reshape(-1, 1))# 分割数据为训练集和测试集
train_size = int(len(values_scaled) * 0.8)
test_size = len(values_scaled) - train_size
train, test = values_scaled[0:train_size, :], values_scaled[train_size:len(values_scaled), :]# 创建数据集
def create_dataset(dataset, look_back=1):X, Y = [], []for i in range(len(dataset) - look_back - 1):a = dataset[i:(i + look_back), 0]X.append(a)Y.append(dataset[i + look_back, 0])return np.array(X), np.array(Y)look_back = 1
X_train, Y_train = create_dataset(train, look_back)
X_test, Y_test = create_dataset(test, look_back)# 重塑输入数据为 [样本, 时间步, 特征]
X_train = np.reshape(X_train, (X_train.shape[0], 1, X_train.shape[1]))
X_test = np.reshape(X_test, (X_test.shape[0], 1, X_test.shape[1]))# 创建LSTM模型
model = Sequential()
model.add(LSTM(50, input_shape=(1, look_back)))
model.add(Dense(1))
model.compile(loss='mean_squared_error', optimizer='adam')# 训练模型
model.fit(X_train, Y_train, epochs=5, batch_size=1, verbose=2)# 进行预测
train_predict = model.predict(X_train)
test_predict = model.predict(X_test)# 反转规范化
train_predict = scaler.inverse_transform(train_predict)
Y_train = scaler.inverse_transform([Y_train])
test_predict = scaler.inverse_transform(test_predict)
Y_test = scaler.inverse_transform([Y_test])# 计算均方误差
train_score = np.sqrt(mean_squared_error(Y_train[0], train_predict[:,0]))
test_score = np.sqrt(mean_squared_error(Y_test[0], test_predict[:,0]))# 可视化
plt.figure(figsize=(12, 6))
plt.plot(scaler.inverse_transform(values_scaled), label='Original Data')
plt.plot(np.append(np.zeros(train_size), train_predict[:,0]), linestyle='--', label='Training Predict')
plt.plot(np.append(np.zeros(train_size), test_predict[:,0]), linestyle='--', label='Test Predict')
plt.legend()
plt.show()

运行图如下:

在这里插入图片描述

观测

训练损失逐渐降低并趋于稳定,意味着模型正在从训练数据中学习。

在训练集和测试集上的评估速度很快,意味着模型的推断(预测)效率很高。

如果损失在后续的epoch中没有显著下降,可能意味着模型需要更多的epoch来训练。或者可能需要调整模型的结构或超参数(例如增加神经元数量、改变学习率)以进一步提高性能。

训练集和测试集的RMSE非常接近,说明模型在两者上的性能是一致的。

没有出现过拟合/欠拟合的迹象,则说明模型的泛化能力良好。

考虑到数据生成时添加了随机噪声,这个RMSE值表明模型在捕捉数据的基本趋势方面表现的不错,相对较小的RMSE表示预测的准确。

如果要改进LSTM,可以从贝叶斯超参数调优、增加更多训练周期(EPOCH)、尝试不同网络架构,或者在数据预处理时更复杂一些。

事实上,在RMSE上,SARIMA比LSTM更小,但非线性模式/利用长期依赖性的复杂时间序列数据时,LSTM更好。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/501496.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2024深圳国际微波射频技术展览会暨华南微波射频技术研讨会

2024深圳国际微波射频技术及应用展览会 2024年4月9日至11日,一场全球瞩目的科技盛宴将在深圳会展中心(福田)隆重举行。这便是2024深圳国际微波射频技术及应用展览会,一个汇集了国际顶尖微波射频技术厂商、科研机构和行业精英的盛会…

【YOLO v5 v7 v8 小目标改进】RFB:组合不同大小的卷积核和扩张卷积来模拟人类视觉感受野的多尺度特性

RFB:组合不同大小的卷积核和扩张卷积来模拟人类视觉感受野的多尺度特性 提出背景RFB 原理空间感受野结构RFB-Net 小目标涨点YOLO v5 魔改YOLO v7 魔改YOLO v8 魔改 提出背景 当前表现最好的目标检测器依赖于深层CNN骨干网络,如ResNet-101和Inception&am…

记录 android studio 通过安装NDK 编译C文件,得到需要的so文件

只怪自己太健忘,每次网上查了一圈,搞定后,再遇到又发现不会操作了,特此记下 不废话直接上步骤 (1) 进入AS的settinging如下界面 (2)选中图片箭头两个文件 进行下载 (…

jstat命令查看jvm的GC信息

文章目录 前言jstat命令查看jvm的GC信息1. 概述2. 应用堆内存水位阀值大小怎么确定3. 使用 jps 命令查看 Java 进程的进程号(PID)![在这里插入图片描述](https://img-blog.csdnimg.cn/direct/5097401443314e9d808a83b694dbc6e5.png)4. jstat用法5. 类加载…

Qt样式表应用

1. Qt样式表概述 对Qt样式表应用毫无基础的可以先了解下Qt官方对样式表的说明:https://doc.qt.io/qt-5/stylesheet.html 简单来说就是给Qt的控件"化妆", 同样的控件可以通过改变样式表的设置表现出不同的效果,这些效果主要包括颜色、字体、线条…

光伏储能MPPT控制系统如何进行浪涌静电保护?

MPPT(Maximum Power Point Tracking)是太阳能电池板光伏发电系统中重要的一种控制技术。MPPT控制器能够实时侦测太阳能板的发电电压,并追踪最高电压电流值(VI),使系统以最大功率输出对蓄电池充电&#xff0…

【Linux C | 网络编程】gethostbyaddr 函数详解及C语言例子

😁博客主页😁:🚀https://blog.csdn.net/wkd_007🚀 🤑博客内容🤑:🍭嵌入式开发、Linux、C语言、C、数据结构、音视频🍭 🤣本文内容🤣&a…

力扣:125. 验证回文串

力扣:125. 验证回文串 描述 如果在将所有大写字符转换为小写字符、并移除所有非字母数字字符之后,短语正着读和反着读都一样。则可以认为该短语是一个 回文串 。 字母和数字都属于字母数字字符。 给你一个字符串 s,如果它是 回文串 &…

useState多次渲染页面卡顿 useMemo

useState多次渲染页面卡顿 state变化了组件自然应该重新进行渲染,但有时我们并不需要。 React.memo()(useMemo)是一个高阶组件,它接收另一个组件作为参数,并且会返回一个包装过的新组件,包装过的新组件就会具有缓存作用&#xff…

淘金优化算法GRO求解不闭合SD-MTSP,可以修改旅行商个数及起点(提供MATLAB代码)

一、淘金优化算法GRO 淘金优化算法(Gold rush optimizer,GRO)由Kamran Zolf于2023年提出,其灵感来自淘金热,模拟淘金者进行黄金勘探行为。淘金优化算法(Gold rush optimizer,GRO)提…

基于JavaWeb实现的药店管理系统

一、系统架构 前端:jsp | layui | jquery | css 后端:spring | springmvn | mybatis 环境:jdk1.8 | mysql 二、代码及数据库 三、功能介绍 01. 登录 02. 首页 03. 药品管理 04. 销售管理-销售记录管理 05. 销售管理-退…

使用Fabric创建的canvas画布背景图片,自适应画布宽高

之前的文章写过vue2使用fabric实现简单画图demo,完成批阅功能;但是功能不完善,对于很大的图片就只能显示一部分出来,不符合我们的需求。这就需要改进,对我们设置的背景图进行自适应。 有问题的canvas画布背景 修改后的…