Implementing Word2vec by Hand

Source: https://www.cnblogs.com/zhangyh-blog/p/18200191

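1. Corpus download: https://dumps.wikimedia.org/zhwiki/latest/zhwiki-latest-pages-articles.xml.bz2 [Chinese Wikipedia corpus]

2. Corpus preprocessing

(1) Extract the text from the dataset

The downloaded dump cannot be used directly. It is a bz2-compressed XML file full of wiki markup, so the plain article text has to be extracted first.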
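WikiCorpus (used below) does two jobs: it streams the compressed XML, and it strips the wiki markup. To make the first job concrete, here is a minimal standard-library sketch. The helper names `iter_wiki_texts` and `local_name` are hypothetical (they are not part of gensim), and the sketch returns raw wikitext without any markup cleaning.

```python
import bz2
import io
import xml.etree.ElementTree as ET


def local_name(tag):
    # Dump tags look like "{http://www.mediawiki.org/xml/export-0.10/}text";
    # drop the namespace so the code does not depend on the export version.
    return tag.rsplit("}", 1)[-1]


def iter_wiki_texts(path_or_file):
    """Yield the raw wikitext of every <text> element in a .xml.bz2 dump.

    Sketch only: templates, links and tables are left untouched; cleaning
    them is what gensim's WikiCorpus additionally does.
    """
    with bz2.open(path_or_file, "rb") as stream:
        # iterparse reads the XML incrementally instead of loading it whole
        for _event, elem in ET.iterparse(stream, events=("end",)):
            name = local_name(elem.tag)
            if name == "text":
                yield elem.text or ""
            elif name == "page":
                elem.clear()  # release the finished page's subtree


if __name__ == "__main__":
    # Tiny in-memory stand-in for a real dump file
    sample = (b'<mediawiki xmlns="http://www.mediawiki.org/xml/export-0.10/">'
              b'<page><title>A</title><revision><text>hello [[world]]</text>'
              b'</revision></page></mediawiki>')
    for text in iter_wiki_texts(io.BytesIO(bz2.compress(sample))):
        print(text)
```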

Install the Python libraries:

pip install numpy
pip install scipy
pip install gensim
Python code (process_wiki.py):
      
'''
Description: extract plain text from the Chinese Wikipedia dump
Author: zhangyh
Date: 2024-05-09 21:31:22
LastEditTime: 2024-05-09 22:10:16
LastEditors: zhangyh
'''
import logging
import os.path
import sys
import warnings

warnings.filterwarnings(action='ignore', category=UserWarning, module='gensim')
from gensim.corpora import WikiCorpus

if __name__ == '__main__':
    program = os.path.basename(sys.argv[0])
    logger = logging.getLogger(program)
    logging.basicConfig(format='%(asctime)s: %(levelname)s: %(message)s')
    logging.root.setLevel(level=logging.INFO)
    logger.info("running %s" % ' '.join(sys.argv))

    # check and process input arguments
    if len(sys.argv) != 3:
        print("Usage: python process_wiki.py enwiki.xxx.xml.bz2 wiki.en.text")
        sys.exit(1)
    inp, outp = sys.argv[1:3]

    i = 0
    # passing dictionary={} avoids an extra vocabulary-building pass over the dump
    wiki = WikiCorpus(inp, dictionary={})
    with open(outp, 'w', encoding='utf-8') as output:
        # get_texts() yields one article at a time as a list of tokens
        for text in wiki.get_texts():
            output.write(" ".join(text) + "\n")
            i += 1
            if i % 10000 == 0:
                logger.info("Saved " + str(i) + " articles")
    logger.info("Finished. Saved " + str(i) + " articles")

Run the script to extract the text:

PS C:\Users\zhang\Desktop\nlp 自然语言处理\data> python .\process_wiki.py .\zhwiki-latest-pages-articles.xml.bz2 wiki_zh.text
2024-05-09 21:43:10,036: INFO: running .\process_wiki.py .\zhwiki-latest-pages-articles.xml.bz2 wiki_zh.text
2024-05-09 21:44:02,944: INFO: Saved 10000 articles
2024-05-09 21:44:51,875: INFO: Saved 20000 articles
...
2024-05-09 22:22:34,244: INFO: Saved 460000 articles
2024-05-09 22:23:33,323: INFO: Saved 470000 articles

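The extracted text (it still contains Traditional Chinese characters):

(2) Convert Traditional Chinese to Simplified Chinese

  • Use the OpenCC tool for the conversion. Download OpenCC: https://bintray.com/package/files/byvoid/opencc/OpenCC (Bintray has since been shut down; OpenCC is also published on GitHub as BYVoid/OpenCC)
  • Run the conversion command:
opencc -i wiki_zh.text -o wiki_sample_chinese.text -c "C:\Program Files\OpenCC\build\share\opencc\t2s.json"
  • The converted Simplified Chinese text:

(3) Word segmentation (with jieba)

  • Segmentation code: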
      
'''
Description: segment the articles into words with jieba
Author: zhangyh
Date: 2024-05-10 22:48:45
LastEditTime: 2024-05-10 23:02:57
LastEditors: zhangyh
'''
# Article word segmentation: one article per line in, space-separated words out
import jieba

with open('data\\wiki_sample_chinese.text', 'r', encoding='utf8') as f, \
        open('data\\wiki_word_cutted_result.text', 'w', encoding='utf8') as target:
    # Iterate over the file lazily instead of calling readline() by hand
    for line_num, line in enumerate(f, start=1):
        print('---- processing', line_num, 'article----------------')
        # jieba.cut returns a generator of words; the line's trailing newline is
        # kept as a token, so each article stays on its own line in the output
        line_seg = " ".join(jieba.cut(line))
        target.write(line_seg)
  • Segmentation result:

 

3. Model training

(1) Skip-gram model
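Skip-gram turns every position in the token stream into several training pairs: the current word is the input, and each word within `n_gram` positions on either side is a prediction target. A minimal sketch of that pair generation (the function name `skipgram_pairs` is hypothetical), using the same slicing as the training loop. It only makes sense when the token list keeps its corpus order.

```python
def skipgram_pairs(tokens, window):
    """Return (center, context) pairs with `window` context words per side."""
    pairs = []
    for pos, center in enumerate(tokens):
        # Python slices clamp at the ends, so no explicit min(len(...)) is needed
        left = tokens[max(0, pos - window):pos]
        right = tokens[pos + 1:pos + window + 1]
        for context in left + right:
            pairs.append((center, context))
    return pairs


if __name__ == "__main__":
    print(skipgram_pairs(["我", "喜欢", "自然", "语言"], window=1))
```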

      
'''
Description: skip-gram word2vec trained with a full softmax (numpy only)
Author: zhangyh
Date: 2024-05-12 21:51:03
LastEditTime: 2024-05-16 11:08:59
LastEditors: zhangyh
'''
import pickle
from itertools import islice

import numpy as np
from tqdm import tqdm


def load_stop_words(file="作业-skipgram\\stopwords.txt"):
    with open(file, "r", encoding="utf-8") as f:
        return set(f.read().split("\n"))  # a set makes each membership test O(1)


def load_cutted_data(num_lines: int):
    stop_words = load_stop_words()
    data = []
    # with open('wiki_word_cutted_result.text', mode='r', encoding='utf-8') as file:
    with open('作业-skipgram\\wiki_word_cutted_result.text', mode='r', encoding='utf-8') as file:
        # islice reads only the first num_lines lines instead of the whole corpus
        for line in tqdm(islice(file, num_lines)):
            data += [word for word in line.split() if word not in stop_words]
    # Do not deduplicate with set(): that would shuffle the tokens, and the
    # context window would no longer see real neighbours
    return data


def get_dict(data):
    index_2_word = []
    word_2_index = {}
    for word in data:
        if word not in word_2_index:
            word_2_index[word] = len(index_2_word)
            index_2_word.append(word)
    return word_2_index, index_2_word


def softmax(x):
    ex = np.exp(x - np.max(x, axis=1, keepdims=True))  # subtract the max to avoid overflow
    return ex / np.sum(ex, axis=1, keepdims=True)


# negative sampling (not used yet)
# def negative_sampling(word_2_index, word_count, num_negative_samples):
#     word_probs = [word_count[word]**0.75 for word in word_2_index]
#     word_probs = np.array(word_probs) / sum(word_probs)
#     neg_samples = np.random.choice(len(word_2_index), size=num_negative_samples, replace=True, p=word_probs)
#     return neg_samples

if __name__ == "__main__":
    batch_size = 562  # number of tokens per batch
    data = load_cutted_data(5)
    word_2_index, index_2_word = get_dict(data)
    word_size = len(word_2_index)
    embedding_num = 100
    lr = 0.01
    epochs = 200
    n_gram = 3  # window size: n_gram words on each side of the centre word
    # num_negative_samples = 5
    # count word frequencies
    # word_count = dict.fromkeys(word_2_index, 0)
    # for word in data:
    #     word_count[word] += 1

    batches = [data[j:j + batch_size] for j in range(0, len(data), batch_size)]
    # uniform initialisation in [-1, 1), centred on 0
    w1 = np.random.uniform(-1, 1, size=(word_size, embedding_num))
    w2 = np.random.uniform(-1, 1, size=(embedding_num, word_size))

    for epoch in range(epochs):
        print(f'-------- epoch {epoch + 1} --------')
        for batch in tqdm(batches):
            for pos, now_word in enumerate(batch):
                now_idx = word_2_index[now_word]
                other_words = batch[max(0, pos - n_gram): pos] + batch[pos + 1: pos + n_gram + 1]
                for other_word in other_words:
                    # one_hot @ w1 only selects row now_idx, so index it directly
                    hidden = w1[now_idx:now_idx + 1]        # shape (1, embedding_num)
                    pre = softmax(hidden @ w2)              # shape (1, word_size)
                    # A @ B = C
                    # delta_C = G
                    # delta_A = G @ B.T
                    # delta_B = A.T @ G
                    G2 = pre.copy()
                    G2[0, word_2_index[other_word]] -= 1    # pre - one_hot(other_word)
                    delta_w2 = hidden.T @ G2
                    G1 = G2 @ w2.T
                    w1[now_idx] -= lr * G1[0]               # one_hot.T @ G1 is non-zero only in this row
                    w2 -= lr * delta_w2

    with open("作业-skipgram\\word2vec_skipgram.pkl", "wb") as f:
        # with open("word2vec_skipgram.pkl","wb") as f:
        pickle.dump([w1, word_2_index, index_2_word, w2], f)
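The commented-out `negative_sampling` hints at the usual speed-up: instead of a softmax over the whole vocabulary, each step scores the true context word plus `k` noise words drawn from unigram counts raised to the 0.75 power. A minimal numpy sketch of one such update, assuming the same `w1` (shape V, E) and `w2` (shape E, V) layout as above. The helper names are hypothetical, and this is the standard skip-gram negative-sampling objective, not something the code above implements.

```python
from collections import Counter

import numpy as np


def noise_distribution(tokens, word_2_index):
    """Unigram frequencies raised to the 0.75 power, normalised to sum to 1."""
    counts = Counter(tokens)
    freqs = np.zeros(len(word_2_index))
    for word, idx in word_2_index.items():
        freqs[idx] = counts[word]
    probs = freqs ** 0.75
    return probs / probs.sum()


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def sgns_step(w1, w2, center, context, probs, k, lr, rng):
    """One SGD step of skip-gram with negative sampling; returns the loss.

    w1: (V, E) input vectors, w2: (E, V) output vectors, both updated in place.
    Only k + 1 columns of w2 are touched instead of all V.
    """
    negatives = rng.choice(len(probs), size=k, p=probs)
    targets = np.concatenate(([context], negatives)).astype(int)
    labels = np.zeros(k + 1)
    labels[0] = 1.0                    # 1 = real context word, 0 = noise word
    v = w1[center].copy()              # (E,)
    u = w2[:, targets]                 # (E, k + 1); fancy indexing makes a copy
    scores = v @ u
    loss = -np.log(sigmoid(scores[0])) - np.sum(np.log(sigmoid(-scores[1:])))
    g = sigmoid(scores) - labels       # d loss / d score for each target
    w1[center] -= lr * (u @ g)
    for j, t in enumerate(targets):    # a loop so repeated negatives accumulate
        w2[:, t] -= lr * g[j] * v
    return loss
```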

  

(2) CBOW model

      
'''
Description: CBOW word2vec trained with a full softmax (numpy only)
Author: zhangyh
Date: 2024-05-13 20:47:57
LastEditTime: 2024-05-16 09:21:40
LastEditors: zhangyh
'''
import pickle
from itertools import islice

import numpy as np
from tqdm import tqdm


def load_stop_words(file="stopwords.txt"):
    with open(file, "r", encoding="utf-8") as f:
        return set(f.read().split("\n"))  # a set makes each membership test O(1)


def load_cutted_data(num_lines: int):
    stop_words = load_stop_words()
    data = []
    with open('wiki_word_cutted_result.text', mode='r', encoding='utf-8') as file:
        for line in tqdm(islice(file, num_lines)):
            data += [word for word in line.split() if word not in stop_words]
    return data  # keep corpus order; set() would destroy the context windows


def get_dict(data):
    index_2_word = []
    word_2_index = {}
    for word in data:
        if word not in word_2_index:
            word_2_index[word] = len(index_2_word)
            index_2_word.append(word)
    return word_2_index, index_2_word


def softmax(x):
    ex = np.exp(x - np.max(x, axis=1, keepdims=True))  # subtract the max to avoid overflow
    return ex / np.sum(ex, axis=1, keepdims=True)


if __name__ == "__main__":
    batch_size = 562
    data = load_cutted_data(5)
    word_2_index, index_2_word = get_dict(data)
    word_size = len(word_2_index)
    embedding_num = 100
    lr = 0.01
    epochs = 200
    context_window = 3

    batches = [data[j:j + batch_size] for j in range(0, len(data), batch_size)]
    # uniform initialisation in [-1, 1), centred on 0
    w1 = np.random.uniform(-1, 1, size=(word_size, embedding_num))
    w2 = np.random.uniform(-1, 1, size=(embedding_num, word_size))

    for epoch in range(epochs):
        print(f'-------- epoch {epoch + 1} --------')
        for batch in tqdm(batches):
            for pos, target_word in enumerate(batch):
                context_words = batch[max(0, pos - context_window): pos] + batch[pos + 1: pos + context_window + 1]
                if not context_words:
                    continue  # a one-token batch has no context
                context_idx = [word_2_index[word] for word in context_words]
                target_idx = word_2_index[target_word]
                # (mean of the context one-hots) @ w1 == mean of the context rows of w1
                hidden = w1[context_idx].mean(axis=0, keepdims=True)  # (1, embedding_num)
                pre = softmax(hidden @ w2)
                # cross-entropy loss: loss = -np.log(pre[0, target_idx])
                # backpropagation
                G2 = pre.copy()
                G2[0, target_idx] -= 1                # pre - one_hot(target_word)
                delta_w2 = hidden.T @ G2
                G1 = G2 @ w2.T
                # each context row receives G1 / len(context); repeated words accumulate
                for idx in context_idx:
                    w1[idx] -= lr * G1[0] / len(context_idx)
                w2 -= lr * delta_w2

    with open("word2vec_cbow.pkl", "wb") as f:
        pickle.dump([w1, word_2_index, index_2_word, w2], f)
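Both models share one backward pass: the softmax cross-entropy gradient `pre - one_hot(target)`, followed by the matrix rules noted in the skip-gram comments (if C = A @ B and the gradient of C is G, the gradient of A is G @ B.T and the gradient of B is A.T @ G). A finite-difference check is a cheap way to confirm those formulas. Below is a minimal numpy sketch with hypothetical helper names, using a CBOW-style input (the average of two one-hot vectors) and tiny random matrices.

```python
import numpy as np


def softmax(x):
    ex = np.exp(x - x.max(axis=1, keepdims=True))
    return ex / ex.sum(axis=1, keepdims=True)


def loss_fn(x, w1, w2, target):
    """Cross-entropy of softmax(x @ w1 @ w2) against the index `target`."""
    pre = softmax(x @ w1 @ w2)
    return -np.log(pre[0, target])


def analytic_grads(x, w1, w2, target):
    hidden = x @ w1
    pre = softmax(hidden @ w2)
    G2 = pre.copy()
    G2[0, target] -= 1.0        # pre - one_hot(target)
    delta_w2 = hidden.T @ G2    # delta_B = A.T @ G
    G1 = G2 @ w2.T              # delta_A = G @ B.T
    delta_w1 = x.T @ G1
    return delta_w1, delta_w2


def numeric_grad(f, w, eps=1e-6):
    """Central differences: perturb each entry of w in place and re-evaluate f()."""
    grad = np.zeros_like(w)
    for idx in np.ndindex(w.shape):
        old = w[idx]
        w[idx] = old + eps
        plus = f()
        w[idx] = old - eps
        minus = f()
        w[idx] = old
        grad[idx] = (plus - minus) / (2 * eps)
    return grad


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    V, E, target = 6, 3, 2
    x = np.zeros((1, V))
    x[0, [1, 4]] = 0.5          # mean of two context one-hots, as in CBOW
    w1 = rng.normal(size=(V, E))
    w2 = rng.normal(size=(E, V))
    d_w1, d_w2 = analytic_grads(x, w1, w2, target)
    f = lambda: loss_fn(x, w1, w2, target)
    print(np.abs(d_w1 - numeric_grad(f, w1)).max(), np.abs(d_w2 - numeric_grad(f, w2)).max())
```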

  

4. Training results

(1) Cosine similarity

      
'''
Description: cosine similarity between trained word vectors
Author: zhangyh
Date: 2024-05-13 20:12:56
LastEditTime: 2024-05-16 21:16:19
LastEditors: zhangyh
'''
import pickle
import numpy as np

# w1, voc_index, index_voc, w2 = pickle.load(open('word2vec_cbow.pkl','rb'))
with open('作业-CBOW\\word2vec_cbow.pkl', 'rb') as f:
    w1, voc_index, index_voc, w2 = pickle.load(f)


def word_voc(word):
    return w1[voc_index[word]]


def voc_sim(word, top_n):
    v_w1 = word_voc(word)
    word_sim = {}
    for i in range(len(voc_index)):
        v_w2 = w1[i]
        theta_sum = np.dot(v_w1, v_w2)
        theta_den = np.linalg.norm(v_w1) * np.linalg.norm(v_w2)
        # use a separate name so the query `word` is not overwritten
        other_word = index_voc[i]
        word_sim[other_word] = theta_sum / theta_den
    words_sorted = sorted(word_sim.items(), key=lambda kv: kv[1], reverse=True)
    # note: the query word itself always ranks first, with similarity 1.0
    for other_word, sim in words_sorted[:top_n]:
        print(f'word: {other_word}, similarity: {sim}')


voc_sim('学院', 20)
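The loop above computes one cosine at a time, cos θ = (v1 · v2) / (‖v1‖ ‖v2‖). Normalising every row once turns all V similarities into a single matrix-vector product. A minimal numpy sketch with a hypothetical `most_similar` helper, which also skips the query word (otherwise it always ranks first with similarity 1):

```python
import numpy as np


def most_similar(query, w1, word_2_index, index_2_word, top_n=10):
    """Rank all words by cosine similarity to `query`, excluding the query itself."""
    norms = np.linalg.norm(w1, axis=1)
    norms[norms == 0] = 1e-12           # guard against all-zero rows
    unit = w1 / norms[:, None]
    q = unit[word_2_index[query]]
    sims = unit @ q                     # cosine similarity with every word at once
    order = np.argsort(-sims)           # indices from most to least similar
    ranked = [(index_2_word[i], float(sims[i])) for i in order
              if index_2_word[i] != query]
    return ranked[:top_n]
```

With the pickle loaded as in the script above, `most_similar('学院', w1, voc_index, index_voc, 20)` would give the same ranking as `voc_sim('学院', 20)` minus the query word.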

  

(2) Visualization

      
'''
Description: 
Author: zhangyh
Date: 2024-05-16 21:41:33
LastEditTime: 2024-05-17 23:50:07
LastEditors: zhangyh
'''
import numpy as np
import pickle
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

# Use fonts that can render the Chinese labels
plt.rcParams['font.family'] = ['Microsoft YaHei', 'SimHei', 'sans-serif']
plt.rcParams['axes.unicode_minus'] = False  # draw minus signs correctly with these fonts

# Load trained word embeddings
with open("word2vec_cbow.pkl", "rb") as f:
    w1, word_2_index, index_2_word, w2 = pickle.load(f)

# Select specific words for visualization (skip any that are not in the vocabulary)
visual_words = ['研究', '电脑', '雅典', '数学', '数学家', '学院', '函数', '定理', '实数', '复数']
visual_words = [word for word in visual_words if word in word_2_index]

# Get the word vectors corresponding to the selected words
subset_vectors = np.array([w1[word_2_index[word]] for word in visual_words])

# Perform PCA for dimensionality reduction
pca = PCA(n_components=2)
reduced_vectors = pca.fit_transform(subset_vectors)

# Visualization
plt.figure(figsize=(10, 8))
plt.scatter(reduced_vectors[:, 0], reduced_vectors[:, 1], marker='o')
for i, word in enumerate(visual_words):
    plt.annotate(word, xy=(reduced_vectors[i, 0], reduced_vectors[i, 1]), xytext=(5, 2),
                 textcoords='offset points', ha='right', va='bottom')
plt.title('Word Embeddings Visualization')
plt.xlabel('Principal Component 1')
plt.ylabel('Principal Component 2')
plt.grid(True)
plt.show()
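`PCA(n_components=2)` centres the selected vectors and keeps the two directions of largest variance. For intuition, the same projection can be sketched directly with numpy's SVD. The helper `pca_2d` is hypothetical; scikit-learn may choose a different solver and may flip the sign of each axis, which mirrors the plot without changing the relative distances.

```python
import numpy as np


def pca_2d(vectors):
    """Project row vectors onto their first two principal components."""
    centered = vectors - vectors.mean(axis=0)
    # Rows of vt are the principal directions, ordered by decreasing singular value
    _u, _s, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:2].T


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    subset_vectors = rng.normal(size=(10, 100))  # stand-in for 10 word vectors of size 100
    print(pca_2d(subset_vectors).shape)
```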

(3) Analogy experiments (e.g., 王子 prince - 男 male + 女 female = 公主 princess)

'''
Description: 
Author: zhangyh
Date: 2024-05-16 23:13:21
LastEditTime: 2024-05-19 11:51:53
LastEditors: zhangyh
'''
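import numpy as np
import pickle
from sklearn.metrics.pairwise import cosine_similarity

# Load the trained word vectors
with open("word2vec_cbow.pkl", "rb") as f:
    w1, word_2_index, index_2_word, w2 = pickle.load(f)

# Compute the analogy: 王子 (prince) - 男 (male) + 女 (female)
v_prince = w1[word_2_index["王子"]]
v_man = w1[word_2_index["男"]]
v_woman = w1[word_2_index["女"]]
v_princess = v_prince - v_man + v_woman

# Find the most similar word vector, skipping the three input words,
# which otherwise tend to be the closest matches themselves
similarities = cosine_similarity(v_princess.reshape(1, -1), w1)[0]
for word in ("王子", "男", "女"):
    similarities[word_2_index[word]] = -np.inf
most_similar_index = int(np.argmax(similarities))
most_similar_word = index_2_word[most_similar_index]
print("Result:", most_similar_word)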
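Why can vector offsets solve analogies at all? If training places words so that one direction roughly encodes gender, then prince - male + female lands near princess. A toy sketch with hand-made 2-D vectors (invented for illustration, not trained; first axis = royalty, second = gender) and a hypothetical `analogy` helper that normalises the vectors before the arithmetic and skips the three inputs:

```python
import numpy as np


def analogy(a, b, c, vectors, word_2_index, index_2_word):
    """Return the word d whose vector is most cosine-similar to a - b + c."""
    unit = vectors / np.linalg.norm(vectors, axis=1, keepdims=True)
    query = unit[word_2_index[a]] - unit[word_2_index[b]] + unit[word_2_index[c]]
    sims = unit @ (query / np.linalg.norm(query))
    for word in (a, b, c):
        sims[word_2_index[word]] = -np.inf   # never answer with an input word
    return index_2_word[int(np.argmax(sims))]


if __name__ == "__main__":
    index_2_word = ["王子", "公主", "男", "女", "苹果"]
    word_2_index = {w: i for i, w in enumerate(index_2_word)}
    toy = np.array([
        [1.0, 1.0],    # 王子 prince: royal, male
        [1.0, -1.0],   # 公主 princess: royal, female
        [0.0, 1.0],    # 男 male
        [0.0, -1.0],   # 女 female
        [-1.0, 0.0],   # 苹果 apple: unrelated
    ])
    print(analogy("王子", "男", "女", toy, word_2_index, index_2_word))  # prints 公主
```

On real embeddings trained from only a few lines of corpus, the answer depends heavily on data size and training quality.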

  

