大模型LLM训练显存消耗详解

参考论文:ZeRO: Memory Optimizations Toward Training Trillion Parameter Models

大模型的显存消耗一直都是面试常见的问题,这次我就彻彻底底的根据论文ZeRO中的调研和分析做一次分析

显存消耗的两个部分:Model States(跟模型的参数量和优化器相关)Residual Memory Consumption(跟训练时的batchsize,序列长度有关)

接下来,我就从这两个部分详细分析:


Model States

在这里插入图片描述

一个模型在显存消耗上,分为三个部分

  1. Optimizer States
  2. Gradients States
  3. Parameters States

更加具体的说,对于一个模型参数(Parameters)我们需要维护维护三个不同方面的参数
我们假设:模型的参数量大小为ModelSize

Parameters States

故名思义就是模型本身的权重参数,对于一个使用Float32存储的参数,我们需要32/8=4byte进行存储。

Gradients States

记录参数的梯度,对于一个使用Float32存储的参数,我们同样需要一个相同大小的梯度(4byte)保存它的梯度。

Optimizer States

对于最常用的Adam优化器以及其变体,对于一个使用Float32存储的参数需要维护两个额外的参数momentumvariance,也就是需要2*4=8byte进行保存


总的来说,对于Float32保存的模型来说,我们显存消耗是16(4+4+8)* ModelSize byte

但是对于半精度保存的模型(Float16),每个参数Parameters StatesGradients States的显存消耗都是2byte。在训练时,我们仍然需要保存其Float32的Parameters States用以加速运算,同时Adam优化器的两个参数momentumvariance同样也是Float32形式保存的,每个参数消耗的即为4+4+4=12 byte。所以半精度保存的模型,计算时的显存消耗仍然为16(2+2+12)* ModelSize byte


Residual Memory Consumption

剩下的显存消耗跟我们训练时的配置有关
主要有三个部分

  1. Activations
  2. Temporary buffers
  3. Memory Fragmentation

Activations

对于一个transformer based的模型来说,Activations的显存消耗和如下公式是成比例的:

number of transformer layers × hidden dimensions × sequence length × batch size

对于GPT2来说,这个比例大约为12

Temporary buffers 和 Memory Fragmentation

这两个参数不容易具体量化,Temporary buffers是多卡训练过程中为了提升梯度计算的效率,通常会执行一些类似于gradient all-reducegradient norm computation等操作,把数据集合到一个临时的缓存区中,这个临时区也会占用相当数量的显存

Memory Fragmentation,内存碎片的产生会导致内存空间的利用效率低下,即使有空余空间但是不足以分配给一个新的内存请求。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/475644.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python学习 --- 文件操作

1.文件的基础操作 --- 打开,关闭与读文件 文件的主要操作有:打开,关闭与读写 1. name 是文件的路径,要用字符串的形式来表示 2. mode 模式也要用字符串的形式来表示 3.open函数会返回一个文件对象,该文件对象指向的是被打开的文件 1.read方法在调用完之后会生成一个指…

html表格标签(下):lable标签,select标签和textara标签

html表格标签(下):lable标签,select标签和textarea标签 lable标签 搭配 input 使用,点击 label 标签就能选中对应的单选/复选框, 能够提升用户体验。 for 属性: 指定当前 label 和哪个相同 id 的 input 标签对应 (此时点击才是有用的) 运行效果&#x…

Java集合篇之深度解析Queue,单端队列、双端队列、优先级队列、阻塞队列

写在开头 队列是Java中的一个集合接口,之前的文章已经讲解了List和Set,那么今天就来唠一唠它吧。队列的特点:存储的元素是有序的、可重复的。 队列的两大接口Queue vs Deque Queue 是单端队列,只能从一端插入元素,另…

微信小程序之开发会议OA项目

目录 前言 本篇目标 首页 会议 投票 个人中心 会议OA项目-首页 配置 tabbar mock工具 page swiper 会议信息 会议OA项目-会议 自定义tabs组件 会议管理 会议OA项目-投票 会议OA项目-个人中心 前言 文章含源码资源,投票及个人中心详细自行查看…

【Kuiperinfer】笔记01 项目预览与环境配置

学习目标 实现一个深度学习推理框架设计、编写一个计算图实现常见的算子,例如卷积、池化、全连接学会如何进行算子的优化加速使用自己的推理框架推理常见模型,检查结果是否能够和torch对齐 什么是推理框架? 推理框架用于对已经训练完成的模…

JavaScript中延迟加载的方式有哪些

在web前端开发中,性能优化一直是一个非常重要的话题。当我们开发一个页面时,为了提高用户的体验和页面加载速度,我们往往需要采用一些延迟加载的技术。JavaScript中延迟加载的方式有很多种,下面我将为大家详细介绍几种常用的方式。…

Spring MVC(基于 Spring4.x)基础学习

一、SpringMVC概述 二、SpringMVC的HelloWorld 三、使用RequestMapping映射请求 四、映射请求参数&请求头 五、处理模型数据 六、视图和视图解析器 七、RESTful CRUD 八、SpringMVC表单标签&处理静态资源 九、数据转换&数据格式化&数据校验 十、处理JSON:使用…

IO进程-day1

1、使用fgets统计给定文件的行数。 #include<stdio.h> #include<string.h> #include<stdlib.h>int main(int argc, const char *argv[]) {if(argc ! 2){printf("inout file error\n");printf("usage:./a.out srcfile destfile\n");ret…

ClickHouse从入门到精通(高级)

第1章 Explain查看执行计划 第2章 建表优化 第3章 ClickHouse语法优化规则 第4章 查询优化 第5章 数据一致性(重点) 第6章 物化视图 第7章 MaterializeMySQL引擎 第8章 常见问题排查

PHP支持的伪协议

php.ini参数设置 在php.ini里有两个重要的参数allow_url_fopen、allow_url_include。 allow_url_fopen:默认值是ON。允许url里的封装协议访问文件&#xff1b; allow_url_include:默认值是OFF。不允许包含url里的封装协议包含文件&#xff1b; 各协议的利用条件和方法 php:/…

萝卜大杂烩 | 把微信接入ChatGPT,变成聊天机器人竟然这么简单!(一起来尝试吧~)

本文来源公众号“萝卜大杂烩”&#xff0c;仅用于学术分享&#xff0c;侵权删&#xff0c;干货满满。 原文链接&#xff1a;把微信接入ChatGPT&#xff0c;变成聊天机器人竟然这么简单&#xff01; 最近的 ChatGPT 又再次火热起来了&#xff0c;各种周边工具也是层出不穷&…

mfc140u.dll文丢失导致应用程序无法正常,有哪些解决办法

mfc140u.dll是Microsoft Foundation Classes&#xff08;MFC&#xff09;的一个重要组件&#xff0c;它提供了许多用于开发Windows应用程序的功能和工具。然而&#xff0c;当系统或应用程序升级、恶意软件感染或文件损坏以及用户错误操作等情况发生时&#xff0c;mfc140u.dll文…