Pytorch自动混合精度的计算:torch.cuda.amp.autocast

1 autocast介绍

1.1 什么是AMP?

默认情况下,大多数深度学习框架都采用32位浮点算法进行训练。2017年,NVIDIA研究了一种用于混合精度训练的方法,该方法在训练网络时将单精度(FP32)与半精度(FP16)结合在一起,并使用相同的超参数实现了与FP32几乎相同的精度。

FP16也即半精度是一种计算机使用的二进制浮点数据类型,使用2字节存储。而FLOAT就是FP32。

1.2 autocast作用

torch.cuda.amp.autocast是PyTorch中一种混合精度的技术(仅在GPU上训练时可使用),可在保持数值精度的情况下提高训练速度和减少显存占用。

    def __init__(self, enabled : bool = True, dtype : torch.dtype = torch.float16, cache_enabled : bool = True):

它是一个自动类型转换器,可以根据输入数据的类型自动选择合适的精度进行计算,从而使得计算速度更快,同时也能够节省显存的使用。使用autocast可以避免在模型训练过程中手动进行类型转换,减少了代码实现的复杂性。

在深度学习中,通常会使用浮点数进行计算,但是浮点数需要占用更多的显存,而低精度数值可以在减少精度的同时,减少缓存使用量。因此,对于正向传播和反向传播中的大多数计算,可以使用低精度型的数值,提高内存使用效率,进而提高模型的训练速度。

1.3 autocast原理

autocast的要做的事情,简单来说就是:在进入算子计算之前,选择性的对输入进行cast操作。为了做到这点,在PyTorch1.9版本的架构上,可以分解为如下两步:

  • 在PyTorch算子调用栈上某一层插入处理函数
  • 在处理函数中对算子的输入进行必要操作

核心代码:autocast_mode.cpp

2 autocast优缺点

PyTorch中的autocast功能是一个性能优化工具,它可以自动调整某些操作的数据类型以提高效率。具体来说,它允许自动将数据类型从32位浮点(float32)转换为16位浮点(float16),这通常在使用深度学习模型进行训练时使用。

2.1 autocast优点

  • 提高性能:使用16位浮点数(half precision)进行计算可以在支持的硬件上显著提高性能,特别是在最新的GPU上。

  • 减少内存占用:16位浮点数占用的内存比32位少,这意味着在相同的内存限制下可以训练更大的模型或使用更大的批量大小。

  • 自动管理autocast能够自动管理何时使用16位浮点数,何时使用32位浮点数,这降低了手动管理数据类型的复杂性。

  • 保持精度:尽管使用了较低的精度,但autocast通常能够维持足够的数值精度,对最终模型的准确度影响不大。

2.2 autocast缺点

  • 硬件要求:并非所有的GPU都支持16位浮点数的高效运算。在不支持或优化不足的硬件上,使用autocast可能不会带来性能提升。

  • 精度问题:虽然在大多数情况下精度损失不显著,但在某些应用中,尤其是涉及到小数值或非常大的数值范围时,降低精度可能会导致问题。

  • 调试复杂性:由于autocast在模型的不同部分自动切换数据类型,这可能会在调试时增加额外的复杂性。

  • 算法限制:某些特定的算法或操作可能不适合在16位精度下运行,或者在半精度下的实现可能还不成熟。

  • 兼容性问题:某些PyTorch的特性或第三方库可能还不完全支持半精度运算。

在实际应用中,是否使用autocast通常取决于特定任务的需求、所使用的硬件以及对性能和精度的权衡。通常,对于大多数现代深度学习应用,特别是在使用最新的GPU时,使用autocast可以带来显著的性能优势。

3 使用示例

3.1 autocast混合精度计算

with autocast(): 语句块内的代码会自动进行混合精度计算,也就是根据输入数据的类型自动选择合适的精度进行计算,并且这里使用了GPU进行加速。使用示例如下:

# 导入相关库
import torch
from torch.cuda.amp import autocast# 定义一个模型
class MyModel(torch.nn.Module):def __init__(self):super(MyModel, self).__init__()self.linear = torch.nn.Linear(10, 1)def forward(self, x):with autocast():x = self.linear(x)return x# 初始化数据和模型
x = torch.randn(1, 10).cuda()
model = MyModel().cuda()# 进行前向传播
with autocast():output = model(x)# 计算损失
loss = output.sum()# 反向传播
loss.backward()

3.2 autocast与GradScaler一起使用

因为autocast会损失部分精度,从而导致梯度消失的问题,并且经过中间层时可能计算得到inf导致最终loss出现nan。所以我们通常将GradScaler与autocast配合使用来对梯度值进行一些放缩,来缓解上述的一些问题。

from torch.cuda.amp import autocast, GradScalerdataloader = ...
model = Model.cuda(0)
optimizer = ...
scheduler = ...
scaler = GradScaler()  # 新建GradScale对象,用于放缩
for epoch_idx in range(epochs):for batch_idx, (dataset) in enumerate(dataloader):optimizer.zero_grad()dataset = dataset.cuda(0)with autocast():  # 自动混精度logits = model(dataset)loss = ...scaler.scale(loss).backward()  # scaler实现的反向误差传播scaler.step(optimizer)  # 优化器中的值也需要放缩scaler.update()  # 更新scalerscheduler.step()
...

4 可能出现的问题

使用autocast技术进行混精度训练时loss经常会出现'nan',有以下三种可能原因:

  • 精度损失,有效位数减少,导致输出时数据末位的值被省去,最终出现nan的现象。该情况可以使用GradScaler(上文所示)来解决。
  • 损失函数中使用了log等形式的函数,或是变量出现在了分母中,并且训练时,该数值变得非常小时,混精度可能会让该值更接近0或是等于0,导致了数学上的log(0)或是x/0的情况出现,从而出现'inf'或'nan'的问题。这种时候需要针对该问题设置一个确定值。例如:当log(x)出现-inf的时候,我们直接将输出中该位置的-inf设置为-100,即可解决这一问题。
  • 模型内部存在的问题,比如模型过深,本身梯度回传时值已经非常小。这种问题难以解决。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/175777.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

20.2 设备树中的 platform 驱动编写

一、设备树下的 platform 驱动 platform 驱动框架分为总线、设备和驱动,总线不需要我们去管理,这个是 Linux 内核提供。在有了设备树的前提下,我们只需要实现 platform_driver 即可。 1. 修改 pinctrl-stm32.c 文件 先复习一下 pinctrl 子系…

网络运维Day17

文章目录 什么是数据库MySQL介绍实验环境准备构建MySQL服务连接数据库修改root密码 数据库基础常用的SQL命令分类SQL命令使用规则MySQL基本操作创建库创建表查看表结构 记录管理命令 数据类型数值类型 数据类型日期时间类型时间函数案例枚举类型 约束条件案例修改表结构添加新字…

【源码运行打包】kkFileView 下载与安装

目录导航 1、源码下载2、IDEA部署2.1、克隆代码2.2、配置maven2.3、下载依赖报错2.4、执行maven打包 3、Centos7.9部署启动3.1、环境要求3.2、部署jdk环境3.3、上传部署包3.4、解压部署包3.5、访问测试3.6、解决乱码 4、使用指南5、部署包下载 文件预览服务 kkFileView &#x…

Linux socket编程(3):利用fork实现服务端与多个客户端建立连接

上一节,我们实现了一个客户端/服务端的Socket通信的代码,在这个例子中,客户端连接上服务端后发送一个字符串,而服务端接收到字符串并打印出来后就关闭所有套接字并退出了。 上一节的代码较为简单,在实际的应用中&…

CloudCompare 二次开发(21)——点云平面拟合

目录 一、概述二、代码集成三、结果展示本文由CSDN点云侠原创,原文链接。爬虫网站自重。 一、概述 由CloudCompare——点云平面拟合一文的实际操作知:CloudCompare软件中的已经集成了点云平面拟合功能,但是无法输出平面的标准方程。因此,本文在原有算法的基础上进行修改,…

PP-ChatOCRv2、PP-TSv2、大模型半监督学习工具...PaddleX新特性等你来pick!

小A是一名刚刚毕业的算法工程师,有一天,他被老板安排了一个活,要对一批合同扫描件进行自动化信息抽取,输出结构化的分析报表。OCR问题不大,但是怎么进行批量的结构化信息抽取呢?小A陷入了苦苦思索… 小B是…

Maven 的 spring-boot-maven-plugin 红色报错

1、想要处理此情况&#xff0c;在工具下面加上指定的版本号。 2、给自己的maven的setting文件加工一下。 <mirrors><!--阿里云镜像1--><mirror><id>aliyunId</id><mirrorOf>central</mirrorOf><name>aliyun maven</name>…

【京东API】商品详情+搜索商品列表接口

利用电商API获取数据的步骤 1.申请API接口&#xff1a;首先要在相应电商平台上注册账号并申请API接口。 2.获取授权&#xff1a;在账号注册成功后&#xff0c;需要获取相应的授权才能访问电商API。 3.调用API&#xff1a;根据电商API提供的请求格式&#xff0c;通过编程实现…

PostGIS学习教程五:数据

教程的数据是有关纽约市的四个shapefile文件和一个包含社会人口经济数据的数据表。在前面一节我们已经将shapefile加载为PostGIS表&#xff0c;在后面我们将添加社会人口经济数据。 下面描述了每个数据集的记录数量和表属性。这些属性值和关系是我们以后分析的基础。 要在pgAdm…

Hadoop-HDFS架构与设计

HDFS架构与设计 一、背景和起源二、HDFS概述1.设计原则1.1 硬件错误1.2 流水访问1.3 海量数据1.4 简单一致性模型1.5 移动计算而不是移动数据1.6 平台兼容性 2.HDFS适用场景3.HDFS不适用场景 三、HDFS架构图1.架构图2.Namenode3.Datanode 四、HDFS数据存储1.数据块存储2.副本机…

647. 回文子串 516.最长回文子序列

647. 回文子串 题目&#xff1a; 给你一个字符串 s &#xff0c;请你统计并返回这个字符串中 回文子串 的数目。 回文字符串 是正着读和倒过来读一样的字符串。 子字符串 是字符串中的由连续字符组成的一个序列。 具有不同开始位置或结束位置的子串&#xff0c;即使是由相…

一台电脑存在多个版本的python . python切换使用

1、安装库的命令 py -3.X -m pip install XXX 2、查看已安装库的命令 py -3.X -m pip list 3、pip更新的命令 py -3.X -m pip install --upgrade pip 4、切换 默认那个不用改&#xff0c;新建那个需要改。 A.首先在环境变量——>系统变量——>Path中加入python安装…