PyTorch-Lightning:trining_step的自动优化

文章目录

  • PyTorch-Lightning:trining_step的自动优化
      • 总结:
    • class _ AutomaticOptimization()
      • def run
      • def _make_closure
      • def _training_step
        • class ClosureResult():
          • def from_training_step_output
      • class Closure

PyTorch-Lightning:trining_step的自动优化

使用PyTorch-Lightning时,在trining_step定义损失,在没有定义损失,没有任何返回的情况下没有报错,在定义一个包含loss的多个元素字典返回时,也可以正常训练,那么到底lightning是怎么完成训练过程的。

总结:

在自动优化中,training_step必须返回一个tensor或者dict或者None(跳过),对于简单的使用,在training_step可以return一个tensor会作为Loss回传,也可以return一个字典,其中必须包括key"loss",字典中的"loss"会提取出来作为Loss回传,具体过程主要包含在lightning\pytorch\loop\sautomatic.py中的_ AutomaticOptimization()类。

在这里插入图片描述

class _ AutomaticOptimization()

实现自动优化(前向,梯度清零,后向,optimizer step)

在training_epoch_loop中会调用这个类的run函数。

def run

首先通过 _make_closure得到一个closure,详见def _make_closure,最后返回一个字典,如果我们在training_step只return了一个loss tensor则字典只有一个’loss’键值对,如果return了一个字典,则包含其他键值对。

可以看到调用了_ optimizer_step,_ optimizer_step经过层层调用,最后会调用torch默认的optimizer.zero_grad,而我们通过 make_closure得到的closure作为参数传入,具体而言是调用了closure类的_ call __()方法。

def run(self, optimizer: Optimizer, batch_idx: int, kwargs: OrderedDict) -> _OUTPUTS_TYPE:closure = self._make_closure(kwargs, optimizer, batch_idx)if (# when the strategy handles accumulation, we want to always call the optimizer stepnot self.trainer.strategy.handles_gradient_accumulation and self.trainer.fit_loop._should_accumulate()):# For gradient accumulation# -------------------# calculate loss (train step + train step end)# -------------------# automatic_optimization=True: perform ddp sync only when performing optimizer_stepwith _block_parallel_sync_behavior(self.trainer.strategy, block=True):closure()# ------------------------------# BACKWARD PASS# ------------------------------# gradient update with accumulated gradientselse:self._optimizer_step(batch_idx, closure)result = closure.consume_result()if result.loss is None:return {}return result.asdict()

def _make_closure

创建一个closure对象,来捕捉给定的参数并且运行’training_step’和可选的其他如backword和zero_grad函数

比较重要的是step_fn,在这里调用了_training_step,得到的是一个存储我们在定义模型时重写的training step的输出所构成ClosureResult数据类。具体见def _training_step

def _make_closure(self, kwargs: OrderedDict, optimizer: Optimizer, batch_idx: int) -> Closure:step_fn = self._make_step_fn(kwargs)backward_fn = self._make_backward_fn(optimizer)zero_grad_fn = self._make_zero_grad_fn(batch_idx, optimizer)return Closure(step_fn=step_fn, backward_fn=backward_fn, zero_grad_fn=zero_grad_fn)

def _training_step

通过hook函数实现真正的训练step,返回一个存储training step输出的ClosureResult数据类。

将我们在定义模型时定义的lightning.pytorch.core.LightningModule.training_step的输出作为参数传入存储容器class ClosureResult的from_training_step_output方法,见class Closure

class ClosureResult():

一个数据类,包含closure_loss,loss,extra

    closure_loss: Optional[Tensor]loss: Optional[Tensor] = field(init=False, default=None)extra: Dict[str, Any] = field(default_factory=dict)
def from_training_step_output

一个类方法,如果我们在training_step定义的返回是一个字典,则我们会将key值为"loss"的value赋值给closure_loss,而其余的键值对赋值给extra字典,如果返回的既不是包含"loss"的字典也不是tensor,则会报错。当我们在training_step不设定返回,则自然为None,但是不会报错。

class Closure

闭包是将外部作用域中的变量绑定到对这些变量进行计算的函数变量,而不将它们明确地作为输入。这样做的好处是可以将闭包传递给对象,之后可以像函数一样调用它,但不需要传入任何参数。

在lightning,用于自动优化的Closure类将training_step和backward, zero_grad三个基本的闭包结合在一起。

这个Closure得到training循环中的结果之后传入torch.optim.Optimizer.step。

参数:

  • step_fn: 这里是一个存储了training step输出的ClosureResult数据类,见def _training_step
  • backward_fn: 梯度回传函数
  • zero_grad_fn: 梯度清零函数

按照顺序,会先检查得到loss,之后调用梯度清零函数,最后调用梯度回传函数

class Closure(AbstractClosure[ClosureResult]):warning_cache = WarningCache()def __init__(self,step_fn: Callable[[], ClosureResult],backward_fn: Optional[Callable[[Tensor], None]] = None,zero_grad_fn: Optional[Callable[[], None]] = None,):super().__init__()self._step_fn = step_fnself._backward_fn = backward_fnself._zero_grad_fn = zero_grad_fn@override@torch.enable_grad()def closure(self, *args: Any, **kwargs: Any) -> ClosureResult:step_output = self._step_fn()if step_output.closure_loss is None:self.warning_cache.warn("`training_step` returned `None`. If this was on purpose, ignore this warning...")if self._zero_grad_fn is not None:self._zero_grad_fn()if self._backward_fn is not None and step_output.closure_loss is not None:self._backward_fn(step_output.closure_loss)return step_output@overridedef __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:self._result = self.closure(*args, **kwargs)return self._result.loss

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/611015.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

纯小白蓝桥杯备赛笔记--DAY14(计算几何)

文章目录 计算几何基础平面几何距离圆的周长和面积圆与圆之间的关系:海伦公式计算三角形面积点到直线的距离 点积和叉积例题: 点和线的关系点的表示形式和代码判断点在直线的那边点到线的垂足点到线的距离例题-1242例题-1240升级--点到线段的距离--1285 …

基于级联H桥的多电平逆变器PWM控制策略的simulink建模与仿真

目录 1.课题概述 2.系统仿真结果 3.核心程序与模型 4.系统原理简介 5.完整工程文件 1.课题概述 级联H桥(CHB)多电平逆变器是一种通过多个H桥单元级联实现更高电压等级和更高质量输出波形的电力电子转换装置。这种逆变器在高压大功率场合应用广泛&am…

软考123-上午题-【软件工程】-系统设计

一、系统设计 1-1、概要设计 设计软件系统总结结构数据结构及数据库设计编写概要设计文档评审 1-1-1、设计软件系统总结结构 其基本任务是采用某种设计方法,将一个复杂的系统按功能划分成模块; 确定每个模块的功能;确定模块之间的调用关系…

【LeetCode: 680. 验证回文串 II + 贪心 + 边界处理】

🚀 算法题 🚀 🌲 算法刷题专栏 | 面试必备算法 | 面试高频算法 🍀 🌲 越难的东西,越要努力坚持,因为它具有很高的价值,算法就是这样✨ 🌲 作者简介:硕风和炜,…

PaddleDetection 项目使用说明

PaddleDetection 项目使用说明 PaddleDetection 项目使用说明数据集处理相关模块环境搭建 PaddleDetection 项目使用说明 https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.7/configs/ppyoloe/README_cn.md 自己项目: https://download.csdn.net/d…

基于GSP工具箱的NILM算法matlab仿真

目录 1.课题概述 2.系统仿真结果 3.核心程序与模型 4.系统原理简介 5.完整工程文件 1.课题概述 基于GSP工具箱的NILM算法matlab仿真。GSP是图形信号处理的缩写,GSP非常适合对未知数据进行分类,尤其是当训练数据非常短时。GSPBox的基本理论是谱图论和…

牛客 NC36 在两个长度相等的排序数组中找到上中位数【中等 模拟 Java,Go,PHP】

题目 题目链接: https://www.nowcoder.com/practice/6fbe70f3a51d44fa9395cfc49694404f 思路 直接模拟2个数组有顺序放到一个数组中,然后返回中间的数参考答案java import java.util.Scanner;// 注意类名必须为 Main, 不要有任何 package xxx 信息 pu…

神经网络背后的数学原理

原文地址:The Math Behind Neural Networks 2024 年 3 月 29 日 深入研究现代人工智能的支柱——神经网络,了解其数学原理,从头开始实现它,并探索其应用。 神经网络是人工智能 (AI) 的核心,为…

在线拍卖系统|基于Springboot的在线拍卖系统设计与实现(源码+数据库+文档)

在线拍卖系统目录 基于Springboot的在线拍卖系统设计与实现 一、前言 二、系统设计 三、系统功能设计 1、前台: 2、后台 用户功能模块 5.2用户功能模块 四、数据库设计 五、核心代码 六、论文参考 七、最新计算机毕设选题推荐 八、源码获取&#xff1a…

基于STC12C5A60S2系列1T 8051单片机的液晶显示器LCD1602显示汉字的功能

基于STC12C5A60S2系列1T 8051单片机的液晶显示器LCD1602显示汉字的功能 STC12C5A60S2系列1T 8051单片机管脚图STC12C5A60S2系列1T 8051单片机I/O口各种不同工作模式及配置STC12C5A60S2系列1T 8051单片机I/O口各种不同工作模式介绍LCD1602字符型液晶显示器介绍一、LCD1602字符型…

搜维尔科技:【煤矿安全仿真】煤矿事故预防处置VR系统,矿山顶板灾害,冲击地压灾害等预防演练!

产品概述 煤矿事故预防处置VR系统 系统内容: 事故预防处置VR系统的内容包括:火灾的预防措施、火灾预兆、防灭火系统、火灾案例重现、顶板事故预兆、顶板事故原因、顶板事故案例重现、瓦斯概念及性质、瓦斯的涌出形式、瓦斯预兆、瓦斯爆炸条件及预防措…

spiiii

数据手册里面有这么一段解释,就是说如果我们开启了看门狗,那么LSI就会跟随强制打开,等待LSI稳定之后就可以自动为独立看门狗提供时钟了。所以这里的第一步开启时钟不需要我们写代码来执行 2.写入预分频器和重装寄存器 在写入这两个寄存器之前…