机器学习入门--双向长短期记忆神经网络(BiLSTM)原理与实践

双向长短记忆网络(BiLSTM)

BiLSTM(双向长短时记忆网络)是一种特殊的循环神经网络(RNN),它能够处理序列数据并保持长期记忆。与传统的RNN模型不同的是,BiLSTM同时考虑了过去和未来的信息,使得模型能够更好地捕捉序列数据中的上下文关系。在本文中,我们将详细介绍BiLSTM的数学原理、代码实现以及应用场景。

数学原理

LSTM(长短期记忆网络)是一种递归神经网络(RNN),通过引入门控机制来解决传统RNN中的梯度消失或梯度爆炸问题。BiLSTM是LSTM的一个变体,它在时间序列上同时运行两个LSTM,一个从前向后处理,一个从后向前处理。

LSTM的关键部分是单元状态(cell state)和各种门控机制,包括遗忘门、输入门和输出门。这些门控机制使用sigmoid函数来控制信息流动的程度。

具体来说,对于给定的时间步 t t t和输入 X t X_t Xt,LSTM计算以下值:
1.遗忘门(forget gate):
f t = σ ( W f ⋅ [ h t − 1 , X t ] + b f ) f_t = σ(W_f \cdot [h_{t-1}, X_t] + b_f) ft=σ(Wf[ht1,Xt]+bf)
2.输入门(input gate):
i t = σ ( W i ⋅ [ h t − 1 , X t ] + b i ) i_t = \sigma(W_i \cdot [h_{t-1}, X_t] + b_i) it=σ(Wi[ht1,Xt]+bi)
3.更新单元状态(new cell state):
C ~ t = t a n h ( W c ⋅ [ h t − 1 , X t ] + b c ) \tilde{C}_t = tanh(W_c \cdot [h_{t-1}, X_t] + b_c) C~t=tanh(Wc[ht1,Xt]+bc)
4.单元状态(cell state):
C t = f t ∗ C t − 1 + i t ∗ C ~ t C_t = f_t * C_{t-1} + i_t * \tilde{C}_t Ct=ftCt1+itC~t
5.输出门(output gate):
o t = σ ( W o ⋅ [ h t − 1 , X t ] + b o ) o_t = \sigma(W_o \cdot [h_{t-1}, X_t] + b_o) ot=σ(Wo[ht1,Xt]+bo)
6.隐状态(hidden state)更新:
h t = o t ∗ tanh ⁡ ( C t ) h_t = o_t * \tanh(C_t) ht=ottanh(Ct)
其中, σ \sigma σ表示sigmoid函数

BiLSTM通过将LSTM层沿着时间轴前向和后向运行来计算双向隐藏状态。前向LSTM从序列的第一个元素到最后一个元素顺序计算,而后向LSTM则相反。这两个隐藏状态被连接在一起形成最终的双向隐藏状态。

双向LSTM的输出可以用于各种任务,如序列标注、文本分类等。它能够捕捉到序列中每个时间步之前和之后的上下文信息,从而提供更全面的特征表示。

代码实现

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from sklearn.datasets import load_boston
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt# 加载数据集并进行标准化
data = load_boston()
X = data.data
y = data.target
scaler = StandardScaler()
X = scaler.fit_transform(X)
y = y.reshape(-1, 1)# 转换为张量
X = torch.tensor(X, dtype=torch.float32).unsqueeze(1)
y = torch.tensor(y, dtype=torch.float32)# 定义BiLSTM模型
class BiLSTMNet(nn.Module):def __init__(self, input_size, hidden_size, num_layers, dropout,output_size):super(BiLSTMNet, self).__init__()self.input_size = input_sizeself.hidden_size = hidden_sizeself.num_layers = num_layersself.dropout = dropoutself.output_size = output_sizeself.lstm = nn.LSTM(input_size = self.input_size, hidden_size = self.hidden_size,num_layers = self.num_layers,bidirectional = True,batch_first=True,dropout = self.dropout)self.fc1 = nn.Linear(self.hidden_size * 2, self.hidden_size)self.fc2 = nn.Linear(self.hidden_size, self.output_size)def forward(self, x):output, (h_n, c_n) = self.lstm(x)out = torch.concat([h_n[-1,:,:], h_n[-2, :, :]], dim=-1)out = F.relu(out)out = self.fc1(out)out = F.relu(out)out = self.fc2(out)return outinput_size = X.shape[2]
hidden_size = 32
output_size = 1
model = BiLSTMNet(input_size=input_size, hidden_size=hidden_size, num_layers=4,dropout=0, output_size=output_size)# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练模型
num_epochs = 10000
loss_list = []
for epoch in range(num_epochs):optimizer.zero_grad()outputs = model(X)loss = criterion(outputs, y)loss.backward()optimizer.step()if (epoch+1) % 100 == 0:loss_list.append(loss.item())print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}')# 可视化损失曲线
plt.plot(range(100), loss_list)
plt.xlabel('num_epochs')
plt.ylabel('loss of LSTM Training')
plt.show()# 预测新数据
new_data_point = X[0].reshape(1, 1, -1)
prediction = model(new_data_point)
print(f'Predicted value: {prediction.item()}')

上述代码实现了一个双向LSTM模型,用于预测波士顿房价数据集中的房价。其中,首先加载并标准化数据集,然后定义了一个自定义的BiLSTMNet模型类,该类包含了一个双向LSTM层和两个全连接层,并使用MSELoss作为损失函数和Adam作为优化器进行模型训练。在训练过程中,每100个epoch将损失值记录下来,最后使用Matplotlib将损失曲线可视化(如下图所示)。最后,使用模型对新数据点进行预测并输出结果。整个代码实现了一个简单的房价预测模型,可以通过调节模型参数和超参数进行进一步优化。
BiLSTM-损失曲线

总结

BiLSTM是一种强大的序列建模工具,它能够处理序列数据并捕捉长期依赖关系。在实践中,BiLSTM已被广泛应用于语音识别、自然语言处理和股票预测等领域。本文介绍了BiLSTM的数学原理、代码实现以及应用场景,希望读者能够通过本文了解到BiLSTM的基本知识和使用方法。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/475368.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【EI会议征稿通知】第五届城市工程与管理科学国际会议(ICUEMS 2024)

【Scopus稳定检索】第五届城市工程与管理科学国际会议(ICUEMS 2024) 2024 5th International Conference on Urban Engineering and Management Science 第五届城市工程与管理科学国际会议(ICUEMS 2024)将于2024年5月31日-6月2日…

惠普打印机驱动安装

一、下载驱动 支持 --> 软件与驱动程序 https://www.hp.com/cn-zh/home.html 选择打印机 输入打印机型号,下拉框选择自己的打印机型号 打印机型号正常在打印机的正面会有 往下滑选择安装软件和全功能/基本功能驱动程序-仅支持打印和扫描功能 (1) 点击下载…

代码随想录 Leetcode435. 无重叠区间

题目&#xff1a; 代码(首刷看解析 2024年2月17日&#xff09;&#xff1a; class Solution { private:const static bool cmp(vector<int>& a,vector<int>& b) {return a[0] < b[0];} public:int eraseOverlapIntervals(vector<vector<int>&…

vue生命周期函数

父子组件加载顺序 加载渲染过程 父beforeCreate->父created->父beforeMount->子beforeCreate->子created->子beforeMount->子mounted->父mounted子组件更新过程 父beforeUpdate->子beforeUpdate->子updated->父updated父组件更新过程 父beforeU…

【CSS】设置文字(文本)的渐变色

# 渐变色 文字 第一步 设置渐变颜色 background: linear-gradient(278.83deg, #5022bd 31.42%, #8636d1 75.55%); // 先设置渐变色背景&#xff1b; 第二步 设置颜色的使用范围 background-clip: text; // 背景被裁剪成文字的前景色。 -webkit-background-clip: text; 第三步…

JMeter接口测试数据分离驱动应用

步骤&#xff1a; 创建csv文件&#xff0c;编写接口测试用例 新建线程组——创建循环控制器&#xff08;循环次数填用例总数&#xff09; 创建CSV数据文件设置&#xff0c;设置参数。&#xff08;注意&#xff1a;是否允许带引号&#xff1f;&#xff1a;一定要设置为true&a…

微信安装包为啥越来越大?

一、微信安装包大小的变化趋势 微信作为中国最流行的即时通讯应用之一&#xff0c;其安装包大小一直是用户普遍关注的话题。随着技术的发展和功能的增加&#xff0c;微信安装包的大小也呈现出明显的变化趋势。从最初的几MB到如今的数十MB甚至上百MB&#xff0c;微信安装包的不…

报错405(errAxiosError: Request failed with status code 405)

errAxiosError: Request failed with status code 405 前端调用接口的方法跟后台定义接口的方法不一致

基于 Python 的景区票务人脸识别系统,附源码

博主介绍&#xff1a;✌程序员徐师兄、7年大厂程序员经历。全网粉丝30W、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专栏推荐订阅&#x1f447;…

StartAI V2文生图咒语集合(一)

1.3D海浪风景画 使用关键词&#xff1a;a paper quilling painting showing large waves crashing on a night coastline, in the style of surreal 3d landscapes, pretty, high-contrast shading, fairy tale:: manga watercolor & oil painting with paper quilling by…

第六十三天 服务攻防-框架安全CVE复现DjangoFlaskNode.JSJQuery

第六十三天 服务攻防-框架安全&CVE复现&Django&Flask&Node.JS&JQuery 知识点&#xff1a; 中间件及框架列表&#xff1a; IIS,Apache,Nginx,Tomcat,Docker,K8s,Weblogic.JBoos,WebSphere, Jenkins,GlassFish,Jetty,Jira,Struts2,Laravel,Solr,Shiro,Thin…

【Redis】深入理解 Redis 常用数据类型源码及底层实现(4.详解Hash数据结构)

Hash数据结构 看过前面的介绍,大家应该知道 Redis 的 Hash 结构的底层实现在 6 和 7 是不同的,Redis 6 是 ziplist 和 hashtable,Redis 7 是 listpack 和 hashtable。 我们先使用config get hash*看下 Redis 6 和 Redis 7 的 Hash 结构配置情况(在Redis客户端的命令行界面…