边写代码边学习之TF Attention

1. 什么是Attention

注意力机制(Attention Mechanism)是机器学习和人工智能领域中的一个重要概念,用于模拟人类视觉或听觉等感知过程中的关注机制。注意力机制的目标是让模型能够在处理信息时,更加关注与任务相关的部分,忽略与任务无关的信息。这种机制最初是受到人类大脑对信息处理的启发而提出的。

注意力机制的基本原理如下:

  1. 输入信息:首先,注意力机制接收输入信息,这可以是序列数据、图像、语音等。

  2. 查询、键和值:对于每个输入,注意力机制引入了三个部分:查询(query)、键(key)、值(value)。这些部分通常是通过神经网络学习得到的。查询用于表示要关注的内容,键用于表示输入信息中的特征,值则是与每个键相关的信息。

  3. 权重分配:注意力机制根据查询和键之间的关系来计算权重,这些权重决定了每个值在最终输出中的贡献程度。通常使用某种形式的相似度度量(如点积、缩放点积等)来计算权重。

  4. 加权求和:将计算得到的权重与对应的值相乘,然后将它们加权求和,得到最终的输出。这个输出通常包含了模型在处理输入信息时关注的部分。

  5. 重复:上述过程通常会被重复多次,以便模型可以在不同的上下文中动态地调整注意力。

注意力机制的核心思想是让模型能够自动地确定在处理输入信息时要关注哪些部分,从而提高了模型在各种任务中的性能。它在自然语言处理、计算机视觉和语音处理等领域都有广泛的应用,如在机器翻译中的Transformer模型、图像分割中的U-Net模型以及语音识别中的Listen, Attend and Spell(LAS)模型等。

总的来说,注意力机制可以帮助模型更好地理解和利用输入信息,提高了模型的表现和泛化能力。

2. Why Attention

由于LSTM和GRU只在一定程度上改进了循环神经网络的长句子依赖问题,并且信息的记忆能力也不是很强和计算能力有限。如果模型要记住很多信息,不得不设计的更复杂,为了解决这些问题,注意力机制出现了,它即能从大量信息中选择重要的信息来缓解神经网络模型的复杂度,而且能高效的并行运算。注意力机制的计算是一个匹配的过程,即通过一个查询(Query)向量到键(Key)和值(Value)对数据对来映射输出值.

注意力的计算一般有三个阶段。第一阶段是计算查询向量Q和每个输入的K的相关性或相似度,得到注意力权重系数S_i :

S_i=f(Q,K_i)

第二阶段是使用SoftMax函数对第一阶段得出的权重系数进行尺度缩放,即把它归一化为概率分布 ai ,分子是把神经元的当前输出映射到(0,+∞),分母是所有输出结果值的总和,公式如下:

a _i=softmax (S_i ) = e^{S_i }/(\sum e^{S_j})

第三阶段:将第二阶段得出的权重与value值加权求和,得到最终需要的Attention数值:

Attention(Q,K,V)=\sum a_i V_i

3. TF attention api 介绍

Attention class

tf.keras.layers.Attention(use_scale=False, score_mode="dot", **kwargs)

Dot-product attention layer, a.k.a. Luong-style attention.

Inputs are query tensor of shape [batch_size, Tq, dim]value tensor of shape [batch_size, Tv, dim] and key tensor of shape [batch_size, Tv, dim]. The calculation follows the steps:

  1. Calculate scores with shape [batch_size, Tq, Tv] as a query-key dot product: scores = tf.matmul(query, key, transpose_b=True).
  2. Use scores to calculate a distribution with shape [batch_size, Tq, Tv]distribution = tf.nn.softmax(scores).
  3. Use distribution to create a linear combination of value with shape [batch_size, Tq, dim]return tf.matmul(distribution, value).

4. 实验代码

4.1.  验证并理解TF attention方法,只输入query和value矩阵。

def softmax(t):s_value = np.exp(t) / np.sum(np.exp(t), axis=-1, keepdims=True)# print('softmax value: ', s_value)return s_valuedef numpy_attention(inputs,mask=None,training=None,return_attention_scores=False,use_causal_mask=False):query = inputs[0]value = inputs[1]key = inputs[2] if len(inputs) > 2 else valuescore = np.matmul(query, key.transpose())attention_score_np = softmax(score)result = np.matmul(attention_score_np, value)print('attention score in numpy =', attention_score_np)print('result in numpy = ', result)def verify_logic_in_attention_with_query_value():query_data = np.array([[1, 0.0, 1],[2, 3, 1]])value_data = np.array([[2, 1.0, 1],[1, 4, 2 ]])print(query_data.shape)numpy_attention([query_data, value_data], return_attention_scores=True)print("=============following is keras attention output================")attention_layer= tf.keras.layers.Attention()result, attention_scores = attention_layer([query_data, value_data], return_attention_scores=True)print('attention_scores = ', attention_scores)print('result=', result);
if __name__ == '__main__':verify_logic_in_attention_with_query_value()

运行结果

(2, 3)
attention score in numpy = [[5.0000000e-01 5.0000000e-01][3.3535013e-04 9.9966465e-01]]
result in numpy =  [[1.5        2.5        1.5       ][1.00033535 3.99899395 1.99966465]]
=============following is keras attention output================
attention_scores =  tf.Tensor(
[[5.0000000e-01 5.0000000e-01][3.3535014e-04 9.9966466e-01]], shape=(2, 2), dtype=float32)
result= tf.Tensor(
[[1.5       2.5       1.5      ][1.0003353 3.998994  1.9996647]], shape=(2, 3), dtype=float32)

4.2.  验证并理解TF attention方法,输入query, key, value矩阵。

def verify_logic_in_attention_with_query_key_value():query_data = np.array([[1, 0.0, 1],[2, 3, 1]])value_data = np.array([[2, 1.0, 1],[1, 4, 2 ]])key_data = np.array([[1, 2.0, 2], [3, 1, 0.1]])print(query_data.shape)numpy_attention([query_data, value_data, key_data], return_attention_scores=True)print("=============following is keras attention output================")attention_layer= tf.keras.layers.Attention()result, attention_scores = attention_layer([query_data, value_data, key_data], return_attention_scores=True)print(attention_layer.get_weights())print('attention_scores = ', attention_scores)print('result=', result);
if __name__ == '__main__':verify_logic_in_attention_with_query_key_value()

结果

(2, 3)
attention score in numpy = [[0.47502081 0.52497919][0.7109495  0.2890505 ]]
result in numpy =  [[1.47502081 2.57493756 1.52497919][1.7109495  1.86715149 1.2890505 ]]
=============following is keras attention output================
[]
attention_scores =  tf.Tensor(
[[0.47502086 0.52497923][0.7109495  0.28905058]], shape=(2, 2), dtype=float32)
result= tf.Tensor(
[[1.4750209 2.5749378 1.5249794][1.7109495 1.8671517 1.2890506]], shape=(2, 3), dtype=float32)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/94298.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

电缆工厂 3D 可视化管控系统 | 智慧工厂

近年来,我国各类器材制造业已经开始向数字化生产转型,使得生产流程变得更加精准高效。通过应用智能设备、物联网和大数据分析等技术,企业可以更好地监控生产线上的运行和质量情况,及时发现和解决问题,从而提高生产效率…

TDengine(2):wsl2+ubuntu20.04+TDengine安装

一、ubuntu系统下提供了三种安装TDengine的方式: 二、通过 apt 指令安装失败 因为是linux初学者,对apt 指令较为熟悉,因此首先使用了该方式进行安装。 wget -qO - http://repos.taosdata.com/tdengine.key | sudo apt-key add -echo "…

vscode 清除全部的console.log

在放页面的大文件夹view上面右键点击在文件夹中查找 console.log.*$ 注意:要选择使用正则匹配 替换为 " " (空字符串)

什么是Flex容器和Flex项目(Flex Container and Flex Item)?它们之间有什么关系?

聚沙成塔每天进步一点点 ⭐ 专栏简介⭐ Flex容器和Flex项目⭐ Flex容器⭐ Flex项目⭐ 关系⭐ 写在最后 ⭐ 专栏简介 前端入门之旅:探索Web开发的奇妙世界 记得点击上方或者右侧链接订阅本专栏哦 几何带你启航前端之旅 欢迎来到前端入门之旅!这个专栏是为…

【LeetCode每日一题】——1365.有多少小于当前数字的数字

文章目录 一【题目类别】二【题目难度】三【题目编号】四【题目描述】五【题目示例】六【题目提示】七【解题思路】八【时间频度】九【代码实现】十【提交结果】 一【题目类别】 排序 二【题目难度】 简单 三【题目编号】 1365.有多少小于当前数字的数字 四【题目描述】 …

Android JNI系列详解之生成指定CPU的库文件

一、前提 这次主要了解Android的cpu架构类型,以及在使用CMake工具的时候,如何指定生成哪种类型的库文件。 如上图所示,是我们之前使用CMake工具默认生成的四种cpu架构的动态库文件:arm64-v8a、armeabi-v7a、x86、x86_64&#xff0…

docker-compose 部署 Seata整合nacos,Postgresql 为DB存储

docker-compose 部署 Seata整合nacos,Postgresql 为DB存储 环境 详情环境可参考 https://github.com/alibaba/spring-cloud-alibaba/wiki/%E7%89%88%E6%9C%AC%E8%AF%B4%E6%98%8E 我这里 <spring.cloud.alibaba-version>2021.1</spring.cloud.alibaba-version>所…

数据库-DML

DML&#xff1a;用来对数据库中表的数据记录进行增、删、改等操作。 添加数据&#xff08;INSERT&#xff09; insert语法&#xff1a; 指定字段添加数据&#xff1a;insert into 表单&#xff08;字段名1&#xff0c;字段名2&#xff09;values&#xff08;值1&#xff0c;值…

【C++】C++11新特性(下)

上篇文章&#xff08;C11的新特性&#xff08;上&#xff09;&#xff09;我们讲述了C11中的部分重要特性。本篇接着上篇文章进行讲解。本篇文章主要进行讲解&#xff1a;完美转发、新类的功能、可变参数模板、lambda 表达式、包装器。希望本篇文章会对你有所帮助。 文章目录 一…

Day5:react函数组件与类组件

「目标」: 持续输出&#xff01;每日分享关于web前端常见知识、面试题、性能优化、新技术等方面的内容。 「主要面向群体&#xff1a;」前端开发工程师&#xff08;初、中、高级&#xff09;、应届、转行、培训、自学等同学 Day4-今日话题 react「函数组件和类组件」的区别&…

一百六十八、Kettle——用海豚调度器定时调度从Kafka到HDFS的任务脚本(持续更新追踪、持续完善)

一、目的 在实际项目中&#xff0c;从Kafka到HDFS的数据是每天自动生成一个文件&#xff0c;按日期区分。而且Kafka在不断生产数据&#xff0c;因此看看kettle是不是需要时刻运行&#xff1f;能不能按照每日自动生成数据文件&#xff1f; 为了测试实际项目中的海豚定时调度从…

JVM调优指令参数

常用命令查找文档站点&#xff1a;https://docs.oracle.com/javase/8/docs/technotes/tools/unix/index.html -XX:PrintFlagsInitial 输出所有参数的名称和默认值&#xff0c;默认不包括Diagnostic和Experimental的参数。可以配合 -XX:UnlockDiagnosticVMOptions和-XX:UnlockEx…