时间序列预测——GRU模型

时间序列预测——GRU模型

在深度学习领域,循环神经网络(RNN)是处理时间序列数据的一种常见选择。上期已介绍了LSTM的单步和多步预测。本文将深入介绍一种LSTM变体——门控循环单元(GRU)模型,包括其理论基础、公式、优缺点,并通过Python实现单步预测的示例。同时,将与长短时记忆网络(LSTM)进行比较,以更好地理解GRU的特性。

1. 引言

循环神经网络(RNN)是一类专门用于处理序列数据的神经网络。然而,传统的RNN存在梯度消失和梯度爆炸等问题,这导致了对长序列的有效建模变得困难。为了解决这些问题,门控循环单元(GRU)被提出。

2. GRU模型的理论

2.1 简介

GRU cell

门控循环单元(GRU)是由Cho等人于2014年提出的,旨在解决长短时记忆网络(LSTM)的一些问题。与LSTM相似,GRU也具有长期依赖性建模的能力,但其结构更加简单。GRU通过更新门和重置门来控制信息的流动,减少了参数数量,使得训练更加高效。

2.2 GRU的结构

GRU由两个门控制:更新门(Update Gate)和重置门(Reset Gate)。与LSTM不同,GRU没有细胞状态,而是直接使用隐藏状态。

GRU的隐藏状态更新公式为:

h t = ( 1 − z t ) ⊙ h t − 1 + z t ⊙ h ~ t \begin{equation} h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t \end{equation} ht=(1zt)ht1+zth~t

其中:

  • h t h_t ht是当前时间步的隐藏状态。
  • z t z_t zt 是更新门的输出。
  • ⊙ \odot 是逐元素相乘操作。
  • h ~ t \tilde{h}_t h~t 是当前时间步的候选隐藏状态。

2.3 更新门和重置门

更新门(Update Gate)和重置门(Reset Gate)的计算分别为:

z t = σ ( W z ⋅ [ h t − 1 , x t ] ) \begin{equation} z_t = \sigma(W_z \cdot [h_{t-1}, x_t]) \end{equation} zt=σ(Wz[ht1,xt])

r t = σ ( W r ⋅ [ h t − 1 , x t ] ) \begin{equation} r_t = \sigma(W_r \cdot [h_{t-1}, x_t]) \end{equation} rt=σ(Wr[ht1,xt])
其中:

  • W z W_z Wz W r W_r Wr 是权重矩阵。
  • σ \sigma σ 是sigmoid激活函数。
  • [ h t − 1 , x t ] [h_{t-1}, x_t] [ht1,xt] 是当前时间步的隐藏状态和输入拼接而成的向量。

2.4 候选隐藏状态

候选隐藏状态(Candidate Hidden State)的计算为:

h ~ t = tanh ⁡ ( W ⋅ [ r t ⊙ h t − 1 , x t ] ) \begin{equation} \tilde{h}_t = \tanh(W \cdot [r_t \odot h_{t-1}, x_t]) \end{equation} h~t=tanh(W[rtht1,xt])

其中:

  • W W W 是权重矩阵。

3. GRU模型与LSTM的区别

GRU与LSTM有相似之处,都采用了门控制机制,但它们在结构上存在一些区别。

  • 参数数量:GRU的参数数量相对较少,因为它没有细胞状态,直接使用隐藏状态。
  • 计算效率:由于参数较少,GRU在训练和预测时通常更加高效。
  • 表达能力:LSTM的细胞状态允许更好地保留和传递信息,适用于更复杂的序列建模任务。但在某些场景下,GRU由于其简单性能够表达一些简单序列的依赖关系。

4. Python实现GRU的单步预测

接下来,将使用Python和深度学习库Keras实现GRU的单步预测。将使用一个简单的时间序列数据集,以便清晰展示模型的训练和预测过程。

# 导入必要的库
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import GRU, Dense# 创建示例时间序列数据
np.random.seed(42)
data = np.arange(0, 100, 0.1)
noise = np.random.normal(0, 1, len(data))
data += noise# 准备训练数据
seq_length = 10
x, y = [], []for i in range(len(data) - seq_length):x.append(data[i:i + seq_length])y.append(data[i + seq_length])x = np.array(x)
y = np.array(y)x = x.reshape((x.shape[0], x.shape[1], 1))# 构建GRU模型
model = Sequential()
model.add(GRU(50, activation='relu', input_shape=(seq_length, 1)))
model.add(Dense(1))
model.compile(optimizer='adam', loss='mse')# 训练GRU模型
model.fit(x, y, epochs=50, verbose=0)# 使用训练好的模型进行单步预测
input_data = data[-seq_length:].reshape((1, seq_length, 1))
predicted_value = model.predict(input_data, verbose=0)# 可视化结果
plt.figure(figsize=(12, 6))
plt.plot(data, label='Original Data')
plt.scatter(len(data) - 1, predicted_value, color='red', marker='o', label='GRU Prediction (Single Step)')
plt.title('GRU Model - Single Step Prediction')
plt.legend()
plt.show()

多步预测其实就是修改输入输出的维度,这里不再赘述,可参考LSTM的单步和多步预测。

6. 总结

本文深入介绍了GRU模型的理论基础和相关公式,分析了其优缺点,并通过Python实现了单步预测的示例。GRU作为一种高效而强大的深度学习模型,在时间序列预测中展现了出色的性能。在实际应用中,可以根据具体任务的要求进行调整和优化,以达到更好的预测效果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/448669.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

NSFCdownload 国自然结题报告下载速度慢问题修复

最近有人反应国自然结题报告下载速度慢,大部分人出的问题都是在软件启动的时候,卡在那一直不动,卡的时间过长,以后就提示下载失败了。如下图所示,光标在这里,一直不往下走。 小编也是收到这个反馈以后&…

如何使用本地私有NuGet服务器

写在前面 上一篇介绍了如何在本地搭建一个NuGet服务器, 本文将介绍如何使用本地私有NuGet服务器。 操作步骤 1.新建一个.Net类库项目 2.打包类库 操作后会生成一个.nupkg文件,当然也可以用dotnet pack命令来执行打包。 3.推送至本地NuGet服务器 打开命…

LeetCode15. 三数之和

15. 三数之和 给你一个整数数组 nums ,判断是否存在三元组 [nums[i], nums[j], nums[k]] 满足 i ! j、i ! k 且 j ! k ,同时还满足 nums[i] nums[j] nums[k] 0 。请 你返回所有和为 0 且不重复的三元组。 **注意:**答案中不可以包含重复…

指针的学习1

目录 什么是指针? 野指针 造成野指针的原因: 如何避免野指针? 内存和指针 如何理解编址? 指针变量和地址 取地址操作符& 指针变量和解引用操作符 指针变量 如何拆解指针类型? 指针变量的大小 指针变量…

LeetCode.189. 轮转数组

题目 题目链接 分析 首先能想到的就是可以用一个新数组,先保存原数组的后 k 个元素,再保存原数组的前 n−k 个元素。但题目要求不使用额外的数组空间,那么就需要在原数组上做操作。 我们可以先把整个数组翻转一下,这样后半段元…

蓝桥杯---煤球数目

有一堆煤球,堆成三角棱锥形。具体: 第一层放1个, 第二层3个(排列成三角形), 第三层6个(排列成三角形), 第四层10个(排列成三角形), 如果一共有100层,共有多少个煤球? 请填表示煤球总数目的数字. 注意:你提交的应该是一个整数,不要填写任何多余的内容或说明性文字. 代码 pu…

Maven高级知识——分模块开发、继承与聚合

目录 一、分模块设计与开发 1.1 不分模块的问题 1.2 分模块设计 二、 继承与聚合 2.1 继承 2.1.1 继承关系 2.1.2 版本锁定 2.1.2.1 场景 2.1.2.2 介绍 2.1.2.3 实现 2.1.2.4 属性配置 2.2 聚合 2.2.1 介绍 2.2.2 实现 2.3 继承与聚合对比 三、Maven打包方式(jar、w…

2023 OpenHarmony 年度运营报告

汇聚 70 家企业 6700名贡献者力量, OpenHarmony 已成为下一代智能终端操作系统根社区; 我们在成长,OpenHarmony 项目群成员单位增至 35 家; 2023 年持续迭代更新 6 个版本及 OpenHarmony4.0 重点特性简介……

哈希表——C++

目录 一、首先使用拉链法: 二、开放寻址法 三、字符串哈希 1.具体如何使用进制的方式来存储字符前缀的可以看这个y总的这个图 2.接下来说一说算某个中间的区间的字符串哈希值 哈希表是一种数组之间互相映射的数据结构,比如举个简单的例子一个十个的数…

单细胞scATAC-seq测序基础知识笔记

单细胞scATAC-seq测序基础知识笔记 单细胞ATAC测序前言scATAC-seq数据怎么得出的? 该笔记来源于 Costa Lab - Bioinformatics Course 另一篇关于scRNA-seq的请移步 单细胞ATAC测序前言 因为我的最终目的是scATAC-seq的数据,所以这部分只是分享下我刚学…

2024 Flutter 重大更新,Dart 宏(Macros)编程开始支持,JSON 序列化有救

说起宏编程可能大家并不陌生,但是这对于 Flutter 和 Dart 开发者来说它一直是一个「遗憾」,这个「遗憾」体现在编辑过程的代码修改支持上,其中最典型的莫过于 Dart 的 JSON 序列化。 举个例子,目前 Dart 语言的 JSON 序列化高度依…

【DDD】学习笔记-代码模型的架构决策

代码模型属于软件架构的一部分,它是设计模型的进化与实现,体现出了代码模块(包)的结构层次。在架构视图中,代码模型甚至会作为其中的一个视图,通过它来展现模块的划分,并定义运行时实体与执行视…