【Bert101】变压器模型背后的复杂数学【01/4】

一、说明

        众所周知,变压器架构是自然语言处理(NLP)领域的突破。它克服了 seq-to-seq 模型(如 RNN 等)无法捕获文本中的长期依赖性的局限性。变压器架构被证明是革命性架构(如 BERT、GPT 和 T5 及其变体)的基石。正如许多人所说,NLP正处于黄金时代,说变压器模型是一切开始的地方并没有错。

二、对变压器架构的需求

        如前所述,需求是发明之母。传统的 seq-to-seq 模型在处理长文本时并不好。这意味着模型在处理输入序列的后半部分时,往往会忘记输入序列前半部分的学习。这种信息丢失是不可取的。

        尽管像 LSTM 和 GRU 这样的门控架构通过丢弃在记住重要信息的过程中无用的信息,在处理长期依赖关系方面表现出一些改进,但这仍然不够。世界需要更强大的东西,2015年,Bahdanau等人引入了“注意力机制”。 它们与RNN / LSTM结合使用,以模仿人类行为,专注于选择性事物,而忽略其余事物。Bahdanau建议为句子中的每个单词分配相对重要性,以便模型专注于重要的单词而忽略其余单词。它被认为是对神经机器翻译任务的编码器-解码器模型的巨大改进,很快,注意力机制的应用也在其他任务中推广。

变压器模型时代

        变压器模型完全基于一种注意力机制,也称为“自我注意”。这种架构在 2017 年的论文“注意力是你所需要的一切”中向世界介绍。它由编码器-解码器架构组成。

无花果。高级转换器模型体系结构(来源:作者)

在高层次上,

  • 编码器负责接受输入的句子并将其转换为隐藏的表示形式,丢弃所有无用的信息。
  • 解码器接受此隐藏表示形式并尝试生成目标句子。

在本文中,我们将深入研究变压器模型的编码器组件的详细细分。在下一篇文章中,我们将详细介绍解码器组件。让我们开始吧!

三、变压器编码器

        变压器的编码器块由一堆按顺序工作的N个编码器组成。一个编码器的输出是下一个编码器的输入,依此类推。最后一个编码器的输出是馈送到解码器块的输入句子的最终表示形式。

无花果。带堆叠编码器的Enoder模块(来源:作者)

如下图所示,每个编码器块可以进一步分成两个组件。

        无花果。编码器层的组件(来源:作者)

        让我们逐一详细研究这些组件中的每一个,以了解编码器块的工作原理。编码器模块中的第一个组成部分是多头注意力,但在我们进入细节之前,让我们先了解一个基本概念:自我注意

3. 1 自我注意机制

        每个人脑海中可能出现的第一个问题:注意力和自我注意力是不同的概念吗? 是的,他们是。(咄!

        传统上,注意力机制是为神经机器翻译任务而存在的,如上一节所述。因此,本质上应用了注意力机制来映射源句和目标句。当 seq-to-seq 模型逐个令牌执行翻译任务时,注意机制可帮助我们在为目标句子生成标记 x 时识别源句子中的哪些标记需要更多关注。为此,它利用编码器和解码器的隐藏状态表示来计算注意力分数,并根据这些分数生成上下文向量作为解码器的输入。如果您想了解有关注意力机制的更多信息,请跳到这篇文章(精彩的解释!

        回到自我注意,主要思想是计算注意力分数,同时将源句子映射到自身。如果你有这样的句子,

“男孩没有过马路因为它太宽了。

        我们人类很容易理解上面句子中的“它”一词指的是“道路”,但是我们如何使我们的语言模型也理解这种关系呢?这就是自我关注的地方!

在高层次上,将句子中的每个单词与句子中的每个其他单词进行比较,以量化关系并理解上下文。出于表示目的,您可以参考下图。

让我们详细看看这种自我注意是如何计算的(真实)。

  • 为输入句子生成嵌入

查找所有单词的嵌入并将它们转换为输入矩阵。这些嵌入可以通过简单的标记化和独热编码生成,也可以通过嵌入算法(如BERT等)生成。输入矩阵的维度将等于句子长度 x 嵌入维度。我们称此输入矩阵为 X 以供将来参考。

  • 将输入矩阵转换为Q,K和V

为了计算自我注意,我们需要将X(输入矩阵)转换为三个新矩阵:
- 查询 (Q)- 键 (K)- 值 (V)

 

为了计算这三个矩阵,我们将随机初始化三个权重矩阵,即Wq,Wk和Wv。输入矩阵X将与这些权重矩阵Wq,Wk和Wv相乘,分别获得Q,K和V的值。在此过程中将学习权重矩阵的最佳值,以获得更准确的Q,K和V值。

  • 计算 Q 和 K 转置的点积

从上图中,我们可以暗示 qi、ki 和 vi 表示句子中第 i 个单词的 Q、K 和 V 的值。

无花果。Q 和 K 转置的点积示例(来源:作者)

输出矩阵的第一行将告诉您由 q1 表示的 word1 如何与使用点积的句子中的其余单词相关。点积的值越高,单词越相关。为了直观地了解计算此点积的原因,您可以在信息检索方面理解 Q(查询)和 K(键)矩阵。所以在这里,
- Q 或查询 = 您正在搜索的术语
- K 或 Key = 搜索引擎中的一组关键字,与 Q 进行比较和匹配。

  • 缩放点积

与上一步一样,我们正在计算两个矩阵的点积,即执行乘法运算,该值有可能爆炸。为了确保这种情况不会发生并且梯度稳定,我们将 Q 和 K 转置的点积除以嵌入维度 (dk) 的平方根。

  • 使用 softmax 规范化值

使用 softmax 函数进行归一化将产生介于 0 和 1 之间的值。具有高比例点积的单元格将进一步升高,而低值将减少,使匹配的单词对之间的区别更加清晰。生成的输出矩阵可以视为分数矩阵S

  • 计算注意力矩阵 Z

        将值矩阵或V乘以从上一步获得的分数矩阵S,以计算注意力矩阵Z。

        但是等等,为什么要乘法?

        假设 Si = [0.9, 0.07, 0.03] 是句子中第 i 个单词的分数矩阵值。将此向量与 V 矩阵相乘以计算 Zi(第 i 个单词的注意力矩阵)。

Zi = [0.9 * V1 + 0.07 * V2 + 0.03 * V3]

        我们是否可以说,为了理解第 i 个单词的上下文,我们应该只关注 word1(即 V1),因为注意力分数值的 90% 来自 V1?我们可以清楚地定义重要的单词,在这些单词中,应该更加注意理解第i个单词的上下文。

        因此,我们可以得出结论,一个词在Zi表示中的贡献越高,单词之间的批判性和相关性就越大。

        现在我们知道了如何计算自我注意力矩阵,让我们了解多头注意力机制的概念。

3.2 多头注意力机制

        如果您的分数矩阵偏向于特定的单词表示,会发生什么?它会误导您的模型,结果不会像我们预期的那样准确。让我们看一个例子来更好地理解这一点。

        S1:“一切都很好

        Z(井) = 0.6 * V(全部) + 0.0 * v(是) + 0.4 * V(井)

        S2:“狗吃了食物,因为它饿了

        Z(它) = 0.0 * V(的) + 1.0 * V(狗) + 0.0 * V(吃) + ...... + 0.0 * V(饥饿)

        在 S1 情况下,在计算 Z(well) 时,对 V(all) 给予了更多的重要性。它甚至比V(好吧)本身还要多。无法保证这有多准确。

        在 S2 的情况下,在计算 Z(it) 时,所有的重要性都给了 V(dog),而其余单词的分数也是 0.0,包括 V(it)。这看起来可以接受,因为“it”这个词是模棱两可的。将它更多地与另一个词联系起来而不是将这个词本身联系起来是有意义的。这就是计算自我注意力的全部目的。处理输入句子中歧义单词的上下文。

        换句话说,我们可以说,如果当前单词是模棱两可的,那么在计算自我注意时可以更加重视其他单词,但在其他情况下,这可能会对模型产生误导。那么,我们现在该怎么办?

        如果我们计算多个注意力矩阵而不是计算一个注意力矩阵并从中导出最终的注意力矩阵会怎样?

        这正是多头注意力的全部意义所在!我们计算注意力矩阵z1,z2,z3,.....,zm的多个版本,并将它们连接起来以得出最终的注意力矩阵。这样我们就可以对自己的注意力矩阵更有信心。

        转到下一个重要概念,

3.3 位置编码

        在seq-to-seq模型中,输入的句子被逐字输入到网络,这允许模型跟踪单词相对于其他单词的位置。

        但在变压器模型中,我们遵循不同的方法。它们不是逐字逐句地输入,而是并行馈送,这有助于减少训练时间和学习长期依赖性。但是使用这种方法,单词顺序就丢失了。但是,要正确理解句子的含义,词序非常重要。为了克服这个问题,引入了一种称为“位置编码”(P)的新矩阵。

        该矩阵 P 与输入矩阵 X 一起发送,以包含与词序相关的信息。出于显而易见的原因,X 和 P 矩阵的维度是相同的。

        为了计算位置编码,使用下面给出的公式。

        fig。计算位置编码的公式(来源:作者)

在上面的公式中,

  • pos = 单词在句子中的位置
  • d = 字/标记嵌入的维度
  • i = 表示嵌入中的每个维度

在计算中,d 是固定的,但 pos 和 i 会有所不同。如果 d=512,则 i ∈ [0, 255],因为我们取 2i。

如果您想了解更多信息,本视频将深入介绍位置编码。

转换器神经网络可视化指南 — (第 1 部分)位置嵌入

我正在使用上述视频中的一些视觉效果来用我的话来解释这个概念。

无花果。位置编码向量表示(来源:作者)

上图显示了位置编码向量的示例以及不同的变量值。

无花果。具有常数 i 和 d 的位置编码向量(来源:作者)

fig。具有常数 i 和 d 的位置编码向量(来源:作者)

上图显示了如果 i 是常量并且只有 pos 变化,PE(pos, 2i) 的值将如何变化。众所周知,正弦波是一个周期函数,倾向于在固定间隔后重复。我们可以看到 pos = 0 和 pos = 6 的编码向量是相同的。这是不可取的,因为我们需要不同的位置编码向量来表示不同的 pos 值

这可以通过改变正弦波的频率来实现。

无花果。具有不同 pos 和 i 的位置编码向量(来源:作者)

随着i的值变化,正弦波的频率也随之变化,导致不同的波,因此,导致每个位置编码向量的值不同。这正是我们想要实现的目标。

位置编码矩阵(P)被添加到输入矩阵(X)并馈送到编码器。

无花果。将位置编码添加到输入嵌入(来源:作者)

编码器的下一个组件是前馈网络

3.4 前馈网络

        编码器块中的这个子层是具有两个密集层和 ReLU 激活的经典神经网络。它接受来自多头注意力层的输入,对同一层执行一些非线性变换,最后生成上下文化向量。全连接层负责考虑每个注意力头并从中学习相关信息。由于注意力向量彼此独立,因此它们可以以并行方式传递到变压器网络。

        编码器块的最后一个也是最后一个组件是Add&Norm组件

3.5 添加和规范组件

这是一个残差层,然后是层归一化。残差层确保在处理过程中不会丢失与子层输入相关的重要信息。而规范化层可促进更快的模型训练并防止值发生大量变化。

无花果。包含添加和规范层的编码器组件(来源:作者)

        在编码器中,有两个添加层和规范层:

  • 将多头注意力子层的输入连接到其输出
  • 将前馈网络子图层的输入连接到其输出

        至此,我们总结编码器的内部工作。总结本文,让我们快速回顾一下编码器使用的步骤:

  • 生成输入句子的嵌入或标记化表示。这将是我们的输入矩阵 X。
  • 生成位置嵌入以保留与输入句子的词序相关的信息,并将其添加到输入矩阵 X 中。
  • 随机初始化三个矩阵:Wq,Wk和Wv,即查询,键和值的权重。这些权重将在变压器模型训练期间更新。
  • 将输入矩阵X与Wq,Wk和Wv中的每一个相乘,以生成Q(查询),K(键)和V(值)矩阵。
  • 计算 Q 和 K 转置的点积,通过将其除以 dk 的平方根或嵌入维数来缩放乘积,最后使用 softmax 函数对其进行归一化。
  • 通过将 V 或值矩阵乘以 softmax 函数的输出来计算注意力矩阵 Z。
  • 将此注意力矩阵传递给前馈网络以执行非线性转换并生成上下文化嵌入。

四、后记 

        在下一篇文章中,我们将了解转换器模型的解码器组件的工作原理。这就是本文的全部内容。我希望你觉得它有用。如果你这样做了,请不要忘记鼓掌并与您的朋友分享。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/63808.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

RocketMQ 延迟消息

RocketMQ 延迟消息 RocketMQ 消费者启动流程 什么是延迟消息 RocketMQ 延迟消息是指,生产者发送消息给消费者消息,消费者需要等待一段时间后才能消费到。 使用场景 用户下单之后,15分钟未支付,对支付账单进行提醒或者关单处理…

走进知识图谱(二)【世界知识图谱篇】知识表示的经典模型与平移模型及基于复杂关系建模的知识表示学习

上篇文章提到,该系列文章将主要围绕世界知识图谱和语言知识图谱这两大类知识图谱进行展开,并且提到知识图谱的主要研究包括了知识表示学习、知识自动获取和知识的推理与应用三大部分。今天主要介绍世界知识图谱的知识表示学习,其中包括经典的…

使用C语言实现UDP消息接收

目录 简介:步骤:步骤 1: 创建套接字步骤 2: 接收消息步骤 3: 完成 函数及变量解释总结: 简介: 在网络通信中,UDP(User Datagram Protocol)是一种无连接协议,它提供了一种快速、高效的数据传输方法。本文将向您展示如何使用C语言编…

Spring Bean 生命周期的执行流程

问题描述 Spring 生命周期全过程大致分为五个阶段: 1、创建前准备阶段 2、创建实例阶段 3、依赖注入阶段 4、 容器缓存阶段 5、销毁实例阶段 下图是 Spring Bean 生命周期完整流程图,其中对每个阶段的具体操作做了详细介绍: 一、创建前准备阶…

栈和队列详解

目录 栈 栈的概念及结构: 栈的实现: 代码实现: Stack.h stack.c 队列: 概念及结构: 队列的实现: 代码实现: Queue.h Queue.c 拓展: 循环队列(LeetCode题目链接&#xff0…

每天一道leetcode:516. 最长回文子序列(动态规划中等)

今日份题目: 给你一个字符串 s ,找出其中最长的回文子序列,并返回该序列的长度。 子序列定义为:不改变剩余字符顺序的情况下,删除某些字符或者不删除任何字符形成的一个序列。 示例1 输入:s "bbb…

【高频面试题】JVM篇

文章目录 一、JVM组成1.什么是程序计数器2.什么是Java堆?3.能不能介绍一下方法区(元空间)4.你听过直接内存吗5.什么是虚拟机栈6.垃圾回收是否涉及栈内存?7.栈内存分配越大越好吗?8.方法内的局部变量是否线程安全?9.什么…

【技巧】如何保护PowerPoint不被改动?

PPT,也就是PowerPoint,是很多小伙伴在工作生活中经常用到的图形演示文稿软件。 做好PPT后,担心自己不小心改动了或者不想他人随意更改,我们可以如何保护PPT呢?下面小编就来分享两个常用的方法: 1. 将PPT改…

吉利科技携手企企通,打造集团化数智供应链系统

近日,吉利科技集团有限公司(以下简称“吉利科技”)联合企企通成功召开SRM采购供应链管理项目启动会。企企通与吉利科技高层、项目负责人与团队成员出席此次启动会。 双方将携手在企业供应商全生命周期管理、采购全流程、电子招投标、采购分析…

阿里云预装LAMP应用导致MySQL不显示访问密码如何解决

😀前言 本篇博文是关于阿里云云服务器ECS部署MySQL过程中出现的一下坑,希望能够帮助到您😊 🏠个人主页:晨犀主页 🧑个人简介:大家好,我是晨犀,希望我的文章可以帮助到大家…

【问题解决】Git命令行常见error及其解决方法

以下是我一段时间没有使用xshell,然后用git命令行遇到的一些系列错误和他们的解决方法 遇到了这个报错: fatal: Not a git repository (or any of the parent directories): .git 我查阅一些博客和资料,可以解决的方式: git in…

C++笔记之字节数组的处理

C笔记之字节数组的处理 code review! 文章目录 C笔记之字节数组的处理1.字节数组打印2.将字节数组转换为十六进制字符串并打印3.将字符串转为字节数组4.将字节数组转为字符串5.字节数组和字符数组的区别6.字节数组用于二进制数据存储7.字节数组用于网络通信数据传输8.使用 un…