文档分类DPCNN简介(pytorch实现)

文档分类DPCNN简介

        • DPCNN简介
      • 模型结构
          • 区域嵌入
          • 等长卷积
          • 1/2池化
          • DPCNN模型代码实现

DPCNN简介

论文中提出了一种基于 word-level 级别的网络-DPCNN,由于 TextCNN 不能通过卷积获得文本的长距离依赖关系,而论文中 DPCNN 通过不断加深网络,可以抽取长距离的文本依赖关系。

实验证明在不增加太多计算成本的情况下,增加网络深度就可以获得最佳的准确率。‍

前面我们提到过TextRCNN就是将CNN中的池化加入到RNN中,来解决RNN是一个有的偏倚,现在DPCNN通过不断加深网络,来弥补自身短缺的长距离依赖问题,可见每一种模型都不是十全十美的,只有不断探索,不断创新,相互借鉴,才能够使性能进一步提升。

模型结构

在这里插入图片描述

区域嵌入

这里是将TextCNN的包含多尺寸卷积滤波器的卷积层的卷积结果称之为区域嵌入,即对一个文本区域文本片段(比如3-gram)进行一组卷积操作后生成的embedding。这里不同于textCNN的二维卷积,DPCNN采用的是一维卷积。以3-gram为例子回顾textCNN,设置了一个大小为3xD的二维卷积核进行卷积操作(其中D是词嵌入的维度),其实这是一种保留词序的做法。那么对于DPCNN,采用的是不保留词序的做法,即:首先对3-grm中的3个词的词向量取均值得到一个大小为1xD的向量,然后设置一组大小为1*D的一维卷积核对该3-grm进行卷积操作。

等长卷积

经过区域嵌入后,是两层卷积层,这里采用的是等长卷积,以此来提高词位embedding的表示的丰富性。首先先介绍一下三种卷积的概念:

假设输入的序列长度为 n,卷积核大小为 m,步长(stride)为 s,输入序列两端各填补 p 个零(zero padding),那么该卷积层的输出序列为 (n-m+2p)/s+1。

  1. 窄卷积:步长 s=1,两端不补零,即 p=0,卷积后输出长度为 n-m+1。
  2. 宽卷积:步长 s=1,两端补零 p=m-1,卷积后输出长度 n+m-1。
  3. 等长卷积:步长 s=1,两端补零 p=(m-1)/2,卷积后输出长度为 n。

输入输出序列的位置数一样多,即为等长卷积,该卷积的意义是:输出的词是由该位置输入的词以及其左右词的上下文信息提取得到的,也就是说,这个词包含被上下文信息修饰过的更高级别的语义。

1/2池化

本文使用一个 size=3,stride=2(大小为3,步长为2)的池化层进行最大池化,在此称为1/2池化层。每经过一个1/2池化层,序列的长度就被压缩成了原来的一半。因此,经过1/2池化后,同样一个size为3的卷积核,其能够感知到的文本片段就比之前长了一倍。
在堆叠多层卷积池化层之后,就得到了加深的可以抽取长距离的文本依赖关系的网络。最后的池化层把每段文本聚合为一个向量。

主要区别在于输入层由无监督词嵌入层作为输入,把文档的每个词的词向量作出二维数组作为输入;卷积层有两个卷积层组成,卷积层输入通过跳跃连接,恒等映射和卷积层输出相加作为卷积层输出;采样层以尺度大小为2进行下采样,达到尺度缩放的目的。堆叠几层卷积层和采样层,形成尺度缩放金字塔,达到维度缩放的目的。最终将卷积层输出拼接成向量通过隐藏层和softmax层作为输出分类。

DPCNN模型代码实现
import torch.nn as nn
import torch.nn.functional as F
import numpy as npclass Config(object):"""配置参数"""def __init__(self):self.dropout = 0.5  # 随机失活self.require_improvement = 1000  # 若超过1000batch效果还没提升,则提前结束训练self.num_classes =10 # 类别数self.n_vocab = 10000  # 词表大小,在运行时赋值self.num_epochs = 20  # epoch数self.batch_size = 128  # mini-batch大小self.pad_size = 32  # 每句话处理成的长度(短填长切)self.learning_rate = 1e-3  # 学习率self.embed = 300  # 字向量维度self.num_filters = 250  # 卷积核数量(channels数)'''Deep Pyramid Convolutional Neural Networks for Text Categorization'''class Model(nn.Module):def __init__(self, config):super(Model, self).__init__()self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1)self.conv_region = nn.Conv2d(1, config.num_filters, (3, config.embed), stride=1)self.conv = nn.Conv2d(config.num_filters, config.num_filters, (3, 1), stride=1)self.max_pool = nn.MaxPool2d(kernel_size=(3, 1), stride=2)# (pad_left, pad_right, pad_top, pad_bottom)填充self.padding1 = nn.ZeroPad2d((0, 0, 1, 1))  # top bottomself.padding2 = nn.ZeroPad2d((0, 0, 0, 1))  # bottomself.relu = nn.ReLU()self.fc = nn.Linear(config.num_filters, config.num_classes)def forward(self, x):x = x[0]  # torch.Size([128, 32])x = self.embedding(x)  # torch.Size([128, 32,300])x = x.unsqueeze(1)  # torch.Size([128, 1, 32, 300])x = self.conv_region(x)  # torch.Size([128, 250, 30, 300])x = self.padding1(x)  # [128, 250, 32, 1]x = self.relu(x)  # [128, 250, 32, 1]x = self.conv(x)  # [125, 250, 30, 1]x = self.padding1(x)  # [128, 250, 32, 1]x = self.relu(x)  # [128, 250, 32, 1]x = self.conv(x)  # [128, 250, 30, 1]while x.size()[2] > 2:x = self._block(x)# print("x10", x)#torch.Size([128, 250, 1, 1])x = x.squeeze()  # [128, 250]x = self.fc(x)  # [128, 10]return xdef _block(self, x):x = self.padding2(x)px = self.max_pool(x)x = self.padding1(px)x = F.relu(x)x = self.conv(x)x = self.padding1(x)x = F.relu(x)x = self.conv(x)x = x + pxreturn xconfig=Config()
model=Model(config)
print(model)

输出:

Model((embedding): Embedding(10000, 300, padding_idx=9999)(conv_region): Conv2d(1, 250, kernel_size=(3, 300), stride=(1, 1))(conv): Conv2d(250, 250, kernel_size=(3, 1), stride=(1, 1))(max_pool): MaxPool2d(kernel_size=(3, 1), stride=2, padding=0, dilation=1, ceil_mode=False)(padding1): ZeroPad2d((0, 0, 1, 1))(padding2): ZeroPad2d((0, 0, 0, 1))(relu): ReLU()(fc): Linear(in_features=250, out_features=10, bias=True)
)

参考:
https://blog.csdn.net/sikh_0529/article/details/126912490
https://blog.csdn.net/qq_43592352/article/details/122764889

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/706846.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

RTMP低延迟推流

人总是需要压力才能进步, 最近有个项目, 需要我在RK3568上, 推流到公网, 最大程度的降低延迟. 废话不多说, 先直接看效果: 数据经过WiFi发送到Inenter的SRS服务器, 再通过网页拉流的. 因为是打金任务, 所以逼了自己一把, 把RTMP推流好好捋一遍. 先说说任务目标, 首先是MPP编码…

按照官网引擎问题重新设置监控目录,仍然存在空三等待的问题怎么办?

答:任务目录和引擎目录设置一样,然后取消任务重新写入. 重建大师是一款专为超大规模实景三维数据生产而设计的集群并行处理软件,输入倾斜照片,激光点云,POS信息及像控点,输出高精度彩色网格模型&#xff0…

Iphone更换后摄像头蓝光珠

拆蓝光珠 风枪加热240℃,风速70,直接融化掉蓝光珠 清除残胶 风枪加热140℃,风速50 更换新的蓝光珠 点UV胶紫外灯加固,防止晃动 安装完毕!

RT-Thread的 FAL 组件_使用笔记

RT-Thread的FAL分区表组件 1、FAL介绍 FAL (Flash Abstraction Layer) Flash 抽象层,是对 Flash 及基于 Flash 的分区进行管理、操作的抽象层,对上层统一了 Flash 及 分区操作的 API (框架图如下所示),并具有以下特性: 1.1 FAL目…

【EI稳定检索|主题广泛】2024年航空航天、遥感与计算机国际会议(ARSC 2024)

2024年航空航天、遥感与计算机国际会议(ARSC 2024) 2024 International Conference on Aerospace, Remote Sensing, and Computing 【会议简介】 2024年航空航天、遥感与计算机国际会议将在古都西安召开。本次会议是航空航天、遥感与计算机领域的一次…

[牛客网]——C语言刷题day3

答案&#xff1a;A 解析&#xff1a; A.表示将数组a的首地址赋值给指针变量p B.将一个int型变量直接赋值给一个int型的指针是不行的 C.道理同B D.j2是一个右值&#xff0c;右值是不能进行取地址操作的 #include <iostream> using namespace std;#define N 7 int fun…

Ubuntu16 扩展磁盘空间

一、扩展容量 关闭虚拟机->硬盘->扩展->输入要扩展的空间大小 二、重新磁盘分区 打开虚拟机&#xff0c;在终端安装gparted&#xff1a; sudo apt-get install gparted 打开gparted&#xff1a; sudo gparted 磁盘分区如下图所示 选择/dev/sda5分区&#xff0c;选择…

web3 ETF软件开发难点

开发一个涉及到 Web3 ETF&#xff08;Exchange-Traded Fund&#xff0c;交易所交易基金&#xff09;的软件可能会面临一些挑战和难点&#xff0c;特别是在整合 Web3 技术和金融服务方面。以下是一些可能的难点。北京木奇移动技术有限公司&#xff0c;专业的软件外包开发公司&am…

微信在线预约系统怎么做_让您的业务更高效!

在这个数字化飞速发展的时代&#xff0c;传统的业务预约方式已经逐渐无法满足现代人的需求。随着智能手机的普及和微信用户数量的不断攀升&#xff0c;微信在线预约系统已成为许多企业和个人提升服务效率、优化客户体验的不二之选。今天&#xff0c;就让我们一起探讨微信在线预…

【Linux玩物志】Linux环境开发基本工具使用(1)——vim

W...Y的主页 &#x1f60a; 代码仓库分享&#x1f495; Linux开发工具 首先我们要知道vim是什么&#xff1f; vi&#xff08;Visual Editor&#xff09;是由美国程序员比尔乌尔曼&#xff08;Bill Joy&#xff09;于1976年开发的&#xff0c;最初是为了在Unix系统上进行文本编…

npm install [Error]

npm install 依赖的时候报错 依赖版本问题的冲突&#xff0c;忽视即可 使用 npm install --legacy-peer-deps

2024年重庆等保测评公司有哪些?分别位于哪里?

2024年重庆等保测评公司有哪些&#xff1f;分别位于哪里&#xff1f; 【回答】&#xff1a;目前2024年重庆等保测评公司有四家&#xff0c;具体公司名称以及地址如下&#xff1a; 1、重庆信安网络安全等级测评有限公司&#xff0c;重庆市两江新区黄山大道中段55号附2号麒麟D座…