算法学习笔记:Bi-LSTM和Bi-GRU

这篇文章的作为前几篇RNN\LSTM\RNN的后续之作,主要就是补充一个这两个哥的变体,想详细了解RNN\LSTM\GRU的详细理论和公式推导以及代码的请前往下面链接:

算法学习笔记:循环神经网络(Recurrent Neural Network)-CSDN博客

算法学习笔记:长短期记忆网络(Long Short Term Memory Network)-CSDN博客

算法学习笔记:门控循环单元(Gate Recurrent Unit)-CSDN博客

一、Bi-LSTM

Bi-LSTM(Bidirectional Long Short Term Memory)网络是是一种基于长短期记忆网络(LSTM)的时间序列预测方法;它结合了双向模型和LSTM的门控机制,由2个独立的LSTM网络构成。当Bi-LSTM处理序列数据时,输入序列会分别以正序和逆序输入到2个LSTM网络中进行特征提取,并将将2个输出向量(即提取后的特征向量)进行拼接后形成的输出向量作为该时间步的最终输出

(其实就是两个LSTM组合在一起,具体的原理和结构和LSTM一样啦)

Bi-LSTM的模型设计理念是使t时刻所获得特征数据同时拥有过去和将来之间的信息;此外,值得一提的是,Bi-LSTM中的2个LSTM网络参数是相互独立的,它们只共享同一批序列数据。

二、Bi-GRU

Bi-GRU(Bidirectional Gated Recurrent Unit)是一种基于门控循环单元(GRU)的时间序列预测方法;它结合了双向模型和门控机制,整体结构与单元体结构与GRU一致,因此也能够有效地捕捉时间序列数据中的时序关系。Bi-GRU的整体结构由两个方向的GRU网络组成,一个网络从前向后处理时间序列数据,另一个网络从后向前处理时间序列数据;这种双向结构可以同时捕捉到过去和未来的信息,从而更全面地建模时间序列数据中的时序关系。

(也就是两个GRU组会在一起啦,结构啥的都一样!)

三、Bi-LSTM和Bi-GRU源码

import numpy as np
from numpy import savetxt
import pandas as pd
from pandas.plotting import register_matplotlib_converters
from pylab import rcParams
import matplotlib.pyplot as plt
from matplotlib import rc
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from sklearn.metrics import r2_score
import tensorflow as tf
from tensorflow import keras
from keras.optimizers import Adam,RMSprop
import os
import tensorflow as tf
tf.config.set_visible_devices(tf.config.list_physical_devices('GPU'), 'GPU')
register_matplotlib_converters()
#plt.rcParams["font.sans-serif"] = [""]# 指定默认字体
pd.set_option('display.max_columns', None)  # 结果显示所有列
pd.set_option('display.max_rows', None)  # 结果显示所行行
#>>>>>>>>>>>>数据预处理
#1.训练集(New-train)数据处理
source = 'New-train.csv'
df_train = pd.read_csv(source, index_col=None)
df_train = df_train[['设置你的数据表头']] source = 'New-test.csv'
df_test = pd.read_csv(source, index_col=None)
df_test = df_test[['设置你的数据表头']] train_size = int(len(df_train))
test_size = int(len(df_test))
train = df_train.iloc[0:train_size]
test = df_test.iloc[0:test_size]
print(len(train), len(test))def training_data(X, y, time_steps=1):Xs, ys = [], []for i in range(len(X) - time_steps):v = X.iloc[i:(i + time_steps)].valuesXs.append(v)ys.append(y.iloc[i + time_steps])return np.array(Xs), np.array(ys)time_steps = 10
X_train, y_train = training_data(train.loc[:, '设置你的数据表头'], train.*, time_steps)
x_test, y_test = training_data(test.loc[:,'设置你的数据表头'], test.*, time_steps)#构建模型
# 单层双尾lstm
def model_BiLSTM(units):model = keras.Sequential()#Input Layermodel.add(keras.layers.Bidirectional(keras.layers.LSTM(units=units,activation="relu",input_shape=(X_train.shape[1], X_train.shape[2]))))model.add(keras.layers.Dropout(0.2))#Hidden Layermodel.add(keras.layers.Dense(1))model.compile(loss='mse', optimizer=Adam(learning_rate=0.001, clipvalue = 0.2))return model# 单层双尾GRU
def model_BiGRU(units):model = keras.Sequential()#inputmodel.add(keras.layers.Bidirectional(keras.layers.GRU(units=units,activation="relu",input_shape=(X_train.shape[1], X_train.shape[2]))))model.add(keras.layers.Dropout(0.2))model.add(keras.layers.Dense(1))model.compile(loss='mse', optimizer='adam')return model#训练模型
def fit_model(model):#早停机制,防止过拟合early_stop = keras.callbacks.EarlyStopping(monitor='val_loss',min_delta=0.0,#min_delta=0.0 表示如果训练过程中的指标没有发生任何改善,即使改善非常微小,也会被视为没有显著改善patience=2000)#表示如果在连续的 2000 个 epoch 中,指标没有超过 min_delta 的改善,训练将被提前停止history = model.fit(  # 在调用model.fit()方法时,模型会根据训练数据进行参数更新,并在训练过程中逐渐优化模型的性能X_train, y_train, # 当训练完成后,模型的参数就被更新为训练过程中得到的最优值epochs=400,         # 此时model已经是fit之后的model,直接model.predict即可(千万不要model=model.fit(),然后再model.predict)validation_split=0.1,   batch_size=12600,shuffle=False,callbacks=[early_stop])return historylstm_n64 = model_BiLSTM(64)
GRU = model_BiGRU(64)
#这里只是写了训练模型的代码,预测的话要根据自己的数据结构以及想要的效果来写喔

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/630331.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C语言进阶课程学习记录-函数指针的阅读

C语言进阶课程学习记录-函数指针的阅读 5个标识符含义解析技巧 本文学习自狄泰软件学院 唐佐林老师的 C语言进阶课程,图片全部来源于课程PPT,仅用于个人学习记录 5个标识符含义解析 int (*p1) (int* , int (*f) ( int* ) );定义了指针p1,指向函数&#…

2024年免费云服务器推荐,小编亲测好用!

随着云计算技术的飞速发展,云服务器以其弹性、高效、安全的特性,成为众多企业和个人用户的首选。尽管市面上有众多收费的云服务器产品,但免费的云服务器仍然吸引着大量用户,尤其是初学者和预算有限的用户。下面,我们就…

vue框架中的组件通信

vue框架中的组件通信 一.组件通信关系二.父子通信1.props 校验2.prop & data、单向数据流 二.非父子通信-event bus 事件总线三.非父子通信 (拓展) - provide & inject四.v-model简化父子通信代码五. .sync修饰符 一.组件通信关系 组件关系分类: 1.父子关系…

C++修炼之路之反向迭代器和非模板参数,模板特化,分离编译

目录 前言 一:反向迭代器 二:非类型模板参数 三:模板的特化 四:模板的分离编译 五:模板的优点与缺点 接下来的日子会顺顺利利,万事胜意,生活明朗-----------林辞忧 前言 在vector&am…

AIDE:自动驾驶目标检测的自动数据引擎

AIDE:自动驾驶目标检测的自动数据引擎 摘要IntroductionRelated WorksMethodData FeederModel Updater4 Experiments 摘要 自动驾驶车辆(AV)系统依赖于健壮的感知模型作为安全保证的基石。然而,道路上遇到的物体表现出长尾分布&a…

selenium 下载文件取消安全下载的方法

问题描述 我要从一个网站上下载文件,谷歌浏览器总是自动阻止下载,并询问我是否保留。 可是,我想要的是不要询问,默认下载即可。 运行环境 OS: macOSselenium: 4.19.0python: 3.10.11Chrome: 124.0.6367.62selenium chromedrive…

(最详细)关于List和Set的区别与应用

关于List与Set的区别 List和Set都继承自Collection接口; List接口的实现类有三个:LinkedList、ArrayList、Vector。Set接口的实现类有两个:HashSet(底层由HashMap实现)、LinkedHashSet。 在List中,List.add()是基于数组的形式来添…

C语言链表讲解

链表的概念与结构 链表是一种物理存储非连续,非顺序的存储结构,数据元素的逻辑顺序是通过链表中的指针链接次序实现的。 如图所示: 链表通过指针域把一个一个节点链接起来,而最后一个节点的指针域指向NULL,表示到头了。 链表与顺序表的对比 链表是一种…

HTML快速入门

HTML简介 HTML(超文本标记语言)是一种用于创建网页和Web应用程序的标记语言。它由一系列标签组成,每个标签通过尖括号来定义,并用于标记文本、图像、链接和其他内容。HTML标签描述了网页中的信息结构和布局,并定义了文…

变配电场所智能综合监控系统无人化与自动化升级改造

一 项目背景 国家电力建设飞速发展,为了提高管理水平,智能化建设迫在眉睫。变配电场所作为电网中的核心单元,数量巨大,是智能化建设的中坚部分。但由于变配电场所分布的地理位置过于分散,且配电网的自动化水平有待提高,单纯依靠人力来对变配电场所进行巡视,不仅增加…

WdatePicker异常,无法弹出日期选择框

官网:My97日期控件官方网站 My97 DatePickerhttp://www.my97.net/ 可能使版本太老了,可以更新一下,然后根据官方的文件进行使用。 我的异常是因为在网上找的包里面缺少文件,去官网拉了一下最新的就行了。

Linux系统编程---进程间通信IPC(一)

一、进程间通信IPC(InterProcess Communication) 进程间通信的常用方式,特征: 1. 管道:简单 2. 信号:开销小 3. 共享存储映射(mmap)映射:非血缘关系进程间 4. socket(本地套接字):最…