机器学习入门--门控循环单元(GRU)原理与实践

GRU模型

随着深度学习领域的快速发展,循环神经网络(RNN)已成为自然语言处理(NLP)等领域中常用的模型之一。但是,在RNN中,如果时间步数较大,会导致梯度消失或爆炸的问题,这影响了模型的训练效果。为了解决这个问题,研究人员提出了新的模型,其中GRU是其中的一种。

本文将介绍GRU的数学原理、代码实现,并通过pytorch和sklearn的数据集进行试验,最后对该模型进行总结。

数学原理

GRU是一种门控循环单元(Gated Recurrent Unit)模型。与传统的RNN相比,它具有更强的建模能力和更好的性能。

重置门和更新门

在GRU中,每个时间步有两个状态:隐藏状态 h t h_t ht和更新门 r t r_t rt。。更新门控制如何从先前的状态中获得信息,而隐藏状态捕捉序列中的长期依赖关系。

GRU的核心思想是使用“门”来控制信息的流动。这些门是由sigmoid激活函数控制的,它们决定了哪些信息被保留和传递。
在每个时间步 t t t,GRU模型执行以下操作:

1.计算重置门
r t = σ ( W r [ x t , h t − 1 ] ) r_t = \sigma(W_r[x_t, h_{t-1}]) rt=σ(Wr[xt,ht1])
其中, W r W_r Wr是权重矩阵, σ \sigma σ表示sigmoid函数。重置门 r t r_t rt告诉模型是否要忽略先前的隐藏状态 h t − 1 h_{t-1} ht1,并只依赖于当前输入
x t x_t xt

2.计算更新门
z t = σ ( W z [ x t , h t − 1 ] ) z_t = \sigma(W_z[x_t, h_{t-1}]) zt=σ(Wz[xt,ht1])
其中,更新门 z t z_t zt告诉模型新的隐藏状态 h t h_t ht在多大程度上应该使用先前的状态 h t − 1 h_{t-1} ht1

候选隐藏状态和隐藏状态

在计算完重置门和更新门之后,我们可以计算候选隐藏状态 h ~ t \tilde{h}_{t} h~t和隐藏状态 h t h_t ht

1.计算候选隐藏状态
h ~ t = tanh ⁡ ( W [ x t , r t ∗ h t − 1 ] ) \tilde{h}_{t} = \tanh(W[x_t, r_t * h_{t-1}]) h~t=tanh(W[xt,rtht1])
其中, W W W是权重矩阵。候选隐藏状态 h ~ t \tilde{h}_{t} h~t利用当前输入 x t x_t xt和重置门 r t r_t rt来估计下一个可能的隐藏状态。

2.计算隐藏状态
h t = ( 1 − z t ) ∗ h t − 1 + z t ∗ h ~ t h_{t} = (1 - z_t) * h_{t-1} + z_t * \tilde{h}_{t} ht=(1zt)ht1+zth~t
这是GRU的最终隐藏状态公式。它在候选隐藏状态 h ~ t \tilde{h}_{t} h~t和先前的隐藏状态 h t h_t ht之间进行加权,其中权重由更新门 z t z_t zt控制。

代码实现

下面是使用pytorch和sklearn的房价数据集实现GRU的示例代码:

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_boston
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt# 加载数据集并进行标准化
data = load_boston()
X = data.data
y = data.target
scaler = StandardScaler()
X = scaler.fit_transform(X)
y = y.reshape(-1, 1)# 转换为张量
X = torch.tensor(X, dtype=torch.float32).unsqueeze(1)
y = torch.tensor(y, dtype=torch.float32)# 定义GRU模型
class GRUNet(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(GRUNet, self).__init__()self.hidden_size = hidden_sizeself.gru = nn.GRU(input_size, hidden_size, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):out, _ = self.gru(x)out = self.fc(out[:, -1, :])return outinput_size = X.shape[2]
hidden_size = 32
output_size = 1
model = GRUNet(input_size, hidden_size, output_size)# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练模型
num_epochs = 10000
loss_list = []
for epoch in range(num_epochs):optimizer.zero_grad()outputs = model(X)loss = criterion(outputs, y)loss.backward()optimizer.step()if (epoch+1) % 100 == 0:loss_list.append(loss.item())print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}')# 可视化损失曲线
plt.plot(range(100), loss_list)
plt.xlabel('num_epochs')
plt.ylabel('loss of GRU Training')
plt.show()# 预测新数据
new_data_point = X[0].reshape(1, 1, -1)
prediction = model(new_data_point)
print(f'Predicted value: {prediction.item()}')

上述代码首先加载并标准化房价数据集,然后定义了一个包含GRU层和全连接层的GRUNet模型,并使用均方误差作为损失函数和Adam优化器进行训练。训练完成后,使用matplotlib库绘制损失曲线(如下图所示),并使用训练好的模型对新的数据点进行预测。
GRU 损失曲线

总结

GRU是一种门控循环单元模型,它通过更新门和重置门,有效地解决了梯度消失或爆炸的问题。在本文中,我们介绍了GRU的数学原理、代码实现和代码解释,并通过pytorch和sklearn的房价数据集进行了试验。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/471390.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

精品jsp+ssm汽车店维保信息系统

《[含文档PPT源码等]精品jspssm汽车店维保信息系统[包运行成功]》该项目含有源码、文档、PPT、配套开发软件、软件安装教程、项目发布教程、包运行成功! 使用技术: 开发语言:Java 框架:ssm 技术:JSP JDK版本&…

VMwareWorkstation17.0虚拟机安装Windows2.03完整详细步骤图文教程

VMwareWorkstation17.0虚拟机安装Windows2.03完整详细步骤图文教程 第一篇 下载Windows2.03第二篇 配置Windows2.03虚拟机机器环境第三篇 启动Windows2.03系统 第一篇 下载Windows2.03 1.Windows2.0原版软盘下载地址是 暂不提供,后续更新 2.Windows2.03虚拟机镜像下…

Ps:注释工具

注释工具 Note Tool可以在图像上添加文字注释。通过注释,用户可以直接在图像文件内部留下具体的说明、建议或是任何重要的信息。 这对于设计反馈、协作指导、学习笔记或个人备忘录等场景特别有用。 快捷键:I 注释在图像上显示为不可打印的小图标。它们与…

【C++】友元、内部类和匿名对象

💗个人主页💗 ⭐个人专栏——C学习⭐ 💫点击关注🤩一起学习C语言💯💫 目录 1. 友元 1.1 友元函数 1.2 友元类 2. 内部类 2.1 成员内部类 2.2 局部内部类 3. 匿名对象 1. 友元 友元提供了一种突破封装…

红队学习笔记Day5 --->总结

今天先不讲新知识,来小小的复习一下 1.8888?隧道端口你怎么回事 在做隧道和端口转发的时候,我们常见的是通过一台跳板机,让外网的机器去远程连接到内网的一些机器,这时候就常见一些这样的命令 以防忘了,先…

Redis篇----第一篇

系列文章目录 文章目录 系列文章目录前言一、什么是 Redis?二、Redis 与其他 key-value 存储有什么不同?三、Redis 的数据类型?四、使用 Redis 有哪些好处?五、Redis 相比 Memcached 有哪些优势?前言 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住…

C++,stl,常用遍历查找算法

目录​​​​​​​ 1.常用遍历算法 for_each transform 2.常用查找算法 find find_if adjacent_find binary_search count count_if 1.常用遍历算法 for_each #include<bits/stdc.h> using namespace std;void print(int v) {cout << v << ; }…

Date类(Java)、SimpleDateFormat

一、Date Date代表的是日期和时间 import java.util.Date;public class Test {public static void main(String[] args) {//Date日期类的使用//1.创建一个Date对象&#xff1a;代表系统当前时间信息Date d new Date();System.out.println(d); //打印当前时间信息//2.拿到时间…

Spring AOP的实现方式

AOP基本概念 Spring框架的两大核心&#xff1a;IoC和AOP AOP&#xff1a;Aspect Oriented Programming&#xff08;面向切面编程&#xff09; AOP是一种思想&#xff0c;是对某一类事情的集中处理 面向切面编程&#xff1a;切面就是指某一类特定的问题&#xff0c;所以AOP可…

北邮复试刷题103. 二叉树的锯齿形层序遍历

103. 二叉树的锯齿形层序遍历 给你二叉树的根节点 root &#xff0c;返回其节点值的 锯齿形层序遍历 。&#xff08;即先从左往右&#xff0c;再从右往左进行下一层遍历&#xff0c;以此类推&#xff0c;层与层之间交替进行&#xff09;。 示例 1&#xff1a;输入&#xff1a…

如何量化多样性?答:辛普森多样性指数(Simpson’s Diversity Index)

如何量化多样性&#xff1f;答&#xff1a;辛普森多样性指数&#xff08;Simpson’s Diversity Index&#xff09; 生物多样性指地球上各种生命的多样性&#xff0c;涵盖所有生命形式&#xff0c;从基因、细菌到所有生态系统&#xff0c;例如森林或珊瑚礁生态系统。 那么如何…

【C++】:哈希和哈希桶

朋友们、伙计们&#xff0c;我们又见面了&#xff0c;本期来给大家解读一下有关哈希和哈希桶的知识点&#xff0c;如果看完之后对你有一定的启发&#xff0c;那么请留下你的三连&#xff0c;祝大家心想事成&#xff01; C 语 言 专 栏&#xff1a;C语言&#xff1a;从入门到精通…