边写代码边学习之LSTM

1.  什么是LSTM

长短期记忆网络 LSTM(long short-term memory)是 RNN 的一种变体,其核心概念在于细胞状态以及“门”结构。细胞状态相当于信息传输的路径,让信息能在序列连中传递下去。你可以将其看作网络的“记忆”。理论上讲,细胞状态能够将序列处理过程中的相关信息一直传递下去。因此,即使是较早时间步长的信息也能携带到较后时间步长的细胞中来,这克服了短时记忆的影响。信息的添加和移除我们通过“门”结构来实现,“门”结构在训练过程中会去学习该保存或遗忘哪些信息。
 

在这里插入图片描述

 

2. 实验代码

2.1. 搭建一个只有一层RNN和Dense网络的模型。

2.2. 验证LSTM里的逻辑

 假设我的输入数据是x = [1,0], 

kernel = [[[2, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0],

              [1, 1, 0, 1, 1, 0, 0, 1, 1 ,0, 0, 0],]]

recurrent_kernel = [[1, 0, 0, 1, 2,1,0,1,2,0,1,0],

                              [1, 1, 0, 0, 2,1,0,1,2,2,0,0],

                              [1, 0, 1, 2, 0,1,0,1,1,0,1,0]]

biase = [3, 1, 0, 1, 1,0,0,1,0,2,0.0,0]

通过下面手算,h的结果是[0, 4,1], c 的结果是[0,4,1].  注意无激活函数。

代码验证上面的结果


def change_weight():# Create a simple Dense layerlstm_layer = LSTM(units=3, input_shape=(3, 2), activation=None, recurrent_activation=None, return_sequences=True,return_state= True)# Simulate input data (batch size of 1 for demonstration)input_data = np.array([[[1.0, 2], [2, 3], [3, 4]],[[5, 6], [6, 7], [7, 8]],[[9, 10], [10, 11], [11, 12]]])# Pass the input data through the layer to initialize the weights and biaseslstm_layer(input_data)kernel, recurrent_kernel, biases = lstm_layer.get_weights()# Print the initial weights and biasesprint("recurrent_kernel:", recurrent_kernel, recurrent_kernel.shape ) # (3,3)print('kernal:',kernel, kernel.shape) #(2,3)print('biase: ',biases , biases.shape) # (3)kernel = np.array([[2, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0],[1, 1, 0, 1, 1, 0, 0, 1, 1 ,0, 0, 0],])recurrent_kernel = np.array([[1, 0, 0, 1, 2,1,0,1,2,0,1,0],[1, 1, 0, 0, 2,1,0,1,2,2,0,0],[1, 0, 1, 2, 0,1,0,1,1,0,1,0]])biases = np.array([3, 1, 0, 1, 1,0,0,1,0,2,0.0,0])lstm_layer.set_weights([kernel, recurrent_kernel, biases])print(lstm_layer.get_weights())# test_data = np.array([#     [[1.0, 3], [1, 1], [2, 3]]# ])test_data = np.array([[[1,0.0]]])output, memory_state, carry_state  = lstm_layer(test_data)print(output)print(memory_state)print(carry_state)
if __name__ == '__main__':change_weight()

执行结果:

recurrent_kernel: [[-0.36744034 -0.11181469 -0.10642298  0.5450207  -0.30208975  0.54054320.09643812 -0.14983998  0.1859854   0.2336958  -0.16187981  0.11621032][ 0.07727922 -0.226477    0.1491096  -0.03933501  0.31236103 -0.129630920.10522162 -0.4815724  -0.2093935   0.34740582 -0.60979587 -0.15877807][ 0.15371156  0.01244636 -0.09840634 -0.32093546  0.06523462  0.189349320.38859126 -0.3261706  -0.05138849  0.42713478  0.49390993  0.37013963]] (3, 12)
kernal: [[-0.47606698 -0.43589187 -0.5371355  -0.07337284  0.30526626 -0.18241835-0.03675252  0.2873094   0.33218485  0.24838251  0.17765659  0.4312396 ][ 0.4007727   0.41280174  0.40750778 -0.6245315   0.6382301   0.428892250.11961156 -0.6021105  -0.43556038  0.39798307  0.6390712   0.16719025]] (2, 12)
biase:  [0. 0. 0. 1. 1. 1. 0. 0. 0. 0. 0. 0.] (12,)
[array([[2., 1., 1., 0., 0., 0., 0., 1., 1., 0., 1., 0.],[1., 1., 0., 1., 1., 0., 0., 1., 1., 0., 0., 0.]], dtype=float32), array([[1., 0., 0., 1., 2., 1., 0., 1., 2., 0., 1., 0.],[1., 1., 0., 0., 2., 1., 0., 1., 2., 2., 0., 0.],[1., 0., 1., 2., 0., 1., 0., 1., 1., 0., 1., 0.]], dtype=float32), array([3., 1., 0., 1., 1., 0., 0., 1., 0., 2., 0., 0.], dtype=float32)]
tf.Tensor([[[0. 4. 0.]]], shape=(1, 1, 3), dtype=float32)
tf.Tensor([[0. 4. 0.]], shape=(1, 3), dtype=float32)
tf.Tensor([[0. 4. 1.]], shape=(1, 3), dtype=float32)

可以看出h=[0,4,0], c=[0,4,1]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/55932.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

3.2 防火墙

数据参考:CISP官方 目录 防火墙基础概念防火墙的典型技术防火墙企业部署防火墙的局限性 一、防火墙基础概念 防火墙基础概念: 防火墙(Firewall)一词来源于早期的欧式建筑,它是建筑物之间的一道矮墙,用…

无人车沿着指定线路自动驾驶与远程控制的实践应用

有了前面颜色识别跟踪的基础之后,我们就可以设定颜色路径,让无人车沿着指定线路做自动驾驶了,视频:PID控制无人车自动驾驶 有了前几章的知识铺垫,就比较简单了,也是属于颜色识别的一种应用,主要…

约数个数和欧拉函数

1.约数个数 一个数等于它的质因子的c次方相乘,那么约数个数为所有的次数分别1再相乘。 2. 大概时间复杂度 1-n中,所有数的约数个数之和 3.int范围内约数最t多的数大概1600个左右 一个数的约数大概 根号n 的复杂度

使用Python + Flask搭建web服务

示例脚本 from flask import Flask# 获取一个实例对象 app Flask(__name__)# 1、注册 app.route(/reg, methods[get]) def reg():return {code: 200,msg: reg ok!}# 2、登录 app.route(/login, methods[get]) def login():return login ok!if __name__ __main__:…

Linux 终端操作命令(1)

Linux 命令 终端命令格式 command [-options] [parameter] 说明: command:命令名,相应功能的英文单词或单词的缩写[-options]:选项,可用来对命令进行控制,也可以省略parameter:传给命令的参…

Linux基础与应用开发系列四:ARM-GCC与交叉编译

三个问题: ARM-GCC是什么?它与GCC有什么关系? 编译工具链和目标程序运行相同的架构平台,就叫本地编译 编译工具链和目标程序运行在不同的架构平台,叫做交叉编译 ARM-GCC是针对arm平台的一款编译器,它是GCC编译工具链的一个分支 虚拟机…

百度智能创做AI平台

家人们好,在数字化时代,人工智能正引领着一场前所未有的创新浪潮。今天,我们将为大家介绍百度智能创做AI平台,这个为创意赋能、助力创作者的强大工具。无论你是创意工作者、内容创作者,还是想要释放内心创造力的个人&a…

深入探索Python数据容器:绚丽字符串、神奇序列切片与魔幻集合奇遇

一 数据容器:str(字符串) 1.1 字符串初识 字符串也是数据容器的一员,字符串是一种数据容器,用于存储和处理文本数据。字符串是字符的容器,一个字符串可以存放任意数量的字符,可以包含字母、数字、标点符号、空格等字…

c++11 标准模板(STL)(std::basic_ofstream)(五)

定义于头文件 <fstream> template< class CharT, class Traits std::char_traits<CharT> > class basic_ofstream : public std::basic_ostream<CharT, Traits> 类模板 basic_ofstream 实现文件上基于流的高层输出操作。它将 std::basic_ost…

无人驾驶实战-第八课(定位算法)

无人驾驶中定位的作用&#xff1a; 定位高精度地图&#xff1a;提供当前位置的静态环境感知 &#xff08;车道线/交通指示牌/红绿灯/柱子/建筑物/等&#xff09; 定位动态物体感知&#xff1a;将感知到的动态物体正确放入静态环境 定位获取位置姿态&#xff1a;用于路径规划/决…

Qt实现自定义QDoubleSpinBox软键盘

在Qt应用程序开发中&#xff0c;经常会遇到需要自定义输入控件的需求。其中&#xff0c;对于QDoubleSpinBox控件&#xff0c;如果希望在点击时弹出一个自定义的软键盘&#xff0c;以便用户输入数值&#xff0c;并将输入的值设置给QDoubleSpinBox&#xff0c;该如何实现呢&#…

冠达管理投资前瞻:三星加码机器人领域 大信创建设提速

上星期五&#xff0c;沪指高开高走&#xff0c;盘中一度涨超1%打破3300点&#xff0c;但随后涨幅收窄&#xff1b;深成指、创业板指亦强势震动。截至收盘&#xff0c;沪指涨0.23%报3288.08点&#xff0c;深成指涨0.67%报11238.06点&#xff0c;创业板指涨0.95%报2263.37点&…