nn.Embedding()函数详解

news/2025/3/16 21:45:52/文章来源:https://www.cnblogs.com/wangxiaobin/p/18775759

nn.Embedding()函数详解

nn.Embedding()函数:随机初始化词向量,词向量在正态分布N(0,1)中随机取值

输入:

torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None)

num_embeddings:词典的大小尺寸,比如该词典共有5000词,则num_embeddings=5000,此时index的取值范围为0-4999

embedding_dim:词嵌入向量的维度,即用多少维度表示一个符号/词

padding_idx=None:填充id,比如,输入的句子的长度为100,但是每次的句子长度并不一定是相同的,后面就需要统一用数字填充,这里就是相当于指定填充的数字。这样,网络在遇到填充id时,就不会计算其与其他符号的相关性。(初始化为0) 或者另一种说法,padding_idx是不更新梯度的“单词”的index,可以在字典中指定一个不被训练的embedding

max_norm=None:最大范数,暂不考虑

norm_type=2.0:指定利用什么范数计算,暂不考虑

scale_grad_by_freq=False:根据单词在mini-batch中出现的频率,对梯度进行放缩。默认为False

输出:

[length_seq, batch_size, embedding_dim]

length_seq:词向量长度

batch_size:批次数量

embedding_dim:嵌入词向量维度

举例如下:

import torch# 创建page索引
a = torch.LongTensor([[1,3], [3, 8]])# 创建一个词典,词典包含词的数量为10,每个词的维度为5
emb = torch.nn.Embedding(10, 5)
print(emb.weight, emb.weight.shape)# 通过索引查询emb内容
y = emb(a)
print(y, y.shape)
# 


关于padding_idx,看下面的例子:

import torcha = torch.LongTensor([[1, 3], [3, 5]])emb = torch.nn.Embedding(10, 5, padding_idx=0)
print(emb.weight, emb.weight.shape)
y = emb(a)
print(y, y.shape)


其中,emb不仅仅是一个矩阵,其属性有以下:

简单来说,nn.Embedding()就是随机初始化了一个[num_embeddings, embedding_dim]的二维表格,每一行代表着对应索引的词向量的表示。我们要想得到一句话的初始化词向量,需要将句子进行分词,即得到每个词的索引,将索引送入nn.embedding()函数中,会自动在已经建立的二维表中找到索引对应的初始化词向量

参考链接:

关于nn.Embedding的解释,以及它是如何将一句话变成vector的

torch.nn.Embedding函数用法图解

通俗讲解PyTorch中nn.Embedding原理及使用

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/900004.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

htb Authority

端口扫描 nmap -sC -sV -p- -Pn -T4 10.10.11.222 Starting Nmap 7.92 ( https://nmap.org ) at 2024-10-04 19:42 CST Nmap scan report for 10.10.11.222 (10.10.11.222) Host is up (0.40s latency). Not shown: 65506 closed tcp ports (reset) PORT STATE SERVICE …

蓝桥杯14届省B

蓝桥杯14届省赛B组A:int a[105]; int day[]={0,31,28,31,30,31,30,31,31,30,31,30,31};//记录每个月有多少天 set<int> st;//记录不重复的日期void check(int mm,int dd){if (mm>12||mm<1||dd<1||dd>day[mm]) return;else st.insert(mm*100+dd);//st存日期 …

docker 安装 oracle database 问题记录

pre本地docker (WSL)安装运行 Oracle1. 镜像处理参考链接:https://www.cnblogs.com/wuchangsoft/p/18344847 oracle 镜像获取:https://container-registry.oracle.com/ords/f?p=113:10:::::: (Oracle官网,由于部分问题导致直接pull无法拉取) 阿里云,参考链接里有个个人19…

20242103 实验一《Python程序设计》实验报告

20242103 《Python程序设计》实验1报告 课程:《Python程序设计》 班级: 2421 姓名: 李雨虓 学号:20242103 实验教师:王志强 实验日期:2025年3月12日 必修/选修: 公选课 1.实验内容: 1.熟悉Python开发环境; 2.练习Python运行、调试技能;(编写书中的程序,并进行调试…

20241313 2024-2025-2 《Python程序设计》实验一报告

20241313 2024-2025-2 《Python程序设计》实验一报告 课程:《Python程序设计》 班级: 2413 姓名: 刘鸣宇 学号:20241313 实验教师:王志强 实验日期:2025年3月12日 必修/选修: 公选课 1.实验内容 1.熟悉Python开发环境; 2.练习Python运行、调试技能;(编写书中的程序…

mutatingwebhook的简单实例

一. k8s集群准备 这里不再赘述k8s集群搭建。主要注意参数:kubectl get po kube-apiserver-server -n kube-system -o yaml | grep plugin 预期结果为:- --enable-admission-plugins=NodeRestriction,MutatingAdmissionWebhook,ValidatingAdmissionWebhook 至少要拥有两个参数…

Tauri新手向 - 基于LSB隐写的shellcode加载器

此篇是记录自己初次学习tauri开发工具,包含遇到的一些问题以及基本的知识,也给想上手rust tauri的师傅们一些小小的参考。此项目为保持免杀性暂不开源,希望各位师傅多多支持,反响可以的话后续会放出代码大家一起交流学习。ShadowMeld - 基于图像隐写技术的载荷生成框架 通过…

P2341 [USACO03FALL / HAOI2006] 受欢迎的牛 G(缩点)

P2341 [USACO03FALL / HAOI2006] 受欢迎的牛 G 题目背景 本题测试数据已修复。 题目描述 每头奶牛都梦想成为牛棚里的明星。被所有奶牛喜欢的奶牛就是一头明星奶牛。所有奶牛都是自恋狂,每头奶牛总是喜欢自己的。奶牛之间的“喜欢”是可以传递的——如果 \(A\) 喜欢 \(B\),\(…

允许蜘蛛访问,屏蔽访客的php代码

大部分时候我们制作的泛目录需要屏蔽访客,php的优于js识别蜘蛛屏蔽,毕竟一个在服务器内运行后输出,一个在html中调用。 这里分享一段屏蔽游客查查看真实页面的php代码,直接命名为啥php文件,后在想要屏蔽游客的页面中引用(如:include /baidu.php;)就可以了,代码如下:&…

【程设の旅】第二次上机卡题复盘

python上机 其实很快就写完了,第五题有个坑,讲一下 05:奇偶ASCII值判断 描述 任意输入一个字符,判断其ASCII是否是奇数,若是,输出YES,否则,输出NO 例如,字符A的ASCII值是65,则输出YES,若输入字符B(ASCII值是66),则输出NO 输入 输入一个字符 输出 如果其ASCII值为奇数…