weight-tying探索

在一些领域,将嵌入层和输出层的权重绑定,以达到减少参数量并使得相同token保持统一的embedding空间的作用。

下面的nn.Linear(3, 10)的权重矩阵的尺寸是10*3,即y = W @ x + b,因此跟nn.Embedding(10, 3)的权重矩阵大小相等。

import torch
import torch.nn as nn
import torch.nn.functional as Fclass Model_1(nn.Module):def __init__(self):super(Model_1, self).__init__()self.embedding = nn.Embedding(10, 3)self.head = nn.Linear(3, 10)# self.embedding.weight = self.head.weightdef forward(self, x):output = self.embedding(x)output = self.head(output)return F.softmax(output, dim=-1)    class Model_2(nn.Module):def __init__(self):super(Model_2, self).__init__()self.embedding = nn.Embedding(10, 3)self.head = nn.Linear(3, 10)# 使用下面这行代码,二者权重会同步更新self.embedding.weight = self.head.weightdef forward(self, x):output = self.embedding(x)output = self.head(output)return F.softmax(output, dim=-1)model_1 = Model_1()
model_2 = Model_2()torch.manual_seed(0)
input_indexes = torch.randint(0, 10, (2, 3))
target = torch.zeros(2, 3, 10)
for i in range(2):for j in range(3):target[i, j, input_indexes[i, j]] = 1
print(target)# criterion = nn.CrossEntropyLoss()
criterion = nn.MSELoss()
optimizer_1 = torch.optim.Adam(model_1.parameters(), lr=0.001)
optimizer_2 = torch.optim.Adam(model_2.parameters(), lr=0.001)
loss_tying = []
loss_no_tying = []for _ in range(2000):output_1 = model_1(input_indexes)loss = criterion(output_1, target)optimizer_1.zero_grad()loss.backward()optimizer_1.step()loss_no_tying.append(loss.item())output_2 = model_2(input_indexes)loss = criterion(output_2, target)optimizer_2.zero_grad()loss.backward()optimizer_2.step()loss_tying.append(loss.item())# print(output)
print(model_1.embedding.weight==model_1.head.weight)
print(model_2.embedding.weight==model_2.head.weight)
import matplotlib.pyplot as plt
plt.plot(loss_tying, label="use weight tying")
plt.plot(loss_no_tying, label="not use weight tying")
plt.legend()
plt.show()

在这里插入图片描述
可以看到,在这个例子中,使用 weight-tying 后 loss 收敛更快。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/600916.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python MNIST 转图片

Python MNIST 转图片 1 获取数据2 显示图片3 转换图片4 全部代码 1 获取数据 import numpy as np import tensorflow as tf from tensorflow.keras import datasets # type: ignoredef save(mnist_path):# 输出TensorFlow版本print("TensorFlow: {0}".format(tf.__v…

基于SpringBoot+Vue的健身教练预约管理系统(源码+文档+部署+讲解)

一.系统概述 私人健身与教练预约管理系统,可以摆脱传统手写记录的管理模式。利用计算机系统,进行用户信息、管理员信息的管理,其中包含首页,个人中心,用户管理,教练管理,健身项目管理&#xff0…

C语言 03 VSCode开发

安装好 C 语言的开发环境后,就需要创建项目进行开发了。 使用 IDE(集成开发环境)进行开发了。 C 语言的开发工具很多,现在主流的有 Clion、Visual Studio、VSCode。 这里以 VSCode 作为演示。 创建项目 安装 VSCode。 推荐直接在…

Yii2 路由美化访问需要加s

不得不说yii真是反人类,怪不得现在都不维护了,今天解析下路由美化下的路由访问问题。 设置main.php配置文件 urlManager > [enablePrettyUrl > true, // 启用 URL美化showScriptName > false, // 隐藏入口文件index.phpenableStrictParsing…

计算机网络-TCP连接建立阶段错误应对机制

错误现象 丢包 网络问题:网络不稳定可能导致丢包,例如信号弱或干扰强。带宽限制可能导致路由器或交换机丢弃包,尤其是在高流量时段。网络拥塞时,多个数据流竞争有限的资源,也可能导致丢包。缓冲区溢出:TC…

如何搭建企业级MQ消息集成平台

企业级MQ消息集成平台的重要性在于实现不同系统之间的高效、可靠、实时的消息传递和数据交换。它可以帮助企业实现系统解耦,提高系统的可扩展性和灵活性,降低系统间的依赖性。通过消息队列中间件,企业可以实现异步通信、削峰填谷、流量控制等…

橘子学JDK之JMH-02(BenchmarkModes)

一、案例二代码 这次我们来搞一下官网文档的第二个案例,我删除了一些没用的注释,然后对代码做了一下注释的翻译,可以看一下意思。 package com.levi;import org.openjdk.jmh.annotations.*; import org.openjdk.jmh.runner.Runner; import …

Vite+Vue3.0项目使用ant-design-vue <a-calendar>日期组件汉化

antd的弹框、日期等默认为英文,要把英文转为中文请看下文: 1.首先我们要在main.js中引入ant-design组件库并全局挂载: import App from ./App import Antd from ant-design-vue; import ant-design-vue/dist/antd.css;const app createApp(…

paddle实现手写数字模型(一)

参考文档:paddle官网文档环境:Python 3.12.2 ,pip 24.0 ,paddlepaddle 2.6.0 python -m pip install paddlepaddle2.6.0 -i https://pypi.tuna.tsinghua.edu.cn/simple调试代码如下: LeNet.py import paddle import p…

1.网络编程-网络协议

目录 网络编程是什么 网络编程三要素 OSI七层网络模型 TCP/IP五层模型 SSL/TLS 是哪层协议 网络编程是什么 网络编程是计算机科学中的一个重要领域,它涉及到编写能够在网络环境中进行通信的程序。网络编程的核心目标是使不同的设备能够通过网络交换信息&#…

重建大师进行扫码认证了,接下来怎样才能正常使用?(如下图)

重建大师软件授权已经有了后,新建工程后设置任务目录和监控目录一致就可以运行了。 重建大师是一款专为超大规模实景三维数据生产而设计的集群并行处理软件,输入倾斜照片,激光点云,POS信息及像控点,输出高精度彩色网格…

linux创建文件、linux创建文件的几种方式、touch、echo、cat、vi、vim

文章目录 一、创建文件1.1、touch1.2、echo1.3、cat1.4、vi或vim 一、创建文件 1.1、touch touch命令:用于创建一个新的空文件或者更新已存在文件的访问和修改时间。 (1)如果目标文件不存在,则新建一个文件 touch demo.txt&am…