文本分类TextRNN_Att模型(pytorch实现)

TextRNN_Att

        • TextRNN-Att简介
        • 模型结构:
        • pytorch代码实现:

TextRNN-Att简介

TextRNN前面已经介绍过了,主体结构就是一个双向/单向的LSTM层,由于LSTM获得每个时间点的输出信息之间的“影响程度”都是一样的,而在关系分类中,为了能够突出部分输出结果对分类的重要性,引入加权的思想。而本篇模型在LSTM层之后引入了attention层,其实就是对lstm每刻的隐层进行加权平均。

在这里插入图片描述

模型结构:
  • 输入层:输入是一个一个的句子,通过对它进行划分batch,sentence,然后进行编码

  • 词嵌入层:将文本中的离散词汇表示(如单词或者字符)转换为连续的实值向量表示,也称为词嵌入(Word Embedding)。这些实值向量具有语义信息,能够捕捉词汇之间的语义关系,从而提供更丰富的特征表示。

  • LSTM层:双向LSTM是RNN的一种改进,其主要包括前后向传播,每个时间点包含一个LSTM单元用来选择性的记忆、遗忘和输出信息。模型的输出包括前后向两个结果,通过拼接作为最终的Bi-LSTM输出。公式如下:

  • 注意力层:对lstm每刻的隐层进行加权平均,将词级别的特征合并到句子级别的特征。

M = tanh ⁡ ( H ) M=\tanh \left(H \right) M=tanh(H)

α = s o f t max ⁡ ( W T M ) \alpha =soft\max \left(W^TM \right) α=softmax(WTM)

r = H α T r=H\alpha ^T r=HαT

  • 输出层:将句子层级的特征用于关系分类。
pytorch代码实现:
  1. 模型输入: [batch_size, seq_len]
  2. 经过embedding层:加载预训练词向量或者随机初始化, 词向量维度为embed_size: [batch_size, seq_len, embed_size]
  3. 双向LSTM:隐层大小为hidden_size,得到所有时刻的隐层状态(前向隐层和后向隐层拼接) [batch_size, seq_len, hidden_size * 2]
  4. 初始化一个可学习的权重矩阵w=[hidden_size * 2, 1]
  5. 对LSTM的输出进行非线性激活后与w进行矩阵相乘,并经行softmax归一化,得到每时刻的分值:[batch_size, seq_len, 1]
  6. 将LSTM的每一时刻的隐层状态乘对应的分值后求和,得到加权平均后的终极隐层值[batch_size, hidden_size * 2]
  7. 对终极隐层值进行非线性激活后送入两个连续的全连接层[batch_size, num_class]
  8. 预测:softmax归一化,将num_class个数中最大的数对应的类作为最终预测[batch_size, 1]
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as npclass Config(object):"""配置参数"""def __init__(self):self.model_name = 'TextRNN_Att'self.dropout = 0.5  # 随机失活self.require_improvement = 1000  # 若超过1000batch效果还没提升,则提前结束训练self.num_classes = 10 # 类别数self.n_vocab = 10000  # 词表大小,在运行时赋值self.num_epochs = 10  # epoch数self.batch_size = 128  # mini-batch大小self.pad_size = 32  # 每句话处理成的长度(短填长切)self.learning_rate = 1e-3  # 学习率self.embed =  300  # 字向量维度, 若使用了预训练词向量,则维度统一self.hidden_size = 128  # lstm隐藏层self.num_layers = 2  # lstm层数self.hidden_size2 = 64class Model(nn.Module):def __init__(self, config):super(Model, self).__init__()self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1)self.lstm = nn.LSTM(config.embed, config.hidden_size, config.num_layers,bidirectional=True, batch_first=True, dropout=config.dropout)self.tanh1 = nn.Tanh()self.w = nn.Parameter(torch.zeros(config.hidden_size * 2))self.tanh2 = nn.Tanh()self.fc1 = nn.Linear(config.hidden_size * 2, config.hidden_size2)self.fc = nn.Linear(config.hidden_size2, config.num_classes)def forward(self, x):x, _ = x# 词嵌入层emb = self.embedding(x)  # [batch_size, seq_len, embeding]=[128, 32, 300]# LSTM层H, _ = self.lstm(emb)  # [batch_size, seq_len, hidden_size * num_direction]=[128, 32, 256]# 注意力层M = self.tanh1(H)  # [128, 32, 256]alpha = F.softmax(torch.matmul(M, self.w), dim=1).unsqueeze(-1)  # [128, 32, 1]out = H * alpha  # [128, 32, 256]#输出层out = torch.sum(out, 1)  # [128, 256]out = F.relu(out)  # [128, 256]out = self.fc1(out)  # [128, 64]out = self.fc(out)  # [128, 10]return outconfig=Config()
model=Model(config)
print(model)

输出:

Model((embedding): Embedding(10000, 300, padding_idx=9999)(lstm): LSTM(300, 128, num_layers=2, batch_first=True, dropout=0.5, bidirectional=True)(tanh1): Tanh()(tanh2): Tanh()(fc1): Linear(in_features=256, out_features=64, bias=True)(fc): Linear(in_features=64, out_features=10, bias=True)
)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/706768.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Linux】缓冲区

目录 一、初识缓冲区 二、用户级缓冲区 三、内核级缓冲区 四、内核级缓冲区 VS 用户级缓冲区 五、用户级缓冲区在哪里? 一、初识缓冲区 缓冲区是什么?可以简单理解成一部分内存。例如用户缓冲区(char arr[])、C标准库提供的缓冲区、操作系统提供的缓…

【Transformer-BEV编码(9)】Sparse4D v2 v3源代码分析。稀疏感知方向新的baseline,相机参数泛化能力差的问题。

前言: 基于BEV的稠密融合算法或许并不是最优的多摄融合感知框架。同时特征级的多摄融合也并不等价于BEV。这两年,PETR系列(PETR, PETR-v2, StreamPETR) 也取得了卓越的性能,并且其输出空间是稀疏的。在PETR系列方法中,对于每个in…

智能边缘计算 | 2024高通边缘智能创新应用大赛赛道解读

随着物联网设备的普及和数据的井喷式增长,用户对数据处理的效率要求进一步提升,而边缘设备的计算能力日益增强,在边缘端完成复杂计算已经成为可能。 除降低时延与减少宽带资源占用外,边缘计算在离数据源更接近的地方完成数据处理…

SpringBoot环境隔离Profiles

前言 通常我们开发不可能只有一个生产环境,还会有其它的开发,测试,预发布环境等等。为了更好的管理每个环境的配置项,springboot也提供了对应的环境隔离的方法。 直接上干货 知识点 激活环境方法 1,在application…

Aspose.PDF功能演示:在 JavaScript 中将 TXT 转换为 PDF

您是否正在寻找一种在 JavaScript 项目中将纯文本文件从TXT无缝转换为PDF格式的方法?您来对地方了!无论您是要构建 Web 应用程序、创建生产力工具,还是只是希望简化工作流程,直接从 JavaScript 代码中将 TXT 转换为 PDF 的功能都可…

Windows 安装mysql 和 Redis

mysql Windows 图形界面安装: 下载mysql https://dev.mysql.com/downloads/ 1.下载完成后,找到文件双击安装程序 2. 等待一段时间, 选择默认,点击next 3. 选择安装目录 下载mysql产品 安装mysql产品 产品配置向导 安装…

相关的形态

相关的形态可以分为完全线性相关、线性相关、非线性相关和不相关。 (a)中的观测点恰好落在一条直线上,表示了两个变量之间是一一对应的函数关系,可以用直线方程来准确描述这两个变量的关系,称为完全线性相关。 (b)中观测点散落在一条直线周…

多联机常见各部件功能及常见机组制冷原理图

一、各部件名称和主要功能 1、压缩机 压缩机根据实际系统需要,调整其转速达到节能目的。 2、压缩机油温加热带 在待机状态下,保证压缩的油温确再启动可靠性。 3、压缩机 排气 感温包 检测压缩机的排气温度,达到控制和保护目的。 4、高压开…

mybaties查询!!!你就说灵不灵活吧

你就说灵不灵活吧 <?xml version"1.0" encoding"UTF-8" ?> <!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" "http://mybatis.org/dtd/mybatis-3-mapper.dtd"> <mapper namespace"com.ruoyi.sys…

浅析扩散模型与图像生成【应用篇】(二十五)——Plug-and-Play

25. Plug-and-Play: Diffusion Features for Text-Driven Image-to-Image Translation 该文提出一种文本驱动的图像转换方法&#xff0c;输入一张图像和一个目标文本描述&#xff0c;按照文本描述对输入图像进行转换&#xff0c;得到目标图像。图像转换任务其实本质上属于图像编…

web安全学习笔记(16)

记一下第27-28课的内容。Token 验证 URL跳转漏洞的类型与三种跳转形式&#xff1b;URL跳转漏洞修复 短信轰炸漏洞绕过挖掘 一、token有关知识 什么是token&#xff1f;token是用来干嘛的&#xff1f;_token是什么意思-CSDN博客 二、URL跳转漏洞 我们在靶场中&#xff0c;…

JVS物联网模拟点位:如何配置并自动生成点位数据全教程

模拟点位 功能描述 模拟点位常用于业务的调试或数据展示&#xff0c;通过配置对应点位实现自动生成点位数据的功能。 界面操作 如下图所示&#xff0c;从模拟点位菜单进入模拟点位管理界面 模拟点位新增 点击新增按钮&#xff0c;如下图所示&#xff1a; ①&#xff1a;用户…