文档分类FastText模型 (pytorch实现)

文档分类FastText

        • FastText简介
        • 层次softmax
        • N-gram特征
        • FastText代码(文档分类)

FastText简介

FastText与之前介绍过的CBOW架构相似,我们先来会议一下CBOW架构,如下图:

在这里插入图片描述

CBOW的任务是通过上下文去预测中间的词,具体做法是使用滑动窗口内部的词的embedding的均值作为中间词的embedding。

FastText的任务是通过文章中的词去预测文章的类别(文档分类),具体做法是使用文章中的所有词的embedding的均值作为文章的embedding。最后从隐层再经过一次的非线性变换得到输出层的label。

CBOW和FastText的相似之处:

  1. 每个特征都是词向量的平均值。

总结一下CBOW和FastText的不同之处:

  1. FastText是有监督学习,而CBOW是无监督学习
  2. FastText是预测文章的label,而CBOW是预测中心词
  3. FastText使用的是文章中所有词的embedding,而CBOW使用的是中心词所在滑动窗口内其他所有词的embedding

从模型架构上来说,沿用了CBOW的单层神经网络的模式,不过fastText的处理速度才是这个算法的创新之处。

fastText模型的输入是一个词的序列(一段文本或者一句话),输出是这个词序列属于不同类别的概率。在序列中的词和词组构成特征向量,特征向量通过线性变换映射到中间层,再由中间层映射到标签。fastText在预测标签时使用了非线性激活函数,但在中间层不使用非线性激活函数。

fastText是一个快速文本分类算法,与基于神经网络的分类算法相比有两大优点:

  1. fastText在保持高精度的情况下加快了训练速度和测试速度
  2. fastText不需要预训练好的词向量,fastText会自己训练词向量
  3. fastText两个重要的优化:Hierarchical Softmax、N-gram

fastText方法包含三部分,模型架构,层次SoftmaxN-gram特征。

层次softmax

分层 softmax(Hierarchical Softmax)是一种用于加速词嵌入模型训练的技术,特别是在训练大型词汇表时。它通过将词汇表组织成一棵二叉树(通常是霍夫曼树),从而将原来的线性 softmax 运算转换为对树结构进行的多次二元分类,从而减少了计算量。

在这里插入图片描述

构建哈夫曼树

  • 首先,根据词汇表中每个词的词频构建一棵霍夫曼树。
  • 霍夫曼树是一种最优的二叉树,它通过最小化编码长度来实现对频繁出现的词进行更短的编码,以及对不太频繁出现的词进行较长的编码。

在这里插入图片描述

对数学模型进行改造

  • 对于一个普通的 softmax 模型,它的输出层是一个与词汇表大小相同的全连接层,需要对所有词汇进行一次计算。
  • 而在分层 softmax 中,将词汇表组织成二叉树,每个内部节点代表一个二元分类任务。模型的输出层不再是一个全连接层,而是根据霍夫曼树的结构构建的一系列内部节点。

预测过程

  • 在预测过程中,对于给定的目标词,从树的根节点开始,根据二元分类的规则逐级向下遍历,直到达到叶子节点,从而确定目标词的概率分布。
  • 通过遍历的路径,可以确定目标词在霍夫曼树中的编码,从而得到目标词的概率分布。

我们发现对于每一个节点,都是一个二分类[0,1],也就是我们可以使用sigmod来处理节点信息;
θ ( x ) = 1 1 + e − x \theta \left(x \right)=\frac{1}{1+e{-x}} θ(x)=1+ex1
此时,当我们知道了目标单词x,之后,我们只需要计算root节点,到该词的路径累乘,即可. 不需要去遍历所有的节点信息,时间复杂度变为O(log2(V))。

N-gram特征

n-gram是基于语言模型的算法,基本思想是将文本内容按照字节顺序进行大小为N的窗口滑动操作,最终形成窗口为N的字节片段序列。而且需要额外注意一点是n-gram可以根据粒度不同有不同的含义,有字粒度的n-gram和词粒度的n-gram,下面分别给出了字粒度和词粒度的例子:

#我爱中国
2-gram特征为:我爱 爱中 中国
3-gram特征为:我爱中 爱中国#我 爱 中国
2-gram特征为:我/爱 爱/中国
3-gram特征为:我//中国

从上面来看,使用n-gram有如下优点

  1. 为罕见的单词生成更好的单词向量:根据上面的字符级别的n-gram来说,即是这个单词出现的次数很少,但是组成单词的字符和其他单词有共享的部分,因此这一点可以优化生成的单词向量。
  2. 在词汇单词中,即使单词没有出现在训练语料库中,仍然可以从字符级n-gram中构造单词的词向量。
  3. n-gram可以让模型学习到局部单词顺序的部分信息, 如果不考虑n-gram则便是取每个单词,这样无法考虑到词序所包含的信息,即也可理解为上下文信息,因此通过n-gram的方式关联相邻的几个词,这样会让模型在训练的时候保持词序信息。

但正如上面提到过,随着语料库的增加,内存需求也会不断增加,严重影响模型构建速度,针对这个有以下几种解决方案:
1、过滤掉出现次数少的单词
2、使用hash存储
3、由采用字粒度变化为采用词粒度

FastText代码(文档分类)
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as npclass Config(object):"""配置参数"""def __init__(self):self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # 设备self.dropout = 0.5  # 随机失活self.require_improvement = 1000  # 若超过1000batch效果还没提升,则提前结束训练self.num_classes = 10  # 类别数self.n_vocab = 10000  # 词表大小,在运行时赋值self.num_epochs = 20  # epoch数self.batch_size = 128  # mini-batch大小self.pad_size = 32  # 每句话处理成的长度(短填长切)self.learning_rate = 1e-3  # 学习率self.embed = 300  # 字向量维度self.hidden_size = 256  # 隐藏层大小self.n_gram_vocab = 250499  # ngram 词表大小'''Bag of Tricks for Efficient Text Classification'''class Model(nn.Module):def __init__(self, config):super(Model, self).__init__()self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1)self.embedding_ngram2 = nn.Embedding(config.n_gram_vocab, config.embed)self.embedding_ngram3 = nn.Embedding(config.n_gram_vocab, config.embed)self.dropout = nn.Dropout(config.dropout)self.fc1 = nn.Linear(config.embed * 3, config.hidden_size)self.fc2 = nn.Linear(config.hidden_size, config.num_classes)def forward(self, x):out_word = self.embedding(x[0])  # x[0] [batch_size,sentence_len] 经过embedding变为 [batch_size,sentence_len,wmbed_size] torch.Size([128, 32, 300])out_bigram = self.embedding_ngram2(x[2])  # torch.Size([128, 32, 300])out_trigram = self.embedding_ngram3(x[3])  # torch.Size([128, 32, 300])out = torch.cat((out_word, out_bigram, out_trigram), -1)  # torch.Size([128, 32, 900])out = out.mean(dim=1)  # torch.Size([128, 900]),沿着第二个维度(即特征维度)对每个样本的特征值进行平均池化out = self.dropout(out)  # torch.Size([128, 900])out = self.fc1(out)  # torch.Size([128, 900])经过fc1 torch.Size([128, 256])out = F.relu(out)  # torch.Size([128, 256])out = self.fc2(out)  # torch.Size([128, 256])经过fc1 torch.Size([128, 10])return outconfig=Config()
model=Model(config)
print(model)

输出:

Model((embedding): Embedding(10000, 300, padding_idx=9999)(embedding_ngram2): Embedding(250499, 300)(embedding_ngram3): Embedding(250499, 300)(dropout): Dropout(p=0.5, inplace=False)(fc1): Linear(in_features=900, out_features=256, bias=True)(fc2): Linear(in_features=256, out_features=10, bias=True)
)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/707040.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

详解动态规划之01背包问题及其空间压缩(图文并茂+例题讲解)

1. 动态规划问题的本质 记忆化地暴力搜索所有可能性来得到问题的解 我们常常会遇到一些问题,需要我们在n次操作,且每次操作有k种选择时,求出最终需要的最小或最大代价。处理类似的问题,我们一般需要遍历所有的可能性(相当于走一遍…

STM32-串口通信波特率计算以及寄存器的配置详解

您好,我们一些喜欢嵌入式的朋友一起建立的一个技术交流平台,本着大家一起互相学习的心态而建立,不太成熟,希望志同道合的朋友一起来,抱歉打扰您了QQ群372991598 串口通信基本原理 处理器与外部设备通信的两种方式 并行…

邮箱地址验证软件有哪些-邮件地址验证软件

邮箱地址验证软件是帮助用户验证电子邮箱地址是否有效和真实存在的工具。以下是一些常用的邮箱地址验证软件: 易邮件地址验证大师:这是电子邮件营销平台MailerLite提供的一个简单的电子邮件验证工具,通过多层验证过程保证高准确率。寅甲邮件…

ChatGPT-4o 实战 如何快速分析混淆加密和webpack打包的源码

ChatGPT-4o 几个特点 一个对话拥有长时间的记忆,可以连续上传文件,让其分析,最大一个代码文件只能3M,超出3M的文件,可以通过split-file可以进行拆分 其次ChatGPT-4o可以生成文件的下载链接,这有利于大文件的…

TypeScript的数据类型系统

TypeScript的数据类型系统 在上一篇文章中,我们介绍了TypeScript的基本概念和它与JavaScript的关系。TypeScript的核心优势之一是其强大的类型系统,它提供了丰富的数据类型,使得代码更加可靠和易于维护。本文将深入探讨TypeScript中的各种数…

gpt4o在哪用?

GPT-4o功能? 1.感知用户情绪:前沿研究部门主管陈信翰(Mark Chen)让ChatGPT-4o聆听他的呼吸,聊天机器人侦测到他急促的呼吸,并幽默地建议他不要像吸尘器那样呼吸,要放慢速度。随后Mark深呼吸一次…

浏览器插件Video Speed Controller(视频倍速播放),与网页自身快捷键冲突/重复/叠加的解决办法

浏览器插件Video Speed Controller(视频倍速播放),与网站自身快捷键冲突/重复/叠加的解决办法 插件介绍问题曾今尝试的办法今日发现插件列表中打开Video Speed Controller的设置设置页面翻到下面,打开实验性功能。将需要屏蔽的原网…

邮件API接口的优势有哪些?如何有效整合?

邮件API怎么选?SendCloud与AokSend的性能对比分析? 邮件API接口作为企业与用户沟通的重要桥梁,其重要性不言而喻。Aok将深入探讨邮件API接口的优势、有效整合的方法、选择标准以及SendCloud与AokSend两款邮件发送服务的性能对比分析。 邮件…

杨校老师项目之基于SpringBoot+Shiro+Vue的企业人事管理系统

1.获取代码: 有偿获取:mryang511688 2.技术栈 后端 SpringBoot MySQL mybatis-plus shiro Redis 前端 Vue Element-UI 3.开发环境 JDK1.8、Maven3.5.4、MySQL5.7、Redis5.0.5、IntelliJ IDEA、nodejs 4.内置功能 Springboot的项目,…

Hive的窗口函数

定义: 聚合函数是针对定义的行集(组)执行聚集,每组只返回一个值.如sum()、avg()、max() 窗口函数也是针对定义的行集(组)执行聚集,可为每组返回多个值.如既要显示聚集前的数据,又要显示聚集后的数据.步骤: 1.将记录分割成多个分区. 2.在各个分区上调用窗…

工业派-配置Intel神经计算棒二代(NCS2)

最近两天在工业派ubuntu16.04上配置了Intel神经计算棒二代——Intel Neural Compute Stick,配置过程之艰辛我都不想说了,实在是太折磨人。不过历尽千辛万苦,总算让计算棒可以在工业派ubuntu16.04系统上跑了,还是蛮欣慰的。 注&…

究极完整版!!Centos6.9安装最适配的python和yum,附带教大家如何写Centos6.9的yum.repos.d配置文件。亲测可行!

前言! 这里我真是要被Centos6.9给坑惨了,最刚开始学习linux的时候并没有在意那么的,没有考虑到选版本问题,直到23年下半年,官方不维护Centos6.9了,基本上当时配置的文件和安装的依赖都用不了了&#xff0c…