with torch.no_grad()在Pytorch中的应用

with torch.no_grad()在Pytorch中的应用

参考:
https://blog.csdn.net/qq_24761287/article/details/129773333
https://blog.csdn.net/sazass/article/details/116668755

在学习Pytorch时,老遇到 with torch.no_grad(),搞不清其作用,现在详细了解一下。

1、with torch.no_grad()含义

torch.no_grad() 上下文管理器通常用于那些不需要计算梯度的操作,例如在模型评估或推断时。在这些情况下,关闭自动求导功能可以提高代码执行效率,因为不需要计算梯度的操作通常比需要计算梯度的操作更快。

with torch.no_grad():# some code that doesn't require gradients

2、with torch.no_grad()运用场景

简单来说,如果不需要在接下来步骤中用到所计算的式子的梯度,就可以使用with torch.no_grad()来提升运算速度。

2.1 只评估模型

在模型的评估模式下,对验证数据集进行前向传播并计算性能指标,而不计算或存储梯度信息。这有助于节省内存和提高代码执行效率。在此处能使用with torch.no_grad()的根本原因是我们不依赖于模型得到的结果去执行梯度下降操作,例如:

model.eval()
with torch.no_grad():for inputs, targets in validation_loader:outputs = model(inputs)# 计算指标,如准确率、损失等
2.2 此模型的计算结果不参与此模型的梯度下降

在SAC算法的更新过程中,需要用到策略policy网络的结果去更新Q网络的参数,在计算策略policy网络的结果时,该计算结果并不会用于更新policy网络,因此我们需要使用with torch.no_grad():对next_log_prob = self.policy_net.evaluate(next_state)进行修饰。

        predicted_q_value1 = self.soft_q_net1(state)predicted_q_value1 = predicted_q_value1.gather(1, action.unsqueeze(-1))predicted_q_value2 = self.soft_q_net2(state)predicted_q_value2 = predicted_q_value2.gather(1, action.unsqueeze(-1))log_prob = self.policy_net.evaluate(state)# with torch.no_grad()表示不带梯度,因为只是用policy_net得到next_log_prob,对更新Q网络不起作用with torch.no_grad():next_log_prob = self.policy_net.evaluate(next_state)# reward = reward_scale * (reward - reward.mean(dim=0)) / (reward.std(dim=0) + 1e-6) # normalize with batch mean and std; plus a small number to prevent numerical problem# Training Q Functionself.alpha = self.log_alpha.exp()target_q_min = (next_log_prob.exp() * (torch.min(self.target_soft_q_net1(next_state), self.target_soft_q_net2(next_state)) - self.alpha * next_log_prob)).sum(dim=-1).unsqueeze(-1)target_q_value = reward + (1 - done) * gamma * target_q_min  # if done==1, only rewardq_value_loss1 = self.soft_q_criterion1(predicted_q_value1,target_q_value.detach())  # detach: no gradients for the variableq_value_loss2 = self.soft_q_criterion2(predicted_q_value2, target_q_value.detach())self.soft_q_optimizer1.zero_grad()q_value_loss1.backward()self.soft_q_optimizer1.step()self.soft_q_optimizer2.zero_grad()q_value_loss2.backward()self.soft_q_optimizer2.step()
2.3 模型更新参数

当你在优化算法中更新模型参数时,不需要在参数更新步骤中计算梯度。在更新参数时使用 torch.no_grad() 可以防止出现错误,并确保计算过程正确。

def sgd(params, lr, batch_size):with torch.no_grad():for param in params:param -= lr * param.grad / batch_sizeparam.grad.zero_()

3、with torch.no_grad()本质作用

在该模块下,所有计算得出的tensor的requires_grad都自动设置为False。

即使一个tensor(命名为x)的requires_grad = True,在with torch.no_grad计算,由x得到的新tensor(命名为w-标量)requires_grad也为False,且grad_fn也为None,即不会对w求导。例子如下所示:

x = torch.randn(10, 5, requires_grad = True)
y = torch.randn(10, 5, requires_grad = True)
z = torch.randn(10, 5, requires_grad = True)
with torch.no_grad():w = x + y + zprint(w.requires_grad)print(w.grad_fn)
print(w.requires_grad)输出:
False
None
False

4、为什么要使用with torch.no_grad()

如果在这些情况下没有使用torch.no_grad() 会导致哪些错误?

  1. 额外的内存消耗:计算和存储梯度需要额外的内存。在不需要梯度的情况下仍然计算梯度会导致不必要的内存消耗。在内存有限的设备上,如GPU,这可能导致内存不足而无法执行计算。
  2. 降低计算速度:计算梯度会增加计算负担。如果在不需要梯度的情况下仍然计算梯度,会降低计算速度,从而增加模型评估和推理的时间。
  3. 可能的计算错误:在某些情况下,如在优化算法中更新参数时,如果不使用torch.no_grad(),可能导致错误。例如,如果你在需要梯度的张量上执行原地操作,PyTorch会抛出RuntimeError,因为这样的操作会破坏计算图和梯度计算。

虽然在某些情况下忘记使用 torch.no_grad() 可能不会立即导致错误,但为了确保计算效率和正确性,建议在不需要梯度计算的情况下使用 torch.no_grad()。

下面给出使用with torch.no_grad()修饰不需要求导语句和不使用的对比,可以看到在同样的实际内,使用修饰会带来更好的速度。
在这里插入图片描述
效果也是使用了with torch.no_grad()更好,但是这些都是参考,毕竟每次训练的收敛速度都不太一致:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/294611.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

蚂蚁集团5大开源项目获开放原子 “2023快速成长开源项目”

12月16日,在开放原子开源基金会主办的“2023开放原子开发者大会”上,蚂蚁集团主导开源的图数据库TuGraph、时序数据库CeresDB、隐私计算框架隐语SecretFlow、前端框架OpenSumi、数据域大模型开源框架DB-GPT入选“2023快速成长开源项目”。 (图…

SecureCRT连接vmware虚拟机的centos系统配置

软件版本:VMware10.0.3,centos 7,securecrt 8.7.2 1,虚拟网络编辑器选择桥接模式, 2,如果不小心删除网络,centos关机状态下,选择恢复默认设置。 3,进入linux系统&#…

2023年12月22日历史上的今天大事件早读

1808年12月22日 贝多芬《第五交响曲》在维也纳首演 1900年12月22日 清廷全盘接受议和大纲 1900年12月22日 第一辆奔驰汽车诞生 1905年12月22日 中日签订《满州善后协约》 1930年12月22日 苏联全面控制食品供应 1936年12月22日 苏联作家奥斯特洛夫斯基病逝 1942年12月22日…

使用 uiautomatorviewer 获取元素的定位信息

1. 使用 adb 连接设备(真机或模拟器) 连接夜神模拟器:adb connect 127.0.0.1:62001 连接MuMu模拟器:adb connect 127.0.0.1:7555 2. 打开 uiautomatorviewer 在 android-sdk --> tools 目录,找到 uiautomatorvie…

智能优化算法应用:基于社交网络算法3D无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用:基于社交网络算法3D无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用:基于社交网络算法3D无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.社交网络算法4.实验参数设定5.算法结果6.…

Java:LocalDateTime、LocalDate、Date、String相互转化

前言:LocalDateTime、LocalDate、Date三者联系 非线程安全的方法 Date类的大部分方法都不是线程安全的,比如setYear()、setMonth()、setDate()、setHours()、setMinutes()、setSeconds()等方法。这些方法都可以修改Date对象的内部状态。如果多个线程同时…

基于ssm重庆理工大学心理咨询管理子系统的分析与实现论文

摘 要 传统信息的管理大部分依赖于管理人员的手工登记与管理,然而,随着近些年信息技术的迅猛发展,让许多比较老套的信息管理模式进行了更新迭代,心理咨询预约信息因为其管理内容繁杂,管理数量繁多导致手工进行处理不能…

EasyExcel使用: RGB字体,RGB背景颜色,fillForegroundColor颜色对照表

EasyExcel使用: RGB字体,RGB背景颜色,fillForegroundColor颜色对照表 使用EasyExcel导出表格可能会对字体颜色和单元格背景颜色进行自定义的修改。 可以自定义字体颜色或者每个单元格的颜色 要想自定义颜色,需要重写CellWriteHandler接口&am…

Vue 组件传参 prop向下传递、emit向上传递

组件传参 学习了组件用法,就像一种嵌套引用关系,在这个关系中,经常会涉及相互传数据的需求,即父组件传子组件,子组件传父组件。 父、子组件的关系可以总结为 : prop 向下传递(和事件无关), emit 事件&…

AI-数学-初高中-不等式

原教学视频:初高衔接】2不等式解法梳理_哔哩哔哩_bilibili 1.一元二次不等式 2.分式不等式 3.高次函数图像 4.绝对值不等式 5.根式不等式

tsconfig.app.json文件报红:Option ‘importsNotUsedAsValues‘ is deprecated...

在创建vue3 vite ts项目时的 tsconfig.json(或者tsconfig.app.json) 配置文件经常会报一个这样的错误: 爆红: Option ‘importsNotUsedAsValues’ is deprecated and will stop functioning in TypeScript 5.5. Specify compi…

打造完美有声书体验,Audiobook Builder for Mac助您一键生成

在快节奏的生活中,有声书成为越来越多人追求放松与娱乐的方式。然而,找到合适的有声书却不容易,而Audiobook Builder for Mac正是为解决这个问题而诞生的完美解决方案。 Audiobook Builder for Mac是一款专业的有声书生成工具,它…