深度学习代码优化(Config,Registry,Hook)

社区开放麦#9 | OpenMMLab 模块化设计背后的功臣

1. 配置文件管理Config

1.1 早期配置参数加载

早期深度学习项目的代码大多使用parse_args,在代码启动入口加载大量参数,不利于维护。

在这里插入图片描述
在这里插入图片描述
常见的配置文件有3中格式:pythonjsonyaml 格式的配置文件,推荐使用Yaml文件来配置训练参数。

基本所有能影响你模型的因素,都被涵括在了这个文件里,而在代码中,你只需要用一个简单的 yaml.load()就能把这些参数全部读到一个dict里。更关键的是,这个配置文件可以随着你的checkpoint一起被存到相同的文件夹,方便你直接拿来做断点训练、finetune或者直接做测试,用来做测试时你也可以很方便把结果和对应的参数对上。

1.2 方案:Click+OmegaConf

效果和hydra类似,把所有的参数都写在 YAML 文件中。click读取命令行中的config文件路径(也可以不传入,使用代码中默认的config文件路径)然后用Omegaconf根据传入的路径读取配置文件,因此只需要在命令行指定配置文件路径,而不是用argparse控制所有的参数,参数一多命令行参数在shell文件中就会特别长,看起来很乱。

pretrained_model_path: "./ckpt/stable-diffusion-v1-5"
pretrained_controlnet_model_path: "./ckpt/sd-controlnet-canny"
control_type: 'canny'dataset_config:video_path: "videos/hat.mp4"prompt: "A woman with a white hat"n_sample_frame: 1# n_sample_frame: 22sampling_rate: 1stride: 80offset: left: 0right: 0top: 0bottom: 0editing_config:use_invertion_latents: Trueuse_inversion_attention: Trueguidance_scale: 12editing_type: "attribute"dilation_kernel: 3editing_phrase: "hat"  # P_objuse_interpolater: True  # frame interpolaterediting_prompts: "A woman with a pink hat"  # P_tgt# source promptclip_length: "${..dataset_config.n_sample_frame}"num_inference_steps: 50prompt2prompt_edit: Truemodel_config:lora: 160# temporal_downsample_time: 4SparseCausalAttention_index: ['first','second','last'] least_sc_channel: 640# least_sc_channel: 100000test_pipeline_config:target: video_diffusion.pipelines.p2p_ddim_spatial_temporal_controlnet.P2pDDIMSpatioTemporalControlnetPipelinenum_inference_steps: "${..validation_sample_logger.num_inference_steps}"seed: 0

yaml文件全部放在configs路径下:

├── configs
│   ├── LOVECon.yaml
│   ├── TokenFlow.yaml
│   ├── Tune-A-Video.yaml
└── main.py

我们就可以对启动函数 run() 使用装饰器@click传入config.yaml路径,然后用OmegaConf像属性一样读写,处理好参数之后,再加载主函数main()

import click
from typing import Optional,Dict
from omegaconf import DictConfig, OmegaConf
from rich import print  # colorful printdef main(config: str,**kwargs):print("Training...")@click.command()
@click.option("--config", type=str, default="Project_Manage\configs\data.yaml")
def run(config):# load configomega_dict = OmegaConf.load(config)print(omega_dict)# read configprint(omega_dict.data_setting.data_path)# write configomega_dict.seed = 2# add configomega_dict.update({"num": 2})# merge configmerge_dict = OmegaConf.merge(omega_dict, OmegaConf.load("Project_Manage\configs\model.yaml"))print(merge_dict)# save configOmegaConf.save(merge_dict, "Project_Manage\configs\merge.yaml")main(config=config, **omega_dict)if __name__ == "__main__":  run()

2. 注册器机制Registry

2.1 预备知识:python装饰器

  • 一等对象first class:python中一切皆对象,函数不例外。first class是指可以运行时创建、可以赋值给变量、可以当参数传递、可以做函数返回值的东西。
    在这里插入图片描述

  • 高阶函数high order function:拿其他函数作为参数返回值的函数。
    在这里插入图片描述

  • 内层函数、外层函数:当函数嵌套定义的时候,外层函数的变量作用域 会扩展到 内层函数(说人话就是:inner函数可以使用outer函数的变量)。outer()作为高阶函数,返回一等对象inner()

def outer(a):def inner():return areturn inner  # outer函数返回:inner函数(一等对象)
outer(1)()  # 最后的()调用inner函数
> 1
# 等价于 #
def outer(a):def inner():return areturn inner()  # outer函数返回:inner函数调用结果
outer(1)
> 1
  • 闭包:当一个函数返回另一个函数时,内部函数访问外部函数的变量参数内部函数可见的外部对象们(变量或函数)就构成一个闭包环境__closure__。在下面例子中,inner函数形成了一个闭包,包含2个int对象,分别对应outer函数的参数a和b(闭包环境__closure__中可能有多个变量,是一个list)。当outer函数被调用时,它会返回inner函数的引用,同时实例化inner闭包环境中的int对象,inner函数仍然可以访问outer函数传递的参数a和b完成调用。
def outer(a, b):def inner():return a + breturn inner  inner = outer(1, 2)  # outer函数返回:inner函数(一等对象)
inner.__closure__  # inner的闭包环境:(<cell : int object>, <cell : int object>)
inner.__closure__[0].cell_contents  # 1
inner.__closure__[1].cell_contents  # 2
inner()  # 3
  • 万能形参*是对序列进行解包打包*args就是对传入的多个value参数(也叫positional arguments)进行打包成元组**kwargs就是对传入的多个key=value参数(也叫keyword arguments)进行打包成字典*args必须写在**kwargs之前)。 使用了万能形参,管你多少个参数,管你什么类型,我都可以扔到这两个里面。这就减少了重复写同名函数(避免函数重载)。
def foo(*number):  # 对1, 2, 3, 4, 5打包print(type(number), number)
foo(1, 2, 3, 4, 5)def f(a, b, c):  # 对[1,2,3]解包print(a, b, c)
f(*[1, 2, 3])
def foo(*args, **kwargs):print ('args = ', args)    print ('kwargs = ', kwargs)print ("-"*40)
if __name__ == '__main__':foo(1 ,2 ,3 ,4)  # 对 value 参数进行打包foo(a=1 ,b=2 ,c=3)  # 对 key=value 参数进行打包foo(1 ,2 ,3 ,4, a=1 ,b=2 ,c=3)foo('a', 1, None, a=1, b='2', c=3)
args =  (1, 2, 3, 4)
kwargs =  {}
----------------------------------------
args =  ()
kwargs =  {'a': 1, 'b': 2, 'c': 3}
----------------------------------------
args =  (1, 2, 3, 4)
kwargs =  {'a': 1, 'b': 2, 'c': 3}
----------------------------------------
args =  ('a', 1, None)
kwargs =  {'a': 1, 'b': '2', 'c': 3}
----------------------------------------
  • 装饰器:用@语法糖定义和应用装饰器装饰器是一种高阶函数,可以修改其他函数的行为添加额外的功能。my_decorator是一个装饰器函数,它接受一个函数func作为参数,在原始函数执行前后添加了一些额外的操作,并返回一个新的函数wrapper。具体来说有4种类型:(真正的装饰器接受func,可能会加上外层函数接受装饰器的配置参数)

(1)装饰器需要配置,原函数需要包装。

def decorator(func):  # 外层装饰器接受funcprint('do something')return func  # 不包装直接返回func# 使用 @ 语法糖应用装饰器
@decorator
def my_function():print("excute my func")# 调用被装饰后的函数
my_function()

do something
excute my func

(2)装饰器需要配置,原函数需要包装。返回的wrapper是真正的装饰器函数。

def decorator(num):  # 外层函数接受配置参数numdef wrapper(func):  # 内层wrapper才是真正的装饰器print('do something', num)return func  # 不包装直接返回funcreturn wrapper# 使用 @ 语法糖应用装饰器
@decorator(123)
def my_function():print("excute my func")# 调用被装饰后的函数
my_function()

(3)装饰器需要配置,原函数需要包装。最经典应用的就是pre_processpost_process使用time.time(),计算func的执行时间。

def decorator(func):  # 外层装饰器接受funcprint('do something')def wrapper(*args, **kwargs):  # 包装函数func为wrapperprint('pre_process')result = func(*args, **kwargs)print('post_process')return result  # 返回包装函数wrapper执行结果return wrapper# 使用 @ 语法糖应用装饰器
@decorator
def my_function():print("excute my func")# 调用被装饰后的函数
my_function()

(4)装饰器需要配置,原函数需要包装。

def decorator(x):  # 外层函数接受配置参数numdef inner_dec(func):  # 内层装饰器接受funcprint("do something", x)def wrapper(*args, **kwargs):  # 包装函数func为wrapperprint('pre_process')result = func(*args, **kwargs)print('post_process')return resultreturn wrapperreturn inner_dec# 使用 @ 语法糖应用装饰器
@decorator(123)
def my_function():print("excute my func")# 调用被装饰后的函数
my_function()
  • 类装饰器:装饰器也不一定只能用函数来写,也可以使用类装饰器,用法与函数装饰器并没有太大区别,实质是使用了类方法中的__call__魔法方法来实现类的直接调用。
class logging(object):def __init__(self, func):self.func = funcdef __call__(self, *args, **kwargs):print("[DEBUG]: enter {}()".format(self.func.__name__))return self.func(*args, **kwargs)@logging
def hello(a, b, c):print(a, b, c)hello("hello,","good","morning")
-----------------------------
>>>[DEBUG]: enter hello()
>>>hello, good morning

类装饰器也是可以带参数的,如下实现

class logging(object):def __init__(self, level):self.level = leveldef __call__(self, func):def wrapper(*args, **kwargs):print("[{0}]: enter {1}()".format(self.level, func.__name__))return func(*args, **kwargs)return wrapper@logging(level="TEST")
def hello(a, b, c):print(a, b, c)hello("hello,","good","morning")
-----------------------------
>>>[TEST]: enter hello()
>>>hello, good morning

2.2 Registry机制

前面我们读取到的Config实际上是一个大型的字典,仅实现了对参数的模块化解析:包含dataset的configmodel的configlr的configoptmizer的configtrain的config等。
在这里插入图片描述

但是这些都是字典参数,并没有对各个模块进行实例化,Registry要做的就是,从配置文件Config中直接解析出对应模块的信息,用Registry把模型结构与训练策略给实例化出来

在众多深度学习开源库的代码中经常出现Registry代码块,例如OpenMMlab,facebookresearch、BasicSR中都使用了注册器机制。下面以BasicSR为例,解释一下Registry:

class Registry():"""The registry that provides name -> object mapping, to support third-partyusers' custom modules.To create a registry (e.g. a backbone registry):.. code-block:: pythonBACKBONE_REGISTRY = Registry('BACKBONE')To register an object:.. code-block:: python@BACKBONE_REGISTRY.register()class MyBackbone():...Or:.. code-block:: pythonBACKBONE_REGISTRY.register(MyBackbone)"""def __init__(self, name):"""Args:name (str): the name of this registry"""self._name = nameself._obj_map = {}def _do_register(self, name, obj, suffix=None):if isinstance(suffix, str):name = name + '_' + suffixassert (name not in self._obj_map), (f"An object named '{name}' was already registered "f"in '{self._name}' registry!")self._obj_map[name] = objdef register(self, obj=None, suffix=None):"""Register the given object under the the name `obj.__name__`.Can be used as either a decorator or not.See docstring of this class for usage."""if obj is None:# used as a decoratordef deco(func_or_class):name = func_or_class.__name__self._do_register(name, func_or_class, suffix)return func_or_classreturn deco# used as a function callname = obj.__name__self._do_register(name, obj, suffix)def get(self, name, suffix='basicsr'):ret = self._obj_map.get(name)if ret is None:ret = self._obj_map.get(name + '_' + suffix)print(f'Name {name} is not found, use name: {name}_{suffix}!')if ret is None:raise KeyError(f"No object named '{name}' found in '{self._name}' registry!")return retdef __contains__(self, name):return name in self._obj_mapdef __iter__(self):return iter(self._obj_map.items())def keys(self):return self._obj_map.keys()DATASET_REGISTRY = Registry('dataset')
ARCH_REGISTRY = Registry('arch')
MODEL_REGISTRY = Registry('model')
LOSS_REGISTRY = Registry('loss')
METRIC_REGISTRY = Registry('metric')

上面的代码为数据集,架构,网络,损失以及度量方式都创建了一个注册器对象。核心代码在register函数里,register函数使用了装饰器的设计,也就是只要在功能模块前进行@xx.register()进行装饰,就会对原有功能模块进行注册,并且最终返回原始的功能模块,不修改其原有功能。

在更下层的_do_register()中可以看到,这里使用的是一个字典来执行注册操作,记录的键值对分别是模块的名称以及模块本身。这样一来,读取配置文件中的模块字符串后,我们就能够直接通过函数名或者类名找到其具体实现。

使用方法如下所示,只需要在此类前加上装饰,后期则直接能够从字符串L1Loss找到其对应的实现。

@LOSS_REGISTRY.register()
class L1Loss(nn.Module):"""L1 (mean absolute error, MAE) loss.Args:loss_weight (float): Loss weight for L1 loss. Default: 1.0.reduction (str): Specifies the reduction to apply to the output.Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'."""def __init__(self, loss_weight=1.0, reduction='mean'):super(L1Loss, self).__init__()if reduction not in ['none', 'mean', 'sum']:raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')self.loss_weight = loss_weightself.reduction = reductiondef forward(self, pred, target, weight=None, **kwargs):"""Args:pred (Tensor): of shape (N, C, H, W). Predicted tensor.target (Tensor): of shape (N, C, H, W). Ground truth tensor.weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None."""return self.loss_weight * l1_loss(pred, target, weight, reduction=self.reduction)

3. Hook

推荐Pytorch_linghtning,对于训练的封装。(mmcv的Runner也类似)

3.1 钩子编程

hook允许你在特定的代码点插入自定义的代码。通过使用钩子(hooks),你可以在程序执行到特定的位置时注入自己的代码以便进行额外的处理或修改程序的行为

如下面的例子,正常的git commit添加pre-commit-hook后,就会在git commit前执行一些检查操作(文件大小是否合格等):

在这里插入图片描述
但是随着需求不断增加,插入的代码也越来越乱,相比于直接修改原始代码这种侵入式的修改,我们需要一种非侵入式的修改,使得hook加入的更加清晰直观。如下,直接在forward中添加打印模型结构和参数的代码。
在这里插入图片描述
在实际操作中,我们常常在函数执行的前后注册hook函数,实现非侵入式的修改。如pytorch的nn.Module的forward底层是__call__方法,它在执行forward之前会执行_forward_pre_hooks,在执行forward之后会执行_forward_hooks
在这里插入图片描述

3.2 Pytorch_Lightning hook介绍

在这里插入图片描述

下面PL模型的实现可以在fit(train + validate), validate, test, predict每个epoch每个batch前后添加hook函数:如setupon_xxx_epoch_endon_xxx_batch_end等(end函数一般用来作为loss和acc的log hook)。

class LitModel(pl.LightningModule):def __init__(...):# init: 初始化,包括模型和系统的定义。def prepare_data(...):# 准备数据,包括下载数据、预处理等等def setup(...):# 执行fit(train + validate), validate, test, or predict前的hook function,进行数据划分等操作def configure_optimizers(...)# configure_optimizers: 优化器定义,返回一个优化器,或数个优化器,或两个List(优化器,Scheduler)def forward(...):# forward: 前向传播,和正常的Ptorch的forward一样def train_dataloader(...)# 加载train datadef training_step(...)# training_step(self, batch, batch_idx): 即每个batch的处理函数, z=self(x)等价于z=forward(x)def on_train_epoch_end(...)# training epoch end hook functiondef validation_dataloader(...)# 加载validationdatadef validation_step(...)# validation_step(self, batch, batch_idx): 即每个batch的处理函数def on_validation_epoch_end(...)# validation epoch end hook functiondef test_dataloader(...)# 加载testdatadef test_step(...)# test_step(self, batch, batch_idx): 即每个batch的处理函数def on_test_epoch_end(...)# test epoch end hook functiondef any_extra_hook(...)

上面介绍的PL的hook函数只是比较常用的,更多更全的PL ho
ok介绍可以在官网中查看:https://lightning.ai/docs/pytorch/stable/_modules/lightning/pytorch/core/hooks.html

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/230847.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux下文件操作函数

一.常见IO函数 fopen fclose fread fwrite fseek fflush fopen 运行过程 &#xff1a;打开文件 写入数据 数据写到缓冲区 关闭文件后 将数据刷新入磁盘 1.fopen 返回文件类型的结构体的指针 包括三部分 1).文件描述符&#xff08;整形值 索引到磁盘文件&#xff09;…

不同类型的开源许可证

不同类型的开源许可证 什么是开源许可证 最简单的解释是&#xff0c;开源许可证是计算机软件和其他产品的许可证&#xff0c;允许在定义的条款和条件下使用、修改或共享源代码、蓝图或设计。开源并不意味着该软件可以根据需要使用、复制、修改和分发。根据开源许可证的类型&a…

【批处理常用命令及用法大全】

文章目录 1 echo 和 回显控制命令2 errorlevel程序返回码3 dir显示目录中的文件和子目录列表4 cd更改当前目录5 md创建目录6 rd删除目录7 del删除文件8 ren文件重命名9 cls清屏10 type显示文件内容11 copy拷贝文件12 title设置cmd窗口的标题13 ver显示系统版本14 label 和 vol设…

笔记61:注意力提示

本地笔记地址&#xff1a;D:\work_file\&#xff08;4&#xff09;DeepLearning_Learning\03_个人笔记\3.循环神经网络\第10章&#xff1a;动手学深度学习~注意力机制 a a a a a a a a

工艺系统所管理数字化实践

摘要 本文介绍了上海核工程设计研究院在数字化转型方面的实践&#xff0c;包括业务数字化和管理数字化两个方面。业务数字化方面&#xff0c;该院通过开发小工具改进工作流程。管理数字化方面&#xff0c;该院采用零代码平台集中管理管道力学信息相关模型和数据&#xff0c;并…

函数保留凸性的一些运算,限制为一条线

凸优化在学术研究中非常重要&#xff0c;经常遇到的问题是证明凸性。常规证明凸性的方式是二阶导数的黑塞矩阵为半正定&#xff0c;或者在一维函数时二阶导数大于等于零。但很多时候的数学模型并不那么常规、容易求导的&#xff0c;若能够知道一些保留凸性的运算&#xff0c;将…

【Qt之QSqlRelationalDelegate】描述及使用

描述 QSqlRelationalDelegate类提供了一个委托&#xff0c;用于显示和编辑来自QSqlRelationalTableModel的数据。 与默认委托不同&#xff0c;QSqlRelationalDelegate为作为其他表的外键的字段提供了一个组合框。 要使用该类&#xff0c;只需在带有QSqlRelationalDelegate实例…

用flutter 写一个专属于儿子的听书的app

背景: 儿子最近喜欢上了用儿童手表听故事&#xff0c;但是手表边里的应用免费内容很少&#xff0c;会员一年要300多&#xff0c;这么一笔巨款&#xff0c;怎能承担的起&#xff0c;所以打算自己开发一个专属于儿子的听书app。 最终效果&#xff1a; 架构&#xff1a; 后端由两…

Python工具Anaconda+Pycharm安装教程详解

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 一、介绍二、Anaconda的安装三、Pycharm的安装四、环境配置五、python库文安装件——以opencv为例关于Python技术储备一、Python所有方向的学习路线二、Python基础学…

ChatGPT成了背锅侠:利用AI做蹭热点视频

我是卢松松&#xff0c;点点上面的头像&#xff0c;欢迎关注我哦&#xff01; 在抖音\视频号上已经有很多人利用ChatGPT做热点视频的案例了&#xff0c;视频都是点赞大几千、几万。看完本文&#xff0c;你会略知一二&#xff0c;如下图所示&#xff1a; 这个视频&#xff0c…

Elk:filebeat 日志收集工具和logstash

Elk:filebeat 日志收集工具和logstash Filebeat是一个轻量级的日志手机工具,所使用的系统资源比logstash部署和启动时使用的资源要小得多 Filebeat可以在非java环境使用&#xff0c;他可以代理logstash在非java环境上收集日志 缺点 Filebeat无法实现数据的过滤,一般是结合l…

建立健全涉密测绘外业安全保密管理制度,落实监管人员和保密责任,外业所用涉密计算机纳入涉密单机进行管理

建立健全涉密测绘外业安全保密管理制度&#xff0c;落实监管人员和保密责任&#xff0c;外业所用涉密计算机纳入涉密单机进行管理 1.涉密测绘外业安全保密管理制度 2.外业人员及设备清单&#xff08;包括&#xff1a;外业从业人员名单、工作岗位&#xff0c;外业设备名称、密…