Transformer学习: Transformer小模块学习--位置编码,多头自注意力,掩码矩阵

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档

Transformer学习

  • 1 位置编码模块
    • 1.1 PE代码
    • 1.2 测试PE
  • 2 多头自注意力模块
    • 2.1 多头自注意力代码
    • 2.2 测试多头注意力
  • 3 未来序列掩码矩阵

在这里插入图片描述

1 位置编码模块

P E ( p o s , 2 i ) = sin ⁡ ( p o s / 1000 0 2 i / d m o d e l ) PE(pos,2i)=\sin(pos/10000^{2i/d_{\mathrm{model}}}) PE(pos,2i)=sin(pos/100002i/dmodel)

P E ( p o s , 2 i + 1 ) = cos ⁡ ( p o s / 1000 0 2 i / d m o d e l ) PE(pos,2i+1)=\cos(pos/10000^{2i/d_\mathrm{model}}) PE(pos,2i+1)=cos(pos/100002i/dmodel)

pos 是序列中每个对象的索引, p o s ∈ [ 0 , m a x s e q l e n ] pos\in [0,max_seq_len] pos[0,maxseqlen], i i i 向量维度序号, i ∈ [ 0 , e m b e d d i m / 2 ] i\in [0,embed_dim/2] i[0,embeddim/2], d m o d e l d_{model} dmodel是模型的embedding维度

1.1 PE代码

import numpy as np
import matplotlib.pyplot as plt
import math
import torch
import seaborn as snsdef get_pos_ecoding(max_seq_len,embed_dim):# 初始化位置矩阵 [max_seq_len,embed_dim]pe = torch.zeros(max_seq_len,embed_dim])position = torch.arange(0,max_seq_len).unsqueeze(1) # [max_seq_len,1]print("位置:", position,position.shape)div_term = torch.exp(torch.arange(0,embed_dim,2)*-(math.log(10000.0)/embed_dim))   # 除项维度为embed_dim的一半,因为对矩阵分奇数和偶数位置进行填充。pe[:,0::2] = torch.sin(position/div_term)pe[:,1::2] = torch.cos(position/div_term)return pe

1.2 测试PE

pe = get_pos_ecoding(8,4)
plt.figure(figsize=(8,8))
sns.heatmap(pe)
plt.title("Sinusoidal Function")
plt.xlabel("hidden dimension")
plt.ylabel("sequence length")

输出:
位置: tensor([[0],
[1],
[2],
[3],
[4],
[5],
[6],
[7]]) torch.Size([8, 1])
除项: tensor([1.0000, 0.0100]) torch.Size([2])
在这里插入图片描述

plt.figure(figsize=(8, 5))
plt.plot(positional_encoding[1:, 1], label="dimension 1")
plt.plot(positional_encoding[1:, 2], label="dimension 2")
plt.plot(positional_encoding[1:, 3], label="dimension 3")
plt.legend()
plt.xlabel("Sequence length")
plt.ylabel("Period of Positional Encoding")

在这里插入图片描述

2 多头自注意力模块

2.1 多头自注意力代码

import torch
import torch.nn as nn
import torch.nn.functional as F
import copy# 复制网络,即使用几层网络就改变N的数量
# 如 4层线性层  clones(nn.Linear(model_dim,model_dim),4)
def clones(module, N):"Produce N identical layers."return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])# 计算注意力
def attention(q, k, v, mask=None, dropout=None):d_k = q.size(-1)scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)if mask is not None:scores = scores.masked_fill(mask == 0, -1e9)p_attn = F.softmax(scores, dim = -1)if dropout is not None:p_attn = dropout(p_attn)return torch.matmul(p_attn, v), p_attn# 计算多头注意力
class Multi_Head_Self_Att(nn.Module):def __init__(self,head,model_dim,dropout=0.1):super(Multi_Head_Self_Att,self).__init__()+assert model_dim % head == 0self.d_k = model_dim/headself.head = headself.linears = clones(nn.Linear(model_dim,model_dim),4)self.att = Noneself.dropout = nn.Dropout(p=dropout)def forward(self,q,k,v,mask=None):if mask is not None:mask = mask.unsqueeze(1)nbatches = q.size(0)# zip函数 将线性层与q,k,v分别对应(self.linears,q),(self.linears,k),(self.linears,v)# q,k,v [bs,-1,head,embed_dim/head]q,k,v = [l(x).view(nbatches,-1,int(self.head),int(self.d_k)).transpose(1,2) for l,x in zip(self.linears,(q,k,v))] # 返回计算注意力之后的值作为x和注意力分数x, self.attn = attention(q, k, v, mask=mask, dropout=self.dropout)x = x.transpose(1, 2).contiguous().view(nbatches, -1, int(self.head * self.d_k))return self.linears[-1](x),self.attn

2.2 测试多头注意力

# 模型参数
head = 4
model_dim = 128
seq_len = 10
dropout = 0.1# 生成示例输入
q = torch.randn(seq_len, model_dim)
k = torch.randn(seq_len, model_dim)
v = torch.randn(seq_len, model_dim)# 创建多头自注意力模块
att = Multi_Head_Self_Att(head, model_dim, dropout=dropout)# 运行模块
output,att = att(q, k, v)# 输出形状
print("Output shape:", output.shape)
print(att.shape())
sns.heatmap(att.squeeze().detach().cpu())

输出
Output shape: torch.Size([10, 1, 128])
torch.Size([10, 4, 1, 1])
在这里插入图片描述

3 未来序列掩码矩阵

作用: 防止泄露未来要预测的部分,掩码矩阵是一个除对角线的上三角矩阵
3.1 代码

def subsequent_mask(size):"Mask out subsequent positions."attn_shape = (1, size, size)subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')print("掩码矩阵:",subsequent_mask)return torch.from_numpy(subsequent_mask) == 0

测试掩码

plt.figure(figsize=(5,5))
print(subsequent_mask(8),subsequent_mask(8).shape)
plt.imshow(subsequent_mask(8)[0])

掩码矩阵:
[[[0 1 1 1 1 1 1 1]
[0 0 1 1 1 1 1 1]
[0 0 0 1 1 1 1 1]
[0 0 0 0 1 1 1 1]
[0 0 0 0 0 1 1 1]
[0 0 0 0 0 0 1 1]
[0 0 0 0 0 0 0 1]
[0 0 0 0 0 0 0 0]]]

tensor([[[ True, False, False, False, False, False, False, False],
[ True, True, False, False, False, False, False, False],
[ True, True, True, False, False, False, False, False],
[ True, True, True, True, False, False, False, False],
[ True, True, True, True, True, False, False, False],
[ True, True, True, True, True, True, False, False],
[ True, True, True, True, True, True, True, False],
[ True, True, True, True, True, True, True, True]]]) torch.Size([1, 8, 8])
在这里插入图片描述
紫色部分为添加掩码的部分

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/591705.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

电商技术揭秘六:前端技术与用户体验优化

文章目录 引言一、前端技术在电商中的重要性1.1 前端技术概述1.2 用户体验与前端技术的关系 二、响应式设计与移动优化2.1 响应式设计的原则2.2 移动设备优化策略2.3 响应式设计的工具和框架 三、交互设计与用户体验提升3.1 交互设计的重要性3.2 用户体验的量化与优化3.3 通过前…

简单的购物商城

SSM整合后的一个及其简单的商城,首页数据是模拟的,主要测试购物车模块 启动 创建数据库:shopping导入建表脚本:shopping.sql修改db.properties部署和启动项目(项目的path为项目名)访问 http://localhost:…

【web】nginx+php-fpm云导航项目部署-(简版)

一、yum安装nginx yum -y install nginx 二、php环境安装 2.1 php安装 yum -y install php 2.2 php-fpm安装 yum -y install php-fpm 注:PHP在 5.3.3 之后已经讲php-fpm写入php源码核心了。 2.3 项目依赖的php-xml和php-xmlrpc安装 yum -y install php-…

[C#]OpenCvSharp改变图像的对比度和亮度

目的 访问像素值mat.At<T>(y,x) 用0初始化矩阵Mat.Zeros 饱和操作SaturateCast.ToByte 亮度和对比度调整 g(x)αf(x)β 用α(>0)和β一般称作增益(gain)和偏置(bias)&#xff0c;分别控制对比度和亮度 把f(x)看成源图像像素&#xff0c;把g(x)看成输出图像像素…

32-2 APP渗透 - 移动APP架构

前言 app渗透和web渗透最大的区别就是抓包不一样 一、客户端: 反编译: 静态分析的基础手段,将可执行文件转换回高级编程语言源代码的过程。可用于了解应用的内部实现细节,进行漏洞挖掘和算法分析等。调试: 排查软件错误的一种手段,用于分析应用内部原理和行为。篡改/重打…

如何(关闭)断开 Websocket 连接:简单易懂的实现指南

WebSocket 协议提供了一条用于 Web 应用程序中双向通讯的高效通道&#xff0c;让服务器能够实时地向客户端发送信息&#xff0c;而无需客户端每次都发起请求。本文旨在探讨有关结束 WebSocket 连接的适当时机&#xff0c;内容包括协议的基础知识、如何结束连接、一些使用场景&a…

【VUE】ruoyi框架自带页面可正常缓存,新页面缓存无效

ruoyi框架自带页面可正常缓存&#xff0c;新页面缓存无效 背景&#xff1a; 用若依框架进行开发时&#xff0c;发现ruoyi自带的页面缓存正常&#xff0c;而新开发的页面即使设置了缓存&#xff0c;当重新进入页面时依旧刷新了接口。 原因&#xff1a;页面name与 getRouters …

Ds18B20理解与运用

在51单片机中&#xff0c;DS18B20是一个十分重要的外设&#xff0c;它是一个温度传感器&#xff0c;可以读取温度并以16位的数据输出。这个数据由四位符号位&#xff0c;7位整数位和4位小数位组成。当符号位都是0的时候&#xff0c;温度是零上温度&#xff1b;当符号位都是1的时…

我与C++的爱恋:内联函数,auto

​ ​ &#x1f525;个人主页&#xff1a;guoguoqiang. &#x1f525;专栏&#xff1a;我与C的爱恋 ​ 一、内联函数 1.内联函数的概念 内联函数目的是减少函数调用的开销&#xff0c;通过将每个调用点将函数展开来实现。这种方法仅适用于那些函数体小、调用频繁的函数。 …

从“量子”到分子:探索计算的无限可能 | 综述荐读

在2023年年末&#xff0c;两篇划时代的研究报告在《科学》&#xff08;Science&#xff09;杂志上引发了广泛关注。这两篇论文分别来自两个研究小组&#xff0c;它们共同揭示了单氟化钙分子间相互作用的研究成果&#xff0c;成功地在这些分子间创造出了分子量子比特。这一成就不…

前端学习<四>JavaScript基础——03-常量和变量

常量&#xff08;字面量&#xff09;&#xff1a;数字和字符串 常量也称之为“字面量”&#xff0c;是固定值&#xff0c;不可改变。看见什么&#xff0c;它就是什么。 常量有下面这几种&#xff1a; 数字常量&#xff08;数值常量&#xff09; 字符串常量 布尔常量 自定义…

图片改大小尺寸怎么改?几个修改图片尺寸的方法

日常生活和工作中&#xff0c;图片的大小和尺寸对于我们的工作和生活都至关重要&#xff0c;因此我们经常需要调整图片的大小。我们都知道压缩图是一款功能强大的图片在线处理工具&#xff0c;那么用它怎么调整图片大小呢&#xff1f;下面就让我们一起来看一下具体的操作步骤。…