In some domains, the weights of the embedding layer and the output layer are tied (weight tying) to reduce the number of parameters and to keep each token in a single, consistent embedding space.
In the code below, the weight matrix of nn.Linear(3, 10) has shape 10×3 (PyTorch stores it as (out_features, in_features), i.e. y = W @ x + b for a column vector x), so it is exactly the same size as the weight matrix of nn.Embedding(10, 3).
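As a quick sanity check (a minimal sketch, separate from the full experiment below), the two weight matrices do have the same shape:

import torch.nn as nn

emb = nn.Embedding(10, 3)  # weight shape: (num_embeddings, embedding_dim) = (10, 3)
lin = nn.Linear(3, 10)     # weight shape: (out_features, in_features) = (10, 3)
print(emb.weight.shape)    # torch.Size([10, 3])
print(lin.weight.shape)    # torch.Size([10, 3])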
import torch
import torch.nn as nn
import torch.nn.functional as F


class Model_1(nn.Module):
    def __init__(self):
        super(Model_1, self).__init__()
        self.embedding = nn.Embedding(10, 3)
        self.head = nn.Linear(3, 10)
        # self.embedding.weight = self.head.weight

    def forward(self, x):
        output = self.embedding(x)
        output = self.head(output)
        return F.softmax(output, dim=-1)


class Model_2(nn.Module):
    def __init__(self):
        super(Model_2, self).__init__()
        self.embedding = nn.Embedding(10, 3)
        self.head = nn.Linear(3, 10)
        # With the line below, the two weight matrices are updated in sync
        self.embedding.weight = self.head.weight

    def forward(self, x):
        output = self.embedding(x)
        output = self.head(output)
        return F.softmax(output, dim=-1)


model_1 = Model_1()
model_2 = Model_2()

torch.manual_seed(0)
input_indexes = torch.randint(0, 10, (2, 3))

# one-hot targets: each token should be mapped back to its own index
target = torch.zeros(2, 3, 10)
for i in range(2):
    for j in range(3):
        target[i, j, input_indexes[i, j]] = 1
print(target)

# criterion = nn.CrossEntropyLoss()
criterion = nn.MSELoss()
optimizer_1 = torch.optim.Adam(model_1.parameters(), lr=0.001)
optimizer_2 = torch.optim.Adam(model_2.parameters(), lr=0.001)

loss_tying = []
loss_no_tying = []
for _ in range(2000):
    # model without weight tying
    output_1 = model_1(input_indexes)
    loss = criterion(output_1, target)
    optimizer_1.zero_grad()
    loss.backward()
    optimizer_1.step()
    loss_no_tying.append(loss.item())

    # model with weight tying
    output_2 = model_2(input_indexes)
    loss = criterion(output_2, target)
    optimizer_2.zero_grad()
    loss.backward()
    optimizer_2.step()
    loss_tying.append(loss.item())

# print(output)
# element-wise comparison: all entries are True only for the tied model_2
print(model_1.embedding.weight == model_1.head.weight)
print(model_2.embedding.weight == model_2.head.weight)
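# Extra illustrative checks (added here for illustration, assuming they run
# right after the training loop): tying assigns the same Parameter object to
# both layers, and parameters() counts a shared parameter only once, so the
# tied model has fewer parameters.
print(model_2.embedding.weight is model_2.head.weight)  # True: one shared tensor
print(sum(p.numel() for p in model_1.parameters()))     # 70 = 30 (embedding) + 30 + 10 (linear)
print(sum(p.numel() for p in model_2.parameters()))     # 40 = 30 (shared) + 10 (bias)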
import matplotlib.pyplot as plt
plt.plot(loss_tying, label="use weight tying")
plt.plot(loss_no_tying, label="not use weight tying")
plt.legend()
plt.show()
As the plot shows, in this example the loss converges faster with weight tying.
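For context, the toy experiment mirrors how the trick is usually applied in Transformer language models, where the output projection shares its weight with the token embedding. The sketch below is only illustrative; the class name TinyLM and the sizes are made up, and the transformer blocks are omitted:

import torch.nn as nn

class TinyLM(nn.Module):
    def __init__(self, vocab_size=1000, d_model=64):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        # share one (vocab_size, d_model) matrix between input and output
        self.lm_head.weight = self.tok_emb.weight

    def forward(self, idx):
        h = self.tok_emb(idx)   # transformer blocks would go here
        return self.lm_head(h)  # logits over the vocabulary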