Word2Vec实现文本识别分类

深度学习训练营之使用Word2Vec实现文本识别分类

  • 原文链接
  • 环境介绍
  • 前言
  • 前置工作
    • 设置GPU
    • 数据查看
    • 构建数据迭代器
  • Word2Vec的调用
  • 生成数据批次和迭代器
  • 模型训练
    • 初始化
    • 拆分数据集并进行训练
  • 预测

原文链接

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍦 参考文章:365天深度学习训练营-第N4周:用Word2Vec实现文本分类
  • 🍖 原作者:K同学啊|接辅导、项目定制

环境介绍

  • 语言环境:Python3.9.12
  • 编译器:jupyter notebook
  • 深度学习环境:TensorFlow2

前言

本次内容我本来是使用miniconda的环境的,但是好像有文件发生了损坏,出现了如下报错,据我所了解应该是某个文件发生了损坏,应该是之前将anaconda误删有关,有所了解或者有同样问题的朋友可以一起进行探讨

前置工作

设置GPU

如果

# 先进行数据加载
import torch
import torch.nn as nn
import torchvision
import os,PIL,pathlib,warnings
import time
from torchvision import transforms, datasets
from torch import nn
from torch.utils.data.dataset import random_splitwarnings.filterwarnings("ignore")#忽略警告信息
device=torch.device("cuda"if torch.cuda.is_available()else "cpu")
device

device(type=‘cpu’)

数据查看

本次使用的数据集和之前中文文本识别分类的是一样的

import pandas as pd
train_data=pd.read_csv('train.csv',sep='\t',header=None)
train_data.head()

在这里插入图片描述

构建数据迭代器

#构建数据集迭代器
def coustom_data_iter(texts,labels):for x,y in zip(texts,labels):yield x,yx=train_data[0].values[:]
y=train_data[1].values[:]    

添加数据迭代器是为了让数据的随机性增强,进行数据集的划分,可以有效的发挥内存的高利用率

Word2Vec的调用

对Word2Vec进行直接的调用

from gensim.models.word2vec import Word2Vec
import numpy as np
#训练浅层神经网络模型
w2v=Word2Vec(vector_size=100,min_count=3)w2v.build_vocab(x)
w2v.train(x,total_examples=w2v.corpus_count,epochs=30)

build_vocab统计输入每一个词汇出现的次数

def average_vec(text):vec=np.zeros(100).reshape((1,100))#表示平均向量#(n,100),其中n表示x中的元素的数量 for word in text:try:vec+=w2v.wv[word].reshape((1,100))except KeyError:continue#未找到,再进行迭代下一个词return vecx_vec=np.concatenate([average_vec(z) for z in x])
w2v.save('w2v_model.pkl')

该步骤将输入的文本转变成了平均向量
对于输入进来的text当中的每一个单词都进行一个查询,确认是否当中有该词,如果有那么就将其添加到vector当中,否则跳出本层循环,查找下一个词.
最后通过np当中的concatenate方法进行一个向量的连接

train_iter=coustom_data_iter(x_vec,y)#训练迭代器
print(len(x),len(y))

12100 12100

设置训练的迭代器

label_name=list(set(train_data[1].values[:]))
print(label_name)
['FilmTele-Play', 'Weather-Query', 'Audio-Play', 'Radio-Listen', 'HomeAppliance-Control', 'Alarm-Update', 'Travel-Query', 'Video-Play', 'Calendar-Query', 'TVProgram-Play', 'Music-Play', 'Other']

生成数据批次和迭代器

text_pipeline=lambda x:average_vec(x)
label_pipeline=lambda x:label_name.index(x)
#lambda语法:lambda  arguments
text_pipeline("我想你了")

在这里插入图片描述

label_pipeline("Travel-Query")

6

这里的结果每次都会不太一样,具有一定的随机性

from torch.utils.data import DataLoaderdef collate_batch(batch):label_list, text_list= [], []for (_text,_label) in batch:# 标签列表label_list.append(label_pipeline(_label))# 文本列表processed_text = torch.tensor(text_pipeline(_text), dtype=torch.float32)text_list.append(processed_text)# 偏移量,即语句的总词汇量label_list = torch.tensor(label_list, dtype=torch.int64)text_list  = torch.cat(text_list)return text_list.to(device),label_list.to(device)# 数据加载器,调用示例
dataloader = DataLoader(train_iter,batch_size=8,shuffle   =False,collate_fn=collate_batch)

和之前的不同在于没有了offset

模型训练

from torch import nnclass TextClassificationModel(nn.Module):def __init__(self, num_class):super(TextClassificationModel, self).__init__()self.fc = nn.Linear(100, num_class)def forward(self, text):return self.fc(text)

初始化

num_class  = len(label_name)
vocab_size = 100000
em_size    = 12
model      = TextClassificationModel(num_class).to(device)
import timedef train(dataloader):model.train()  # 切换为训练模式total_acc, train_loss, total_count = 0, 0, 0log_interval = 50start_time   = time.time()for idx, (text,label) in enumerate(dataloader):predicted_label = model(text)optimizer.zero_grad()                    # grad属性归零loss = criterion(predicted_label, label) # 计算网络输出和真实值之间的差距,label为真实值loss.backward()                          # 反向传播torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1) # 梯度裁剪optimizer.step()  # 每一步自动更新# 记录acc与losstotal_acc   += (predicted_label.argmax(1) == label).sum().item()train_loss  += loss.item()total_count += label.size(0)if idx % log_interval == 0 and idx > 0:elapsed = time.time() - start_timeprint('| epoch {:1d} | {:4d}/{:4d} batches ''| train_acc {:4.3f} train_loss {:4.5f}'.format(epoch, idx, len(dataloader),total_acc/total_count, train_loss/total_count))total_acc, train_loss, total_count = 0, 0, 0start_time = time.time()def evaluate(dataloader):model.eval()  # 切换为测试模式total_acc, train_loss, total_count = 0, 0, 0with torch.no_grad():for idx, (text,label) in enumerate(dataloader):predicted_label = model(text)loss = criterion(predicted_label, label)  # 计算loss值# 记录测试数据total_acc   += (predicted_label.argmax(1) == label).sum().item()train_loss  += loss.item()total_count += label.size(0)return total_acc/total_count, train_loss/total_count

拆分数据集并进行训练

from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset
# 超参数
EPOCHS     = 30 # epoch
LR         = 5  # 学习率
BATCH_SIZE = 64 # batch size for trainingcriterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
total_accu = None# 构建数据集
train_iter = coustom_data_iter(train_data[0].values[:], train_data[1].values[:])
train_dataset = to_map_style_dataset(train_iter)split_train_, split_valid_ = random_split(train_dataset,[int(len(train_dataset)*0.8),int(len(train_dataset)*0.2)])train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,shuffle=True, collate_fn=collate_batch)valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,shuffle=True, collate_fn=collate_batch)for epoch in range(1, EPOCHS + 1):epoch_start_time = time.time()train(train_dataloader)val_acc, val_loss = evaluate(valid_dataloader)# 获取当前的学习率lr = optimizer.state_dict()['param_groups'][0]['lr']if total_accu is not None and total_accu > val_acc:scheduler.step()else:total_accu = val_accprint('-' * 69)print('| epoch {:1d} | time: {:4.2f}s | ''valid_acc {:4.3f} valid_loss {:4.3f} | lr {:4.6f}'.format(epoch,time.time() - epoch_start_time,val_acc,val_loss,lr))print('-' * 69)
| epoch 1 |   50/ 152 batches | train_acc 0.742 train_loss 0.02635
| epoch 1 |  100/ 152 batches | train_acc 0.820 train_loss 0.02033
| epoch 1 |  150/ 152 batches | train_acc 0.838 train_loss 0.01927
---------------------------------------------------------------------
| epoch 1 | time: 0.95s | valid_acc 0.819 valid_loss 0.023 | lr 5.000000
---------------------------------------------------------------------
| epoch 2 |   50/ 152 batches | train_acc 0.850 train_loss 0.01876
| epoch 2 |  100/ 152 batches | train_acc 0.849 train_loss 0.02012
| epoch 2 |  150/ 152 batches | train_acc 0.847 train_loss 0.01736
---------------------------------------------------------------------
| epoch 2 | time: 0.92s | valid_acc 0.869 valid_loss 0.016 | lr 5.000000
---------------------------------------------------------------------
| epoch 3 |   50/ 152 batches | train_acc 0.858 train_loss 0.01588
| epoch 3 |  100/ 152 batches | train_acc 0.833 train_loss 0.02008
| epoch 3 |  150/ 152 batches | train_acc 0.864 train_loss 0.01813
---------------------------------------------------------------------
| epoch 3 | time: 0.86s | valid_acc 0.835 valid_loss 0.023 | lr 5.000000
---------------------------------------------------------------------
| epoch 4 |   50/ 152 batches | train_acc 0.883 train_loss 0.01309
| epoch 4 |  100/ 152 batches | train_acc 0.899 train_loss 0.00996
| epoch 4 |  150/ 152 batches | train_acc 0.895 train_loss 0.00927
---------------------------------------------------------------------
| epoch 4 | time: 0.87s | valid_acc 0.888 valid_loss 0.011 | lr 0.500000
---------------------------------------------------------------------
| epoch 5 |   50/ 152 batches | train_acc 0.906 train_loss 0.00834
...
| epoch 30 |  150/ 152 batches | train_acc 0.900 train_loss 0.00717
---------------------------------------------------------------------
| epoch 30 | time: 0.92s | valid_acc 0.886 valid_loss 0.010 | lr 0.000000
---------------------------------------------------------------------
test_acc, test_loss = evaluate(valid_dataloader)
print('test accuracy {:8.3f}'.format(test_acc))

在这里插入图片描述

预测

def predict(text, text_pipeline):with torch.no_grad():text = torch.tensor(text_pipeline(text),dtype=torch.float32)print(text.shape)output = model(text)return output.argmax(1).item()ex_text_str = "随便播放一首专辑阁楼里的佛里的歌"
#ex_text_str = "还有双鸭山到淮阴的汽车票吗13号的"
model = model.to(device)print("该文本的类别是:%s" % label_name[predict(ex_text_str, text_pipeline)])
torch.Size([1, 100])
该文本的类别是:Music-Play

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/23232.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

零代码编程:用ChatGPT批量识别图片PDF中的文字

有些PDF页面是图片格式,要怎么批量把图片中的文字识别出来?借助ChatGPT可以轻松完成这个任务。 首先要安装一些相关的软件和Python库。 安装tesseract-ocr(OCR)软件,最新版的是tesseract-ocr-w64-setup-v5.3.0.20221…

API全场景零码测试机器人——ATGen带来“超自动化”测试模式

HDC期间可参与新手入驻华为云Testplan抽奖活动,活动链接在文末 众所周知,软件服务及组件之间的交互主要依赖大量的API接口。以华为云300多个商用云服务为例,平均每个服务含500接口,接口总数高达10万,接口调用上下文业务…

汽车网卡驱动之TJA1101B

TJA1101B汽车网卡驱动(汽车以太网) 1总体描述 2特点和优点 2.1通用 2.2针对汽车用例优化

使用USB转TTL线连接树莓派4B

一般我们刷完树莓派系统后,都是通过连接鼠标键盘及显示器来进行操作,当我们开启SSH功能后我们才可以通过ssh客户端进行远程访问,那么是否有更方便的方式进行连接,并且不需连接外部设备进行操作呢? 串口通信 当然可以…

2022 Robocom CAIP国赛 第四题 变牛的最快方法

原题链接: PTA | 程序设计类实验辅助教学平台 题面: 这里问的是把任意一种动物的图像变成牛的方法…… 比如把一只鼠的图像变换成牛的图像。方法如下: 首先把屏幕上的像素点进行编号;然后把两只动物的外轮廓像素点编号按顺时针记…

Jmeter二次开发实现rsa加密

jmeter函数助手提供了大量的函数,像 counter、digest、random、split、strLen,这些函数在接口测试、性能测试中大量被使用,但是大家在实际工作,形形色色的测试需求不同,导致jmeter自带或者扩展插件给我们提供的函数无法…

【反向代理】反向代理及其作用

反向代理及其作用 一、什么是正向代理 在介绍反向代理之前我们先介绍什么是正向代理 首先要明确的是,在http协议中正向代理一般被称为代理,在web服务中我们可以通过主动配置代理服务器的方式来发送请求,并通过代理服务器接收服务器的响应。…

时序预测 | MATLAB实现Hamilton滤波AR时间序列预测

时序预测 | MATLAB实现Hamilton滤波AR时间序列预测 目录 时序预测 | MATLAB实现Hamilton滤波AR时间序列预测预测效果基本介绍程序设计参考资料预测效果 基本介绍 预测在很大程度上取决于适合周期的模型和所采用的预测方法,就像它们依赖于过滤器提取的周期一样。标准 Hodrick-P…

【DBA课程-笔记】第 3 章:MongoDB数据库核心知识

内容 一、MongoDB 数据库架构 A. MongoDB数据库体系架构 1. 存储引擎(MongoDB Storage Engines): 2. MongoDB 数据逻辑架构 二、MongoDB 存储引擎 A. 查看mongodb服务器的状态 B. 查看引擎信息(4.2.1 没有这个命令&#xf…

火山引擎徐广治:边缘云,下一代云计算

6月30日,2023稀土开发者大会在北京举办。大会以「代码不止,掘金不停」为主题,与上百位海内外技术专家一起剖析行业最新动态,为一直在路上的技术开发者们,拓宽技术视野,传播前沿的技术理念。火山引擎边缘云资…

给LLM装上知识:从LLM+LangChain的本地知识库问答到LLM与知识图谱的结合

前言 过去半年,随着ChatGPT的火爆,直接带火了整个LLM这个方向,然LLM毕竟更多是基于过去的经验数据预训练而来,没法获取最新的知识,以及各企业私有的知识 为了获取最新的知识,ChatGPT plus版集成了bing搜…

1770_VirtualBox下安装Debian

全部学习汇总: GreyZhang/little_bits_of_linux: My notes on the trip of learning linux. (github.com) 作为我自己的日常使用,Debian基本上没有出现过。最多是让它运行在某个设备上作为一个服务的平台,因为很多东西我懒得去配置。 Debia…