勾八头歌之RNN

一、RNN快速入门

1.学习单步的RNN:RNNCell

# -*- coding: utf-8 -*-
import tensorflow as tf# 参数 a 是 BasicRNNCell所含的神经元数, 参数 b 是 batch_size, 参数 c 是单个 input 的维数,shape = [ b , c ]
def creatRNNCell(a,b,c):# 请在此添加代码 完成本关任务# ********** Begin *********#x1=tf.placeholder(tf.float32,[b,c])cell=tf.nn.rnn_cell.BasicRNNCell(num_units=a)h0=cell.zero_state(batch_size=b,dtype=tf.float32)output,h1=cell.__call__(x1,h0)print(cell.state_size)print(h1)# ********** End **********#

2.探幽入微LSTM

# -*- coding: utf-8 -*-
import tensorflow as tf# 参数 a 是 BasicLSTMCell所含的神经元数, 参数 b 是 batch_size, 参数 c 是单个 input 的维数,shape = [ b , c ]
def creatLSTMCell(a,b,c):# 请在此添加代码 完成本关任务# ********** Begin *********#x1=tf.placeholder(tf.float32,[b,c])cell=tf.nn.rnn_cell.BasicLSTMCell(num_units=a)h0=cell.zero_state(batch_size=b,dtype=tf.float32)output,h1=cell.__call__(x1,h0)print(h1.h)print(h1.c)# ********** End **********#

3.进阶RNN:学习一次执行多步以及堆叠RNN

# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np# 参数 a 是RNN的层数, 参数 b 是每个BasicRNNCell包含的神经元数即state_size
# 参数 c 是输入序列的批量大小即batch_size,参数 d 是时间序列的步长即time_steps,参数 e 是单个输入input的维数即input_size
def MultiRNNCell_dynamic_call(a,b,c,d,e):# 用tf.nn.rnn_cell MultiRNNCell创建a层RNN,并调用tf.nn.dynamic_rnn# 请在此添加代码 完成本关任务# ********** Begin *********#cell = tf.nn.rnn_cell.MultiRNNCell([tf.nn.rnn_cell.BasicRNNCell(num_units=b) for _ in range(a)]) # a层RNNinputs = tf.placeholder(np.float32, shape=(c, d, e)) # a 是 batch_size,d 是time_steps, e 是input_sizeh0=cell.zero_state(batch_size=c,dtype=tf.float32)output, h1 = tf.nn.dynamic_rnn(cell, inputs, initial_state=h0)print(output)# ********** End **********#

二、RNN循环神经网络

1.Attention注意力机制(A  ABC  B  C  A)

2.Seq2Seq

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variabledtype = torch.FloatTensor
char_list = [c for c in 'SEPabcdefghijklmnopqrstuvwxyz']
char_dic = {n: i for i, n in enumerate(char_list)}
seq_data = [['man', 'women'], ['black', 'white'], ['king', 'queen'], ['girl', 'boy'], ['up', 'down'], ['high', 'low']]
seq_len = 8
n_hidden = 128
n_class = len(char_list)
batch_size = len(seq_data)##########Begin##########
#对数据进行编码部分
##########End##########
def make_batch(seq_data):batch_size = len(seq_data)input_batch, output_batch, target_batch = [], [], []for seq in seq_data:for i in range(2):seq[i] += 'P' * (seq_len - len(seq[i]))input = [char_dic[n] for n in seq[0]]output = [char_dic[n] for n in ('S' + seq[1])]target = [char_dic[n] for n in (seq[1] + 'E')]input_batch.append(np.eye(n_class)[input])output_batch.append(np.eye(n_class)[output])target_batch.append(target)return Variable(torch.Tensor(input_batch)), Variable(torch.Tensor(output_batch)), Variable(torch.LongTensor(target_batch))##########Begin##########
#模型类定义
input_batch, output_batch, target_batch = make_batch(seq_data)
class Seq2Seq(nn.Module):def __init__(self):super(Seq2Seq, self).__init__()self.encoder = nn.RNN(input_size=n_class, hidden_size=n_hidden)self.decoder = nn.RNN(input_size=n_class, hidden_size=n_hidden)self.fc = nn.Linear(n_hidden, n_class)def forward(self, enc_input, enc_hidden, dec_input):enc_input = enc_input.transpose(0, 1)dec_input = dec_input.transpose(0, 1)_, h_states = self.encoder(enc_input, enc_hidden)outputs, _ = self.decoder(dec_input, h_states)outputs = self.fc(outputs)return outputs
##########End##########model = Seq2Seq()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)##########Begin##########
#模型训练过程
for epoch in range(5001):hidden = Variable(torch.zeros(1, batch_size, n_hidden))optimizer.zero_grad()outputs = model(input_batch, hidden, output_batch)outputs = outputs.transpose(0, 1)loss = 0for i in range(batch_size):loss += criterion(outputs[i], target_batch[i])loss.backward()optimizer.step()
##########End####################Begin##########
#模型验证过程函数
def translated(word):input_batch, output_batch, _ = make_batch([[word, 'P' * len(word)]])hidden = Variable(torch.zeros(1, 1, n_hidden))outputs = model(input_batch, hidden, output_batch)predict = outputs.data.max(2, keepdim=True)[1]decode = [char_list[i] for i in predict]end = decode.index('P')translated = ''.join(decode[:end])print(translated)
##########End##########translated('highh')
translated('kingh')

三、RNN和LSTM

1.循环神经网络简介

import torchdef rnn(input,state,params):"""循环神经网络的前向传播:param input: 输入,形状为 [ batch_size,num_inputs ]:param state: 上一时刻循环神经网络的状态,形状为 [ batch_size,num_hiddens ]:param params: 循环神经网络的所使用的权重以及偏置:return: 输出结果和此时刻网络的状态"""W_xh,W_hh,b_h,W_hq,b_q = params"""W_xh : 输入层到隐藏层的权重W_hh : 上一时刻状态隐藏层到当前时刻的权重b_h : 隐藏层偏置W_hq : 隐藏层到输出层的权重b_q : 输出层偏置"""H = state# 输入层到隐藏层H = torch.matmul(input, W_xh) + torch.matmul(H, W_hh) + b_hH = torch.tanh(H)# 隐藏层到输出层Y = torch.matmul(H, W_hq) + b_qreturn Y,Hdef init_rnn_state(num_inputs,num_hiddens):"""循环神经网络的初始状态的初始化:param num_inputs: 输入层中神经元的个数:param num_hiddens: 隐藏层中神经元的个数:return: 循环神经网络初始状态"""init_state = torch.zeros((num_inputs,num_hiddens),dtype=torch.float32)return init_state

2.长短时记忆网络

import torchdef lstm(X,state,params):"""LSTM:param X: 输入:param state: 上一时刻的单元状态和输出:param params: LSTM 中所有的权值矩阵以及偏置:return: 当前时刻的单元状态和输出"""W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q = params"""W_xi,W_hi,b_i : 输入门中计算i的权值矩阵和偏置W_xf,W_hf,b_f : 遗忘门的权值矩阵和偏置W_xo,W_ho,b_o : 输出门的权值矩阵和偏置W_xc,W_hc,b_c : 输入门中计算c_tilde的权值矩阵和偏置W_hq,b_q : 输出层的权值矩阵和偏置"""#上一时刻的输出 H 和 单元状态 C。(H,C) = state# 遗忘门F = torch.matmul(X, W_xf) + torch.matmul(H, W_hf) + b_fF = torch.sigmoid(F)# 输入门I = torch.sigmoid(torch.matmul(X,W_xi)+torch.matmul(H,W_hi) + b_i)C_tilde = torch.tanh(torch.matmul(X, W_xc) + torch.matmul(H, W_hc) + b_c)C = F * C + I * C_tilde# 输出门O = torch.sigmoid(torch.matmul(X,W_xo)+torch.matmul(H,W_ho) + b_o)H = O * C.tanh()# 输出层Y = torch.matmul(H,W_hq) + b_qreturn Y,(H,C)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/644121.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【复现代码——环境配置】

目录 一、复现代码举例二、创建环境——选择一个Python版本2.1 创建基本环境2.1.1 基于AutoDL2.1.2 基于PyCharm 2.2 终端激活环境2.3 退出环境2.4 删除环境 三、PyTorch安装3.1 查看cuda3.2 安装PyTorch 四、其他依赖安装4.1 tensorboardX4.2 matplotlib4.3 medpy4.4 visdom4.…

【Day 8】MySQL 多表查询 + Mybatis 基础

1 多表查询 笛卡尔积:在数学中,两个集合(A集合 和 B集合)的所有组合情况 在多表查询时,需要消除无效的笛卡尔积 select * from tb_emp,tb_dept where dept_id tb_dept.id;多表查询分为: 连接查询 内连接:相当于查…

5-内核开发-/proc File System 学习

5-内核开发-/proc File System 学习 课程简介: Linux内核开发入门是一门旨在帮助学习者从最基本的知识开始学习Linux内核开发的入门课程。该课程旨在为对Linux内核开发感兴趣的初学者提供一个扎实的基础,让他们能够理解和参与到Linux内核的开发过程中。…

移动零 ----双指针

题目链接 题目: 分析: 上述题目, 是将数组分块, 分为前半非零, 后半零, 这种数组分块题我们首先想到双指针 思路: 定义两个指针, 一个cur 一个dest, cur用来遍历数组, dest 指向分界处的第一个零位置, 将数组分块首先让cur 0; dest 0;cur 遍历数组, 如果cur 0, 那么cur…

python+django校园社交高校交友网站2x7r5.

本课题使用Python语言进行开发。代码层面的操作主要在PyCharm中进行,将系统所使用到的表以及数据存储到MySQL数据库中,方便对数据进行操作本课题基于WEB的开发平台,设计的基本思路是: 前端:vue.jselementui 框架&#…

后端工程师——Java工程师岗位要求

在国内,Java 程序员是后端开发工程师中最大的一部分群体,其市场需求量也是居高不下,C++ 程序员也是热门岗位之一,此二者的比较也常是热点话题,例如新学者常困惑的问题之一 —— 后端开发学 Java 好还是学 C++ 好。读完本文后,我们可以从自身情况、未来的发展,岗位需求量…

SD-WAN制造业网络优化方案

制造业在数字化浪潮的推动下,进行转型的需求越来越强烈。网络作为制造业数字化转型的关键基础设施,其稳定性、安全性和灵活性直接影响着企业的运营效率和市场竞争力。而SD-WAN可以为制造业提供有效的解决方案,让制造业顺利高效地进行数字化转…

揭开六西格玛培训真实面貌,为何它仍是企业优选

近年来,网络上时常有声音称六西格玛培训已经过时,不再适应当今快速变化的商业环境。然而,这种观点并不全面,也未能深入理解六西格玛管理的核心价值和现代应用。事实上,六西格玛作为一种以数据为驱动、旨在减少缺陷和提…

基于Springboot的幼儿园管理系统

基于SpringbootVue的幼儿园管理系统的设计与实现 开发语言:Java数据库:MySQL技术:SpringbootMybatis工具:IDEA、Maven、Navicat 系统展示 用户登录 用户管理 教师管理 幼儿园信息管理 班级信息管理 工作日志管理 会议记录管理…

【AI写作】未来科技趋势:揭秘DreamFusion的革新力量

首先,这篇文章是基于笔尖AI写作进行文章创作的,喜欢的宝子,也可以去体验下,解放双手,上班直接摸鱼~ 按照惯例,先介绍下这款笔尖AI写作,宝子也可以直接下滑跳过看正文~ 笔尖Ai写作:…

如何在Facebook上发布广告?

在广告管理工具中创建广告 创建广告系列和广告组。在广告名称文本框中输入描述性名称。选择代表您业务的Facebook 公共主页和Instagram 帐户。 所有广告都必须具有关联的Facebook 公共主页。选择广告格式。 选择素材。 您可能还会看到其他选项,具体取决于您先前所做…

详细分析MySQL中的distinct函数(附Demo)

目录 前言1. 基本知识2. 基础Demo3. 进阶Demo 前言 该函数主要用于去重,对于细节知识,此文详细补充说明 1. 基本知识 DISTINCT 是一种用于查询结果中去除重复行的关键字 在查询数据库时,可能会得到重复的结果行,但有时只需要这…