Gemma模型论文详解(附源码)

原文链接:Gemma模型论文详解(附源码)

1. 背景介绍

Gemma模型是在2023.2.21号Google新发布的大语言模型, Gemma复用了Gemini相同的技术(Gemini也是Google发布的多模态模型),Gemma这次发布了了2B和7B两个版本的参数,不仅提供了预训练的checkpoints,还提供了用于对话、指令跟随等fine-tune的checkpoints。在QA问答、常识。在11

在这里插入图片描述

2. 模型介绍

2.1 模型结构

Gemma模型使用了transformer decoder结构进行训练,训练的上下文大小为8192个token,模型参数如下:
在这里插入图片描述

相比原始transformer结构的区别:

  • Multi-Query Attention:7B模型使用了multi-head attention,2B模型使用了multi-query attention (with 𝑛𝑢𝑚_𝑘𝑣_ℎ𝑒𝑎𝑑𝑠 = 1)。对比llama2中用了group-query attention
    在这里插入图片描述

  • RoPE Embeddings: 不使用绝对位置编码,在每一层前加下RoPE Embedding,同时共享输入与输出层的embedding权重。

  • GeGLU Activations: ReLU的激活替换为GeGLU的激活。对比llama中用了swiglu。

  • Normalizer Location: 在transformer的每一层layer的前后都进行规一化,这里使用RMSNorm做为规一化层。

2.2 训练搭建

Gemma使用TPUv5e进行训练;一个pod中有256块TPUv5e芯片,256块芯片被设计为16X16的2D拓扑;Gemma-7B使用16个pods(4096块卡)进行训练,Gemma-2B使用2个pods(512块卡)。7B模型在一个pod内使用16路模型并行和16路数据并行,2B模型在一个pod内使用256路数据并行。优化器状态使用ZeRO-3进行切分,减少显存占用。在pod外使用类似Pathways的方式减少数据复制的成本。

和Gemini模型训练一样,综合了Jax和Pathways的单控制器single controller编程范式,使用单个python进程编排整个训练; 使用GSPMD partitioner用于训练step的计算,使用XLA compiler减少中间结果的大小。

2.3 训练数据

Gemma 2B和7B分别基于2T和6T个token进行训练,token来源于纯英文的文本,内容包括网页、数学、代码等。使用SentencePiece的tokenizer,字典大小有256K个token。数据过滤使用基于模型的分类器去除有害的、低质量的内容。最后采用类似Gemini的方式进行训练数据的混合,提升高质量数据的占比。

2.4 指令微调(Instruction Tuning)

2B和7B进行有监督微调(SFT)训练中使用混合生成数据和人工标注的prompt文本对,同时进行RLHF训练。在SFT阶段,基于给定的一个prompt,通过测试模型生成多个响应的回答结果,通过一个更大更好的模型进行结果的好坏判断。基于不同的侧重方向(指令跟随/事实/创造性/安全等)构建不同的prompt。使用多种基于LM的自动判断方法,比如chain-of-thought prompting

训练和推理过程中使用相同的数据格式,格式的设计重点在于两点,一个是确定多轮对话中的角色,一个是确定一轮对话的开始结束。对应格式标记和示例的训练数据如下:

在这里插入图片描述
在这里插入图片描述

3. 源码

  • Tensorflow实现的源码在github google-deepmind/gemma中,PyTorch实现的源码在github google/gemma_pytorch。

  • 模型的配置在gemma/config.py文件中, 7B与2B区别主要在于num_hidden_layers/num_attention_heads/num_key_value_heads/hidden_size/intermediate_size

@dataclasses.dataclass
class GemmaConfig:# The number of tokens in the vocabulary.vocab_size: int = 256000# The maximum sequence length that this model might ever be used with.max_position_embeddings: int = 8192# The number of blocks in the model.num_hidden_layers: int = 28# The number of attention heads used in the attention layers of the model.num_attention_heads: int = 16# The number of key-value heads for implementing attention.num_key_value_heads: int = 16# The hidden size of the model.hidden_size: int = 3072# The dimension of the MLP representations.intermediate_size: int = 24576# The number of head dimensions.head_dim: int = 256# The epsilon used by the rms normalization layers.rms_norm_eps: float = 1e-6# The dtype of the weights.dtype: str = 'bfloat16'# Whether a quantized version of the model is used.quant: bool = False# The path to the model tokenizer.tokenizer: Optional[str] = 'tokenizer/tokenizer.model'def get_dtype(self) -> Optional[torch.dtype]:"""Gets the torch dtype from the config dtype string."""return _STR_DTYPE_TO_TORCH_DTYPE.get(self.dtype, None)def get_config_for_7b() -> GemmaConfig:return GemmaConfig()def get_config_for_2b() -> GemmaConfig:return GemmaConfig(num_hidden_layers=18,num_attention_heads=8,num_key_value_heads=1,hidden_size=2048,intermediate_size=16384)
  • 模型定义在gemma/model.py文件中,GemmaDecoderLayer的定义如下:
class GemmaDecoderLayer(nn.Module):def __init__(self,config: gemma_config.GemmaConfig,):super().__init__()self.self_attn = GemmaAttention(hidden_size=config.hidden_size,num_heads=config.num_attention_heads,num_kv_heads=config.num_key_value_heads,head_dim=config.head_dim,quant=config.quant,)self.mlp = GemmaMLP(hidden_size=config.hidden_size,intermediate_size=config.intermediate_size,quant=config.quant,)self.input_layernorm = RMSNorm(config.hidden_size,eps=config.rms_norm_eps)self.post_attention_layernorm = RMSNorm(config.hidden_size,eps=config.rms_norm_eps)
  • GeGLU的实现跟llama的swiglu不同,geglu相比glu区是采用了gelu的激活,以下是glu的计算示例图:
    在这里插入图片描述

代码参考如下,代码中self.gate_proj对应上图中的B矩阵,gate相当于 σ ( B ) \sigma(B) σ(B)self.up_proj对应上图中的A矩阵.

class GemmaMLP(nn.Module):def __init__(self,hidden_size: int,intermediate_size: int,quant: bool,):super().__init__()self.gate_proj = Linear(hidden_size, intermediate_size, quant)self.up_proj = Linear(hidden_size, intermediate_size, quant)self.down_proj = Linear(intermediate_size, hidden_size, quant)def forward(self, x):gate = self.gate_proj(x)gate = F.gelu(gate)up = self.up_proj(x)fuse = gate * upoutputs = self.down_proj(fuse)return outputs

4. 参考

  • google-deepmind/gemma
  • Gemma 开放模型
  • Gemma: Open Models Based on Gemini Research and Technology
  • gemma-open-models
  • github google/gemma_pytorch
  • github google-deepmind/gemma
  • Grouped Query Attention论文阅读
  • SwiGLU论文阅读

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/487099.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

docker部署seata1.6.0

docker部署seata1.6.0 Seata 是 阿里巴巴 开源的 分布式事务中间件,解决 微服务 场景下面临的分布式事务问题。需要先搭建seata服务端然后与springcloud的集成以实现分布式事务控制的过程 ,项目中只需要在远程调用APi服务的方法上使用注解 GlobalTransa…

Selenium定位不到元素怎么办?一定要这么做

在使用Selenium进行自动化测试时,碰到无法定位元素该怎么办?这里总结了9种情况下的元素定位方法: 1、frame/iframe表单嵌套 WebDriver只能在一个页面上对元素识别与定位,对于frame/iframe表单内嵌的页面元素无法直接定位。 解决…

如何用jmeter请求application/octet-stream,image/jpeg

用postman调用时: 用jmeter: 注意上图不要勾选,不然会把所有的内容都以二进制传进去,我们不勾选只传二进制的图片内容,勾选了会把MIME类型、参数名称都转为二进制传进去。会报错。

最优传输(Optimal Transport)

最优传输(Optimal Transport)是一种数学理论和计算方法,用于描述两个概率分布之间的距离或者对应关系。它的核心概念是如何以最佳方式将一组资源(如质量、能量等)从一个位置传输到另一个位置。 基本概念: …

【SpringCloudAlibaba系列--OpenFeign组件】OpenFeign的配置、使用与测试以及OpenFeign的负载均衡

步骤一 准备两个服务,provider和consumer 本文使用kotlin语言 provider是服务的提供者,由provider连接数据库 RestController RequiredArgsConstructor RequestMapping("/provider/depart") class DepartController(private val departServ…

vscode与vue环境配置

一、下载并安装VScode 安装VScode 官网下载 二、配置node.js环境 安装node.js 官网下载 会自动配置环境变量和安装npm包(npm的作用就是对Node.js依赖的包进行管理),此时可以执行 node -v 和 npm -v 分别查看node和npm的版本号: 配置系统变量 因为在执…

springboot213大学生心理健康管理系统的设计与实现

大学生心理健康管理系统的设计与实现 摘 要 传统信息的管理大部分依赖于管理人员的手工登记与管理,然而,随着近些年信息技术的迅猛发展,让许多比较老套的信息管理模式进行了更新迭代,试卷信息因为其管理内容繁杂,管理…

套接字与套接字编程

对于刚刚学习计算机网络:自顶向下的同学们,在观看了中科大的视频---TCP Socket以及UDP Socket会感到些许疑惑,不过没事,在这篇小文章将会为你解开Socket的神秘面纱 什么是Socket?: Socket 是一套用于不同主机之间通信…

JAVA代码审计之XSS漏洞

Part1 漏洞案例demo&#xff1a; 没有java代码审计XSS漏洞拿赏金的案例。 所以将就看看demo吧 漏洞原理&#xff1a;关于XSS漏洞的漏洞原理核心其实没啥好说的&#xff0c;网上一查一大堆 反射性XSS漏洞 <% page language"java" contentType"text/html; c…

java——File类和字符集

目录 File类File类的常用操作&#xff1a;案例&#xff1a;文件搜索的实现案例&#xff1a;递归文件夹删除 字符集几种常见的字符集总结字符集的编码和解码 File类 File是java.io.包下的类&#xff0c;File类的对象&#xff0c;用于代表当前操作系统的文件&#xff08;可以是文…

ESP8266智能家居(4)——开发APP基础篇

1.前期准备 安装好Android studio 开发环境 准备一台完好的安卓手机 手机要处于开发者模式 设置 --->关于手机---> 一直点击版本号 &#xff08;不同手机进入开发者模式的步骤可能不太一样&#xff09; 进入开发者模式后&#xff0c;找到辅助功能&#xff0c;打开开…

厌倦了混乱的代码?掌握编写干净代码库的艺术

对于入门的开发人员来说&#xff0c;虽然克服了最初的障碍&#xff0c;学会了编程&#xff0c;找到了理想的工作。但其编程旅程并没有就此结束。他们面临真正的挑战&#xff1a;如何编写更好的代码。这不仅仅是为了完善功能&#xff0c;还要编写出经得起时间考验的优雅、可维护…