第N6周:使用Word2vec实现文本分类

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms,datasets
import os,PIL,pathlib,warnings
#忽略警告信息
warnings.filterwarnings("ignore")
# win10系统
device = torch.device("cuda"if torch.cuda.is_available()else"cpu")
deviceimport pandas as pd
# 加载自定义中文数据
train_data= pd.read_csv('./data/train2.csv',sep='\t',header=None)
train_data.head()# 构造数据集迭代器
def coustom_data_iter(texts,labels):for x,y in zip(texts,labels):yield x,y
x = train_data[0].values[:]
#多类标签的one-hot展开
y = train_data[1].values[:]from gensim.models.word2vec import Word2Vec
import numpy as np
#训练word2Vec浅层神经网络模型
w2v=Word2Vec(vector_size=100#是指特征向量的维度,默认为100。,min_count=3)#可以对字典做截断。词频少于min_count次数的单词会被丢弃掉,默认为5w2v.build_vocab(x)
w2v.train(x,total_examples=w2v.corpus_count,epochs=20)# 将文本转化为向量
def average_vec(text):vec =np.zeros(100).reshape((1,100))for word in text:try:vec +=w2v.wv[word].reshape((1,100))except KeyError:continuereturn vec
#将词向量保存为Ndarray
x_vec= np.concatenate([average_vec(z)for z in x])
#保存Word2Vec模型及词向量
w2v.save('data/w2v_model.pk1')train_iter= coustom_data_iter(x_vec,y)
len(x),len(x_vec)label_name =list(set(train_data[1].values[:]))
print(label_name)text_pipeline =lambda x:average_vec(x)
label_pipeline =lambda x:label_name.index(x)text_pipeline("你在干嘛")
label_pipeline("Travel-Query")from torch.utils.data import DataLoader
def collate_batch(batch):label_list,text_list=[],[]for(_text,_label)in batch:# 标签列表label_list.append(label_pipeline(_label))# 文本列表processed_text = torch.tensor(text_pipeline(_text),dtype=torch.float32)text_list.append(processed_text)label_list = torch.tensor(label_list,dtype=torch.int64)text_list = torch.cat(text_list)return text_list.to(device),label_list.to(device)
# 数据加载器,调用示例
dataloader = DataLoader(train_iter,batch_size=8,
shuffle =False,
collate_fn=collate_batch)from torch import nn
class TextclassificationModel(nn.Module):def __init__(self,num_class):super(TextclassificationModel,self).__init__()self.fc = nn.Linear(100,num_class)def forward(self,text):return self.fc(text)num_class =len(label_name)
vocab_size =100000
em_size=12
model= TextclassificationModel(num_class).to(device)import time
def train(dataloader):model.train()#切换为训练模式total_acc,train_loss,total_count =0,0,0log_interval=50start_time= time.time()for idx,(text,label)in enumerate(dataloader):predicted_label= model(text)# grad属性归零optimizer.zero_grad()loss=criterion(predicted_label,label)#计算网络输出和真实值之间的差距,labelloss.backward()#反向传播torch.nn.utils.clip_grad_norm(model.parameters(),0.1)#梯度裁剪optimizer.step()#每一步自动更新#记录acc与losstotal_acc+=(predicted_label.argmax(1)==label).sum().item()train_loss += loss.item()total_count += label.size(0)if idx % log_interval==0 and idx>0:elapsed =time.time()-start_timeprint('Iepoch {:1d}I{:4d}/{:4d} batches''|train_acc {:4.3f} train_loss {:4.5f}'.format(epoch,idx,len(dataloader),total_acc/total_count,train_loss/total_count))total_acc,train_loss,total_count =0,0,0start_time = time.time()
def evaluate(dataloader):model.eval()#切换为测试模式total_acc,train_loss,total_count =0,0,0with torch.no_grad():for idx,(text,label)in enumerate(dataloader):predicted_label= model(text)loss = criterion(predicted_label,label)# 计算loss值# 记录测试数据total_acc+=(predicted_label.argmax(1)== label).sum().item()train_loss += loss.item()total_count += label.size(0)return total_acc/total_count,train_loss/total_countfrom torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset
# 超参数
EPOCHS=10#epoch
LR=5 #学习率
BATCH_SIZE=64 # batch size for training
criterion = torch.nn.CrossEntropyLoss()
optimizer= torch.optim.SGD(model.parameters(),lr=LR)
scheduler=torch.optim.lr_scheduler.StepLR(optimizer,1.0,gamma=0.1)
total_accu = None
# 构建数据集
train_iter= coustom_data_iter(train_data[0].values[:],train_data[1].values[:])
train_dataset = to_map_style_dataset(train_iter)split_train_,split_valid_= random_split(train_dataset,[int(len(train_dataset)*0.8),int(len(train_dataset)*0.2)])
train_dataloader =DataLoader(split_train_,batch_size=BATCH_SIZE,
shuffle=True,collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid_,batch_size=BATCH_SIZE,
shuffle=True,collate_fn=collate_batch)
for epoch in range(1,EPOCHS+1):epoch_start_time = time.time()train(train_dataloader)val_acc,val_loss = evaluate(valid_dataloader)# 获取当前的学习率lr =optimizer.state_dict()['param_groups'][0]['1r']if total_accu is not None and total_accu>val_acc:scheduler.step()else:total_accu = val_accprint('-'*69)print('|epoch {:1d}|time:{:4.2f}s |''valid_acc {:4.3f} valid_loss {:4.3f}I1r {:4.6f}'.format(epoch,time.time()-epoch_start_time,val_acc,val_loss,lr))print('-'*69)# test_acc,test_loss =evaluate(valid_dataloader)
# print('模型准确率为:{:5.4f}'.format(test_acc))
#
#
# def predict(text,text_pipeline):
#     with torch.no_grad():
#         text = torch.tensor(text_pipeline(text),dtype=torch.float32)
#         print(text.shape)
#         output = model(text)
#         return output.argmax(1).item()
# # ex_text_str="随便播放一首专辑阁楼里的佛里的歌"
# ex_text_str="还有双鸭山到淮阴的汽车票吗13号的"
# model=model.to("cpu")
# print("该文本的类别是:%s"%label_name[predict(ex_text_str,text_pipeline)])

以上是文本识别基本代码

输出:

[[-0.85472693  0.96605204  1.5058695  -0.06065784 -2.10079319 -0.120211511.41170089  2.00004494  0.90861696 -0.62710127 -0.62408304 -3.805954991.02797993 -0.45584389  0.54715634  1.70490362  2.33389823 -1.996075184.34822938 -0.76296186  2.73265275 -1.15046433  0.82106878 -0.32701646-0.50515595 -0.37742117 -2.02331601 -1.365334    1.48786476 -1.63949711.59438308  2.23569647 -0.00500725 -0.65070192  0.07377997  0.01777986-1.35580809  3.82080549 -2.19764423  1.06595343  0.99296588  0.58972518-0.33535255  2.15471306 -0.52244038  1.00874437  1.28869729 -0.72208139-2.81094289  2.2614549   0.20799019 -2.36187895 -0.94019454  0.49448857-0.68613767 -0.79071895  0.47535057 -0.78339124 -0.71336574 -0.279315671.0514895  -1.76352624  1.93158554 -0.85853558 -0.65540617  1.3612217-1.39405773  1.18187538  1.31730198 -0.02322496  0.14652854  0.222498812.01789951 -0.40144247 -0.39880068 -0.16220299 -2.85221207 -0.277228682.48236791 -0.51239379 -1.47679498 -0.28452797 -2.64497767  2.12093259-1.2326943  -1.89571355  2.3295732  -0.53244872 -0.67313893 -0.808146040.86987564 -1.31373079  1.33797717  1.02223087  0.5817025  -0.835356470.97088164  2.09045361 -2.57758138  0.07126901]]
6

输出结果并非为0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/589140.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

OpenHarmony实战开发-如何通过Stage模型实现一个简单的游戏卡片

介绍 本示例展示了如何通过Stage模型实现一个简单的游戏卡片。 通过卡片支持的点击事件进行交互,让用户通过点击的先后顺序把一个乱序的成语排列成正确的成语。使用了C和TS的混合编程方式,将获取随机数的能力下沉到C实现,并通过NAPI的能力将…

动态规划详细讲解c++|经典例题讲解认识动态规划|0-1背包问题详解

引言 uu们,你们好!这次的分享是动态规划,其中介绍了动态规划的相关概念和做题模板(三要素),同时为了uu们对动态规划方法有更加形象的认识,特地找了两个经典问题,和大家一起分析。并…

数字乡村创新之路:科技引领农村实现高质量发展

随着信息技术的快速发展,数字乡村建设已成为推动农村高质量发展的重要引擎。数字乡村通过科技创新,不仅改变了传统农业生产方式,也提升了乡村治理水平,为农民带来了更加便捷的生活。本文将从数字乡村的内涵、科技引领农村高质量发…

SV学习笔记(三)

类和对象概述 类和对象 面向对象的编程语言更符号人对自然语言的理解(属性property和功能function)。 这个世界由无数的类(class)和对象(object)构成的。 类是将相同的个体抽象出来的描述方式&#xff0c…

ObjectiveC-08-OOP面向对象程序设计-类的分离与组合

本节用一简短的文章来说下是ObjectiveC中的类。类其实是OOP中的一个概念,概念上简单来讲类是它是一组关系密切属性的集合,所谓的关系就是对现实事物的抽象。 上面提到的关系包括很多种,比如has a, is a,has some等&…

Linux TCP连接数查询

1 tcp连接查看 netstat -anput 2 统计连接数 2.1统计80端口的连接数 netstat -nat|grep -i "80"|wc -l 2.2统计总连接数 netstat -nat|wc -l 2.3统计已连接上的,状态为established netstat -na|grep ESTABLISHED|wc -l 3 统计所有请求状态及数量 …

Redis 应用问题解决——缓存穿透、缓存击穿、缓存雪崩、分布式锁

缓存穿透 key对应的数据在数据源不存在,每次针对此key的请求从缓存获取不到,请求都会压到数据源,从而可能压垮数据源。比如用一个不存在的用户id获取用户信息,不论缓存还是数据库都没有,若黑客利用此漏洞进行攻击可能…

JavaSE:抽象类和接口

目录 一、前言 二、抽象类 (一)抽象类概念 (二)使用抽象类的注意事项 (三)抽象类的作用 三、接口 (一)接口概念 (二)接口语法规则 (三&a…

Keil不能生成.bin文件,解决方法

脚本: D:\ProgramFiles\Keil_v5\ARM\ARM5.06\bin\fromelf.exe --bin --outputBin\keyboard.bin ..\..\Output\keyboard.axf 说明: fromelf.exe --bin --outputBin\keyboard.bin ..\..\Output\keyboard.axf 通过生成后的keyboard.axf, 执行f…

Loadrunner的使用

Loadrunner的使用 选项公网测试地址:http://cfgjt.cn:8981/devt-web 用户名admin,密码11111111 1.Loadrunner介绍 ​ LoadRunner,是一种预测系统行为和性能的负载测试工具。通过模拟上千万用户实施并发负载及实时性能监测的方式来确认和查…

坦克大战_java源码_swing界面_带毕业论文

一. 演示视频 坦克大战_java源码_swing界面_带毕业论文 二. 实现步骤 完整项目获取 https://githubs.xyz/y22.html 部分截图 启动类是 TankClinet.java,内置碰撞检测算法,线程,安全集合,一切皆对象思想等,是java进阶…

SSTI模板注入(jinja2)

前面学习了SSTI中的smarty类型,今天学习了Jinja2,两种类型都是flask框架的,但是在注入的语法上还是有不同 SSTI:服务器端模板注入,也属于一种注入类型。与sql注入类似,也是通过凭借进行命令的执行&#xff…