【Pytorch】学习记录分享10——TextCNN用于文本分类处理

【Pytorch】学习记录分享10——PyTorchTextCNN用于文本分类处理

      • 1. TextCNN用于文本分类
      • 2. 代码实现

1. TextCNN用于文本分类

具体流程:
在这里插入图片描述
在这里插入图片描述

2. 代码实现

# coding: UTF-8
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as npclass Config(object):"""配置参数"""def __init__(self, dataset, embedding):self.model_name = 'TextCNN'self.train_path = dataset + '/data/train.txt'                                # 训练集self.dev_path = dataset + '/data/dev.txt'                                    # 验证集self.test_path = dataset + '/data/test.txt'                                  # 测试集self.class_list = [x.strip() for x in open(dataset + '/data/class.txt').readlines()]                                # 类别名单self.vocab_path = dataset + '/data/vocab.pkl'                                # 词表self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt'        # 模型训练结果self.log_path = dataset + '/log/' + self.model_nameself.embedding_pretrained = torch.tensor(np.load(dataset + '/data/' + embedding)["embeddings"].astype('float32'))\if embedding != 'random' else None                                       # 预训练词向量self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')   # 设备self.dropout = 0.5                                              # 随机失活self.require_improvement = 1000                                 # 若超过1000batch效果还没提升,则提前结束训练self.num_classes = len(self.class_list)                         # 类别数self.n_vocab = 0                                                # 词表大小,在运行时赋值self.num_epochs = 20                                            # epoch数self.batch_size = 128                                           # mini-batch大小self.pad_size = 32                                              # 每句话处理成的长度(短填长切)self.learning_rate = 1e-3                                       # 学习率self.embed = self.embedding_pretrained.size(1)\if self.embedding_pretrained is not None else 300           # 字向量维度self.filter_sizes = (2, 3, 4)                                   # 卷积核尺寸self.num_filters = 256                                          # 卷积核数量(channels数)'''Convolutional Neural Networks for Sentence Classification'''class Model(nn.Module):def __init__(self, config):super(Model, self).__init__()if config.embedding_pretrained is not None:self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False)else:self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1)self.convs = nn.ModuleList([nn.Conv2d(1, config.num_filters, (k, config.embed)) for k in config.filter_sizes])self.dropout = nn.Dropout(config.dropout)self.fc = nn.Linear(config.num_filters * len(config.filter_sizes), config.num_classes)def conv_and_pool(self, x, conv):x = F.relu(conv(x)).squeeze(3)x = F.max_pool1d(x, x.size(2)).squeeze(2)return xdef forward(self, x):#print (x[0].shape)out = self.embedding(x[0])out = out.unsqueeze(1)out = torch.cat([self.conv_and_pool(out, conv) for conv in self.convs], 1)out = self.dropout(out)out = self.fc(out)return out

该代码对应上述的图像中的模块实现,CNN用于处理文本数据

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/326199.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python接口自动化-参数关联(超详细的)

前言 我们用自动化发帖之后,要想接着对这篇帖子操作,那就需要用参数关联了,发帖之后会有一个帖子的id,获取到这个id,继续操作传这个帖子id就可以了 (博客园的登录机制已经变了,不能用账号和密…

java.lang.NoSuchFieldError: No static field xxx of type I in class

问题描述 将Library编译成 aar导入到另一个项目中依赖成功,编译成功,运行打开发生了崩溃异常如下图: 原因分析: 异常错误提示找不到id为recycler_1的控件了,我的Library中recycler_1控件是在MainActivity的xml中使用的&#xff0c…

Redis高级特性和应用(慢查询、Pipeline、事务、Lua)

Redis的慢查询 许多存储系统(例如 MySQL)提供慢查询日志帮助开发和运维人员定位系统存在的慢操作。所谓慢查询日志就是系统在命令执行前后计算每条命令的执行时间,当超过预设阀值,就将这条命令的相关信息(例如:发生时间,耗时,命令的详细信息)记录下来,Redis也提供了类似…

Android开发编程从入门到精通,安卓技术从初级到高级全套教学

一、教程描述 本套教程基于JDK1.8版本,教学内容主要有,1、环境搭建,UI布局,基础UI组件,高级UI组件,通知,自定义组件,样式主题;2、四大组件,Intent&#xff0…

Excel中快速隐藏中间四位手机号或者身份证号等

注意:以下方式必须再新增一列,配合旧的一列用来对比操作,即根据旧的一列的数据源,通过新的一列的操作逻辑来生成新的隐藏数据 1、快捷方式是使用CtrlE 新建一列:手动输入第一个手机号隐藏后的号码,即在N2单…

RMAN-03002 RMAN-06059 ORA-19625

有个现场经理反馈,每天的rman备份异常,登录系统查看rman的log日志,报错信息如下 RMAN> run{ 2> backup filesperset 50 archivelog all format /backup/ARCHBAK_%d_%T_%s tag arch_bak delete all input; 3> } 4> Starting …

微信小程序的驾校预约管理系统

🍅点赞收藏关注 → 私信领取本源代码、数据库🍅 本人在Java毕业设计领域有多年的经验,陆续会更新更多优质的Java实战项目希望你能有所收获,少走一些弯路。🍅关注我不迷路🍅一 、设计说明 1.1课题背景 在I…

mysql 单表 操作 最大条数验证 以及优化

1、背景 开车的多年老司机,是不是经常听到过,“mysql 单表最好不要超过 2000w”,“单表超过 2000w 就要考虑数据迁移了”,“你这个表数据都马上要到 2000w 了,难怪查询速度慢”。 2、实验 实验一把看看… 建一张表 CREATE TABL…

MyBatisPlus学习二:常用注解、条件构造器、自定义sql

常用注解 基本约定 MybatisPlus通过扫描实体类&#xff0c;并基于反射获取实体类信息作为数据库表信息。可以理解为在继承BaseMapper 要指定对应的泛型 public interface UserMapper extends BaseMapper<User> 实体类中&#xff0c;类名驼峰转下划线作为表名、名为id的…

单元测试、系统测试、集成测试知识总结

一、单元测试的概念 单元测试是对软件基本组成单元进行的测试&#xff0c;如函数或一个类的方法。当然这里的基本单元不仅仅指的是一个函数或者方法&#xff0c;有可能对应多个程序文件中的一组函数。 单元也具有一些基本的属性。比如&#xff1a;明确的功能、规格定义&#…

Python 面向对象之多态和鸭子类型

Python 面向对象之多态和鸭子类型 【一】多态 【1】概念 多态是面向对象的三大特征之一多态&#xff1a;允许不同的对象对同一操作做出不同的反应多态可以提高代码的灵活性&#xff0c;可扩展性&#xff0c;简化代码逻辑 【2】代码解释 在植物大战僵尸中&#xff0c;有寒冰…

解决列表和元组多索引bug问题(TypeError: list indices must be integers or slices, not tuple)

在对列表和元组进行索引的时候&#xff0c;发现使用多维索引会出现以下bug: TypeError: list indices must be integers or slices, not tuple TypeError: tuple indices must be integers or slices, not tuple list: list1 [[1,2,3], [4,5,6]] m1 list1[1,0]tuple: tup…