单头注意力机制(ScaledDotProductAttention) python实现

输入是query和 key-value,注意力机制首先计算query与每个key的关联性(compatibility),每个关联性作为每个value的权重(weight),各个权重与value的乘积相加得到输出。
在这里插入图片描述

import torch
import torch.nn as nnclass ScaledDotProductAttention(nn.Module):""" Scaled Dot-Product Attention """def __init__(self, scale):super().__init__()self.scale = scaleself.softmax = nn.Softmax(dim=2)def forward(self, q, k, v, mask=None):u = torch.bmm(q, k.transpose(1, 2)) # 1.Matmulu = u / self.scale # 2.Scaleif mask is not None:u = u.masked_fill(mask, -np.inf) # 3.Maskattn = self.softmax(u) # 4.Softmaxoutput = torch.bmm(attn, v) # 5.Outputreturn attn, outputif __name__ == "__main__":n_q, n_k, n_v = 2, 4, 4d_q, d_k, d_v = 128, 128, 64batch = 2q = torch.randn(batch, n_q, d_q)k = torch.randn(batch, n_k, d_k)v = torch.randn(batch, n_v, d_v)mask = torch.zeros(batch, n_q, n_k).bool()attention = ScaledDotProductAttention(scale=np.power(d_k, 0.5))attn, output = attention(q, k, v, mask=mask)print(attn)print(output)

运行结果:


tensor([[[0.4165, 0.3548, 0.1667, 0.0620],[0.0381, 0.3595, 0.4584, 0.1439]],[[0.3611, 0.1587, 0.2078, 0.2723],[0.1603, 0.0530, 0.0670, 0.7198]]])
tensor([[[ 2.2813e-01, -6.3289e-01,  1.3624e+00,  8.4069e-01,  8.1762e-02,-6.3727e-01, -6.3929e-01, -1.0091e+00,  3.7668e-01, -2.9384e-01,-6.2543e-02, -4.4706e-01,  3.8331e-01,  2.2979e-02, -1.1968e+00,-3.7061e-01, -1.9007e-01, -1.7616e-01,  3.6516e-01,  1.1321e-01,-9.5077e-01, -1.3449e+00, -1.2594e+00,  4.2644e-01, -6.3195e-01,-5.2016e-01, -2.5782e-01, -2.4116e-01,  1.7582e-01, -1.5177e+00,-9.3120e-01, -4.9671e-01, -4.5024e-01, -1.0746e+00,  5.4357e-01,-6.2079e-01,  5.1379e-01,  5.6308e-02, -6.3830e-01, -3.6174e-01,-3.0044e-01, -3.0946e-01, -5.0303e-01, -1.8382e-01,  1.1064e+00,-7.5142e-01, -1.5372e-01, -3.3204e-01, -7.9568e-01,  1.3108e-01,-8.6041e-01,  2.5165e-01,  8.8248e-02,  3.7294e-01, -5.2247e-02,4.8462e-01, -7.4389e-01, -5.4351e-01, -9.7697e-01, -9.3327e-01,-4.4550e-02,  6.1108e-01, -5.4613e-01,  2.3962e-01],[ 6.9032e-02,  9.0591e-01,  8.3206e-01,  1.3668e+00,  1.8095e-02,-7.3172e-02, -3.0873e-01, -9.2571e-01,  4.3452e-01, -4.7707e-02,-3.0431e-01, -1.7578e-01,  4.0575e-01, -4.4958e-01, -4.9809e-01,-1.7263e-02, -3.8684e-01,  2.8536e-01,  4.1150e-02, -3.7069e-01,-7.2903e-01, -2.5185e-01, -1.0011e-01,  9.0434e-01, -7.8387e-02,6.9680e-01,  5.3684e-01,  2.8456e-01,  2.2887e-01, -1.7423e+00,-4.4135e-01, -2.9209e-01,  1.7053e-01, -6.4208e-01,  1.7977e-01,1.3822e-01, -1.7873e-01, -4.7619e-01, -6.7788e-01, -5.3340e-01,3.1518e-01, -5.6127e-02,  2.2175e-01, -3.9524e-01,  5.4478e-01,-5.7730e-01,  5.8043e-01, -3.0143e-01, -5.7146e-01,  1.5063e-05,-6.8221e-01, -1.3456e-02, -6.5192e-01,  7.4233e-02,  3.1776e-01,3.1504e-01, -9.5457e-01, -8.9894e-01, -7.8422e-01, -4.1440e-01,-9.4272e-02,  2.7226e-01, -7.0286e-01,  8.9388e-01]],[[-7.6068e-02,  1.6911e-01,  5.1532e-02, -5.3612e-02,  2.4258e-02,1.6490e-01,  7.4469e-01, -1.1471e+00, -4.5234e-01,  1.0684e-01,1.0929e+00, -5.8079e-01,  1.7665e-01, -2.0187e-02, -3.3850e-01,4.4517e-01, -4.5871e-01,  6.7840e-01, -4.3617e-01,  7.6141e-01,3.8135e-02, -2.3898e-01,  3.2086e-01,  4.1481e-01, -1.8267e-01,8.4337e-01,  7.8504e-02, -1.0101e+00,  5.0766e-02,  2.3338e-01,-3.5572e-01,  1.3751e-01, -4.9570e-02,  4.8627e-01, -3.3225e-01,6.5361e-01,  2.8979e-01,  9.9991e-02,  8.6995e-01, -7.2569e-02,2.5490e-01, -2.6418e-01,  6.1185e-01, -7.7243e-01, -4.6956e-01,-3.1459e-01, -2.1278e-01,  9.1588e-01, -2.1349e-02, -5.0036e-01,3.6214e-01,  1.3723e-02,  1.2322e-01, -5.3018e-01,  2.4809e-01,-3.2042e-01,  2.4807e-01, -1.5764e-01, -2.6655e-01,  1.8610e-01,-1.6585e-01,  2.3454e-01,  3.1852e-01,  6.1627e-01],[-1.7126e-01,  8.6634e-01,  4.7069e-01, -8.1842e-01, -6.2145e-01,-3.8596e-02,  1.2991e+00, -8.4528e-01, -1.5742e+00,  1.2813e+00,1.1197e+00, -1.2562e+00,  7.3848e-01,  2.2198e-02, -4.1664e-01,1.1044e+00, -1.2744e+00, -1.6599e-01, -6.4863e-01,  1.1497e+00,-1.4236e-01, -1.2829e-01, -2.7600e-01,  4.7095e-01, -5.1933e-02,8.7453e-01, -6.4251e-01, -4.2953e-01,  3.5337e-01, -2.2782e-01,2.5079e-01,  1.7728e-01,  6.4826e-01,  2.4980e-01,  8.3032e-02,2.1247e+00, -3.0265e-01, -1.9821e-01,  9.7439e-01, -3.6237e-01,-2.6392e-01, -5.1498e-01,  1.3055e+00, -9.1860e-01, -6.9769e-01,6.5717e-01,  5.8009e-01,  3.6944e-01,  2.0414e-01, -9.0271e-01,4.5972e-01,  9.4667e-01,  1.3700e-02, -2.7962e-01,  3.7535e-01,-4.1842e-01, -6.2615e-01,  6.8238e-03, -3.4866e-01,  5.7681e-01,-5.5240e-01,  1.8245e-01,  6.2508e-01,  6.0020e-01]]])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/539409.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数据结构的美之链表和树

有种感觉叫做,不同的场景,应用不同的数据结构和算法,可以大大滴优化增删改查以及存储方面等等的性能。笔者这里呢也是在最近复习准备面试的时候,去阅读源码,觉得设计这种数据结构和引用的人真的是非常牛逼,…

【计算机网络】概述

文章目录 一、Internet 因特网1.1 网络、互联网、因特网1.2 因特网的组成 二、三种交换方式2.1 电路交换 (Circuit Switching)2.2 *分组交换 (Packet Switching)2.3 报文交换 (Message Switching) 三、计算…

IIS上部署.netcore WebApi项目及swagger

.netcore项目一般是直接双击exe文件,运行服务,今天有个需求,需要把.netcore项目运行在IIS上,遇到了一个小坑,在这里记录一下。 安装IIS,怎么部署站点,这些过于简单就不细说了,不知道…

力扣热题100_矩阵_48_旋转图像

文章目录 题目链接解题思路解题代码 题目链接 48.旋转图像 给定一个 n n 的二维矩阵 matrix 表示一个图像。请你将图像顺时针旋转 90 度。 你必须在 原地 旋转图像,这意味着你需要直接修改输入的二维矩阵。请不要 使用另一个矩阵来旋转图像。 示例 1&#xff1…

LORA_ LOW-RANK ADAPTATION OF LARGE LANGUAGE MODELS

paper: https://arxiv.org/pdf/2106.09685.pdf code: https://github.com/microsoft/LoRA 摘要 作者提出了低秩自适应,或称LoRA,它冻结了预先训练的模型权值,并将可训练的秩分解矩阵注入变压器架构的每一层,大大减少了下游任务的…

使用Python对文本文件进行分词、词频统计和可视化

目录 一、引言 二、文本分词 三、词频统计 四、可视化 五、案例与总结 六、注意事项与扩展 七、总结与展望 一、引言 在大数据时代,文本处理是信息提取和数据分析的重要一环。分词、词频统计和可视化是文本处理中的基础任务,它们能够帮助…

《LeetCode热题100》笔记题解思路技巧优化_Part_2

《LeetCode热题100》笔记&题解&思路&技巧&优化_Part_2 😍😍😍 相知🙌🙌🙌 相识😢😢😢 开始刷题普通数组🟡1. 最大子数组和🟡2. 合…

【数据结构高阶】图

目录 一、图的基本概念 二、 图的存储结构 2.1 邻接矩阵 2.2.1 邻接矩阵存储模式的代码实现 2.2.2 邻接矩阵存储的优缺点 2.2 邻接表 2.2.1 无向图的邻接表 2.2.2 有向图的邻接表 2.2.3 邻接表存储模式的代码实现 2.2.4 邻接表存储的优缺点 三、图的遍历 3.1 图的…

稀碎从零算法笔记Day17-LeetCode:有效的括号

题型:栈 链接:20. 有效的括号 - 力扣(LeetCode) 来源:LeetCode 题目描述(红字为笔者添加) 给定一个只包括 (,),{,},[,] 的字符串 …

Ubuntu Flask 运行 gunicorn+Nginx 部署

linux Ubuntu 下运行python 程序出现killed 原因:CPU或内存限制:在华为云上,你可能有CPU或内存使用的限制。例如,如果你使用的是一个固定大小的实例,那么超过该实例的CPU或内存限制可能会导致进程被杀死。 参考&am…

微前端框架 qiankun 配置使用【基于 vue/react脚手架创建项目 】

qiankun官方文档:qiankun - qiankun 一、创建主应用: 这里以 vue 为主应用,vue版本:2.x // 全局安装vue脚手架 npm install -g vue/clivue create main-app 省略 vue 创建项目过程,若不会可以自行百度查阅教程 …

【mask】根据bbox提示同一张图片生成多个矩形框掩码

前提:使用labelimg得到bbox 1.代码 import cv2 import numpy as np# 读取图片 image cv2.imread("D:\Desktop\mult_test\images\SL03509990_1694761223500.jpg")# 假设我们有多个目标的ROI(感兴趣区域) rois [(565,635,1006,85…