【PyTorch实战演练】基于全连接网络构建RNN并生成人名

0. 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解及成果,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。

本文基于PyTorch中的全连接模块 nn.Linear() 构建RNN,并使用人名数据训练RNN,最后使用RNN生成人名。

1. RNN简介

循环神经网络(Recurrent Neural Network,简称RNN)是一种深度学习模型,我在之前的文章介绍过RNN的结构及算法基础——基于Numpy构建RNN模块并进行实例应用(附代码)这里不再赘述。

RNN独特之处在于它能够处理序列数据,并且在处理过程中,上一时刻的隐藏状态h_{t-1}会作为当前时刻的一部分输入。这种结构使得RNN具有捕捉和处理时间序列数据中长期依赖关系的能力,非常适合于自然语言处理、语音识别、音乐生成等各种涉及时序数据的任务。

这里再介绍下根据输入输出长度的RNN分类:

  • 1 vs 1 RNN: 在这种结构中,网络接受一个单一时间步长的输入,并产生一个单一时间步长的输出。这通常用于处理不需要考虑时序依赖或只需要针对单个输入元素生成单个输出元素的任务,例如情感分析或文本分类,其中每个样本代表整个输入。

  • 1 vs N RNN: 这种结构接收一个单独的时间步长作为输入,但产生一个包含多个时间步长的输出序列。例如,在音乐创作或者生成任务中,模型可能从一个音符开始,然后预测接下来的一系列音符,形成一段旋律。

  • N vs 1 RNN: 此类RNN接受一个包含多个时间步长的输入序列,但只输出一个单一的总结性结果。例如,在文本摘要任务中,模型读取一整段文本(可能是多个句子),然后生成一个简洁的总结;或者是语音识别任务,输入是一段音频信号,输出是识别出的一个词或一句话。

  • N vs N RNN: 它处理同样长度不同的输入和输出序列。比如在机器翻译中,源语言句子被转换为目标语言句子,两者长度往往不等,但经过处理后成为等长的序列。每一个时间步,RNN都会基于之前的隐藏状态和当前输入计算新的隐藏状态,并输出对应位置的预测值。Seq2Seq模型中的Encoder-Decoder结构就是典型的N vs N结构,其中Encoder将输入序列编码为固定长度的上下文向量,而Decoder则根据该上下文向量逐步解码出目标序列。

2. 实例说明

本文使用人名数据,具体来说是日文中的姓氏数据来训练RNN。训练后,给定首写字母使用训练好的RNN生成名字的剩余部分。

为什么使用小日子的姓氏呢?因为我只有这个数据。。。

2.1 训练数据

日文中的姓氏:

总共992个名字。

需要源文件可以评论留下邮箱

2.2 数据导入及处理

这一步需要把原始数据导入成一维列表,并且在每个名字后加上“!”作为结束符号(好让RNN知道什么时候停止)

import unicodedata
import string
from io import openall_letters = string.ascii_letters+'!'
n_letters = len(all_letters)+1name_path = 'names.txt'names = open(name_path, encoding='utf-8').read().strip().split('\n')
names_with_endmark=[]
for name in names:names_with_endmark.append(name + '!')
# print(names_with_endmark)

生成的一维列表为:

['Abe!', 'Abukara!', 'Adachi!', 'Aida!'... 'Yuhara!', 'Yunokawa!']
2.3 onehot编码

One-hot编码是一种将分类变量或离散特征转换为数值型数据的常用方法,在机器学习和深度学习领域中广泛应用。它通过创建一个“独热”向量来表示每个类别,该向量的长度等于所有可能类别的总数,且向量中只有一个位置(对应类别所在的位置)的值为1,其他所有位置的值均为0。

在本文实例中共使用53个字符:

abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!

所以字母‘a’的onehot编码为[1, 0, 0, 0.....,0],‘b’的onehot编码为[0, 1, 0, 0.....,0],以此类推。

本文实例使用N vs N RNN结构,需要进行两种onehot编码:

def input_onehot(name):onehot_tensor = torch.zeros(len(name),1,n_letters)for i in range(len(name)):onehot_tensor[i][0][all_letters.find(name[i])] = 1return onehot_tensordef target_onehot(name):onehot = []for i in range(1, len(name)):onehot.append(all_letters.find(name[i]))onehot.append(all_letters.find('!'))onehot_tensor = torch.tensor(onehot)return onehot_tensor

用‘Abe’这个名字举例来说,input_onehot就是对应[Abe]的onehot向量(训练输入),而target_onehot就是对应[be!]的onehot向量(训练输出目标)。

还需要定义一个onehot解码的函数,用于把训练后的onehot向量转回字母:

def onehot_letter(onehot):  #onehot编码转letter_,letter_index = torch.topk(onehot,k=1)return all_letters[letter_index]
2.4 RNN构建

本文构建的RNN是基于RNN的改进版,改进前后的对比原理图如下:

首先我们看下用全连接层 nn.Linear() 构建正常RNN的方法:

class SimpleRNN(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(SimpleRNN, self).__init__()self.hidden_size = hidden_sizeself.i2h = nn.Linear(input_size + hidden_size, hidden_size) # 输入到隐藏层的权重和偏置self.h2o = nn.Linear(hidden_size, output_size) # 隐藏层到输出层的权重和偏置self.activation = nn.Tanh()  # 非线性激活函数,这里使用tanhdef forward(self, input_step, hidden_state):combined_input = torch.cat((input_step, hidden_state), dim=1)hidden_state = self.activation(self.i2h(combined_input))output = self.h2o(hidden_state)return output, hidden_state

本文的改进结构为:

class RNN(nn.Module): def __init__(self, input_size, hidden_size, output_size):super(RNN, self).__init__()self.hidden_size = hidden_sizeself.h0 = torch.zeros(1, self.hidden_size)# i2h input → hidden,hidden理解为语义# i2o input → output# o2o output→ outputself.i2h = nn.Linear(input_size + hidden_size, hidden_size)self.i2o = nn.Linear(input_size + hidden_size, output_size)self.o2o = nn.Linear(hidden_size + output_size, output_size)self.dropout = nn.Dropout(0.1) #抑制过拟合self.softmax = nn.LogSoftmax(dim=1)def forward(self, input, hidden):input_combined = torch.cat(( input, hidden), 1)hidden = self.i2h(input_combined)output = self.i2o(input_combined)output_combined = torch.cat((hidden, output), 1)output = self.o2o(output_combined)output = self.dropout(output)output = self.softmax(output)return output, hidden

我们可以把隐藏层输出理解为“语义”,本文的改进目的是让最终输出不仅考虑t-1时刻的隐藏层输出(语义),也把t时刻的隐藏层输出纳入考虑。

3. 模型训练

训练相关参数设定如下:

criterion = nn.NLLLoss()   #Negative Log Likelihood Loss,即负对数似然损失。
opt = torch.optim.SGD(params=rnn.parameters(),lr = 5e-4)    #随机梯度下降优化方法
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=opt, T_max=100, last_epoch= -1)  #增加余弦退火自调整学习率
epoch = 1000

训练过程如下:

这里loss值很大是因为取的所有992个名字的loss总和。

4. 验证结果

以字母'A'开头,使用训练好的RNN输出的名字为'Aso',即あそ,麻生(或阿苏,两个姓氏同音)。

5. 完整代码

5.1 训练组
from io import open
import unicodedata
import string
import torch
import torch.nn as nn
from tqdm import tqdm
import matplotlib.pyplot as pltall_letters = string.ascii_letters+'!'
n_letters = len(all_letters)+1name_path = 'names.txt'names = open(name_path, encoding='utf-8').read().strip().split('\n')
names_with_endmark=[]
for name in names:names_with_endmark.append(name + '!')
# print(names_with_endmark)Ascii_names = []   #把names格式转为Ascii
for name in names_with_endmark:Ascii_names.append(''.join(letter for letter in unicodedata.normalize('NFD',name) if unicodedata.category(letter) != 'Mn' and letter in all_letters))#上面这行代码解读:
#         1. unicodedata.normalize('NFD', name):对输入的字符串name进行NFD(Normalization Form D)标准化。NFD将每个字符分解为其基本形式和所有可分解的组合标记。
#         2. letter for letter in ...:这是一个生成器表达式,它会遍历经过NFD标准化后的字符串name中的每一个字符letter。
#         3. if unicodedata.category(letter) != 'Mn':检查每个字符c的Unicode类别是否不等于'Mn'。'Mn'代表"Mark, Non-Spacing",即非-spacing组合标记,这些标记不占据自己的空间位置,而是附加在其他字符上改变其样式或语意。
#         4. ''.join(...):将所有满足条件(非'Mn'类别)的字符连接成一个新的字符串。由于连接符是空字符串'',所以结果是一个没有分隔符的连续字符串。
# print(Ascii_names)   #这里和上面pring(names_with_endmark)输出结果看不出差别,因为只是编码方式不同def input_onehot(name):onehot_tensor = torch.zeros(len(name),1,n_letters)for i in range(len(name)):onehot_tensor[i][0][all_letters.find(name[i])] = 1return onehot_tensordef target_onehot(name):onehot = []for i in range(1, len(name)):onehot.append(all_letters.find(name[i]))onehot.append(all_letters.find('!'))onehot_tensor = torch.tensor(onehot)return onehot_tensor# print(input_onehot('Arai'))
# print(target_onehot('Arai'))
class RNN(nn.Module):   #注意,这不是完全意义上的RNNdef __init__(self, input_size, hidden_size, output_size):super(RNN, self).__init__()self.hidden_size = hidden_sizeself.h0 = torch.zeros(1, self.hidden_size)# i2h input → hidden,hidden理解为语义# i2o input → output# o2o outputself.i2h = nn.Linear(input_size + hidden_size, hidden_size)self.i2o = nn.Linear(input_size + hidden_size, output_size)self.o2o = nn.Linear(hidden_size + output_size, output_size)self.dropout = nn.Dropout(0.1) #抑制过拟合self.softmax = nn.LogSoftmax(dim=1)def forward(self, input, hidden):input_combined = torch.cat(( input, hidden), 1)hidden = self.i2h(input_combined)output = self.i2o(input_combined)output_combined = torch.cat((hidden, output), 1)output = self.o2o(output_combined)output = self.dropout(output)output = self.softmax(output)return output, hiddenif __name__ == '__main__':rnn = RNN(n_letters, 128, n_letters)criterion = nn.NLLLoss()   #Negative Log Likelihood Loss,即负对数似然损失。opt = torch.optim.SGD(params=rnn.parameters(),lr = 5e-4)    #随机梯度下降优化方法scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=opt, T_max=100, last_epoch= -1)  #增加余弦退火自调整学习率epoch = 1000def train(input_name_tensor, target_name_tensor):target_name_tensor.unsqueeze_(-1)hidden = rnn.h0   #对h0进行初始化opt.zero_grad()name_loss = 0for i in range(input_name_tensor.size(0)):output, hidden = rnn(input_name_tensor[i], hidden)loss = criterion(output, target_name_tensor[i])name_loss += lossname_loss.backward()  #对整个名字的loss进行backwardopt.step()return name_lossfor e in tqdm(range(epoch)):total_loss = 0for name in Ascii_names:total_loss = total_loss + train(input_onehot(name),target_onehot(name))print(total_loss)plt_loss = total_loss.detach()plt.scatter(e, plt_loss, s=2, c='r')scheduler.step()torch.save(rnn.state_dict(), 'weight/epoch=1000--initial_lr=5e-4.pth')  #保存训练好的权重plt.xlabel('epoch')plt.ylabel('loss')plt.show()
5.2 验证组
import torch
from rnn_main import RNN, input_onehot
import stringall_letters = string.ascii_letters+'!'
n_letters = len(all_letters)+1rnn_predict = RNN(n_letters, 128, n_letters)
rnn_predict.load_state_dict(state_dict=torch.load('weight/epoch=1000--initial_lr=5e-4.pth'))def onehot_letter(onehot):  #onehot编码转letter_,letter_index = torch.topk(onehot,k=1)return all_letters[letter_index]rnn_predict.eval()current_letter_onehot = input_onehot('A').squeeze(0)
current_letter = onehot_letter(current_letter_onehot)
hpre = rnn_predict.h0
full_name = ''
while current_letter != '!':  #判断是不是该结束了full_name = full_name + current_letterpredict_onehot, hcur = rnn_predict(current_letter_onehot, hpre)hpre = hcurcurrent_letter_onehot = predict_onehotcurrent_letter = onehot_letter(current_letter_onehot)
print(full_name)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/336733.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SpringIOC之support模块FileSystemXmlApplicationContext

博主介绍:✌全网粉丝5W,全栈开发工程师,从事多年软件开发,在大厂呆过。持有软件中级、六级等证书。可提供微服务项目搭建与毕业项目实战,博主也曾写过优秀论文,查重率极低,在这方面有丰富的经验…

【samba】Ubuntu20.04安装 error255解决方法

目录 使用samba报错 net usershare returned error 255时(如下图)解决方法如下: 1、安装 Samba 服务: 2、配置 Samba 共享: 3、设置 Samba 用户密码: 4、重启 Samba 服务: 6、在 Windows 上…

用通俗易懂的方式讲解:一文讲清大模型 RAG 技术全流程

目录 一、为什么业界普遍关注RAG?通俗易懂讲解大模型系列技术交流 二、RAG技术要怎么干?(1)安装pdf解析库(2)检索引擎准备工作将文本片段灌入检索引擎实现关键字检索 (3)LLM 接口封装…

在Linux中使用Apache HTTP服务器

Apache HTTP服务器,也被称为Apache,是全球使用最广泛的Web服务器软件之一。它以其稳定性、强大的功能和灵活性而闻名,尤其在Linux操作系统上表现得尤为出色。以下是关于如何在Linux中使用Apache HTTP服务器的详细指南。 1. 安装Apache 首先…

Yolov4重大的更新,结构组件

YOLO之父在2020年初宣布退出CV界,YOLOv4 的作者并不是YOLO系列 的原作者。YOLO V4是YOLO系列一个重大的更新,其在COCO数据集上的平均精度(AP)和帧率精度(FPS)分别提高了10% 和12%,并得到了Joseph Redmon的官方认可,被认为是当前最…

护眼台灯有AAA级吗?国家AA级护眼灯推荐

儿童的近视年龄是越来越小,在我国儿童以及青少年总体的近视率为52.7%,在想到自己两个孩子的也有近视的预兆,我们应该对近视低龄化的现象感到警惕, 在日常生活中学习时的环境光线过亮或过暗以及不好的用眼习惯,都可能诱…

coredump+gdb调试

1、什么是coredump Coredump(核心转储)是操作系统在程序异常终止(例如由于段错误或其他严重错误)时创建的一种文件。这个文件包含了程序崩溃时刻进程的内存镜像,通常还包括程序计数器、寄存器内容和堆栈内存等信息&am…

基于模块自定义扩展字段的后端逻辑实现(一)

目录 一:背景介绍 二:实现过程 三:字段标准化 四:数据存储 五:数据扩展 六:表的设计 一:背景介绍 最近要做一个系统,里面涉及一个模块是使用拖拉拽的形式配置模块使用的字段表…

bilibi分类id的秘密

问题 今天想通过rss来阅读bilibili的相关信息,但是如何获取排行榜的分类呢?研究了一下。 办法 浏览器最喜欢的F12,过滤关键才v2?rid,后面的数字就是分类id。 rss获取路径 [最后的数字是0,是所有投稿,数字是1的话是…

失去记忆的朱令对父亲说:如果你不照顾我,就再也没有人可以了

这句话深深触动了朱父和朱母,他们最害怕的就是:除了他们,还有谁会如此细心地照料女儿?他们担心有一天女儿苏醒,他们却无法再支撑自己。 这样的苦难并没有击垮两位老人,时间的流逝是最无情的。随着年岁的增长…

LabVIEW在微生物检测中的应用

随着对食品安全关注的增加,食品检测的准确性变得越来越重要。其中,微生物计数作为食品合格的关键指标,对其检测技术的准确性和实时性要求极高。传统的微生物检测面临着菌落识别困难、设备实时性差和自动化程度不高等问题,尤其在疫…

K8S的存储卷---数据卷

容器内的目录和宿主机的目录进行挂载 容器在系统上的生命周期是短暂的。delete,K8S用控制器创建的pod,delete相当于重启,容器的状态也会恢复到初始状态。一旦回到初始状态,所有的后天编辑的文件都会消失 容器和节点之间创建一个…