NLP项目实战01之电影评论分类

介绍:

欢迎来到本篇文章!在这里,我们将探讨一个常见而重要的自然语言处理任务——文本分类。具体而言,我们将关注情感分析任务,即通过分析电影评论的情感来判断评论是正面的、负面的。

展示:
训练展示如下:

在这里插入图片描述
在这里插入图片描述

实际使用如下:

请添加图片描述

实现方式:

选择PyTorch作为深度学习框架,使用电影评论IMDB数据集,并结合torchtext对数据进行预处理。

环境:

Windows+Anaconda
重要库版本信息
torch==1.8.2+cu102
torchaudio==0.8.2
torchdata==0.7.1
torchtext==0.9.2
torchvision==0.9.2+cu102

实现思路:

1、数据集
本次使用的是IMDB数据集,IMDB是一个含有50000条关于电影评论的数据集
数据如下:
请添加图片描述
请添加图片描述

2、数据加载与预处理
使用torchtext加载IMDB数据集,并对数据集进行划分
具体划分如下:

TEXT = data.Field(tokenize='spacy', tokenizer_language='en_core_web_sm')
LABEL = data.LabelField(dtype=torch.float)
# Load the IMDB dataset
train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)

创建一个 Field 对象,用于处理文本数据。同时使用spacy分词器对文本进行分词,由于IMDB是英文的,所以使用en_core_web_sm语言模型。
创建一个 LabelField 对象,用于处理标签数据。设置dtype 参数为 torch.float,表示标签的数据类型为浮点型。

使用 datasets.IMDB.splits 方法加载 IMDB 数据集,并将文本字段 TEXT 和标签字段 LABEL 传递给该方法。返回的 train_data 和 test_data 包含了 IMDB 数据集的训练和测试部分。
下面是train_data的输出
请添加图片描述

3、构建词汇表与加载预训练词向量

TEXT.build_vocab(train_data,max_size=25000,vectors="glove.6B.100d",unk_init=torch.Tensor.normal_)
LABEL.build_vocab(train_data)

train_data:表示使用train_data中数据构建词汇表
max_size:限制词汇表的大小为 25000
vectors=“glove.6B.100d”:表示使用预训练的 GloVe 词向量,其中 “glove.6B.100d” 指的是包含 100 维向量的 6B 版 GloVe。
unk_init=torch.Tensor.normal_ :表示指定未知单词(UNK)的初始化方式,这里使用正态分布进行初始化。
LABEL.build_vocab(train_data):表示对标签进行类似的操作,构建标签的词汇表

train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits( (train_data, valid_data, test_data), batch_size=BATCH_SIZE, device=device)

使用data.BucketIterator.splits 来创建数据加载器,包括训练、验证和测试集的迭代器。这将确保你能够方便地以批量的形式获取数据进行训练和评估。

4、定义神经网络
这里的网络定义比较简单,主要采用在词嵌入层(embedding)后接一个全连接层的方式完成对文本数据的分类。
具体如下:

class NetWork(nn.Module):def __init__(self,vocab_size,embedding_dim,output_dim,pad_idx):super(NetWork,self).__init__()self.embedding = nn.Embedding(vocab_size,embedding_dim,padding_idx=pad_idx)self.fc = nn.Linear(embedding_dim,output_dim)self.dropout = nn.Dropout(0.5)self.relu = nn.ReLU()def forward(self,x):embedded = self.embedding(x)embedded = embedded.permute(1,0,2) pooled = F.avg_pool2d(embedded, (embedded.shape[1], 1)).squeeze(1)pooled = self.relu(pooled)pooled = self.dropout(pooled)output = self.fc(pooled)return output

5、模型初始化

vocab_size = len(TEXT.vocab)
embedding_dim  = 100
output = 1
pad_idx = TEXT.vocab.stoi[TEXT.pad_token]
model = NetWork(vocab_size,embedding_dim,output,pad_idx)
pretrained_embeddings = TEXT.vocab.vectors
model.embedding.weight.data.copy_(pretrained_embeddings)

定义模型的超参数,包括词汇表大小(vocab_size)、词向量维度(embedding_dim)、输出维度(output,在这个任务中是1,因为是二元分类,所以使用1),以及 PAD 标记的索引(pad_idx)

之后需要将预训练的词向量加载到嵌入层的权重中。TEXT.vocab.vectors 包含了词汇表中每个单词的预训练词向量,然后通过 copy_ 方法将这些词向量复制到模型的嵌入层权重中对网络进行初始化。这样做确保了模型的初始化状态良好。

6、训练模型

 total_loss = 0train_acc = 0 
model.train()
for batch in train_iterator:optimizer.zero_grad()preds = model(batch.text).squeeze(1)loss = criterion(preds,batch.label)total_loss += loss.item()batch_acc = (torch.round(torch.sigmoid(preds)) == batch.label).sum().item()train_acc += batch_accloss.backward()optimizer.step()average_loss = total_loss / len(train_iterator)train_acc /= len(train_iterator.dataset)

optimizer.zero_grad():表示将模型参数的梯度清零,以准备接收新的梯度。
preds = model(batch.text).squeeze(1):表示一次前向传播的过程,由于model输出的是torch.tensor(batch_size,1)所以使用squeeze(1)给其中的1维度数据去除,以匹配标签张量的形状
criterion(preds,batch.label):定义的损失函数 criterion 计算预测值 preds 与真实标签 batch.label 之间的损失

(torch.round(torch.sigmoid(preds)) == batch.label).sum().item():
通过比较模型的预测值与真实标签,计算当前批次的准确率,并将其累加到 train_acc 中
后面的就是进行反向传播更新参数,还有就是计算loss和train_acc的值了
7、模型评估:

model.eval()valid_loss = 0valid_acc = 0best_valid_acc = 0with torch.no_grad():for batch in valid_iterator:preds = model(batch.text).squeeze(1)loss = criterion(preds,batch.label)valid_loss += loss.item()batch_acc = ((torch.round(torch.sigmoid(preds)) == batch.label).sum().item())valid_acc += batch_acc

和训练模型的类似,这里就不解释了

8、保存模型
这里一共使用了两种保存模型的方式:

torch.save(model, "model.pth")
torch.save(model.state_dict(),"model.pth")

第一种方式叫做模型的全量保存
第二种方式叫做模型的参数保存

全量保存是保存了整个模型,包括模型的结构、参数、优化器状态等信息
参数量保存是保存了模型的参数(state_dict),不包括模型的结构
9、测试模型
测试模型的基本思路:
加载训练保存的模型、对待推理的文本进行预处理、将文本数据加载给模型进行推理

加载模型:

saved_model_path = "model.pth"
saved_model = torch.load(saved_model_path)

输入文本:
input_text = “Great service! The staff was very friendly and helpful.”

文本进行处理:

tokenizer = get_tokenizer("spacy", language="en_core_web_sm")
tokenized_text = tokenizer(input_text)
indexed_text = [TEXT.vocab.stoi[token] for token in tokenized_text]
tensor_text = torch.LongTensor(indexed_text).unsqueeze(1).to(device)

模型推理:

saved_model.eval()
with torch.no_grad():output = saved_model(tensor_text).squeeze(1)prediction = torch.round(torch.sigmoid(output)).item()probability = torch.sigmoid(output).item()

由于笔者能力有限,所以在描述的过程中难免会有不准确的地方,还请多多包含!

更多NLP和CV文章以及完整代码请到"陶陶name"获取。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/260077.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

软件设计中如何画各类图之八深入解析部署图:物理布局与系统架构的视觉化呈现

目录 1 前言2 部署图的符号及说明3 画部署图的步骤3.1 **识别节点**3.2 **定义组件**3.3 **标识部署关系**3.4 **添加细节** 4 部署图的用途4.1 **系统设计与规划**4.2 **系统架构分析**4.3 **系统维护与升级** 5 实际场景举例5.1 Web应用部署图5.2 云端服务部署图 6 结语 1 前…

【GAMES101】观测变换

图形学不等于 OpenGL,不等于光线追踪,而是一套生成整个虚拟世界的方法 记得有个概念叫光栅化,就是把三维虚拟世界的事物显示在二维的屏幕上,这里就涉及到观察变换 观察变换,叫viewing transformation,包括…

二分查找|前缀和|滑动窗口|2302:统计得分小于 K 的子数组数目

作者推荐 贪心算法LeetCode2071:你可以安排的最多任务数目 本文涉及的基础知识点 二分查找算法合集 题目 一个数组的 分数 定义为数组之和 乘以 数组的长度。 比方说,[1, 2, 3, 4, 5] 的分数为 (1 2 3 4 5) * 5 75 。 给你一个正整数数组 nums 和一个整数…

【带头学C++】----- 九、类和对象 ---- 9.10 C++设计模式之单例模式设计

❤️❤️❤️❤️❤️❤️❤️❤️❤️❤️❤️麻烦您点个关注,不迷路❤️❤️❤️❤️❤️❤️❤️❤️❤️❤️❤️❤️ 目 录 9.10 C设计模式之单例模式设计 举例说明: 9.10 C设计模式之单例模式设计 看过我之前的文章的,简单讲解过C/Q…

Python:核心知识点整理大全8-笔记

目录 ​编辑 4.5 元组 4.5.1 定义元组 dimensions.py 4.5.2 遍历元组中的所有值 4.5.3 修改元组变量 4.6 设置代码格式 4.6.1 格式设置指南 4.6.2 缩进 4.6.3 行长 4.6.4 空行 4.6.5 其他格式设置指南 4.7 小结 第5章 if语句 5.1 一个简单示例 cars.py 5.2 条…

Kafka快速实战以及基本原理详解

文章目录 一、Kafka介绍为什么要用Kafka 二、Kafka快速上手实验环境单机服务体验 三、理解Kakfa的消息传递机制四、Kafka集群服务五、理解服务端的Topic、Partition和Broker七、Kafka集群的整体结构八、Kraft集群Kraft集群简介配置Kraft集群 一、Kafka介绍 ChatGPT对于Apache …

探索HarmonyOS开发—Slider滑动条组件

Slider Slider 滑块组件 Slider({min: 0, // 最小值max: 350, // 最大值value: 30, // 当前值step:10, // 滑动步长style:SliderStyle.OutSet, // Inset 滑块的位置direction:Axis.Horizontal, // Verticalreverse:false // 是否反向滑动 }) style属性可以控制滑块在整个滑块…

元宇宙vr党建云上实景展馆扩大党的影响力

随着科技的飞速发展,VR虚拟现实技术已经逐渐融入我们的日常生活,尤其在党建领域,VR数字党建展馆更是成为引领红色教育新风尚的重要载体。今天,就让我们一起探讨VR数字党建展馆如何提供沉浸式体验,助力党建工作创新升级…

使用STM32 HAL库进行GPIO控制的实例

✅作者简介:热爱科研的嵌入式开发者,修心和技术同步精进, 代码获取、问题探讨及文章转载可私信。 ☁ 愿你的生命中有够多的云翳,来造就一个美丽的黄昏。 🍎获取更多嵌入式资料可点击链接进群领取,谢谢支持!…

通过误差改变控制的两种策略

如果反馈误差越来越大,需要改变调节方向以减小误差并实现更好的控制。以下是两种常见的调节方向改变的方法: PID控制器中的积分限制:在PID控制中,积分项可以用来减小稳态误差。然而,当反馈误差持续增大时,积…

28、pytest实战:获取多用户鉴权

前提 测试过程中有用户体系,例如包括管理员、商家、用户角色,不同测试用例需要使用不同角色来操作,操作权限根据用户的鉴权来判断实现。 技能点 建立全局变量文件,保存账号相关信息获取鉴权信息变为module级别fixture&#xff…

Kafka在微服务架构中的应用:实现高效通信与数据流动

微服务架构的兴起带来了分布式系统的复杂性,而Kafka作为一款强大的分布式消息系统,为微服务之间的通信和数据流动提供了理想的解决方案。本文将深入探讨Kafka在微服务架构中的应用,并通过丰富的示例代码,帮助大家更全面地理解和应…