在 PyTorch 中理解词向量,将单词转换为有用的向量表示

news/2025/2/10 9:09:26/文章来源:https://www.cnblogs.com/jellyai/p/18707099

你要是想构建一个大型语言模型,首先得掌握词向量的概念。幸运的是,这个概念很简单,也是本系列文章的一个完美起点。

那么,假设你有一堆单词,它可以只是一个简单的字符串数组。

animals = ["cat", "dog", "rat", "pig"]

你没法直接用单词进行数学运算,所以必须先把它们转换成数字。最简单的方法就是用它们在数组中的索引值。

animal_to_idx = {animal: idx for idx, animal in enumerate(animals)}

animal_to_idx

Output:

当然,等你把数学运算做完,你还需要把索引转换回对应的单词。可以这样做:

idx_to_animal = {idx: animal for animal, idx in animal_to_idx.items()}

idx_to_animal

Output:

用索引来表示单词,在自然语言处理中一般不是个好主意。问题在于,索引会暗示单词之间存在某种顺序关系,而实际上并没有。

比如,我们的数据里,猫和猪之间并没有固有的关系,狗和老鼠之间也没有。但是,使用索引后,看起来猫离猪“很远”,而狗似乎“更接近”老鼠,仅仅因为它们在数组中的位置不同。这些数值上的距离可能会暗示一些实际上并不存在的模式。同样,它们可能会让人误以为这些动物之间存在基于大小或相似度的关系,而这在这里完全没有意义。

一个更好的方法是使用独热编码(one-hot encoding)。独热向量是一个数组,其中只有一个元素是 1(表示“激活”),其他所有元素都是 0。这种表示方式可以完全消除单词之间的错误排序关系。

让我们把单词转换成独热向量:

import numpy as np

n_animals = len(animals)

animal_to_onehot = {}

for idx, animal in enumerate(animals):

one_hot = np.zeros(n_animals, dtype=int)

one_hot[idx] = 1

animal_to_onehot[animal] = one_hot

animal_to_onehot

Output:

{

'cat': array([1, 0, 0, 0]),

'dog': array([0, 1, 0, 0]),

'rat': array([0, 0, 1, 0]),

'pig': array([0, 0, 0, 1])

}

可以看到,现在单词之间没有任何隐含的关系了。

独热编码的缺点是,它是一种非常稀疏的表示,只适用于单词数量较少的情况。想象一下,如果你有 10,000 个单词,每个编码都会有 9,999 个零和一个 1,太浪费内存了,存那么多零干嘛……

是时候创建更密集的向量表示了。换句话说,我们现在要做词向量(word embeddings)了。

词向量是一种密集向量(dense vector),其中大多数(甚至所有)值都不是零。在机器学习,尤其是自然语言处理和推荐系统中,密集向量可以用来紧凑而有意义地表示单词(或句子、或其他实体)的特征。更重要的是,它们可以捕捉这些特征之间的有意义关系。

举个例子,我们创建一个词向量,其中每个单词用 2 个特征表示,而总共有 4 个单词。

用 PyTorch 创建词向量非常简单。我们只需要使用 nn.Embedding 层。你可以把它想象成一个查找表,其中行代表每个唯一单词,而列代表该单词的特征(即单词的密集向量)。

import torch

import torch.nn as nn

embedding_layer = nn.Embedding(num_embeddings=4, embedding_dim=2)

好,现在我们把单词的索引转换成词向量。这几乎不费吹灰之力,因为我们只需要把索引传给 nn.Embedding 层就行了。

indices = torch.tensor(np.arange(0, len(animals)))

indices

Output:

tensor([0, 1, 2, 3])

embeddings = embedding_layer(indices)

embeddings

Output:

tensor([[ 1.6950, -2.7905],

[ 2.4086, -0.1779],

[ 0.7402, 0.0955],

[-0.5155, 0.0738]], grad_fn=)

现在,我们可以用索引查看每个单词的词向量了。

for animal, _ in animal_to_idx.items():

print(f"{animal}'s embedding is {embeddings[animal_to_idx[animal]]}")

Output:

cat's embedding is tensor([ 1.6950, -2.7905], grad_fn=)

dog's embedding is tensor([ 2.4086, -0.1779], grad_fn=)

rat's embedding is tensor([0.7402, 0.0955], grad_fn=)

pig's embedding is tensor([-0.5155, 0.0738], grad_fn=)

每个单词都有两个特征——正是我们想要的结果。

目前这些数值没啥实际意义,因为 nn.Embedding 层还没有经过训练。但一旦它被适当地训练了,这些特征就会变得有意义。

注意:

这些特征对模型来说非常关键,但对人类来说可能永远不会“有意义”。它们代表的是通过训练学到的抽象特征。对我们来说,这些特征看起来可能是随机的、毫无意义的,但对一个训练好的模型来说,它们能够捕捉到重要的模式和关系,使其能够有效地理解和处理数据。

在本系列的下一篇文章中,我们将学习如何训练词向量模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/881542.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

解决 virsh 无法通过 console 登录虚拟机

报错信息解决方法 登录虚拟机 通过 vnc 或者 ssh登录虚拟机 在虚拟机中执行已执行内容 echo "ttyS0" >> /etc/securetty echo "S0:12345:respawn:/sbin/agetty ttyS0 115200, 1152000 xterm" >>/etc/inittab grub2.cfg grep kernel /etc/grub2…

【分享】Ftrans内外网数据交换方案:打破网络边界,高效融合数据!

随着网络技术和互联网技术的成熟和高速发展,来自互联网的安全威胁越来越严重。数据隔离对很多企业来说并不陌生,越来越多的企业选择网络隔离技术来保护网络安全,而在500强企业中,使用网络隔离技术的企业几乎是绝对性的占比。网络隔离后,仍存在数据交换的需求,内外网数据交…

[gym 102428] Fabricating Sculptures

前言 现在补题是不是不太好 思路 转化题意给定列数 \(S\) , 方块数 \(B\) , 求一种摆放方式, 使得每一列的方块数 \(a_i\) 满足 \(a_i \geq 1\) 且 \(a_i\) 呈非严格单峰观察单峰函数的性质, 发现如果按行处理, 就是单调不增的 更一般的, 每一行放置的方块数非严格小于上一行放…

重做 CF906E Reverses

不是,JJ 怎么退役了,悲。 嗯,先有\[dp_{r}=\min_{l=1}^r[s(l,r)\text{是回文串,且长度不为 2}]dp_{l-1}+1 \] 总复杂度就 \(n^2\),考虑优化 然后有引理说,一个字符串的所有 border 构成 \(\log n\) 个等差数列 我们考虑什么样的点能够转移到 \(dp_i\)尝试借助 \(log\) 段…

Zerto 10.0 Update 6 下载 - 适用于本地、混合和多云环境的灾难恢复和数据保护

Zerto 10.0 Update 6 下载 - 适用于本地、混合和多云环境的灾难恢复和数据保护Zerto 10.0 U6 - 适用于本地、混合和多云环境的灾难恢复和数据保护 勒索软件防护、灾难恢复和多云移动性的统一解决方案 请访问原文链接:https://sysin.org/blog/zerto-10/ 查看最新版。原创作品,…

OpenWrt 24.10 OVF:在 ESXi 8.0、Fusion 13 和 Workstation 17 上运行 OpenWrt 的简单方法

OpenWrt 24.10 OVF:在 ESXi 8.0、Fusion 13 和 Workstation 17 上运行 OpenWrt 的简单方法OpenWrt 24.10 OVF:在 ESXi 8.0、Fusion 13 和 Workstation 17 上运行 OpenWrt 的简单方法 OpenWrt 24.10.0 x86_64 OVF 请访问原文链接:https://sysin.org/blog/openwrt-ovf/ 查看最…

SQL Server 2022新功能:将数据库备份到S3兼容的对象存储

SQL Server 2022新功能:将数据库备份到S3兼容的对象存储 本文介绍将S3兼容的对象存储用作数据库备份目标所需的概念、要求和组件。 数据库备份和恢复功能在概念上类似于使用SQL Server备份到Azure Blob存储的URL作为备份设备类型。 要注意的是,不只是amazon S3对象存储,只要…

护眼神器!LightBulb电脑屏幕护眼软件,你值得拥有!

点击上方蓝字关注我 前言 LightBulb是一个免费的护眼软件,它可以帮助我们在晚上或长时间看电脑屏幕时,减少眼睛的不舒服和疲劳。这个软件会随着一天时间的推移,自动调整电脑屏幕的颜色。比如,在白天,它会让屏幕颜色更偏向冷蓝色,就像阳光下的颜色;到了晚上,它会让屏幕颜…

【JWT安全】攻防指南全面梳理

一、简单介绍 JWT(JSON Web Token)是一种用于身份认证和授权的开放标准,它通过在网络应用间传递被加密的JSON数据来安全地传输信息使得身份验证和授权变得更加简单和安全,JWT对于渗透测试人员而言可能是一种非常吸引人的攻击途径,因为它们不仅是让你获得无限访问权限的关键而…

【CodeForces训练记录】Codeforces Round 1003 (Div. 4)

训练情况赛后反思 题面读的有点疑惑,怀疑自己阅读理解不大行了,简单题狂WA,C2二分调半天没出,水平严重退步 A题 最后两个字母 us 换成 i点击查看代码 #include <bits/stdc++.h> // #define int long long #define endl \nusing namespace std;void solve(){string s;…

[流程图/技术调研] drawio : 流程图绘制工具

引言 流程图绘制工具: draw.io 简介urlhttps://www.drawio.com/ (官网首页) https://github.com/jgraph/drawio (github)【官网简介】 drawio 这个项目,是一个可配置的图表/白板可视化应用程序。drawio 是由 JGraph Ltd 和 draw AG 共同拥有和开发的。 在运行这个项目的同时,…