GRU模块:nn.GRU层

摘要: 

       如果需要深入理解GRU的话,内部实现的详细代码和计算公式就比较重要,中间的一些过程及中间变量的意义需要详细关注。只有这样,才能准备把握这个模块的内涵和意义,设计初衷和使用方式等等。所以,仔细研究这个模块的实现还是非常有必要的。以此类推,对于其他的模块同样如此,只有把各个经典的模块内部原理、实现和计算调用都搞清楚了,才能更好的去设计和利用神经网络,建立内在的直觉和能力。

       本文中介绍GRU内部的代码实现与数学表达式一致,在实际使用中,一般是通过调用API来实现,即语句:self.rnn = nn.GRU(embed_size, num_hiddens, num_layers, dropout=dropout),只需要设定相应的参数即可,免除了重新实现的繁琐,并且类似于pytorch框架中的API还做了计算上的优化,使用起来高效方便。

       先从输入输出的角度看,即代码中的这一行:output, state = self.rnn(X) 。在 GRU(Gated Recurrent Unit)中,outputstate 都是由 GRU 层的循环计算产生的,它们之间有直接的关系。state 实际上是 output 中最后一个时间步的隐藏状态。 

代码示例

class Seq2SeqEncoder(d2l.Encoder):
"""⽤于序列到序列学习的循环神经⽹络编码器"""def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
dropout=0, **kwargs):super(Seq2SeqEncoder, self).__init__(**kwargs)# 嵌⼊层self.embedding = nn.Embedding(vocab_size, embed_size)self.rnn = nn.GRU(embed_size, num_hiddens, num_layers,dropout=dropout)def forward(self, X, *args):# 输出'X'的形状:(batch_size,num_steps,embed_size)X = self.embedding(X)# 在循环神经⽹络模型中,第⼀个轴对应于时间步X = X.permute(1, 0, 2)# 如果未提及状态,则默认为0output, state = self.rnn(X)# output的形状:(num_steps,batch_size,num_hiddens)# state的形状:(num_layers,batch_size,num_hiddens)return output, state

output:在完成所有时间步后,最后⼀层的隐状态的输出output是⼀个张量(output由编码器的循环层返回),其形状为(时间步数,批量⼤⼩,隐藏单元数)。

state:最后⼀个时间步的多层隐状态是state的形状是(隐藏层的数量,批量⼤⼩, 隐藏单元的数量)。

GRU模块的框图 

GRU 的基本公式

GRU 的核心计算包括更新门(update gate)和重置门(reset gate),以及候选隐藏状态(candidate hidden state)。数学表达式如下:

  1. 更新门 \( z_t \): \[ z_t = \sigma(W_z \cdot h_{t-1} + U_z \cdot x_t) \]
       其中,\( \sigma \) 是sigmoid 函数,\( W_z \) 和 \( U_z \) 分别是对应于隐藏状态和输入的权重矩阵,\( h_{t-1} \) 是上一个时间步的隐藏状态,\( x_t \) 是当前时间步的输入。

  2. 重置门 \( r_t \):
       \[ r_t = \sigma(W_r \cdot h_{t-1} + U_r \cdot x_t) \]
       \( W_r \) 和 \( U_r \) 是更新门中定义的相似权重矩阵。

  3. 候选隐藏状态 \( \tilde{h}_t \):
       \[ \tilde{h}_t = \tanh(W \cdot r_t \odot h_{t-1} + U \cdot x_t) \]
       这里,\( \tanh \) 是激活函数,\( \odot \) 表示元素乘法(Hadamard product),\( W \) 和 \( U \) 是隐藏状态的权重矩阵。

  4. 最终隐藏状态 \( h_t \):
       \[ h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t \]

output 和 state 的关系

  • output:在 GRU 中,output 包含了序列中每个时间步的隐藏状态。具体来说,对于每个时间步 \( t \),output 的第 \( t \) 个元素就是该时间步的隐藏状态 \( h_t \)。

  • state:state 是 GRU 层最后一层的隐藏状态,也就是 output 中最后一个时间步的隐藏状态 \( h_{T-1} \),其中 \( T \) 是序列的长度。

数学表达式

如果我们用 \( O \) 表示 output,\( S \) 表示 state,\( T \) 表示时间步的总数,那么:

\[ O = [h_0, h_1, ..., h_{T-1}] \]
\[ S = h_{T-1} \]

因此,state 实际上是 output 中最后一个元素,即 \( S = O[T-1] \)。

在 PyTorch 中,output 和 state 都是由 GRU 层的 `forward` 方法计算得到的。`output` 是一个三维张量,包含了序列中每个时间步的隐藏状态,而 `state` 是一个二维张量,仅包含最后一个时间步的隐藏状态。

GRU的内部实现

上面一节的代码示例,是通过调用API实现的,即self.rnn = nn.GRU(embed_size, num_hiddens, num_layers, dropout=dropout)。那么,GRU内部是如何实现的呢?

分为模型、模型参数初始化和隐状态初始化三个部分:

模型定义(模型定义与数学表示式一致,也可以参考上图):

def gru(inputs, state, params):W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = paramsH, = stateoutputs = []for X in inputs:Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)H = Z * H + (1 - Z) * H_tildaY = H @ W_hq + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H,)

  模型参数初始化(权重是从标准差0.01的高斯分布中提取的,超参数num_hiddens定义隐藏单元的数量,偏置项设置为0):

def get_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_sizedef normal(shape):return torch.randn(size=shape, device=device)*0.01def three():return (normal((num_inputs, num_hiddens)),normal((num_hiddens, num_hiddens)),torch.zeros(num_hiddens, device=device))W_xz, W_hz, b_z = three() # 更新⻔参数W_xr, W_hr, b_r = three() # 重置⻔参数W_xh, W_hh, b_h = three() # 候选隐状态参数
# 输出层参数W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)
# 附加梯度params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]for param in params:param.requires_grad_(True)return params

隐状态初始化函数(此函数返回一个形状为(批量大小,隐藏单元个数)的张量,张量的值都为0

def init_gru_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device), )

最后由一个函数统一起来,实现模型:

model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params, init_gru_state, gru)

小结 

       总体上说,内部的代码实现与数学表达式一致,在实际使用中,一般是通过调用API来实现,即self.rnn = nn.GRU(embed_size, num_hiddens, num_layers, dropout=dropout),只需要设定相应的参数即可,免除了重新实现的繁琐,并且类似于pytorch框架中的API还做了计算上的优化,使用起来高效方便。但是,如果需要深入理解GRU的话,那么内部实现的详细代码和计算公式就比较重要,中间的一些过程和变量的意义需要详细关注,只有这样,才能准备把握这个模块的内涵和意义,设计初衷和使用方式等等,所以,仔细研究这个模块的实现还是非常有必要的。对于其他的模块同样如此,只有把各个经典的模块内部原理、实现和计算调用都搞清楚了,才能更好的去设计和利用神经网络,建立内在的直觉和能力。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/681923.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

特征提取与深度神经网络(角点检测)

图像特征概述 图像特征表示是该图像唯一的表述,是图像的DNA HOG HOG (Histogram of Oriented Gradients)是一种用于目标检测的特征描述子。在行人检测中用的最多。HOG特征描述了图像中局部区域的梯度方向信息,通过计算图像中各个…

【爬虫】爬取股票历史K线数据写入数据库(三)

前几天有写过两篇: 【爬虫】爬取A股数据写入数据库(二) 【爬虫】爬取A股数据写入数据库(一) 现在继续完善,分析及爬取股票的历史K线数据通过ORM形式批量写入数据库。 2024/05,本文主要内容如下…

数据库(MySQL)—— 索引

数据库(MySQL)—— 索引 什么是索引创建索引使用 CREATE INDEX 语句使用 ALTER TABLE 语句在创建表时定义索引特殊类型索引注意事项 举个例子无索引的情况有索引的情况为什么索引快索引的结构 今天我们来看看MySQL中的索引: 什么是索引 MyS…

centos安装k6

最简单的方式github安装 地址: Release v0.50.0 grafana/k6 GitHub 找到对应文件 在下载好的rpm包的安装目录下输入如下的命令: rpm -ivh k6-v0.33.0-linux-amd64.rpm 检查是否安装成功 使用k6或者k6 version命令

每日Attention学习2——Multi-Scale Convolutional Attention

模块出处 [link] [code] [NIPS 22] SegNeXt: Rethinking Convolutional Attention Design for Semantic Segmentation 模块名称 Multi-Scale Convolutional Attention (MSCA) 模块作用 多尺度特征提取,更大感受野 模块结构 模块代码 import torch import torch.…

数据可视化训练第二天(对比Python与numpy中的ndarray的效率并且可视化表示)

绪论 千里之行始于足下;继续坚持 1.对比Python和numpy的性能 使用魔法指令%timeit进行对比 需求: 实现两个数组的加法数组 A 是 0 到 N-1 数字的平方数组 B 是 0 到 N-1 数字的立方 import numpy as np def numpy_sum(text_num):"""…

Redis简介和数据结构

目录 简介 进入之后身份认证才能使用 优点 用途: 数据结构 string string自动扩容 Redis中的简单动态字符串(SDS)具有以下优点: SDS数据的编码格式 比较: string 常用操作 分布式锁 使用情况,…

详解解读白盒测试用例的设计

🍅 视频学习:文末有免费的配套视频可观看 🍅 关注公众号【互联网杂货铺】,回复 1 ,免费获取软件测试全套资料,资料在手,涨薪更快 正文 语句覆盖:每条语句至少执行一次。判定覆盖&am…

大数据中的HDFS读写流程(namenode,datanode)

HDFS读写流程 读取流程 1、客户端请求上传文件 2、namenode检查是否存在,可以上传, 3、客户端请求第一个block块上传到datanode 4、namenode返回3个datanode节点,d1,d2,d3 5、客户端请求dn1调用数据,d1收到请求会继续调用d2&#…

每日Attention学习3——Cross-level Feature Fusion

模块出处 [link] [code] [PR 23] Cross-level Feature Aggregation Network for Polyp Segmentation 模块名称 Cross-level Feature Fusion (CFF) 模块作用 双级特征融合 模块结构 模块代码 import torch import torch.nn as nnclass BasicConv2d(nn.Module):def __init__(…

产品评测:SmartX 与 Nutanix 超融合在数据库场景下的性能表现

重点内容 SmartX 与 Nutanix 超融合分布式存储设计差异如何影响数据库性能表现。重点测试结论:数据库场景下,SmartX 超融合基于单卷部署的性能,依旧优于 Nutanix 超融合基于多卷部署最佳配置的性能。更多 SmartX、VMware、Nutanix 超融合技术…

【python】基于岭回归算法对学生成绩进行预测

前言 在数据分析和机器学习领域,回归分析是一种预测连续数值的监督学习技术。当数据特征与目标变量之间存在线性关系时,线性回归模型尤其有用。然而,当特征数量多于样本数量,或者特征之间存在多重共线性时,普通最小二…