深度学习(RNN,LSTM,GRU)

news/2025/2/3 13:36:53/文章来源:https://www.cnblogs.com/tiandsp/p/18237323

三个网络的架构图:

RNN:

LSTM:

GRU:

特性对比列表:

特性
RNN
LSTM
GRU
门控数量
3门(输入/遗忘/输出)
2门(更新/重置)
记忆机制
 仅隐藏状态ht
显式状态Ct + 隐藏状态ht
隐式记忆(通过门控更新状态)
核心操作
 直接状态传递
门控细胞状态更新 
门控候选状态混合 
计算复杂度
O(d2)(1组权重)
 O(4d2)(4 组权重)
 O(3d2)(3 组权重)
长期依赖学习
差(<10步)
优秀(>1000步)
良好(~100步)
梯度消失问题
严重
显著缓解
较好缓解
参数数量
最少
最多(3倍于RNN)
中等(2倍于RNN)
训练速度
最快
最慢
较快
过拟合风险
中等
典型应用场景
简单序列分类
机器翻译/语音识别
文本生成/时间序列预测
下面是两个例子:

一:LSTM识别数字:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms,datasetsdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)class RNN(nn.Module):def __init__(self, input_size, hidden_size, num_layers, num_classes):super(RNN, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size, num_classes)def forward(self, x):h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device) c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)out, _ = self.lstm(x, (h0, c0))  out = self.fc(out[:, -1, :])return outsequence_length = 28
input_size = 28  
hidden_size = 128
num_layers = 2
num_classes = 10model =RNN( input_size, hidden_size, num_layers, num_classes).to(device)
model.train()optimizer = optim.Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()num_epochs = 10
for epoch in range(num_epochs):running_loss = 0.0correct = 0total = 0for images, labels in train_loader:B,_,_,_= images.shape  images = images.reshape(B,sequence_length,input_size)images = images.to(device)labels = labels.to(device)output = model(images)loss = criterion(output, labels)optimizer.zero_grad()loss.backward()optimizer.step()running_loss += loss.item()_, predicted = torch.max(output.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}, Accuracy: {(100 * correct / total):.2f}%")

二:GRU数据拟合:

import torch
import torch.nn as nn
import matplotlib.pyplot as pltclass RNN(nn.Module):def __init__(self):super().__init__()# self.rnn=nn.RNN(input_size=1,hidden_size=128,num_layers=1,batch_first=True)# self.rnn=nn.LSTM(input_size=1,hidden_size=128,num_layers=1,batch_first=True)self.rnn=nn.GRU(input_size=1,hidden_size=128,num_layers=1,batch_first=True)self.linear=nn.Linear(128,1)def forward(self,x):output,_=self.rnn(x)x=self.linear(output)return xif __name__ == '__main__':x = torch.linspace(-300,300,1000)*0.01 y = torch.sin(x*3.0) + torch.linspace(-300,300,1000)*0.01plt.plot(x, y,'r')  x = x.unsqueeze(1).cuda()y = y.unsqueeze(1).cuda()model=RNN().cuda()optimizer=torch.optim.Adam(model.parameters(),lr=5e-4)criterion = nn.MSELoss().cuda()for epoch in range(5000):preds=model(x)loss=criterion(preds,y)optimizer.zero_grad()loss.backward()optimizer.step()print('loss',epoch, loss.item())x = torch.linspace(-300,310,1000)*0.01x = x.unsqueeze(1).cuda()pred = model(x)plt.plot(x.cpu().detach().numpy(), pred.cpu().detach().numpy(),'b')plt.show()

参考:Understanding LSTM Networks -- colah's blog

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/878261.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

AI定制祝福视频,广州塔、动态彩灯、LED表白,直播互动新玩法(附下载链接)

在追剧的时候经常能看到一些浪漫的告白桥段,男主用圣诞彩灯表白、用城市标志性建筑的LED表白,或者在五光十色的烟花绽放后刻下女主角的名字,充满了仪式感和氛围感~现在,这样的表白效果用AI软件就能实现了,在社交平台上甚至还出现了类似的直播内容,观众送热气球或者其他礼…

VMware ARIA缺陷,黑客可获用户权限,哪些版本受影响?

VMware发布了安全更新,以修补影响VMware ARIA操作和日志操作的五个安全缺陷,并警告客户,黑客可以利用他们获得提升的访问或获得敏感信息。发行 安全 更新要补丁五安全 缺陷 影响VMware ARIA操作和日志操作,警告客户攻击者可以开发他们提高了使用权或者获得 敏感的信息。 列…

1. 2025年:致每一位在软件测试道路上奋斗的伙伴

亲爱的读者朋友们: 新年好!时光荏苒,转眼间我们已经迈入2025年。在这辞旧迎新的时刻,我怀着无比感恩的心情,向一路相伴的每一位软件测试从业者、爱好者以及关注者们致以最诚挚的祝福!愿大家在新的一年里,健康平安,事业有成,代码无Bug,需求皆清晰! 过去的一年,是软件测试行业蓬勃…

执行npm run dev时,报错10% building 2/5 modules 3 active node,如何解决?

错误信息如下:原因:版本问题,为了不替换node版本使用如下方法 在package.json文件下 将 "dev": " vue-cli-service serve", "build:prod": "vue-cli-service build", "build:stage": "vue-cli-service build --mode…

Make your ternimal more useful

目录引入Iterm2配置和Zshell配置TmuxVim配置基本使用插件配置Coc默认配置快捷键说明NerdTree快捷键分屏:Buffer, Windows和Tab 引入 本着好程序员要用好终端的信念,加之在使用mac过程中对快捷键依赖度增加,对鼠标的依赖逐渐减少,所以打算尝试配置终端的代码编写环境。 不曾…

龙哥量化:通达信技术指标编写技巧分享篇1-成交量和换手率

龙哥微信:Long622889代写通达信技术指标、选股公式(通达信,同花顺,东方财富,大智慧,文华,博易,飞狐)代写期货量化策略(TB交易开拓者,文华8,金字塔) 春节假期, 和朋友闲聊,发现在选股思路上很杂乱, 完全没有体系,但是大致可以分为两种,趋势策略和震荡策略,其…

昆明理工大学材料科学与工程学院 2025年硕士研究生招生预测调剂名额 (供考生提前规划)

亲爱的考生: 为助力各位考生提前规划考研调剂方向,昆明理工大学材料科学与工程学院结合近年招生趋势及学科发展需求,预测2025年材料工程相关专业将有部分调剂名额,具体信息如下。欢迎符合条件的考生持续关注! 一、预测调剂专业及名额注: 最终调剂名额以2025年研招网官方发…

hive-pig--pig安装

1.下载 curl https://dlcdn.apache.org/pig/pig-0.17.0/pig-0.17.0.tar.gz -o /opt/software/pig-0.17.0.tar.gz2.解压 tar -zxvf /opt/software/pig-0.17.0.tar.gz -C /usr/local/src/ mv /usr/local/src/pig-0.17.0/ /usr/local/src/pig 3.把二进制路径添加到命令行路径 echo…

PyTorch生态系统中的连续深度学习:使用Torchdyn实现连续时间神经网络

神经常微分方程(Neural ODEs)是深度学习领域的创新性模型架构,它将神经网络的离散变换扩展为连续时间动力系统。与传统神经网络将层表示为离散变换不同,Neural ODEs将变换过程视为深度(或时间)的连续函数。这种方法为机器学习开创了新的研究方向,尤其在生成模型、时间序…

[ArkUI] 记录一次 ArkUI 学习心得 (1) -- 基础概念

1.一个原生鸿蒙应用的源码目录其中:ets是项目的源码目录.ets/pages是页面目录, 用于渲染页面.resources是资源目录,下面会讲. 2.第一个原生鸿蒙应用 话不多说,直接上代码. @Entry @Component struct Index {@State message: string = My First Program!;@State num: number = 0…

互联网已经没法用了

图片:作者制作我们已经到了这样的地步——曾经能让我们随时随地获取全世界信息的互联网,现在已经完全没法用了。 罪魁祸首是广告,情况糟糕到一种极端的程度,以至于它被称为“广告末日”(adpocalypse)。 现在我打开的几乎每个网站都塞满了广告,整个页面都快撑爆了。在电脑…

uniCloud(dcloud.net.cn)https证书配制

前端网页托管-->参数配置-->域名信息-->更新证书 阿里云 https--SSL证书获取