torch.nn.Embedding的导入与导出

news/2024/9/18 16:20:24/文章来源:https://www.cnblogs.com/kingwz/p/18411016

简介及导入转自:torch.nn.Embedding使用

在RNN模型的训练过程中,需要用到词嵌入,使用torch.nn.Embedding可以快速的完成:只需要初始化torch.nn.Embedding(n,m)即可(n是单词数,m是词向量的维度)(n是嵌入字典的大小,m是嵌入向量的维度。)。

注意: embedding开始是随机的,在训练的时候会自动更新。

简单使用

举个简单的例子:

  • 输入:word1和word2是两个长度为3的句子,保存的是单词所对应的词向量的索引号。
  • 输出:随机生成(4,5)维度大小的embedding,可以通过embedding.weight查看embedding的内容。
  • 过程:输入word1时,embedding会输出第0、1、2行词向量的内容; word2同理。
import torchword1 = torch.LongTensor([0, 1, 2])
word2 = torch.LongTensor([3, 1, 2])
embedding = torch.nn.Embedding(4, 5)print(embedding.weight)
print('word1:')
print(embedding(word1))
print('word2:')
print(embedding(word2))

image

导出

创建一个嵌入层并将其导出为numpy数组。

import torch
import numpy as np# 创建嵌入层
embedding = torch.nn.Embedding(10, 5)# 将权重转换为numpy数组
embedding_weights = embedding.weight.data.numpy()# 保存权重到文件
np.savetxt("embedding_weights.txt", embedding_weights)

导入

导入已经训练好的词向量,需要设置训练过程中不更新(固定embedding)。

如下所示,emb是已经训练得到的词向量,先初始化等同大小的embedding,然后将emb的数据复制过来,最后一定要设置weight.requires_grad为False。

self.embedding = torch.nn.Embedding(emb.size(0), emb.size(1))
self.embedding.weight = torch.nn.Parameter(emb)# 固定embedding
self.embedding.weight.requires_grad = False

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/796209.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

第十七讲 为什么这些SQL语句逻辑相同,性能却差异巨大?

第十七讲: 为什么这些SQL语句逻辑相同,性能却差异巨大? 简概:引入: ​ 在 MySQL 中,有很多看上去逻辑相同,但性能却差异巨大的 SQL 语句。对这些语句使用不当的话,就会不经意间导致整个数据库的压力变大。我今天挑选了三个这样的案例和你分享。希望再遇到相似的问题时,…

2024-09-12 TypeError: Cannot read properties of undefined (reading 0) ==》检查未定义的对象or数组

TypeError: Cannot read properties of undefined (reading 0) ==》TypeError:无法读取undefined的属性(读取“0”) 请记住出现这种错误大多数都是因为你读取了未定义的对象或数组 排查结果:后端返回的id由原来的小写id改成了大写Id。 666,服了,哥们。

UNO 已知问题 在后台线程触发 SKXamlCanvas 的 Invalidate 且在 PaintSurface 事件抛出异常将炸掉应用

本文记录一个 UNO 已知问题,在 UNO 里面可以利用 SKXamlCanvas 对接 Skia 绘制到应用里面。如果此时在后台线程里面调用 SKXamlCanvas 的 Invalidate 触发界面的重新刷新,但在具体的执行绘制 PaintSurface 事件里面对外抛出异常,将会导致应用炸掉背景: 我准备在 UNO 里面将…

WPF 的 WriteableBitmap 在 Intel 11 代 Iris Xe Graphics 核显设备上停止渲染

在 Intel 11 代锐炬 Intel Iris Xe Graphics 核显设备上,如果此设备使用旧版本驱动,则可能导致 WPF 的 WriteableBitmap 停止渲染。此问题和 WPF 无关,此问题是 Intel 的 bug 且最新驱动版本已修复官方问题记录地址:https://www.intel.cn/content/www/cn/zh/support/articl…

WPF 的 Viewport3D 等 3D 模块在带 Intel UHD 770 设备上抛出渲染异常

在带 Intel UHD 770 的设备上,使用旧版本驱动,即小于 30.0.101.1660 版本驱动,将会导致 WPF 的 3D 模块出现渲染异常。此问题和 WPF 无关,此问题是 Intel 的 bug 且最新驱动版本已修复官方问题记录地址:https://community.intel.com/t5/Graphics/Crash-with-UHD-770-in-WP…

【Azure Service Bus】批量处理Service Bus Topic 中的死信消息(dead-lettered messages)

问题描述 在Azure的门户页面上,因为Service Bus Topic中有很多dead-lettered message,而这些消息占用了大量的存储空间,通过门户上的Service Bus Explorer每次只能消费一条消息。 虽然可以通过修改代码来指定消费私信队列中消息,但是需要修改代码,需要一些工作量。 有没有…

Transformer两大发展方向——GPT系列及BERT(一)

前面介绍了Transformer,随着其发展在NLP领域应用越来越多,在其基础上主要有两篇影响非常大的文章,一篇是GPT,另一篇是BERT。OpenAI提出的GPT采用Transformer解码器结构,一路更新迭代到了现在有了GPT-4,而Google提出的BERT采用Transformer的编码器结构。大体时间线如下图所…

RustPython简单使用

RustPython介绍 同CPython,Jpython,PyPy一样,RustPython,是使用Rust语言实现的Python解释器,支持Python3语法。 项目地址:https://github.com/RustPython/RustPython RustPython真正方便的是可以编译成Wasm文件,可以直接在浏览器中使用,示例网站:https://rustpython.g…

【解题报告】P8478 「GLR-R3」清明

我无可代替,哪怕来历已不神秘;麦克风接力,百万人就等我出席。P8478 「GLR-R3」清明 参考了出题人题解和 xcyyyyyy 大神的题解,强推前两篇。 拿到题完全没思路怎么办??? 人类智慧的巅峰,思维量的登峰造极。 换句话说就是非人题目,不过不得不说 GLR 的题是真的好,难度也…

Openwrt安装ddns-go

必备条件已刷好OpenWRT的路由 Openwrt已配置好网络根据CPU架构下载DDNS-go 我用的是迅雷赚钱宝1代,其CPU是arm7,所以要下载对应的arm7版本 https://github.com/jeessy2/ddns-go/releases 解压文件,将文件复制到openwrt 用WinSCP连接OpenWRT,复制ddns-go进去 WinSCP下载 如果…

python如何使用 秘钥证书 进行 SM2 加密

最近一个项目,需要使用sm2非对称加密,对方直接给的秘钥证书,python使用gmssl 进行加密,解密,加签,验签用的秘钥是这种格式 # Private Key秘钥 5aa03412c3051e1d4cf9d19cfbeeec70c28f388c9f82747cc912096c9cd44bea # Public Key 公钥 044291b381a039a8d7d02d7272d2d7c78a30d33e…

让小爱音箱播放电脑/NAS上歌曲,支持自动从哔哩哔哩/油管下载歌曲,无需刷机。支持语音控制和WebUI控制,docker部署多平台兼容,解决仅能播放试听版的苦恼

小米AI音箱很多人都有,但使用中播放歌曲时总是提示仅能播放试听版,不能完整听歌,很烦人。今天介绍的方法就是要彻底解决这个问题,实现让小爱AI音箱能够播放本地歌曲,本地没有的歌曲还能自动从网上搜索下载的功能。 已测试支持的设备:型号 名称L06A 小爱音箱L07A Redmi小爱…