NLP|LSTM+Attention文本分类

目录

一、Attention原理简介

二、LSTM+Attention文本分类实战

1、数据读取及预处理

2、文本序列编码

3、LSTM文本分类

三、划重点

少走10年弯路


        LSTM是一种特殊的循环神经网络(RNN),用于处理序列数据和时间序列数据的建模和预测。而在NLP和时间序列领域上Attention-注意力机制也早已有了大量应用,本文将介绍在LSTM基础上如何添加Attention来优化模型效果。

一、Attention原理简介

        注意力机制通过聚焦于重要的信息,忽略不重要的信息,从而有效地处理输入信息。在神经网络中,注意力机制可以帮助模型更好地关注输入中的重要特征,从而提高模型的性能。

        简单而言,在文本处理任务中,self-attention对每一个词会随机初始化q、k、v三个向量,用每个词的q向量和其他k向量做点积、再归一化得到这个词的权重向量w,用w给v向量加权求和得到z向量(该词attention之后的向量)。再延伸一点,其实可以初始化多组q、k、v矩阵,从而得到多组z矩阵拼接起来(类似于CNN中的多个卷积核、来提取不同信息),再乘上一个矩阵压缩回原来的维度,得到最终的embedding。

        细节原理相对繁琐,推荐大家可以去看一下这篇博客的bert介绍,其中self-attention部分详细且清晰。

https://blog.csdn.net/jiaowoshouzi/article/details/89073944

二、LSTM+Attention文本分类实战

1、数据读取及预处理

import re
import os
from sqlalchemy import create_engine
import pandas as pd
import numpy as np
import warnings
warnings.filterwarnings('ignore')
import sklearn
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_curve,roc_auc_score
import xgboost as xgb
from xgboost.sklearn import XGBClassifier
import lightgbm as lgb
import matplotlib.pyplot as plt
import gcfrom tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras import models
from tensorflow.keras import layers
from tensorflow.keras import optimizers# 2、数据读取+预处理
data=pd.read_excel('Inshorts Cleaned Data.xlsx')def data_preprocess(data):df=data.drop(['Publish Date','Time ','Headline'],axis=1).copy()df.rename(columns={'Source ':'Source'},inplace=True)df=df[df.Source.isin(['YouTube','India Today'])].reset_index(drop=True)df['y']=np.where(df.Source=='YouTube',1,0)df=df.drop(['Source'],axis=1)return dfdf=data.pipe(data_preprocess)
print(df.shape)
df.head()# 导入英文停用词
from nltk.corpus import stopwords  
from nltk.tokenize import sent_tokenize
stop_english=stopwords.words('english')  
stop_spanish=stopwords.words('spanish') 
stop_english# 4、文本预处理:处理简写、小写化、去除停用词、词性还原
from nltk.stem import WordNetLemmatizer
from nltk.corpus import stopwords  
from nltk.tokenize import sent_tokenize
import nltkdef replace_abbreviation(text):rep_list=[("it's", "it is"),("i'm", "i am"),("he's", "he is"),("she's", "she is"),("we're", "we are"),("they're", "they are"),("you're", "you are"),("that's", "that is"),("this's", "this is"),("can't", "can not"),("don't", "do not"),("doesn't", "does not"),("we've", "we have"),("i've", " i have"),("isn't", "is not"),("won't", "will not"),("hasn't", "has not"),("wasn't", "was not"),("weren't", "were not"),("let's", "let us"),("didn't", "did not"),("hadn't", "had not"),("waht's", "what is"),("couldn't", "could not"),("you'll", "you will"),("i'll", "i will"),("you've", "you have")]result = text.lower()for word_replace in rep_list:result=result.replace(word_replace[0],word_replace[1])
#     result = result.replace("'s", "")return resultdef drop_char(text):result=text.lower()result=re.sub('[^\w\s]',' ',result) # 去掉标点符号、特殊字符result=re.sub('\s+',' ',result) # 多空格处理为单空格return resultdef stemed_words(text,stop_words,lemma):word_list = [lemma.lemmatize(word, pos='v') for word in text.split() if word not in stop_words]result=" ".join(word_list)return resultdef text_preprocess(text_seq):stop_words = stopwords.words("english")lemma = WordNetLemmatizer()result=[]for text in text_seq:if pd.isnull(text):result.append(None)continuetext=replace_abbreviation(text)text=drop_char(text)text=stemed_words(text,stop_words,lemma)result.append(text)return resultdf['short']=text_preprocess(df.Short)
df[['Short','short']]# 5、划分训练、测试集
test_index=list(df.sample(2000).index)
df['label']=np.where(df.index.isin(test_index),'test','train')
df['label'].value_counts()

2、文本序列编码

        按照词频排序,创建长度为6000的高频词词典、来对文本进行序列化编码。

from tensorflow.keras.preprocessing.text import Tokenizer
def word_dict_fit(train_text_list,num_words):'''train_text_list: ['some thing today ','some thing today2']'''tok_params={'num_words':num_words,  # 词典的长度,仅保留词频top的num_words个词'filters':'!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n','lower':True, 'split':' ', 'char_level':False, 'oov_token':None, # 设定词典外的词编码}tok = Tokenizer(**tok_params) # 分词tok.fit_on_texts(train_text_list)return tokdef word_dict_apply_sequences(tok_model,text_list,len_vec):'''text_list: ['some thing today ','some thing today2']'''list_tok = tok_model.texts_to_sequences(text_list) # 编码映射pad_params={'sequences':list_tok,'maxlen':len_vec,  # 补全后向量长度'padding':'pre', # 'pre' or 'post',在前、在后补全'truncating':'pre', # 'pre' or 'post',在前、在后删除长度多余的部分'value':0, # 补全0}seq_tok = pad_sequences(**pad_params) # 补全编码向量,返回二维arrayreturn seq_toknum_words,len_vec = 6000,40
tok_model= word_dict_fit(df[df.label=='train'].short,num_words)
tok_train = word_dict_apply_sequences(tok_model,df[df.label=='train'].short,len_vec)
tok_test = word_dict_apply_sequences(tok_model,df[df.label=='test'].short,len_vec)
tok_test

图片

3、LSTM文本分类

        LSTM层的输入是三维张量(batch_size, timesteps, input_dim),所以使用的数据可以是时间序列、也可以是文本数据的embedding;输出设置return_sequences为False,返回尺寸为 (batch_size, units) 的 2D 张量。

'''
LSTM层核心参数units:输出维度activation:激活函数recurrent_activation: RNN循环激活函数use_bias: 布尔值,是否使用偏置项dropout:0~1之间的浮点数,神经元失活比例recurrent_dropout:0~1之间的浮点数,循环状态的神经元失活比例return_sequences: True时返回RNN全部输出序列(3D),False时输出序列的最后一个输出(2D)
'''
def init_lstm_model(max_features, embed_size):model = Sequential()model.add(Embedding(input_dim=max_features, output_dim=embed_size))model.add(Bidirectional(LSTM(units=32,activation='relu', recurrent_dropout=0.1)))model.add(Dropout(0.25,seed=1))model.add(Dense(64))model.add(Dropout(0.3,seed=1))model.add(Dense(1, activation='sigmoid'))model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])return modeldef model_fit(model, x, y,test_x,test_y):return model.fit(x, y, batch_size=100, epochs=2, validation_data=(test_x,test_y))embed_size = 128
lstm_model=init_lstm_model(num_words, embed_size)
model_train=model_fit(lstm_model,tok_train,np.array(df[df.label=='train'].y),tok_test,np.array(df[df.label=='test'].y))
lstm_model.summary()def model_fit(model, x, y,test_x,test_y):return model.fit(x, y, batch_size=100, epochs=2, validation_data=(test_x,test_y))embed_size = 128
lstm_model=init_lstm_model(num_words, embed_size)
model_train=model_fit(lstm_model,tok_train,np.array(df[df.label=='train'].y),tok_test,np.array(df[df.label=='test'].y))
lstm_model.summary()

 

def ks_auc_value(y_value,y_pred):fpr,tpr,thresholds= roc_curve(list(y_value),list(y_pred))ks=max(tpr-fpr)auc= roc_auc_score(list(y_value),list(y_pred))return ks,aucprint('train_ks_auc',ks_auc_value(df[df.label=='train'].y,lstm_model.predict(tok_train)))
print('test_ks_auc',ks_auc_value(df[df.label=='test'].y,lstm_model.predict(tok_test)))'''train_ks_auc (0.7223217797649937, 0.922939132379851)test_ks_auc (0.7046603930606234, 0.9140880065296716)
'''

4、LSTM+Attention文本分类

        在LSTM层之后添加Attention层优化效果。

from tensorflow.keras.models import Model
def init_lstm_model(max_features, embed_size ,embedding_matrix):input_=layers.Input(shape=(40,))x=Embedding(input_dim=max_features, output_dim=embed_size,weights=[embedding_matrix],trainable=False)(input_)x=Bidirectional(layers.LSTM(units=32,activation='relu', recurrent_dropout=0.1,return_sequences=True))(x)x=layers.Attention(40)([x,x])x=Dropout(0.25)(x)x=layers.Flatten()(x)x=Dense(64)(x)x=Dropout(0.3)(x)x=Dense(1,activation='sigmoid')(x)model = Model(inputs=input_, outputs=x)model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])return modeldef model_fit(model, x, y,test_x,test_y):return model.fit(x, y, batch_size=100, epochs=5, validation_data=(test_x,test_y))num_words,embed_size = 6000,128
lstm_model2=init_lstm_model(num_words, embed_size ,embedding_matrix)
model_train=model_fit(lstm_model2,tok_train,np.array(df[df.label=='train'].y),tok_test,np.array(df[df.label=='test'].y))print('train_ks_auc',ks_auc_value(df[df.label=='train'].y,gru_model.predict(tok_train)))
print('test_ks_auc',ks_auc_value(df[df.label=='test'].y,gru_model.predict(tok_test)))
'''train_ks_auc (0.7126925954159541, 0.9199721561742299)test_ks_auc (0.7239373279559567, 0.917086274086166)
'''

三、划重点

少走10年弯路

        关注威信公众号 Python风控模型与数据分析,回复 文本分类5 获取本篇数据及代码

        还有更多理论、代码分享等你来拿

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/337686.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Poi实现根据word模板导出-图表篇

往期系列传送门&#xff1a; Poi实现根据word模板导出-文本段落篇 &#xff08;需要完整代码的直接看最后位置&#xff01;&#xff01;&#xff01;&#xff09; 前言&#xff1a; 补充Word中图表的知识&#xff1a; 每个图表在word中都有一个内置的Excel&#xff0c;用于…

20240107移远的4G模块EC20在Firefly的AIO-3399J开发板的Android11下调通能上网

20240107移远的4G模块EC20在Firefly的AIO-3399J开发板的Android11下调通能上网 2024/1/7 11:17 开发板&#xff1a;Firefly的AIO-3399J【RK3399】SDK&#xff1a;rk3399-android-11-r20211216.tar.xz【Android11】 Android11.0.tar.bz2.aa【ToyBrick】 Android11.0.tar.bz2.ab …

开启鸿蒙开发探索之旅ArkTS基本语法介绍(3)

上一章简单的介绍了鸿蒙HUAWEI DevEco Studio框架的搭建&#xff0c;这一章讲一下鸿蒙的主要开发一眼ArkTS的基本语法结构 1.ArkTS语法解释 ArkTS是HarmonyOS优选的主力应用开发语言。ArkTS围绕应用开发在TypeScript&#xff08;简称TS&#xff09;生态基础上做了进一步扩展&…

Android 通知简介

Android 通知简介 1. 基本通知 图1: 基本通知详情 小图标 : 必须提供,通过 setSmallIcon( ) 进行设置.应用名称 : 由系统提供.时间戳 : 由系统提供,也可隐藏时间.大图标(可选) : 可选内容(通常仅用于联系人照片,请勿将其用于应用图标),通过setLargeIcon( ) 进行设置.标题 : 可选…

念数字(C语言)

做法非常巧妙&#xff08;这也是我看别人的写法写的&#xff09; #include <string.h> #include <stdio.h> int main() { int i 0; char str[100000] { \0}; char arr[10][5] { "ling","yi","er","san",&…

压测必经之路,Jmeter分布式压测教程

01、分布式压测原理 Jemter分布式压测是选择其中一台作为调度机&#xff08;master&#xff09;&#xff0c;其他机器作为执行机&#xff08;slave&#xff09;&#xff1b;当然一台机器也可以既做调度机&#xff0c;也做执行机。 调度机执行脚本的时候&#xff0c;master将会…

阿里云 云数据库 Redis 版测评

1. 试用 地址&#xff1a;https://developer.aliyun.com/topic/freetier/database 点击试用 选择相应信息后点击立即试用&#xff0c;此处务必注意ECS和Redis需要在一个地域(可用区)&#xff0c;否则后续连接不方便。 2. 创建实例 购买后&#xff0c;进入控制台&#xff0c…

微信小程序的基本使用1:数据绑定同步,导航方式,导航传参,全局配置,下拉刷新,上拉触底等

各组件属性参考微信官方文档&#xff1a;微信开放文档 视图容器 scroll-view 可滚动视图区域&#xff08;轮播图&#xff09;。使用竖向滚动时&#xff0c;需要给scroll-view一个固定高度&#xff0c;通过 WXSS 设置 height。组件属性的长度单位默认为px&#xff0c;2.4.0起支…

链上数据分析:解读加密生态的秘密武器

作者&#xff1a;shellyfootprint.network 数据源&#xff1a; Wallet Profile 在加密货币的世界里&#xff0c;信息是力量。但如何获取真实、有价值的数据呢&#xff1f;普通个人投资者浏览 Reddit 帖子或观看 YouTube 视频&#xff0c;并根据基本价格图表做出投资决定。这种…

AirSim 的 ROS 功能包测试

参考链接&#xff1a; Ubuntu18.04搭建AirSimROS仿真环境_airsim ros-CSDN博客 ROS: AirSim ROS Wrapper - AirSim 1.编译 ros 包&#xff08;必须是 gcc-8&#xff09; 如果您的默认 GCC 不是 8 或更高&#xff08;使用 gcc --version 检查&#xff09;&#xff0c;那么编译…

离散数学-二元关系

4.1关系的概念 1)序偶及n元有序组 由两个个体x和y&#xff0c;按照一定顺序排序成的、有序数组称为有序偶或有序对、二元有序组&#xff0c; 记作<x&#xff0c;y>&#xff0c;其中x是第一分量&#xff0c;y是第二分量。 相等有序偶&#xff1a;第一分量和第二分量分…

Python-面向对象

面向对象 1.初识对象1.1理解使用对象完成数据组织的思路 2.成员方法2.1类的定义和使用语法2.2成员方法的使用 3.类和对象4.构造方法4.1使用构造方法向成员变量赋值 5.其他内置方法5.1__str__字符串方法5.2__lt__小于符号比较方法5.3__le__小于等于比较符号5.4__eq__比较运算符实…