PixelSNAIL论文代码学习(3)——自注意力机制的实现

文章目录

    • 引言
    • 正文
      • 介绍
      • 自注意力机制的简单实现样例
      • 本文中的自注意力机制
      • 具体实现代码分析
        • nn.nin函数的具体实现
        • nn.causal_attention模块实现
        • 注意力模块实现代码
        • 完整实现代码
        • 使用pytorch实现因果注意力模块causal_atttention模块
      • 问题
    • 总结
    • 引用

引言

  • 阅读了pixelSNAIL,很简短,就用了几页,介绍了网络结构,介绍了试验效果就没有了,具体论文学习链接
  • 这段时间看他的代码,还是挺痛苦的,因为我对于深度学习的框架尚且不是很熟练 ,而且这个作者很厉害,很多东西都是自己实现的,所以看起来十分费力,本来想逐行分析,结果发现逐行分析不现实,所以这里按照模块进行分析。
  • 今天就专门来学习一下他自注意力机制是如何实现的。

正文

介绍

  • 含义:自注意力机制是一种让模型在处理序列数据时,考虑数据其他位置信息的方法(可以用来考虑时序信息)。对于每一个序列中的元素,自注意力机制会计算其与序列中其他元素的相似度,并使用这些相似度来更新元素本身

  • 基本步骤

    • 线性投影:对于输入序列X,通过三个不同的线性变换得到Query(Q)Key(K)Value(V)三个矩阵
      • Query:查询,用于和key进行匹配
      • key:与Query进行匹配,决定了每一个value的权重
      • value:值,实际想要加权平均的内容
    • 计算注意力分数:使用QK的点积来计算注意力分数
    • 缩放:将注意力分数除以 d k d_k dk的平方根, d k d_k dk是key的维度
    • 应用softmax:沿着每一行对缩放后的注意力分数应用softmax函数
    • 加权求和:使用softmax输出对
  • 原理解释

    • 计算Query和Key的点积,因为通过点积来衡量两个矩阵的相似性,如果相似性越大,那么他们的点积就越大。借此使得模型能够关注与Query相似的key

    • 使用softmax函数和缩放因子是为了归一化最终的输出,让最终的输出以概率的方式呈现

    • 最终的输出是通过权重和value的加权和计算出的。

    • 并没有理论推导,但是在transformer中的效果很好

自注意力机制的简单实现样例

  • 下面是公式推导,基本上具体实现也是按照这个公式推导进行的
    在这里插入图片描述

  • 具体代码实现

  • 假设我们有一个句子:“I love dogs”,我们希望通过自注意力机制来重新表示每个词。

  • 首先,我们需要将每个词转化为一个向量。为了简化,我们假设:

  • 在这里插入图片描述

  • 具体代码如下,基本上是按照上述公式实现的

import numpy as np
import torch
import torch.nn.functional as F# Query, Key, Value
Q = torch.Tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
K = torch.Tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
V = torch.Tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]])# Attention Weight Calculation
d = 3  # dimension of Q and K
attention_weights = F.softmax(Q @ K.T / np.sqrt(d), dim=-1)# Output Calculation
output = attention_weights @ V

本文中的自注意力机制

  • 下面是他具体的自注意力模块的生成流程图,无非是明确三个矩阵,Q、K和V,可以看到作者给了标注,分别是经过了1*1的卷积,具体实现代码看下节
    在这里插入图片描述

具体实现代码分析

  • 下述为整个模型中具体实现自注意力机制的代码部分,要实现自注意力机制,无非是明确三个矩阵的具体是哪个矩阵,具体如下

    • Query矩阵:经过n次门控残差网络处理的ul矩阵和背景矩阵background拼接而成

    • Key矩阵:x, ul, background三个矩阵拼接成的矩阵

    • Value矩阵::经过n次门控残差网络处理的ul矩阵

  • 这里两个作者自己定义函数,分别是nn.nin和nn.causal_attention两个操作模块。这里简单介绍一下功能,在下一节具体讲解代码

    • nn.nin: 1* 1的卷积层,用于减少或者增加数据张量的深度,但是不改变对应的batch_size、H和W
    • nn.causal_attention:实现因果注意力机制,确保当前元素之和之前的元素进行交互,不与未来的元素进行交互,通过掩码实现。

nn.nin函数的具体实现

  • 这里是实现了1*1卷积,不改变除了深度以外的任何形状,通过这个操作来改变矩阵的深度或者频道数
@add_arg_scope
def nin(x, num_units, **kwargs):""" a network in network layer (1x1 CONV) """s = int_shape(x)# 这里是将前三个维度相乘,保留最后一个维度,将原来的四维度矩阵变成二维度矩阵x = tf.reshape(x, [np.prod(s[:-1]), s[-1]])# 全连接层实现一乘一卷积x = dense(x, num_units, **kwargs)return tf.reshape(x, s[:-1] + [num_units])
  • 总的来说,实现起来还是很容易的,不过说实话,还是pytorch方便点,直接指定filter_size为1不就行了

nn.causal_attention模块实现

  • 这个模块是因果卷积和自注意力机制的结合,在权重矩阵上乘以一个因果掩码矩阵,来抑制未来的信息

  • 参数说明

    • key: [bs, h, w, chns]

    • mixin: [bs, h, w, chns]

    • query: [bs, h, w, chns]

    • downsample: int.表示下采样的倍数

      • 在必要的情况下,使用下采样减少需要处理的键值数量,加速运算
      • 代码中是使用最大池化进行下采样的
    • use_pos_enc: bool.表示是否使用位置编码

      • 常规的卷积中,并不考虑到位置信息,通过位置编码来补充信息,因为这里处理的是序列信息。
  • 下面是这个代码的具体流程,为了方便起见,这里就忽略了对于下采样和位置编码的判断

在这里插入图片描述

def causal_attention(key, mixin, query, downsample=1, use_pos_enc=False):'''key: [bs, h, w, chns]mixin: [bs, h, w, chns]query: [bs, h, w, chns]downsample: int.表示下采样的倍数use_pos_enc: bool.表示是否使用位置编码'''# 获取key的形状bs, nr_chns = int_shape(key)[0], int_shape(key)[-1]# 下采样if downsample > 1:pool_shape = [1, downsample, downsample, 1]key = tf.nn.max_pool(key, pool_shape, pool_shape, 'SAME')mixin = tf.nn.max_pool(mixin, pool_shape, pool_shape, 'SAME')# 使用位置编码xs = int_shape(mixin)if use_pos_enc:pos1 = tf.range(0., xs[1]) / xs[1]pos2 = tf.range(0., xs[2]) / xs[1]mixin = tf.concat([mixin,tf.tile(pos1[None, :, None, None], [xs[0], 1, xs[2], 1]),tf.tile(pos2[None, None, :, None], [xs[0], xs[2], 1, 1]),], axis=3)# 因果掩码# 通过get_causal_mask函数生成一个上三角矩阵,对角线为0,其余为1mixin_chns = int_shape(mixin)[-1]canvas_size = int(np.prod(int_shape(key)[1:-1]))canvas_size_q = int(np.prod(int_shape(query)[1:-1]))causal_mask = get_causal_mask(canvas_size_q, downsample)# 注意力权重的计算# 使用矩阵乘法来计算查询和键之间的点积dot = tf.matmul(tf.reshape(query, [bs, canvas_size_q, nr_chns]),tf.reshape(key, [bs, canvas_size, nr_chns]),transpose_b=True# 应用因果掩码和一个小数来抑制未来的信息) - (1. - causal_mask) * 1e10dot = dot - tf.reduce_max(dot, axis=-1, keep_dims=True)# 实现softmax,计算注意力权重causal_exp_dot = tf.exp(dot / np.sqrt(nr_chns).astype(np.float32)) * causal_maskcausal_probs = causal_exp_dot / (tf.reduce_sum(causal_exp_dot, axis=-1, keep_dims=True) + 1e-6)# 输出计算mixed = tf.matmul(causal_probs,tf.reshape(mixin, [bs, canvas_size, mixin_chns]))return tf.reshape(mixed, int_shape(query)[:-1] + [mixin_chns])

注意力模块实现代码

  • 虽然这个流程很好理解,根据代码就可以看出来,就是矩阵的变换,但是有个地方是怪怪的,想问为什么?但是这个是通过实验证明有效的。

    • 我知道了,我疑惑的是,作者是如何探索出这种结构的?
      • 为什么经过因果注意力机制处理后,又把他丢进了门控残差网络的处理?
  • 下面是具体的流程图,整个过程主要用到了三个矩阵,分别是

    • x:原始输入矩阵

    • ul:经过n次门控残差网络处理的矩阵

    • background:是一个背景矩阵,用来传递每一个像素的位置信息,主要是在宽度和高度两个维度上的位置信息。维度为[1,4,4,2]

  • 具体流程图如下

在这里插入图片描述

  • 重复了若干次注意力机制处理后,为了防止出现梯度消失,将最终的输出在经过elu指数线性单元进行激活,改变输出维度,作为最终输出。
	# 注意力机制具体实现# 这个ul是门控残差网络的ul = ul_list[-1]# 准备原始内容,包括了原始输入x,上一次的输出ul,以及背景信息raw_content = tf.concat([x, ul, background], axis=3)# 生成key和queryq_size = 16raw = nn.nin(nn.gated_resnet(raw_content, conv=nn.nin), nr_filters // 2 + q_size)key, mixin = raw[:, :, :, :q_size], raw[:, :, :, q_size:]# 这里是生成queryraw_q = tf.concat([ul, background], axis=3)query = nn.nin(nn.gated_resnet(raw_q, conv=nn.nin), q_size)# 计算注意力mixed = nn.causal_attention(key, mixin, query, downsample=att_downsample)# 将注意力的结果和原始结果通过按位加来是心爱ul_list.append(nn.gated_resnet(ul, mixed, conv=nn.nin))

完整实现代码

def _base_noup_smallkey_spec(x, h=None, init=False, ema=None, dropout_p=0.5, nr_resnet=5,nr_filters=256, attn_rep=12, nr_logistic_mix=10,att_downsample=1, resnet_nonlinearity='concat_elu'):"""x:输入张量,形状为(N,H,W,D1),N为batch_size,H,W为图像的高和宽,D1为图像的通道数h:可选的N x K矩阵,用于在生成模型上进行条件init:是否初始化ema:是否使用指数移动平均dropout_p:dropout概率nr_resnet:残差网络的数量nr_filters:卷积核的数量attn_rep:注意力机制的重复次数nr_logistic_mix:logistic混合的数量att_downsample:注意力机制的下采样resnet_nonlinearity:残差网络的非线性激活函数We receive a Tensor x of shape (N,H,W,D1) (e.g. (12,32,32,3)) and producea Tensor x_out of shape (N,H,W,D2) (e.g. (12,32,32,100)), where each fiberof the x_out tensor describes the predictive distribution for the RGB atthat position.'h' is an optional N x K matrix of values to condition our generative model on"""counters = {}# 使用arg_scope,可以给函数的参数自动赋予某些默认值# 设置一组层[nn.conv2d,nn.deconv2d,nn.gated_resnet,nn.dense]这样一组层的counters,init,ema,dropout_p参数为默认值with arg_scope([nn.conv2d, nn.deconv2d, nn.gated_resnet, nn.dense, nn.nin],counters=counters, init=init, ema=ema, dropout_p=dropout_p):# 根据传入的resnet_nonlinearity参数,选择不同的激活函数if resnet_nonlinearity == 'concat_elu':resnet_nonlinearity = nn.concat_eluelif resnet_nonlinearity == 'elu':resnet_nonlinearity = tf.nn.eluelif resnet_nonlinearity == 'relu':resnet_nonlinearity = tf.nn.reluelse:raise('resnet nonlinearity ' +resnet_nonlinearity + ' is not supported')with arg_scope([nn.gated_resnet], nonlinearity=resnet_nonlinearity, h=h):# // 通过PixelCNN进行上行传递 # 创建一个背景张量,形状为(1,H,W,2),其中H,W为图像的高和宽,用来保存每一个像素位置的相对位置信息# 获取输入向量的形状xs = nn.int_shape(x)background = tf.concat([# 创建一个长度为xs[1](即输入x的高度)的一维张量。张量的值从−0.5到0.5,表示水平方向上的位置信息# 例如,如果xs[1]为32,则tf.range(xs[1], dtype=tf.float32)的值为[0,1,2,...,31]# 然后将其归一化到[-0.5,0.5],即((tf.range(xs[1], dtype=tf.float32) - xs[1] / 2) / xs[1])# 最后将其扩展为形状为(1,H,W,1)的张量# 这里是扩展在第二个维度,也就是H,然后加上对应形状的矩阵, 使用扩散机制,将背景矩阵复制为同样大小。((tf.range(xs[1], dtype=tf.float32) - xs[1] / 2) / xs[1])[None, :, None, None] + 0. * x,((tf.range(xs[2], dtype=tf.float32) - xs[2] / 2) / xs[2])[None, None, :, None] + 0. * x,],axis=3)# add channel of ones to distinguish image from padding later on# 增加一个信号,用于区分图像和填充x_pad = tf.concat([x, tf.ones(xs[:-1] + [1])], axis=3)# 下传递,从左上角开始# nn.down_shifted_conv2d:下移卷积:# nn.down_right_shifted_conv2d:右下移卷积# nn.down_shift:下移# nn.right_shift:右移ul_list = [nn.down_shift(nn.down_shifted_conv2d(x_pad, num_filters=nr_filters, filter_size=[1, 3])) +nn.right_shift(nn.down_right_shifted_conv2d(x_pad, num_filters=nr_filters, filter_size=[2, 1]))]# stream for up and to the left# 下传递,从右下角开始for attn_rep in range(attn_rep):# 重复n次的门控残差网络for rep in range(nr_resnet):ul_list.append(nn.gated_resnet(ul_list[-1], conv=nn.down_right_shifted_conv2d))# 注意力机制ul = ul_list[-1]# 准备原始内容,包括了原始输入x,上一次的输出ul,以及背景信息raw_content = tf.concat([x, ul, background], axis=3)# 生成key和queryq_size = 16raw = nn.nin(nn.gated_resnet(raw_content, conv=nn.nin), nr_filters // 2 + q_size)key, mixin = raw[:, :, :, :q_size], raw[:, :, :, q_size:]raw_q = tf.concat([ul, background], axis=3)query = nn.nin(nn.gated_resnet(raw_q, conv=nn.nin), q_size)# 计算注意力mixed = nn.causal_attention(key, mixin, query, downsample=att_downsample)# 将注意力的结果与原始内容进行拼接ul_list.append(nn.gated_resnet(ul, mixed, conv=nn.nin))# /// 通过PixelCNN进行下行传递 ///x_out = nn.nin(tf.nn.elu(ul_list[-1]), 10 * nr_logistic_mix)return x_out

使用pytorch实现因果注意力模块causal_atttention模块

  • 实现整个注意力机制,最重要的是实现作者自己定义的causal_attention模块,这个模块实现了三个矩阵query、key还有value的全部操作,同时包含了因果卷积的内容
  • 具体实现如下
import torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import numpy as np
def get_causal_mask(canvas_size, downsample):"""生成一个上三角矩阵作为因果掩码。参数:- canvas_size: 整数, 矩阵的维度。- downsample: 下采样的倍数。返回:- 因果掩码: 上三角矩阵。"""# 生成一个canvas_size x canvas_size的上三角矩阵mask = torch.triu(torch.ones(canvas_size, canvas_size), diagonal=1+downsample)# 转换为float类型并反转矩阵,使得上三角部分为0,其他部分为1mask = 1.0 - maskreturn mask# causal_attention模块的具体实现
class CausalAttention(nn.Module):# 这里是实现对应因果注意力机制的模块def __init__(self):super(CausalAttention,self).__init__()def forward(self,query,key,mixin,downSample = 1,use_pos_enc = False):'''query:查询矩阵key:关键字矩阵mixin:value矩阵前向传播,实现query和key的点积,以及因果掩码的生成'''# 获取key的形状bs,h,w,nr_chns = key.size()# 进行下采样if downSample > 1:key = F.max_pool2d(key,downSample)mixin = F.max_pool2d(mixin,dowmSample)# 判定是否包含位置编码,这里就是单纯增加了两个维度if use_pos_enc:pos1 = torch.arange(0.,h) / hpos2 = torch.arange(0.,w) / wmixin  =torch.cat([mixin,pos1[None,:,None,None].expand(bs,h,w,1),pos2[None,:,None,None].expand(bs,h,w,1)],dim = 3)# 因果卷积# 生成因果卷积的掩码canvas_size = h * wcanvas_size_q = h * wcausal_mask = get_causal_mask(canvas_size_q,downSample).to(key.device)# 实现key和query的点乘,计算每一个键和查询的相似度,同时屏蔽未来信息# view函数,改变张量的形状,但是不改变数据query = query.view(bs, canvas_size_q, nr_chns) # 形状为:bs,H*W,nr_chnskey = key.view(bs, canvas_size, nr_chns)  # 形状为:bs,H*W,nr_chnsdot = torch.bmm(query, key.permute(0, 2, 1))  # 执行矩阵的批量乘法,bs维度相同,# (H*W,nr_chns) 和(nr_chns,H*W)两个矩阵的点积# 最终的矩阵为(H*W,H*W)# 首先将三角掩码矩阵进行反转,然后再乘以一个极大的负数# 确保未来信息在面对进行softmax激活时,能够变为0dot = dot - (1. - causal_mask) * 1e10# 减去最大值,确保数值稳定性dot = dot - torch.max(dot, dim=-1, keepdim=True)[0]# 实现softmax激活函数,并且加上掩码卷积,抑制未来信息causal_exp_dot = torch.exp(dot / np.sqrt(nr_chns).astype(np.float32)) * causal_maskcausal_probs = causal_exp_dot / (torch.sum(causal_exp_dot, dim=-1, keepdim=True) + 1e-6)# 计算输出矩阵,最终的权重参数乘以对应的因果卷积系数mixin = mixin.view(bs, canvas_size, -1)mixed = torch.bmm(causal_probs, mixin)return mixed.view(bs, h, w, -1)# Test the PyTorch implementation
key = torch.rand(16, 32, 32, 64)
mixin = torch.rand(16, 32, 32, 64)
query = torch.rand(16, 32, 32, 64)
causal_attention = CausalAttention()
result = causal_attention(key, mixin, query)result.shape

问题

  • 这个结构真的复杂,是怎么探索出来?
  • 为什么要重复那么多次门控残差网络?
  • 为什么要重复那么多次注意力机制来提取信息?

总结

  • 这里是实现了具体的注意力模块,这里重点是他所调用的一个因果注意力模块,通过这个模块能够实现注意力机制的同时调用因果卷积,来屏蔽未来信息。
  • 但是具体的执行结果,并不知道作者是怎么探索出来,难道是通过实验吗?如果是这样,自己也可以通过实验,来探索一下,适合特定格式下的声音生成模型的具体结构。
  • 这里学到了很多,chatGPT问了几百条,加上自己的理解。
  • 通过这篇文章,我还知道,我们确实需要不断看新的论文,要总是试试看新的论文能不能添加到对应结构中。

引用

ChatGPT-Plus

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/95256.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ELK安装、部署、调试(一)设计规划及准备

一、整体规划如图: 【filebeat】 需要收集日志的服务器,安装filebeat软件,用于收集日志。logstash也可以收集日志,但是占用的系统资源过大,所以使用了filebeat来收集日志。 【kafka】 接收filebeat的日志&#xff…

SWAT-MODFLOW地表水与地下水耦合

耦合模型被应用到很多科学和工程领域来改善模型的性能、效率和结果,SWAT作为一个地表水模型可以较好的模拟主要的水文过程,包括地表径流、降水、蒸发、风速、温度、渗流、侧向径流等,但是对于地下水部分的模拟相对粗糙,考虑到SWAT…

【LeetCode算法系列题解】第46~50题

CONTENTS LeetCode 46. 全排列(中等)LeetCode 47. 全排列 II(中等)LeetCode 48. 旋转图像(中等)LeetCode 49. 字母异位词分组(中等)LeetCode 50. Pow(x, n)(中等&#xf…

4.2 实现基于栈的表达式求值计算器(难度4/10)

本作业主要考察:解释器模式的实现思想/栈结构在表达式求值方面的绝对优势 C数据结构与算法夯实基础作业列表 通过栈的应用,理解特定领域设计的关键作用,给大家眼前一亮的感觉。深刻理解计算机语言和人类语言完美结合的杰作。是作业中的上等…

QT设置mainwindow的窗口title

QT设置mainwindow的窗口title 在QT程序中,通常会有**aaaa-[bbbbbbb]**这种形式的title,对于刚上手qt的程序员同学,可能会简单的以为修改这种title,就是使用setWindowTitle这个接口,其实只对了一半,这种形式…

聊聊每日站会

这是鼎叔的第七十四篇原创文章。行业大牛和刚毕业的小白,都可以进来聊聊。 欢迎关注本专栏和微信公众号《敏捷测试转型》,星标收藏,大量原创思考文章陆续推出。 每日站会是一线敏捷团队自己的会议,快速同步成员为达成迭代目标所…

Docker consul容器服务自动发现和更新

目录 一、什么是服务注册与发现 二、Docker-consul集群 1.Docker-consul 2.registrator 3.Consul-template 三、Docker-consul实现过程 四、Docker-consul集群配置 1.下载consul服务 2.web服务器启动多例nginx容器,使用registrator自动发现 3.使用…

js:创建一个基于vite 的React项目

相关文档 Vite 官方中文文档React 中文文档React RouterRedux 中文文档Ant Design 5.0Awesome React 创建vite react项目 pnpm create vite react-app --template react# 根据提示,执行命令 cd react-app pnpm install pnpm run dev项目结构 $ tree -L 1 . ├─…

Android Native Code开发学习(三)对java中的对象变量进行操作

Android Native Code开发学习(三) 本教程为native code学习笔记,希望能够帮到有需要的人 我的电脑系统为ubuntu 22.04,当然windows也是可以的,区别不大 对java中的对象变量进行操作 首先我们新建一个java的类 pub…

Oracle21C--Windows卸载与安装

卸载方法: (1)WinR,输入services.msc,打开服务,把Oracle相关的服务全部停止运行(重要) (2)WinR,输入regedit,打开注册表,删除Oracle开…

【Linux】文件

Linux 文件 什么叫文件C语言视角下文件的操作文件的打开与关闭文件的写操作文件的读操作 & cat命令模拟实现 文件操作的系统接口open & closewriteread 文件描述符进程与文件的关系重定向问题Linux下一切皆文件的认识文件缓冲区缓冲区的刷新策略 stuout & stderr 什…

已解决module ‘pip‘ has no attribute ‘pep425tags‘报错问题(如何正确查看pip版本、支持、32位、64位方法汇总)

本文摘要:本文已解决module ‘pip‘ has no attribute ‘pep425tags‘的相关报错问题,并总结提出了几种可用解决方案。同时结合人工智能GPT排除可能得隐患及错误。并且最后说明了如何正确查看pip版本、支持、32位、64位方法汇总 😎 作者介绍&…