相对位置编码(relative position representation)

最近在看wenet项目时,发现其用的是相对位置编码。同时在做tts时,发现其效果还可以,但是就是对于长文本的生成效果不好,一直在思考是什么原因导致的,有想到最有可能是fastspeech是的绝对位置编码问题,所以还想着增大文本长度来重新训练,从实践角度来讲这是可以的,但是是治标不治本。直到最近研究了下相对位置编码,然后又翻vits的源码确认其也用的是相对位置编码,所以恍然大悟。这里总结下相对位置编码的原理,以做更深刻的理解。


一.动机

假设我们的文本为:“I think therefore I am”

如果我们要得到“I”的表示,第二个“I”的表示向量跟第一个“I”应该是不一样的,因为在第二个“I”时,有输入前面的“I think therefore”的隐藏状态信息,而第一个“I”只是输入一个初始化的隐藏状态信息。

如果不加入位置编码情况下(原始的transformer是不会在self-attention中加入位置编码的,只是在传入的时候加入绝对位置编码),那么transformer里面的self-attention会把两个位置上的“I”看做是同一个表示,从语义来看这是不合理的。

所以transformer的self-attention在不加入位置编码时,把两个位置的“I”看做一样来处理,这明显不合理,所以怎么来解决这个问题。如果加入绝对位置编码貌似可以解决问题,但是处理句子的能力只能是训练集中的最大长度,如果超过训练集中最大长度,模型就无能为力。

二.解决


了解到相对位置编码是从transformer-xl中了解到的,其实实际上是《Self-Attention with Relative Position Representations》这篇论文中首先提出来的。先来看下它引入的相对位置编码是怎样的?所谓的相对,从字面意思就是一个位置相对另一个位置。其采用一组可训练的embedding向量来表示输入句子中每个单词的位置编码。

如果我们以其中一个单词为中心,那么其有左边也有右边的单词,假设我们一个句子的长度为5,那么将有9个embedding向量可学习,一个embedding向量表示当前词,其中4个embedding向量用来表示其左边的单词,另外4个embedding向量来表示其左边的单词。为什么是9个embedding向量呢?因为我们在实际计算时候,5个单词都有可能作为中心词。上图中以第5个位置(索引为4)的单词为中心,那么其左边的单词的编号为:-1,-2,-3,-4,右边的单词的编号为:+1,+2,+3,+4 。

下面的图来表示这些embedding是怎么用的

对于第一个位置的单词“I”,当transformer计算“I”跟“therefore”的attention信息时候,"therefore"会采用第6个位置编码,因为我们是以第4个索引为中心,“therefore”是位于“I”的右边相对于“I”的相对距离为2,所以其采用的是第6个embedding向量。

跟之前一样,当计算到第二个“I”和其他单词的attention信息时,如计算其跟左边“therefore”的attention信息,那么“therefore”采用的位置编码为第3个embedding向量,因为它在“I”的左边相对于“I”偏离1个距离,索引应该采用第3个embedding向量。

三.RPR向量

先声明一些变量的,给出其作用先:

zi:输入单词i的输出表示

aij:单词i和单词j的权重系数

eij:单词i和单词j经过softmax后i对于j的关注程度

h:multihead attention的head个数

dx:输入序列的embedding size

dz:zi的embedding size

注意,有两个位置编码向量需要学习,一个是为了计算zi,一个是为了计算eij。最大的长度也会被考虑在内,如果我们中间索引为k,那么会有2k+1个相对位置编码向量需要学习,其中k个是其左边的,k个是其右边的,还有一个属于自己。如果长度超过2k+1,那么其右边超过k的全部置为k,左边超过k的全部置为0。下面是个长度为10的句子的例子,其中k=3,那么它到相对位置编码表中拿向量的索引为:

这么做的原因有:

1.作者认为超出范围的位置还采用精准的位置编码时没必要的

2.clip最大长度是的模型可以学到训练集中没见过的长度


四.实现

先看下传统的transformer没有采用RPR来计算zi向量的方式:

再看下采用RPR来计算的方式:

对比两种计算方式,比较容易看出来,其实是在计算zi的时候,计算完j跟权重w后加上i相对于j的相对位置编码。而在计算eij时候同理,计算完j跟权重w后加上ij的相对位置编码。注意这里两种权重矩阵是不同的,理解transformer中的q,k,v 这三个向量分别对应过去。

五.代码实现解析


实现代码参考:

https://github.com/tensorflow/tensor2tensor/blob/9e0a894034d8090892c238df1bd9bd3180c2b9a3/tensor2tensor/layers/common_attention.py#L1556-L1587

对于没有RPR的计算如下,是比较好计算的,代码通过矩阵的相乘直接可以得到eij.


如果要实现RPR,我们把其eij计算方式转换下得到:

分子第一项中,我们的输入xi的tensor的Shape为:(B,h,seq_length,d),它计算的是query和key的关系,所以第一项的输出为(B,h,seq_length,seq_length)。了解transformer就知道其是多header,分子的第一项其实是比较好计算的,第二项的输出shape必须跟第一项一致。

第二项中,aijK表示的是ij的相对位置编码,从位置编码的Embeding向量table中去lookup得到的,位置编码的embedding向量table的Shape为(seq_length,seq_length,da),转换下维度得到(seq_length,da,seq_length),其中位置编码向量的table我们用A来表示,转换后得到AT。

xi跟WQ相乘后得到tensor其shape为(B,h,seq_length,dz),转换下维度得到(seq_length,B,h,dz),在转换下得到(seq_length,B*h,dz),再跟aijK来相乘,实质是跟AT相乘,所以(seq_length,B*h,dz)跟(seq_length,da,seq_length)矩阵相乘,dz=da,得到(seq_length,B*h,seq_length),reshape下得到(seq_length,B,h,seq_length),transpose为(B,h,seq_length,seq_length)这样就跟第一项对应起来了。

参考:

这篇文章基本上是算是参考第一个链接,只是没有完全按照它的翻译,而是用自己理解的话语写出来。

https://medium.com/@_init_/how-self-attention-with-relative-position-representations-works-28173b8c245a​medium.com/@_init_/how-self-attention-with-relative-position-representations-works-28173b8c245a

Relative Multi-Headed Attention​nn.labml.ai/transformers/xl/relative_mha.html正在上传…重新上传取消

Attention Is All You Need​arxiv.org/abs/1706.03762

编辑于 2022-09-24 17:19

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/7163.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

计算机网络

计算机网络 学习路径规划Cisco Packet TracerCisco Packet Tracer下载和安装探究集线器性质访问Web服务器(加深对网络体系结构的理解)_一台主机访问一个Web服务器,HTTP协议 MAC帧交换机网桥网桥和交换机的区别 学习路径规划 用cisco的课程和…

CV多模态和AIGC的原理解析:从CLIP、BLIP到Stable Diffusion、Midjourney

前言 终于开写本CV多模态系列的核心主题:stable diffusion相关的了,为何执着于想写这个stable diffusion呢,源于三点 去年stable diffusion和midjourney很火的时候,就想写,因为经常被刷屏,但那会时间错不…

OpenCV(图像处理)-基于python-滤波器(低通、高通滤波器的使用方法)

1.概念介绍2. 图像卷积filter2D() 3. 低通滤波器3.1 方盒滤波和均值滤波boxFilter()blur() 3.2 高斯滤波(高斯噪音)3.3 中值滤波(胡椒噪音)3.4 双边滤波 4. 高通滤波器4.1Sobel(索贝尔)(高斯&am…

查询Oracle当前用户下,所有数据表的总条数

1. 需求 查询Oracle当前用户下,所有数据表的总条数 2.方法 存储过程 3. 操作 3.1 新建 右键点击Procedures,点击New 点击OK 把存储过程写进去,然后点击编译运行: create or replace procedure tables_count ist_count n…

uni-App踩坑记录

​ 1、uni自己封装的axios在真机中失效,发不出请求 uniapp中使用axios 需要配置适配器 (添加适配器有点费劲,直接封装uni自带请求也可以) axios-adapter-uniapp传送门 axios.defaults.adapter function(config) { //自己定义个适配器,用来…

WIN11系统安装MySql8.0.15详细安装

一.下载mysql8.015数据库 下载地址: 如下图所示 此处下载的是8.0.15版本,免安装版,系统为64位系统: 二,配置mysql环境变量: D:\program_file_worker\mysql8.15\mysql-8.0.15-winx64\bin 三. 环境配置完成后&#xff…

UI自动化截图之chromeFirefox篇

在web的UI自动化中,小伙伴们经常遇到的一个问题是,IE的截屏非常好实现(一个save_screenshot即可满足),而chrome和Firefox的全屏截图就让人很是头疼了。今天作者来给大家分享下自己实例中使用的chrome和Firefox浏览器全…

Jira UI Locations及注意事项总结

issue view ui locations : https://developer.atlassian.com/server/jira/platform/issue-view-ui-locations/#issue-operations-bar-locations1.问题操作栏Issue Operations Bar Locations模块分为两部分: opsbar-operationsflopsbar-transitions两个location.共同定义了问题…

【uniapp】uniapp反向代理解决跨域问题(devServer)

背景介绍 前段时间,在拿uniapp开发的时候,出现了跨域问题,按理说跨域应该由后端解决,但既然咱前端可以上,我想就上了(顺手装个13) 首先介绍什么是跨域 出于浏览器的同源策略,在发…

【Java-15】反射知识总结

01_类的加载 类的加载过程类的加载时机 类的加载 当程序在运行后,第一次使用某个类的时候,会将此类的class文件读取到内存,并将此类的所有信息存储到一个Class对象中 说明:Class对象是指java.lang.Class类的对象,此类…

想知道PDF转高清图片软件哪个好?

张琳是一名设计师,她经常需要将自己的设计作品整理成PDF文档,以便向客户展示和交付。然而,有时客户需要对设计进行更详细的审查,而PDF格式的文件并不方便进行缩放和查看细节。这一问题让张琳感到非常困扰,她希望能够找…

【Redisson】Redisson--限流器

Redisson系列文章: 【Redisson】Redisson–基础入门【Redisson】Redisson–布隆(Bloom Filter)过滤器【Redisson】Redisson–分布式锁的使用(推荐使用)【分布式锁】Redisson分布式锁底层原理【Redisson】Redisson–限流器 文章目录 一、限流…