python pytorch实现RNN,LSTM,GRU,文本情感分类

python pytorch实现RNN,LSTM,GRU,文本情感分类

数据集格式:
在这里插入图片描述
有需要的可以联系我

实现步骤就是:
1.先对句子进行分词并构建词表
2.生成word2id
3.构建模型
4.训练模型
5.测试模型

代码如下:


import pandas as pd
import torch
import matplotlib.pyplot as plt
import jieba
import numpy as np"""
作业:
一、完成优化
优化思路1 jieba
2 取常用的3000字
3 修改model:rnn、lstm、gru二、完成测试代码
"""# 了解数据
dd = pd.read_csv(r'E:\peixun\data\train.csv')
# print(dd.head())# print(dd['label'].value_counts())# 句子长度分析
# 确定输入句子长度为 500
text_len = [len(i) for i in dd['text']]
# plt.hist(text_len)
# plt.show()
# print(max(text_len), min(text_len))# 基本参数 config
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('my device:', DEVICE)MAX_LEN = 500
BATCH_SIZE = 16
EPOCH = 1
LR = 3e-4# 构建词表 word2id
vocab = []
for i in dd['text']:vocab.extend(jieba.lcut(i, cut_all=True))  # 使用 jieba 分词# vocab.extend(list(i))vocab_se = pd.Series(vocab)
print(vocab_se.head())
print(vocab_se.value_counts().head())vocab = vocab_se.value_counts().index.tolist()[:3000]  # 取频率最高的 3000 token
# print(vocab[:10])
# exit()WORD_PAD = "<PAD>"
WORD_UNK = "<UNK>"
WORD_PAD_ID = 0
WORD_UNK_ID = 1vocab = [WORD_PAD, WORD_UNK] + list(set(vocab))print(vocab[:10])
print(len(vocab))vocab_dict = {k: v for v, k in enumerate(vocab)}# 词表大小,vocab_dict: word2id; vocab: id2word
print(len(vocab_dict))import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
import pandas as pd# 定义数据集 Dataset
class Dataset(data.Dataset):def __init__(self, split='train'):# ChnSentiCorp 情感分类数据集path =  r'E:/peixun/data/' + str(split) + '.csv'self.data = pd.read_csv(path)def __len__(self):return len(self.data)def __getitem__(self, i):text = self.data.loc[i, 'text']label = self.data.loc[i, 'label']return text, label# 实例化 Dataset
dataset = Dataset('train')# 样本数量
print(len(dataset))
print(dataset[0])# 句子批处理函数
def collate_fn(batch):# [(text1, label1), (text2, label2), (3, 3)...]sents = [i[0][:MAX_LEN] for i in batch]labels = [i[1] for i in batch]inputs = []# masks = []for sent in sents:sent = [vocab_dict.get(i, WORD_UNK_ID) for i in list(sent)]pad_len = MAX_LEN - len(sent)# mask = len(sent) * [1] + pad_len * [0]# masks.append(mask)sent += pad_len * [WORD_PAD_ID]inputs.append(sent)# 只使用 lstm 不需要用 masks# masks = torch.tensor(masks)# print(inputs)inputs = torch.tensor(inputs)labels = torch.LongTensor(labels)return inputs.to(DEVICE), labels.to(DEVICE)# 测试 loader
loader = data.DataLoader(dataset,batch_size=BATCH_SIZE,collate_fn=collate_fn,shuffle=True,drop_last=False)inputs, labels = iter(loader).__next__()
print(inputs.shape, labels)# 定义模型
class Model(nn.Module):def __init__(self, vocab_size=5000):super().__init__()self.embed = nn.Embedding(vocab_size, 100, padding_idx=WORD_PAD_ID)# 多种 rnnself.rnn = nn.RNN(100, 100, 1, batch_first=True, bidirectional=True)self.gru = nn.GRU(100, 100, 1, batch_first=True, bidirectional=True)self.lstm = nn.LSTM(100, 100, 1, batch_first=True, bidirectional=True)self.l1 = nn.Linear(500 * 100 * 2, 100)self.l2 = nn.Linear(100, 2)def forward(self, inputs):out = self.embed(inputs)out, _ = self.lstm(out)out = out.reshape(BATCH_SIZE, -1)  # 16 * 100000out = F.relu(self.l1(out))  # 16 * 100out = F.softmax(self.l2(out))  # 16 * 2return out# 测试 Model
model = Model()
print(model)# 模型训练
dataset = Dataset()
loader = data.DataLoader(dataset,batch_size=BATCH_SIZE,collate_fn=collate_fn,shuffle=True)model = Model().to(DEVICE)# 交叉熵损失
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)model.train()
for e in range(EPOCH):for idx, (inputs, labels) in enumerate(loader):# 前向传播,计算预测值out = model(inputs)# 计算损失loss = loss_fn(out, labels)# 反向传播,计算梯度loss.backward()# 参数更新optimizer.step()# 梯度清零optimizer.zero_grad()if idx % 10 == 0:out = out.argmax(dim=-1)acc = (out == labels).sum().item() / len(labels)print('>>epoch:', e,'\tbatch:', idx,'\tloss:', loss.item(),'\tacc:', acc)# 模型测试
test_dataset = Dataset('test')
test_loader = data.DataLoader(test_dataset,batch_size=BATCH_SIZE,collate_fn=collate_fn,shuffle=False)loss_fn = nn.CrossEntropyLoss()out_total = []
labels_total = []model.eval()
for idx, (inputs, labels) in enumerate(test_loader):out = model(inputs)loss = loss_fn(out, labels)out_total.append(out)labels_total.append(labels)if idx % 50 == 0:print('>>batch:', idx, '\tloss:', loss.item())correct=0
sumz=0
for i in range(len(out_total)):out = out_total[i].argmax(dim=-1)correct = (out == labels_total[i]).sum().item() +correctsumz=sumz+len(labels_total[i])#acc = (out_total == labels_total).sum().item() / len(labels_total)print('>>acc:', correct/sumz)

运行结果如下:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/234958.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2021年9月15日 Go生态洞察:TLS加密套件的自动排序机制

&#x1f337;&#x1f341; 博主猫头虎&#xff08;&#x1f405;&#x1f43e;&#xff09;带您 Go to New World✨&#x1f341; &#x1f984; 博客首页——&#x1f405;&#x1f43e;猫头虎的博客&#x1f390; &#x1f433; 《面试题大全专栏》 &#x1f995; 文章图文…

指数退避和抖动

目录 引入 OCC 添加退避机制 添加抖动机制 小结 引入 OCC 乐观并发控制&#xff08;Optimistic Concurrency Control&#xff0c;OCC&#xff09;是一种既能保证多个写入者安全地修改单个对象又能避免丢失写入的古老方法OCC具有三个优点&#xff1a;只要底层存储可用&#…

MySQL之JDBC

&#x1f495;"我像离家的孤儿,回到了母亲的怀抱,恢复了青春。"&#x1f495; 作者&#xff1a;Mylvzi 文章主要内容&#xff1a;MySQL之JDBC 一.什么是JDBC? JDBC编程就是通过Java 代码来操纵数据库 数据库编程&#xff0c; 需要数据库服务器提供一些API供程序…

FastDFS+Nginx - 本地搭建文件服务器同时实现在外远程访问「内网穿透」

文章目录 前言1. 本地搭建FastDFS文件系统1.1 环境安装1.2 安装libfastcommon1.3 安装FastDFS1.4 配置Tracker1.5 配置Storage1.6 测试上传下载1.7 与Nginx整合1.8 安装Nginx1.9 配置Nginx 2. 局域网测试访问FastDFS3. 安装cpolar内网穿透4. 配置公网访问地址5. 固定公网地址5.…

Unity 注释的方法

1、单行注释&#xff1a;使用双斜线&#xff08;//&#xff09;开始注释&#xff0c;后面跟注释内容。通常注释一个属性或者方法&#xff0c;如&#xff1a; //速度 public float Speed;//打印输出 private void DoSomething() {Debug.Log("运行了我"); } …

【C++】异常处理 ③ ( 栈解旋 | 栈解旋概念 | 栈解旋作用 )

文章目录 一、栈解旋1、栈解旋引入2、栈解旋概念3、栈解旋作用 二、代码示例 - 栈解旋1、代码示例2、执行结果 一、栈解旋 1、栈解旋引入 C 程序 抛出异常后 对 局部变量的处理 : 当 C 应用程序 在 运行过程 中发生异常时 , 程序会跳转到异常处理程序 , 并执行一些操作以处理异…

距离向量路由协议——RIP

目录 动态路由动态路由简介为什么需要动态路由动态路由基本原理路由协议的分类 距离向量路由协议RIPv1RIP简介RIPv1的主要特征RIPv1的基本配置RIPv1配置案例被动接口单播更新使用子网地址 RIPv2RIPv2的基本配置RIPv2配置案例 RIPv2的高级配置与RIPv1的兼容性手工路由汇总触发更…

【C++】string模拟

string讲解&#xff1a;【C】String类-CSDN博客 基本框架 #pragma once #include <iostream> using namespace std; ​ namespace wzf {class string{public:// 默认构造函数string(): _str(new char[1]), _size(0), _capacity(0){_str[0] \0; // 在没有内容时仍要有终…

python中的序列

文章目录 序列类型标准类型运算符标准类型运算符序列类型运算符字符串 序列类型 字符串 列表 元组 由元组构成的列表 标准类型运算符 &#xff08;1&#xff09;按字符串大小比较 标准类型运算符 序列类型运算符 序列类型转换内建函数 注&#xff1a; &#xff08;1&#xff…

基于SpringBoot房产销售系统

摘 要 随着科学技术的飞速发展&#xff0c;各行各业都在努力与现代先进技术接轨&#xff0c;通过科技手段提高自身的优势&#xff1b;对于房产销售系统当然也不能排除在外&#xff0c;随着网络技术的不断成熟&#xff0c;带动了房产销售系统&#xff0c;它彻底改变了过去传统的…

frp 配置内网访问

frp介绍 frp 是一个开源、简洁易用、高性能的内网穿透软件&#xff0c;支持 tcp, udp, http, https 等协议。frp 项目官网是 https://github.com/fatedier/frp 下载地址&#xff1a; https://github.com/fatedier/frp/releases frp工作原理 服务端运行&#xff0c;监听一个…

AI Agents 闭门研讨会报名丨CAMEL、AutoAgents、Humanoid agents作者参与

青源Workshop丨No.27 AI Agents主题闭门研讨会 所谓AI智能体&#xff08;AI Agents&#xff09;&#xff0c;是一种能够感知环境、进行决策和执行动作的智能实体。它们拥有自主性和自适应性&#xff0c;可以依靠AI赋予的能力完成特定任务&#xff0c;并在此过程中不断对自我进行…