李沐57_长短期记忆网络LSTM——自学笔记

LSTM

1.忘记门:将值朝着0减少

2.输入门:决定不是忽略掉输入数据

3.输出门:决定是不是使用隐状态

!pip install --upgrade d2l==0.17.5  #d2l需要更新

首先加载时光机器数据集。

import torch
from torch import nn
from d2l import torch as d2lbatch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
Downloading ../data/timemachine.txt from http://d2l-data.s3-accelerate.amazonaws.com/timemachine.txt...

初始化模型参数:超参数num_hiddens定义隐藏单元的数量。 我们按照标准差0.01的高斯分布初始化权重,并将偏置项设为0。

def get_lstm_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_sizedef normal(shape):return torch.randn(size=shape, device=device)*0.01def three():return (normal((num_inputs, num_hiddens)),normal((num_hiddens, num_hiddens)),torch.zeros(num_hiddens, device=device))W_xi, W_hi, b_i = three()  # 输入门参数W_xf, W_hf, b_f = three()  # 遗忘门参数W_xo, W_ho, b_o = three()  # 输出门参数W_xc, W_hc, b_c = three()  # 候选记忆元参数# 输出层参数W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)# 附加梯度params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,b_c, W_hq, b_q]for param in params:param.requires_grad_(True)return params

定义模型:长短期记忆网络的隐状态需要返回一个额外的记忆元, 单元的值为0,形状为(批量大小,隐藏单元数)。

def init_lstm_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device),torch.zeros((batch_size, num_hiddens), device=device))

实际模型的定义与我们前面讨论的一样: 提供三个门和一个额外的记忆元。

def lstm(inputs, state, params):[W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,W_hq, b_q] = params(H, C) = stateoutputs = []for X in inputs:I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)C = F * C + I * C_tildaH = O * torch.tanh(C)Y = (H @ W_hq) + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H, C)

训练和预测:引入的RNNModelScratch类来训练一个长短期记忆网络

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
perplexity 1.3, 28797.1 tokens/sec on cuda:0
time traveller abt whther werl wechat said al somick and the a s
traveller whthe graid that ore seellon twatere whree time s

在这里插入图片描述

简洁实现

num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = d2l.RNNModel(lstm_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
perplexity 1.0, 319759.1 tokens/sec on cuda:0
time traveller for so it will be convenient to speak of himwas e
travelleryou can show black is white by argument said filby

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/643691.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【HarmonyOS4学习笔记】《HarmonyOS4+NEXT星河版入门到企业级实战教程》课程学习笔记(一)

课程地址: 黑马程序员HarmonyOS4NEXT星河版入门到企业级实战教程,一套精通鸿蒙应用开发 (本篇笔记对应课程第 1 - 2节) P1《课程介绍》 开场白,HarmonyOS 的一个简介,话不多说,直接看图吧&…

jar中没有主清单属性

运行springboot的jar 提示&#xff1a;jar中没有主清单属性 我的pom.xml 的plugins配置是下面 <build><plugins><plugin><groupId>org.apache.maven.plugins</groupId><artifactId>maven-compiler-plugin</artifactId><version&g…

政安晨:【Keras机器学习示例演绎】(十二)—— 用利用 MIRNet 增强弱光图像效果

目录 简介 下载 LOL 数据集 创建 TensorFlow 数据集 MIRNet 模型 选择性核特征融合 双注意单元 多尺度残差块 MIRNet 模型 训练 推论 测试图像推理 政安晨的个人主页&#xff1a;政安晨 欢迎 &#x1f44d;点赞✍评论⭐收藏 收录专栏: TensorFlow与Keras机器学习实战…

tiktok如何影响用户行为的分析兼论快速数据分析的策略

tiktok如何影响用户行为的分析 快速数据分析的策略流程&#xff1a; 1.确定指标变量&#xff0c;也就确定了数据分析想要回答的问题。想回答不同的问题&#xff0c;就选择不同的指标变量。 变量筛选方法选出指标变量相关的变量&#xff1b; 针对筛选出的变量进行描述性分析和因…

研究发现:提示中加入数百个示例显著提升大型语言模型的性能

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗&#xff1f;订阅我们的简报&#xff0c;深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同&#xff0c;从行业内部的深度分析和实用指南中受益。不要错过这个机会&#xff0c;成为AI领…

阿赵UE学习笔记——30、HUD简单介绍

阿赵UE学习笔记目录 大家好&#xff0c;我是阿赵。   继续学习虚幻引擎&#xff0c;这次来学习一下HUD的基础使用。 一、 什么是HUD HUD(Head-Up Display)&#xff0c;也就是俗称的抬头显示。很多其他领域里面有用到这个术语&#xff0c;比如开车的朋友可能会接触过&#xf…

编译器的学习

常用的编译器&#xff1a; GCCVisual CClang&#xff08;LLVM&#xff09;&#xff1a; Clang 可以被看作是建立在 LLVM 之上的一个项目, 实际上LLVM是clang的后端&#xff0c;clang作为前端前端生成LLVM IR&#xff0c;https://zhuanlan.zhihu.com/p/656699711MSVC &#xff…

(done) 什么是 SVD 奇异值分解?什么是 TruncatedSVD 截断奇异值分解?

来源&#xff1a;https://www.bilibili.com/video/BV16A411T7zX/?spm_id_from333.337.search-card.all.click&vd_source7a1a0bc74158c6993c7355c5490fc600 奇异值分解其实就是如下图&#xff0c;把矩阵 M 分解成一个正交方阵 U&#xff0c;乘以一个不规则奇异值矩阵 sigma…

Spring Boot入门(21):使用Spring Boot和Log4j2进行高效日志管理:配置详解

Spring Boot 整合 Log4j2 前言 Log4j2是Apache软件基金会下的一个日志框架&#xff0c;它是Log4j的升级版。与Log4j相比&#xff0c;它在性能和功能上有着极大的提升。Spring Boot本身已经默认集成了Logback作为日志框架&#xff0c;但如果需要使用Log4j2来替代Logback&#…

分享基于鸿蒙OpenHarmony的Unity团结引擎应用开发赛

该赛题旨在鼓励更多开发者基于OpenHarmony4.x版本&#xff0c;使用团结引擎创造出精彩的游戏与应用。本次大赛分为“创新游戏”与“创新3D 化应用”两大赛道&#xff0c;每赛道又分“大众组”与“高校组”&#xff0c;让不同背景的开发者同台竞技。无论你是游戏开发者&#xff…

8.4.3 使用3:配置单臂路由实现VLAN间路由

1、实验目的 通过本实验可以掌握&#xff1a; 路由器以太网接口上的子接口配置和调试方法。单臂路由实现 VLAN间路由的配置和调试方法。 2、实验拓扑 实验拓扑如下图所示。 3、实验步骤 &#xff08;1&#xff09;配置交换机S1 S1(config)#vlan 2 S1(config-vlan)#exit S…

LayuiMini使用时候初始化模板修改(下载源码)

忘记加了 下载 地址 &#xff1a; layui-mini: layuimini&#xff0c;后台admin前端模板&#xff0c;基于 layui 编写的最简洁、易用的后台框架模板。只需提供一个接口就直接初始化整个框架&#xff0c;无需复杂操作。 LayuiMini使用时候初始化模板官网给的是&#xff1a; layu…