[内存泄漏][PyTorch](create_graph=True)

PyTorch保存计算图导致内存泄漏

  • 1. 内存泄漏定义
  • 2. 问题发现背景
  • 3. pytorch中关于这个问题的讨论

1. 内存泄漏定义

  内存泄漏(Memory Leak)是指程序中已动态分配的堆内存由于某种原因程序未释放或无法释放,造成系统内存的浪费,导致程序运行速度减慢甚至系统崩溃等严重后果。

2. 问题发现背景

  在使用深度学习求解PDE时,由于经常需要计算高阶导数,在pytorch框架下写的代码需要用到torch.autograd.grad(create_graph=True)或者torch.backward(create_graph=True)这个参数,然后发现了这个内存泄漏的问题。如果要保存计算图用来计算高阶导数,那么其所占的内存不会被释放,会一直占用。也就是如果设置create_graph=True,那么其保存的计算图所占的内存只有在程序运行结束时才会释放,这样导致了一个问题,即如果在循环中需要保存计算图,例如每个循环都需要计算一次黑塞矩阵,那么这个内存占用就会越来越多,最终导致out of memory报错。
在这里插入图片描述

3. pytorch中关于这个问题的讨论

  官网中关于这个问题的讨论见https://github.com/pytorch/pytorch/issues/7343,这里提出的内存泄漏的例子如下:

import torch
import gc_ = torch.randn(1, device='cuda')
del _
torch.cuda.synchronize()
gc.collect()
print(torch.cuda.memory_allocated())
x = torch.randn(1, device='cuda', requires_grad=True)
y = x.tanh()
y.backward(torch.ones_like(y), create_graph=True)
del x, y
torch.cuda.synchronize()
gc.collect()
print(torch.cuda.memory_allocated())

在这里插入图片描述
可以看到虽然删除了变量,依然造成了内存泄漏。这里红色的警告就是关于这个内存泄漏的问题。

UserWarning: Using backward() with create_graph=True will create a reference cycle between
the parameter and its gradient which can cause a memory leak. We recommend using autograd.grad 
when creating the graph to avoid this. If you have to use this function, make sure to reset 
the .grad fields of your parameters to None after use to break the cycle and avoid the leak. 
(Triggered internally at C:\cb\pytorch_1000000000000\work\torch\csrc\autograd\engine.cpp:1000.)
allow_unreachable=True, accumulate_grad=True) 
# Calls into the C++ engine to run the backward pass

看这个UserWarning,提示我们使用torch.autograd.grad()函数可以避免这个梯度泄漏,然后对代码进行改动:

import torch
import gc
from torch.autograd import grad_ = torch.randn(1, device='cuda')
del _
torch.cuda.synchronize()
gc.collect()
print(torch.cuda.memory_allocated())
x = torch.randn(1, device='cuda', requires_grad=True)
y = x.tanh()
z = grad(y, x, retain_graph=True, create_graph=True)
# y.backward(torch.ones_like(y), create_graph=True)
del x, y, z
torch.cuda.synchronize()
gc.collect()
print(torch.cuda.memory_allocated())

在这里插入图片描述
结果显示没有梯度泄漏。进一步,我们求一下二阶导数:

import torch
import gc
from torch.autograd import grad_ = torch.randn(1, device='cuda')
del _
torch.cuda.synchronize()
gc.collect()
print(torch.cuda.memory_allocated())
x = torch.randn(1, device='cuda', requires_grad=True)
y = x.tanh()
z = grad(y, x, retain_graph=True, create_graph=True)
print(torch.cuda.memory_allocated())
q = grad(z, x)
del x, y, z, q
torch.cuda.synchronize()
gc.collect()
print(torch.cuda.memory_allocated())

在这里插入图片描述
结果也没有内存泄漏。但是,如果我们不删除结果二阶导数q,这样是出于如果写在一个函数中,需要将q作为return值返回的情况。

import torch
import gc
from torch.autograd import grad_ = torch.randn(1, device='cuda')
del _
torch.cuda.synchronize()
gc.collect()
print(torch.cuda.memory_allocated())
x = torch.randn(1, device='cuda', requires_grad=True)
y = x.tanh()
z = grad(y, x, retain_graph=True, create_graph=True)
print(torch.cuda.memory_allocated())
q = grad(z, x)
del x, y, z
torch.cuda.synchronize()
gc.collect()
print(torch.cuda.memory_allocated())

在这里插入图片描述
可以看到,这还是会导致一部分内存泄漏。那如果我们一定要返回q,又不想内存泄漏,这里本人想到一直办法,就是将q转换成numpy数据类型,返回这个numpy数组,就不会导致内存泄漏了。

import torch
import gc
from torch.autograd import grad_ = torch.randn(1, device='cuda')
del _
torch.cuda.synchronize()
gc.collect()
print(torch.cuda.memory_allocated())
x = torch.randn(1, device='cuda', requires_grad=True)
y = x.tanh()
z = grad(y, x, retain_graph=True, create_graph=True)
print(torch.cuda.memory_allocated())
q = grad(z, x)
k = q[0].cpu().numpy()
del x, y, z, q
torch.cuda.synchronize()
gc.collect()
print(torch.cuda.memory_allocated())

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/191065.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

element表格头部加入图标

首先看看效果 下面是代码 <el-table-column prop"integralBalance"><template slot"header" slot-scope"scope"><div style"display: flex;justify-content: center;align-items: center;">积分余额<i class&qu…

代码随想录算法训练营Day 54 || 392.判断子序列、115.不同的子序列

392.判断子序列 力扣题目链接(opens new window) 给定字符串 s 和 t &#xff0c;判断 s 是否为 t 的子序列。 字符串的一个子序列是原始字符串删除一些&#xff08;也可以不删除&#xff09;字符而不改变剩余字符相对位置形成的新字符串。&#xff08;例如&#xff0c;&quo…

二叉树的遍历(非递归版)

文章目录 二叉树的前序遍历二叉树的中序遍历二叉树的后序遍历 正文开始前给大家推荐个网站&#xff0c;前些天发现了一个巨牛的 人工智能学习网站&#xff0c; 通俗易懂&#xff0c;风趣幽默&#xff0c;忍不住分享一下给大家。 点击跳转到网站。 二叉树的前序遍历 用递归实…

ROSCon 2023 大会回顾

系列文章目录 文章目录 系列文章目录前言一、会议内容二、其他活动 前言 我们与 ROSCon 2023 全体 700 多名与会者的合影。 视频回放链接 一、会议内容 ROSCon 2023 是我们第十二届年度 ROS 开发者大会&#xff0c;于 2023 年 10 月 18 日至 20 日在路易斯安那州新奥尔良举行。…

【论文精读3】CasMVSNet

模型处理过程&#xff1a; 一. 问题引入 基于学习的MVS算法因为受到显存的限制&#xff0c;输出的深度图的空间分辨率只有输入图像的1/16大小&#xff08;长宽均为输入图像的1/4大小&#xff09;。以MVSNet为例&#xff0c;对于16001184大小的输入图像&#xff0c;需要构建hwD…

【机器学习Python实战】线性回归

&#x1f680;个人主页&#xff1a;为梦而生~ 关注我一起学习吧&#xff01; &#x1f4a1;专栏&#xff1a;机器学习python实战 欢迎订阅&#xff01;后面的内容会越来越有意思~ ⭐内容说明&#xff1a;本专栏主要针对机器学习专栏的基础内容进行python的实现&#xff0c;部分…

Web之CSS笔记

Web之HTML、CSS、JS 二、CSS&#xff08;Cascading Style Sheets层叠样式表&#xff09;CSS与HTML的结合方式CSS选择器CSS基本属性CSS伪类DIVCSS轮廓CSS边框盒子模型CSS定位 Web之HTML笔记 二、CSS&#xff08;Cascading Style Sheets层叠样式表&#xff09; Css是种格式化网…

浏览器页面被恶意控制时的解决方法

解决360流氓软件控制浏览器页面 提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、接受360安全卫士的好意&#xff08;尽量不要选&#xff09;二、拒绝360安全卫士的好意&#xff08;强烈推荐&#xff09;第…

Leetcode—876.链表的中间结点【简单】

2023每日刷题&#xff08;三十三&#xff09; Leetcode—876.链表的中间结点 实现代码 /*** Definition for singly-linked list.* struct ListNode {* int val;* struct ListNode *next;* };*/ struct ListNode* middleNode(struct ListNode* head) {struct ListNod…

Redis持久化机制详解

使用缓存的时候&#xff0c;我们经常需要对内存中的数据进行持久化也就是将内存中的数据写入到硬盘中。大部分原因是为了之后重用数据&#xff08;比如重启机器、机器故障之后恢复数据&#xff09;&#xff0c;或者是为了做数据同步&#xff08;比如 Redis 集群的主从节点通过 …

传输层——TCP协议

文章目录 一.TCP协议二.TCP协议格式1.序号与确认序号2.窗口大小3.六个标志位 三.确认应答机制&#xff08;ACK&#xff09;四.超时重传机制五.连接管理机制1.三次握手2.四次挥手 六.流量控制七.滑动窗口八.拥塞控制九.延迟应答十.捎带应答十一.面向字节流十二.粘包问题十三.TCP…