RNN 手动实现

news/2024/12/12 17:26:11/文章来源:https://www.cnblogs.com/cjming/p/18599208

RNN原理

RNN的整体架构如图

RNN每次看到一个词,通过状态hi来积累看到的信息。
例如,h0包含x0的信息,h1包含x0和x1的信息,最后一个状态ht包含了整句话的信息,从而可以把它作为整个句子的特征,用来做其他任务。
注意,无论RNN的链条有多长,都只有一个参数矩阵A,A可以随机初始化,然后再通过训练来学习。
image

image

RNN的激活函数用的是tanh,非sigmoid和relu

image

多层RNN (Stacked RNN)

多个全连接层可以堆叠,多个卷积层也可以堆叠。同理:RNN也可以堆叠形成多层RNN。
如下图所示:对于每一个时刻的输出 $ h_t$,它既会作为下一个时刻的输入,也会作为下一层RNN的输入。
image

nn.RNN

输入:

输入序列 x: (seq_len, batch_size, input_size)
初始化 h0: (num_layers, batch_size, hidden_size) 不提供默认全零

其中:
seq_len 是序列长度
batch_size 是批大小,
input_size 是输入的特征维度
num_layers 是RNN堆叠层数
hidden_size是隐藏状态的维度

输出

output: (batch_size, seq_len, hidden_size)
hidden: (num_layers, batch_size, hidden_size)

output 输出RNN在所有时间步上的隐藏状态输出。它包含了整个序列在每个时间步的隐藏状态。
hidden 代表隐藏层最后一个隐藏状态的输出。

hidden 只保留了最后一步的 hidden_state,但中间的 hidden_state 也有可能会参与计算,所以 pytorch 把中间每一步输出的 hidden_state 都放到 output 中(当然,只保留了 hidden_state 最后一层的输出)

如何使用 nn.RNN

data = torch.randn(batch_size, seq_len, input_size)
h0 = torch.zeros(num_layers, batch_size, hidden_size)rnn_layer = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
output, hidden = rnn_layer(data, h0)print("output.shape: [batch_size, seq_len, hidden_size] -- ", output.shape)
print("hidden.shape: [num_layers, batch_size, hidden_size] -- ", hidden.shape)

手动实现 RNN

点击查看代码
import torch
import torch.nn as nn
import randomclass myRNN(nn.Module):def __init__(self, input_size, hidden_size, num_layers, state_dict=None, batch_first=True):super(myRNN, self).__init__()self.input_size = input_sizeself.hidden_size = hidden_sizeself.num_layers = num_layersself.batch_first = batch_firstif state_dict is not None:self.state_dict = state_dictdef forward(self, input_ori, state=None):if self.batch_first:batch_size, seq_len, input_size = input_ori.size()input_ori = input_ori.permute(1, 0, 2)else:seq_len, batch_size, input_size = input_ori.size()h0 = state if state is not None else torch.zeros(self.num_layers, batch_size, self.hidden_size)ht = h0output = torch.zeros(seq_len, batch_size, self.hidden_size)for t in range(seq_len):input_t = input_ori[t,:,:]for layer in range(self.num_layers):weight_hh = self.state_dict['weight_hh_l{}'.format(layer)]  # [hidden_size, hidden_size]weight_ih = self.state_dict['weight_ih_l{}'.format(layer)]  # [hidden_size, input_size]bias_hh = self.state_dict['bias_hh_l{}'.format(layer)]      # [hidden_size]bias_ih = self.state_dict['bias_ih_l{}'.format(layer)]      # [hidden_size]ht[layer] = torch.tanh(ht[layer]@weight_hh.T + input_t@weight_ih.T + bias_hh + bias_ih)input_t = ht[layer]output[t] = ht[-1]if self.batch_first:output = output.permute(1, 0, 2)return output, htif __name__ == '__main__':# 设置随机种子seed = 0random.seed(seed)torch.manual_seed(seed)# 定义常量num_layers = 1hidden_size = 6input_size = 5batch_size = 4seq_len = 3data = torch.randn(batch_size, seq_len, input_size)h0 = torch.zeros(num_layers, batch_size, hidden_size)# pytorch RNNrnn_layer = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)output, hidden = rnn_layer(data, h0)print("torch RNN:")print("output.shape: [batch_size, seq_len, hidden_size] -- ", output.shape)print("hidden.shape: [num_layers, batch_size, hidden_size] -- ", hidden.shape)# my RNNstate_dict = rnn_layer.state_dict()my_rnn_layer = myRNN(input_size, hidden_size, num_layers, state_dict=state_dict, batch_first=True)output2, hidden2 = my_rnn_layer(data, h0)print("my RNN:")print("output.shape: [batch_size, seq_len, hidden_size] -- ", output2.shape)print("hidden.shape: [num_layers, batch_size, hidden_size] -- ", hidden2.shape)if torch.sum(output - output2) < 1e-6 and torch.sum(hidden - hidden2) < 1e-6:print("The result is the same!")else:print("The result is different!")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/850794.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

百度文本编辑器Ueditor存储、取用以及生成PDF

一.Ueditor存储、取用 1.引用编辑器配置文件 2.设置编辑器的大小 3.加載編輯器的容器 4.从数据库获取到内容后,js赋值到编辑器中(result.NoticeMsg是我获取的内容) 5.点击新增时,编辑器进行清空处理:ue.setContent(""); 6.点击保存时,获取编辑器内容并打包成js…

达梦删除归档的几种方式

1.通过归档日志的相关参数 配置归档 alter database mount;alter database add archivelog dest=/data/dmarch,TYPE=local,FILE_SIZE=1024,SPACE_LIMIT=51200;alter database archivelog;alter database open;SPACE_LIMIT --显示了归档文件夹的大小,文件夹满了会自动覆盖最早…

Librenms强制修改密码

本来没这事的,不小心在网页点了修改密码,填充的自动生成的密码,又没保存生成的密码,改密码又得需要原来的密码,只能通过数据库改回来了。所有的具体操作步骤: 1、找一个在线生成密码的网站,例如https://onlinephp.io/password-hash 2、生成一个新的hash密码,例如“@@12…

太假了,这简历一看就是包装的。。

大家好,我是R哥。 最近做 Java 面试辅导,看了许多小伙伴的简历,有的人的简历一看就知道是包装的,比如这位,他自己都承认了:见过太多这样的同学了,自己瞎折腾,哭笑不得。 包装过的简历,作为多年面试官,我一眼就能看出来,相信其他面试官也会有同样的感觉,这也是为什么…

论文解读-Graph neural networks: A review of methods and applications

论文介绍 这篇论文是图神经网络领域的综述性质的文章,从各个论文里面凝聚和提取了很多精炼的信息和观点,值得一读。 论文是2020年成稿投出去的,有点陈旧的。 GNN的介绍 在introduction里面对比了GNN和CNN,表示出CNN的关键是局部连接,共享权重,和多层的使用。其中CNN操作…

【虚拟机】Windows(x86)上部署Win11 on ARM虚拟机

参考链接: 1. https://blog.csdn.net/XiaoYuHaoAiMin/article/details/140701250 2. https://mbd.baidu.com/newspage/data/dtlandingsuper?nid=dt_4530491488179269409&sourceFrom=search_a 第一步:安装QEMU虚拟机 1. 下载链接:https://www.qemu.org/(这个链接找到的…

VsCode插件CnBlogs博客园客户端使用体验

VsCode插件CnBlogs博客园客户端使用体验 VsCode安装以及插件安装VsCode官网下载 VsCode插件CnBlog下载地址CnBlog插件功能 1.账户登陆2.工作空间3.随笔分类4.随笔列表5.编辑MarkDown博客

MCGS读取经纬度

1.将4G通讯状态调整为GetLocationFromGaoDe的触发条件 2.取消经纬度循环赋值,减少流量消耗 资料说明: 资料为4G屏获取经纬度的样例工程 注意事项: 1.此样例仅支持4G屏,WiFi屏获取经纬度无效 2.4G驱动仅支持V1.009及以上版本,定位失败优先检查驱动版本 操作步骤:添加mcgsI…

Windows 触控笔

平板以及二合一平板均是触控屏,Laptop现在也有很多屏幕带触控 触控屏,都会配置触控笔配件,目前市场上一般是电容屏+电容笔的技术方案。 触控笔分为主动笔和被动笔,主动笔占绝大部分。主动笔是通过内部电池或电源供电的,可以主动发送信号给设备,采用电磁感应原理,通过在屏…

启动终端判断SSH是否启动

启动终端判断SSH是否启动 原理:在Linux系统启动后,会运行shell(bash、zsh等)软件的配置文件~/.bashrc​,~/.zshrc​等 ‍ ‍ 以zsh​为例,在~/.zshrc​中添加如下内容 # ...# 检查 SSH 服务是否正在运行 ssh_status=$(service ssh status)if echo "$ssh_status"…

记一次与Rocketmq的进程异常行为修复过程

rocketmq部署在docker中。 前段时间,阿里云服务器发出安全告警看到curl和startfsrv.sh,下意识地认为这是下载了一个恶意脚本,接下来把恶意脚本找到,分析内容,修复的思路就有了。 但是找到脚本之后,创建时间是2019年,同时也只是rocketmq一个正常的启动脚本。这样思路就断…

LeetCode:2717、半有序队列

LeetCode算法做题记录题目: 给你一个下标从 0 开始、长度为 n 的整数排列 nums 。 如果排列的第一个数字等于 1 且最后一个数字等于 n ,则称其为 半有序排列 。你可以执行多次下述操作,直到将 nums 变成一个 半有序排列 : 选择 nums 中相邻的两个元素,然后交换它们。 返回…