基于word2vec 和 fast-pytorch-kmeans 的文本聚类实现,利用GPU加速提高聚类速度

文章目录

    • 简介
      • GPU加速
    • 代码实现
    • kmeans
    • 聚类结果
    • kmeans 绘图函数
    • 相关资料参考

简介

本文使用text2vec模型,把文本转成向量。使用text2vec提供的训练好的模型权重进行文本编码,不重新训练word2vec模型。

直接用训练好的模型权重,方便又快捷

完整可运行代码如下:
https://github.com/JieShenAI/csdn/blob/main/machine_learning/kmeans_pytorch.ipynb

GPU加速

传统sklearn的TF-IDF文本转向量,在CPU上计算速度较慢。使用text2vec通过cuda加速,加快文本转向量的速度。
传统使用sklearn的kmeans聚类算法在CPU上计算,如遇到大批量的数据,计算耗时太长。
故本文使用kmeans_pytorch包,基于pytorch在GPU上计算,提高聚类速度。

代码实现

装包

pip install fast-pytorch-kmeans text2vec
import torch
import numpy as npfrom text2vec import SentenceModel

不使用SentenceModel模型也可以,在 text2vec 中,还有很多其他的向量编码模型供选择。

文本编码模型

embedder = SentenceModel()

异常情况说明,该模型需要从huggingface下载模型权重,目前被墙了。(请想办法解决,或者尝试其他的编码模型)
在这里插入图片描述

语料库如下:

# Corpus with example sentences
corpus = ['花呗更改绑定银行卡','我什么时候开通了花呗','A man is eating food.','A man is eating a piece of bread.','The girl is carrying a baby.','A man is riding a horse.','A woman is playing violin.','Two men pushed carts through the woods.','A man is riding a white horse on an enclosed ground.',
]
corpus_embeddings = embedder.encode(corpus)
# numpy 转成 pytorch, 并转移到GPU显存中
corpus_embeddings = torch.from_numpy(corpus_embeddings).to('cuda')

如下图所示,编码的向量是768纬;

type(corpus_embeddings), corpus_embeddings.shape

在这里插入图片描述

kmeans

kmeans_pytorch vs fast-pytorch-kmeans:
在实验过程中,利用kmeans_pytorch 针对30万个词进行聚类的时候,发现显存炸了,程序崩溃退出。30万个词的词向量,占用显存还不到2G,但是运行kmeans_pytorch后,显存就炸了。

fast-pytorch-kmeans不存在上述显存崩溃的问题。本以为词向量很多会跑很长时间,但fast-pytorch-kmeans在非常短的时间内就完成了kmeans聚类。

# kmeans
# from kmeans_pytorch import kmeans
from fast_pytorch_kmeans import KMeansnum_class = 3 # 分类类别数
kmeans = KMeans(n_clusters=num_class, mode='euclidean', verbose=1)# 模型预测结果
labels = kmeans.fit_predict(corpus_embeddings)

聚类程序运行如下:

used 2 iterations (0.3682s) to cluster 9 items into 3 clusters

模型中心点坐标:

kmeans.centroids

在这里插入图片描述

聚类结果

class_data = {i:[]for i in range(3)
}for text,cls in zip(corpus, labels):class_data[cls.item()].append(text)class_data

文本聚类结果如下:
0: 女
1:男
2: 花呗
在这里插入图片描述

kmeans 绘图函数

封装了KMeansPlot 绘图类,方便聚类结果可视化

from sklearn.decomposition import PCA
import matplotlib.pyplot as pltclass KMeansPlot:def __init__(self, numClass=4, func_type='PCA'):if func_type == 'PCA':self.func_plot = PCA(n_components=2)elif func_type == 'TSNE':from sklearn.manifold import TSNEself.func_plot = TSNE(2)self.numClass = numClassdef plot_cluster(self, result, pos, cluster_centers=None):plt.figure(2)Lab = [[] for i in range(self.numClass)]index = 0for labi in result:Lab[labi].append(index)index += 1color = ['oy', 'ob', 'og', 'cs', 'ms', 'bs', 'ks', 'ys', 'yv', 'mv', 'bv', 'kv', 'gv', 'y^', 'm^', 'b^', 'k^','g^'] * 3for i in range(self.numClass):x1 = []y1 = []for ind1 in pos[Lab[i]]:# print ind1try:y1.append(ind1[1])x1.append(ind1[0])except:passplt.plot(x1, y1, color[i])if cluster_centers is not None:#绘制初始中心点x1 = []y1 = []for ind1 in cluster_centers:try:y1.append(ind1[1])x1.append(ind1[0])except:passplt.plot(x1, y1, "rv") #绘制中心plt.show()def plot(self, weight, label, cluster_centers=None):pos = self.func_plot.fit_transform(weight)# 高纬的中心点坐标,也经过降纬处理cluster_centers = self.func_plot.fit_transform(cluster_centers)self.plot_cluster(list(label), pos, cluster_centers)

kmeans.centroids :是一个高纬空间的中心点坐标,故在plot函数中,将其降纬到2D平面上;

k_plot = KMeansPlot(num_class)
k_plot.plot(corpus_embeddings.to('cpu'),labels.to('cpu'),kmeans.centroids.to('cpu')
)

在这里插入图片描述

完整可运行代码如下:
https://github.com/JieShenAI/csdn/blob/main/machine_learning/kmeans_pytorch.ipynb

相关资料参考

  • 动手实战基于 ML 的中文短文本聚类
  • tfidf和word2vec构建文本词向量并做文本聚类
    提到训练word2vec模型,silhouette_score_show(word2vec, 'word2vec') 轮廓系数,判断分几个类别最好。
  • 机器学习:Kmeans聚类算法总结及GPU配置加速demo
    PyTorch kmeans 加速。from scratch 实现;
  • KMeans算法全面解析与应用案例 通俗易懂的原理讲解
  • pytorch K-means算法的实现 底层代码实现
  • 【pytorch】Kmeans_pytorch用于一般聚类任务的代码模板 使用pytorch封装的kmeans包实现,包括训练和预测;
  • text2vec 包

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/539752.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

工具篇--从零开始学Git

一、git概述 1.1安裝 windows版本 官方下载(比较慢):Git - Downloads Linux版本 ​yum install git查看git版本。 git --version 1.2创建仓库gitee 注册账号 Gitee - 基于 Git 的代码托管和研发协作平台 新建仓库 honey2024 配置 git confi…

安装kibaba

官方地址:Past Releases of Elastic Stack Software | Elastic 直接下载就可以 安装好了之后开始配置文件/kibana/config打开kibanba.yml server.port:5601 服务器地址 sercer.name:kibana 服务器名称 kibana.index:.kibana 索引 elasticsearch.hosts:[http://1…

Kafka是什么,以及如何使用SpringBoot对接Kafka

系列文章目录 上手第一关,手把手教你安装kafka与可视化工具kafka-eagle 架构必备能力——kafka的选型对比及应用场景 Kafka存取原理与实现分析,打破面试难关 防止消息丢失与消息重复——Kafka可靠性分析及优化实践 Kafka是什么,以及如何使用…

mysql 主从延迟分析

一、如何分析主从延迟 分析主从延迟一般会采集以下三类信息。 从库服务器的负载情况 为什么要首先查看服务器的负载情况呢?因为软件层面的所有操作都需要系统资源来支撑。 常见的系统资源有四类:CPU、内存、IO、网络。对于主从延迟,一般会…

Gitlab光速发起Merge Request

前言 在我们日常开发过程中需要经常使用到Merge Request,在使用过程中我们需要来回在开发工具和UI界面之前来回切换,十分麻烦。那有没有一种办法可以时间直接开发开工具中直接发起Merge Request呢? 答案是有的。 使用 Git 命令方式创建 Me…

3dmax导入模型渲染过亮---模大狮模型网

在3ds Max中导入模型后渲染过亮可能是由于以下原因导致的: 材质和贴图设置: 检查导入的模型的材质和贴图设置,确保它们没有过度亮度或反射。调整材质的Diffuse(漫反射)颜色和Specular(高光)属性,以使渲染看起来更加自然。 光源设…

操作系统——中断

目录 前置知识 ​编辑 基本概念 1.中断特点 2.PSW(程序状态字,Program statement word) 中断的作用 中断的类型 中断嵌套、中断优先级、中断屏蔽 中断响应过程 前置知识 内核程序 :内核是操作系统的核心部分&#xff…

封装、继承、多态

1.封装 1.1 概念 举个例子解释,我们使用计算机时,并不关心内部核心部件,而是只需要知道如何开机,如何通过键盘鼠标与计算机进行交互即可.因此厂家在出厂计算机时,在外部套上壳子将内部实现细节隐藏起来,仅仅对外提供开机等,让用户可以与计算机进行交互即可. 封装: 将数据和…

14.WEB渗透测试--Kali Linux(二)

免责声明:内容仅供学习参考,请合法利用知识,禁止进行违法犯罪活动! 内容参考于: 易锦网校会员专享课 上一个内容:13.WEB渗透测试--Kali Linux(一)-CSDN博客 netcat简介内容:13.WE…

log4cplus在Qt linux中的应用与问题解决

log4cplus在Qt linux中的应用与问题解决 背景log4cplus下载遇到问题:libm.so.6:undefined reference to __strtof128_nanGLIBC_PRIVATE‘解决方案编译生成在Qt工程里面添加对应依赖编译运行成功 背景 最近工作中需要用到log4cplus的日志做一些记录,用了…

R语言深度学习-3-过拟合问题(无监督正则化/Lasso回归/岭回归/集成和平均算法)

本教程参考《RDeepLearningEssential》 我们从上一个教程看到,我们看到在我们训练迭代或者训练更大神经网络的时候,往往会产生过拟合,而且越来越严重,它可能会把训练它的数据拟合的很好,但是未必能把新数据做的很好。…

通过路由器监控,优化网络效率

路由器是网络的基本连接组件,路由器监控涉及将路由器网络作为一个整体进行管理,其中持续监控路由器的性能、运行状况、安全性和可用性,以确保更好的操作和最短的停机时间,因此监控路由器至关重要。 为什么路由器监控对组织很重要…