探索和构建 LLaMA 3 架构:深入探讨组件、编码和推理技术(四)分组多查询注意力

探索和构建 LLaMA 3 架构:深入探讨组件、编码和推理技术(四)分组多查询注意力

Grouped-query Attention,简称GQA

分组查询注意力(Grouped-query Attention,简称GQA)是多查询和多头注意力的插值。它在保持与多查询注意力相当的处理速度的同时,实现了与多头注意力相似的质量。

在这里插入图片描述

自回归解码的标准做法是缓存序列中先前标记的键和值,以加快注意力计算的速度。

  • 然而,随着上下文窗口或批量大小的增加,多头注意力(Multi-Head Attention,简称MHA)模型中键值缓存(Key-Value Cache,简称KV Cache)的大小所关联的内存成本显著增加。

  • 多查询注意力(Multi-Query Attention,简称MQA)是一种机制,它对多个查询仅使用单个键值头,这可以节省内存并大幅加快解码器的推理速度。

  • Llama(一种模型)整合了GQA,以解决在Transformer模型自回归解码期间的内存带宽挑战。主要问题源于GPU进行计算的速度比它们将数据移入内存的速度快。在每个阶段都需要加载解码器权重和注意力键,这消耗了大量的内存。

在这里插入图片描述
在这里插入图片描述

class SelfAttention(nn.Module): def  __init__(self, args: ModelArgs):super().__init__()self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads# Indicates the number of heads for the queriesself.n_heads_q = args.n_heads# Indiates how many times the heads of keys and value should be repeated to match the head of the Queryself.n_rep = self.n_heads_q // self.n_kv_heads# Indicates the dimentiona of each headself.head_dim = args.dim // args.n_headsself.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)self.cache_k = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim))self.cache_v = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim))def forward(self, x: torch.Tensor, start_pos: int, freq_complex: torch.Tensor):batch_size, seq_len, _ = x.shape #(B, 1, dim)# Apply the wq, wk, wv matrices to query, key and value# (B, 1, dim) -> (B, 1, H_q * head_dim)xq = self.wq(x)# (B, 1, dim) -> (B, 1, H_kv * head_dim)xk = self.wk(x)xv = self.wv(x)# (B, 1, H_q * head_dim) -> (B, 1, H_q, head_dim)xq = xq.view(batch_size, seq_len, self.n_heads_q, self.head_dim)xk = xk.view(batch_size, seq_len, self.n_kv_heads, self.head_dim)# (B, 1, H_kv * head_dim) -> (B, 1, H_kv, head_dim)xv = xv.view(batch_size, seq_len, self.n_kv_heads, self.head_dim)# Apply the rotary embeddings to the keys and values# Does not chnage the shape of the tensor# (B, 1, H_kv, head_dim) -> (B, 1, H_kv, head_dim)xq = apply_rotary_embeddings(xq, freq_complex, device=x.device)xk = apply_rotary_embeddings(xk, freq_complex, device=x.device)# Replace the enty in the cache for this tokenself.cache_k[:batch_size, start_pos:start_pos + seq_len] = xkself.cache_v[:batch_size, start_pos:start_pos + seq_len] = xv# Retrive all the cached keys and values so far# (B, seq_len_kv, H_kv, head_dim)keys = self.cache_k[:batch_size, 0:start_pos + seq_len]values = self.cache_v[:batch_size, 0:start_pos+seq_len] # Repeat the heads of the K and V to reach the number of heads of the querieskeys = repeat_kv(keys, self.n_rep)values = repeat_kv(values, self.n_rep)# (B, 1, h_q, head_dim) --> (b, h_q, 1, head_dim)xq = xq.transpose(1, 2)keys = keys.transpose(1, 2)values = values.transpose(1, 2)# (B, h_q, 1, head_dim) @ (B, h_kv, seq_len-kv, head_dim) -> (B, h_q, 1, seq_len-kv)scores = torch.matmul(xq, keys.transpose(2,3)) / math.sqrt(self.head_dim)scores = F.softmax(scores.float(), dim=-1).type_as(xq)# (B, h_q, 1, seq_len) @ (B, h_q, seq_len-kv, head_dim) --> (b, h-q, q, head_dim)output = torch.matmul(scores, values)# (B, h_q, 1, head_dim) -> (B, 1, h_q, head_dim) -> ()output = (output.transpose(1,2).contiguous().view(batch_size, seq_len, -1))return self.wo(output) # (B, 1, dim) -> (B, 1, dim)

系列博客

探索和构建 LLaMA 3 架构:深入探讨组件、编码和推理技术(一)
https://duanzhihua.blog.csdn.net/article/details/138208650
探索和构建 LLaMA 3 架构:深入探讨组件、编码和推理技术(二)
https://duanzhihua.blog.csdn.net/article/details/138212328

探索和构建 LLaMA 3 架构:深入探讨组件、编码和推理技术(三)KV缓存
https://duanzhihua.blog.csdn.net/article/details/138213306
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/651648.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

docker 集群管理实战mesos+zookeeper+marathon(三)

接上文mesoszookeeper管理docker集群,已安装并成功启动mesos master和mesos slave。 https://www.toutiao.com/article/7221354604351537698/?log_from6b55db495da1d_1681366356776 这个教程主要演示部署和使用marathon 1.1 marathon概述 1.2 下载和解压marath…

ip https证书360

https证书主要作用是保障网络安全,在http协议的基础上通过SSL/TLS加密技术实现安全通信协议。对客户端以及服务器之间的传输数据进行加密,确保数据的完整性和机密性,维护用户隐私。通过HTTPS协议,我们可以安全地进行在线购物、网上…

腾讯云邮件推送如何设置?群发邮件的技巧?

腾讯云邮件推送功能有哪些?怎么有效使用邮件推送? 腾讯云邮件推送以其稳定、高效的特点,受到了众多企业的青睐。那么,腾讯云邮件推送如何设置呢?又有哪些群发邮件的技巧呢?下面AokSend就来详细探讨一下。 …

k8s-身份认证与权限

认证概述 Kubernetes作为一个分布式集群的管理工具,保证集群的安全性是其一个重要的任务。所谓的安全性其实就是保证对Kubernetes的各种客户端进行认证和鉴权操作。 在Kubernetes集群中,客户端通常有两类: User Account:一般是独…

LabVIEW高效目标跟踪系统

LabVIEW高效目标跟踪系统 随着机器视觉技术的飞速发展,设计和实现高效的目标跟踪系统成为了众多领域关注的焦点。基于LabVIEW平台,结合NI Vision机器视觉库,开发了一种既高效又灵活的目标跟踪系统。通过面向对象编程方法和队列消息处理器程序…

Web前端一套全部清晰 ② day2 HTML 标签之文字排版,图片、链接、音视频链接

虽然辛苦&#xff0c;我还是会选择那种滚烫的人生 —— 24.4.25 HTML初体验 1.HTML定义 HTML 超文本标记语言 超文本 —— 链接 标记 —— 标记也叫标签&#xff0c;带尖括号的文本 标签语法 开始标签 需要加粗的文字 结束标签 标签成对出现&#xff0c;中间包裹内容 <>里…

从零入门区块链和比特币(第一期)

欢迎来到我的区块链与比特币入门指南&#xff01;如果你对区块链和比特币感兴趣&#xff0c;但不知道从何开始&#xff0c;那么你来对地方了。本博客将为你提供一个简明扼要的介绍&#xff0c;帮助你了解这个领域的基础知识&#xff0c;并引导你进一步探索这个激动人心的领域。…

绘唐科技AIGC怎么激活

绘唐科技AIGC怎么激活绘唐科技AIGC怎么激活绘唐科技AIGC怎么激活绘唐科技AIGC怎么激活 这里激活免费3天体验 Docshttps://qvfbz6lhqnd.feishu.cn/wiki/D3YLwmIzmivZ7BkDij6coVcbn7W

RabbitMQ(高级)笔记

一、生产者可靠性 &#xff08;1&#xff09;生产者重连&#xff08;不建议使用&#xff09; logging:pattern:dateformat: MM-dd HH:mm:ss:SSSspring:rabbitmq:virtual-host: /hamllport: 5672host: 192.168.92.136username: hmallpassword: 123listener:simple:prefetch: 1c…

新标准日本语初下 课后练习作业

新版标准日本语初下 第二十五課 これは明日会議で使う資料です 第二十五課 これは明日会議で使う資料です &#xff12;&#xff14;&#xff0d;&#xff10;&#xff14;&#xff0d;&#xff12;&#xff16; 練習&#xff12;&#xff15;&#xff0d;1&#xff0d;1 例…

浅谈大数据时代下的电商风控||电商数据API接口

抢抢抢&#xff01;最后1天&#xff0c;双十一直播活动来啦&#xff01;抢直播专属优惠…… 视频号 随着大数据时代的兴起&#xff0c;互联网电商风控已经从无风控、人工抽取规则为主的简易规则模型发展到当前基于大数据的风控。与金融风控不同&#xff0c;互联网电商风控呈现出…

汽车新智能图谱里:理解腾讯的AI TO B路径

将自身的C2B产品和产业理解充分AI化&#xff0c;在自身内部场景率先验证跑通后&#xff0c;进而释放给产业伙伴&#xff0c;对应到具体的需求痛点&#xff0c;一起打磨对应的行业AI模型。 这也恰是腾讯“实用”标签背后的AI产业路径。 作者|皮爷 出品|产业家 成本、性价…