Word2Vec的CBOW模型

Word2Vec中的CBOW(Continuous Bag of Words)模型是一种用于学习词向量的神经网络模型。CBOW的核心思想是根据上下文中的周围单词来预测目标单词。

例如,对于句子“The cat climbed up the tree”,如果窗口大小为5,那么当中心单词为“climbed”时,上下文单词为“The”、“cat”、“up”和“the”。CBOW模型要求根据这四个上下文单词,计算出“climbed”的概率分布。

一个简单的CBOW模型

import torch
import torch.nn as nn
import torch.optim as optim# 定义CBOW模型
class CBOWModel(nn.Module):def __init__(self, vocab_size, embed_size):super(CBOWModel, self).__init__()self.embeddings = nn.Embedding(vocab_size, embed_size)self.linear = nn.Linear(embed_size, vocab_size)def forward(self, context):embedded = self.embeddings(context)embedded_sum = torch.sum(embedded, dim=1)output = self.linear(embedded_sum)return output# 定义训练函数
def train_cbow(data, target, model, criterion, optimizer):optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()return loss.item()# 假设有一个简单的语料库和单词到索引的映射
corpus = ["I like deep learning", "I enjoy NLP", "I love PyTorch"]
word_to_index = {"I": 0, "like": 1, "deep": 2, "learning": 3, "enjoy": 4, "NLP": 5, "love": 6, "PyTorch": 7}# 将语料库转换为训练数据
context_size = 3
data = []
target = []
for sentence in corpus:tokens = sentence.split()for i in range(context_size, len(tokens) - context_size):context = [word_to_index[tokens[j]] for j in range(i - context_size, i + context_size + 1) if j != i]target_word = word_to_index[tokens[i]]data.append(torch.tensor(context, dtype=torch.long))target.append(torch.tensor(target_word, dtype=torch.long))# 超参数
vocab_size = len(word_to_index)
embed_size = 10
learning_rate = 0.01
epochs = 100# 初始化模型、损失函数和优化器
cbow_model = CBOWModel(vocab_size, embed_size)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(cbow_model.parameters(), lr=learning_rate)# 开始训练
for epoch in range(epochs):total_loss = 0for i in range(len(data)):loss = train_cbow(data[i], target[i], cbow_model, criterion, optimizer)total_loss += lossprint(f'Epoch {epoch + 1}/{epochs}, Loss: {total_loss}')# 获取词向量
word_embeddings = cbow_model.embeddings.weight.detach().numpy()
print("Word Embeddings:\n", word_embeddings)
  1. CBOW模型定义(class CBOWModel):

    • __init__ 方法:在初始化过程中定义了两个层,一个是nn.Embedding用于获取词向量,另一个是nn.Linear用于将词向量求和后映射到词汇表大小的空间
    • forward 方法:定义了模型的前向传播过程。给定一个上下文,首先通过Embedding层获取词向量,然后对词向量进行求和,最后通过Linear层进行映射。
  2. 训练函数(train_cbow):

    • train_cbow 函数用于训练CBOW模型。接受训练数据、目标、模型、损失函数和优化器作为输入,并执行前向传播、计算损失、反向传播和优化器更新权重的过程。
  3. 语料库和单词到索引的映射:

    • corpus 包含了三个简单的句子。
    • word_to_index 是单词到索引的映射。
  4. 将语料库转换为训练数据:

    • 对每个句子进行分词,然后构建上下文和目标。上下文是目标词的上下文词的索引列表,目标是目标词的索引。
  5. 超参数和模型初始化:

    • vocab_size 是词汇表大小。
    • embed_size 是词向量的维度。
    • learning_rate 是优化器的学习率。
    • epochs 是训练迭代次数。
    • CBOWModel 实例化为 cbow_model
    • 使用交叉熵损失函数和随机梯度下降(SGD)优化器。
  6. 训练过程:

    • 使用嵌套的循环对训练数据进行多次迭代。
    • 对每个训练样本调用 train_cbow 函数,计算损失并更新模型权重。
  7. 获取词向量:

    • 通过 cbow_model.embeddings.weight 获取训练后的词向量矩阵,并将其转换为 NumPy 数组。

需要注意的是,代码中的训练过程比较简单,通常在实际应用中可能需要更复杂的数据集、更大的模型和更多的训练策略。此处的代码主要用于展示CBOW模型的基本实现。

在CBOW(Continuous Bag of Words)模型中,神经网络的输入和输出数据的构造方式如下:

  1. 输入数据:

    • 对于每个训练样本,输入数据是上下文窗口内的单词的独热编码(one-hot encoding)向量的拼接。
    • 上下文窗口大小为3,因此对于每个目标词,上下文窗口内有3个单词。这3个单词的独热编码向量会被拼接在一起作为输入。
    • 对于语料库中的每个目标词,都会生成一个对应的训练样本。

    以 "I like deep learning" 为例:

    • "deep" 是目标词,上下文窗口为["like", "I", "learning"]。
    • 对应的独热编码向量分别是 [0, 1, 0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 1, 0, 0, 0, 0]。
    • 这三个向量拼接在一起作为神经网络的输入。

    对于整个语料库,这个过程会生成一组输入数据。

  2. 输出数据:

    • 输出数据是目标词的独热编码向量,表示模型要预测的词。
    • 对于 "I like deep learning" 中的 "deep",其对应的独热编码向量是 [0, 0, 0, 1, 0, 0, 0, 0]。
    • 整个语料库中,为每个目标词生成相应的输出数据。

综上所述,CBOW模型的神经网络输入数据是上下文窗口内单词的拼接独热编码向量,输出数据是目标词的独热编码向量。在训练过程中,模型通过学习输入与输出之间的映射关系,逐渐调整权重以更好地捕捉语境信息。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/338187.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[机缘参悟-122] :IT人如何认识自己的?自省、面试、考核、咨询?

目录 一、为什么要认识自己 二、认识自己的哪些方面? 三、如何认识自己 3.1 通过自省认识自己 3.2 通过面试认识自己 3.3 通过咨询认识自己 3.4 通过相亲认识自己 3.5 通过一段感情关系认识自己 一、为什么要认识自己 认识自己在人类的成长和心灵发展过程中…

光纤知识总结

1光纤概念: 光导纤维(英语:Optical fiber),简称光纤,是一种由玻璃或塑料制成的纤维,利用光在这些纤维中以全内反射原理传输的光传导工具。 微细的光纤封装在塑料护套中,使得它能够…

CSND修改付费专栏价格

人工客服在个人中心右下角可以找到 客服回复已订阅专栏不支持修改价格

【ECShop电子商务系统__软件测试作业】ECSHOP系统搭建文档+接口测试用例+接口文档+接口测试脚本

一、选题题目可选《ECShop电子商务系统》、《EPShop电子商城系统》或者自选其它的开源系统(至少有十个以上的功能模块的系统,不得选功能少、简单的系统)。 软件测试作业 说明:接口测试相关资料 二、具体要求 1、搭建测试系统并写出搭建被测系统的全过程。 2、根…

Nginx介绍与安装

目录 nginx服务 1、Nginx 介绍 2、为什么选择 nginx 3、IO多路复用 1、I/O multiplexing【多并发】 2、一个请求到来了,nginx使用epoll接收请求的过程是怎样的? 3、异步,非阻塞 4、nginx 的内部技术架构 5、yum安装部署nginx和配置管理 1.获取…

Kafka集群部署 (KRaft模式集群)

KRaft 模式是 Kafka 在 3.0 版本中引入的新模式。KRaft 模式使用了 Raft 共识算法来管理 Kafka 集群元数据。Raft 算法是一种分布式共识算法,具有高可用性、可扩展性和安全性等优势。 在 KRaft 模式下,Kafka 集群中的每个 Broker 都具有和 Zookeeper 类…

Redis命令总结

1、启动Redis服务,登录Redis # 开启redis服务 redis-server redis配置文件路径例子: redis-server redis.windows.conf# 连接redis 【无密码】 redis-cli# 连接redis【有密码】 # 1 先连接再输入密码 redis-cli auth 密码 2、连接时输入 IP址、端口号、…

仿蓝奏云网盘 /file/list SQL注入漏洞复现

0x01 产品简介 仿蓝奏网盘是一种类似于百度网盘的文件存储和共享解决方案。它为用户提供了一个便捷的平台,可以上传、存储和分享各种类型的文件,方便用户在不同设备之间进行文件传输和访问。 0x02 漏洞概述 仿蓝奏云网盘 /file/list接口处存在SQL注入漏洞,登录后台的攻击…

烟火检测/区域人流统计/AI智能分析网关V4如何配置通道?

TSINGSEE青犀智能分析网关(V4版)是一款高性能、低功耗的软硬一体AI边缘计算硬件设备,硬件内部署了近40种AI算法模型,支持对接入的视频图像进行人、车、物、行为等实时检测分析,并上报识别结果,并能进行语音…

JS 监听网络状态

我们在开发过程中会遇到监听用户网络状态的需求,通过JS可以获取当前的网络状态,包括下载速度、网络延迟、网络在线状态、网络类型等信息 具体获取如下: let info navigator.connection console.log(info)可以看到,包含几个信息…

基于uniapp封装的card容器 带左右侧两侧标题内容区域

代码 <template><view class"card"><div class"x_flex_header"><div><title v-if"title ! " class"title" :title"title" :num"num"></title></div><div><s…

【人工智能】智能电网:未来能源的革命

未来能源的革命 智能电网革命的意义在于将电力行业从传统的集中式发电和集中式输配电模式转变为智能化、分布式、互动式的能源网络。 现在我们从以下方面详细认真的了解一下智能电网&#xff1a; 智能变电站&#xff0c;智能配电网&#xff0c;智能电能表&#xff0c;智能交互…