解决deepspeed框架的bug:不保存调度器状态,模型训练重启时学习率从头开始

deepspeed存在一个bug,即在训练时不保存调度器状态,因此如果训练中断后再重新开始训练,调度器还是会从头开始而不是接着上一个checkpoint的调度器状态来训练。这个bug在deepspeed的github中也有其他人提出:https://github.com/microsoft/DeepSpeed/issues/3875
因此我们需要写一个保存调度器状态的代码,才可以解决这个问题。
具体方法是加一个callback类,专门负责保存调度器的状态以及在训练重新开始时加载调度器的状态:

class SchedulerStateCallback(TrainerCallback):def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):if os.environ.get("RANK", "0") == "0":scheduler = kwargs['lr_scheduler']scheduler_state = scheduler.state_dict()save_path = os.path.join(args.output_dir, SCHEDULER_NAME)torch.save(scheduler_state, save_path)#优化器状态已经被deepspeed框架保存了,所以这里没必要再保存# optimizer = kwargs['optimizer']# optimizer_state = optimizer.state_dict()# save_path = os.path.join(args.output_dir, OPTIMIZER_NAME)# torch.save(optimizer_state, save_path)#torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):# 当训练开始时,尝试加载最近的调度器状态# load_path = os.path.join(args.output_dir, OPTIMIZER_NAME)# if os.path.exists(load_path):#     optimizer = kwargs['optimizer']#     optimizer_state = torch.load(load_path)#     optimizer.load_state_dict(optimizer_state)load_path = os.path.join(args.output_dir, SCHEDULER_NAME)if os.path.exists(load_path):scheduler = kwargs['lr_scheduler']scheduler_state = torch.load(load_path)scheduler.load_state_dict(scheduler_state)

解决效果如下,我们可以看到,在chaeckpoint10重新开始训练的时候,学习率是接着之前的学习率开始的(5.5e-7),而不是从头开始(0.5e-7):
在这里插入图片描述在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/99382.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

laragon 为 php 安装 Xdebug 扩展

众所周知,php 自带的 var_dump() 输出格式很不直观 而 laragon 作为很好的 windos 下开发环境很受欢迎,本文就介绍如何快速为 laragon 的 php 安装 Xdebug,方便开发调试 一:启动开发环境,在任意可访问 php 页面中输出 …

英码科技受邀亮相2023WAIE物联网与人工智能展,荣获行业优秀创新力产品奖!

8月28日-30日,2023WAIE 物联网与人工智能展在深圳福田会展中心顺利举办。英码科技受邀亮相本届展会,并现场重点展出了面向智慧交通、智慧校园、智慧应急、智慧园区等不同行业的创新AIoT产品、AI技术服务等内容,与生态伙伴积极探讨市场需求和问…

WordPress(6)网站侧边栏倒计时进度小工具

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 效果图在这里插入图片描述一、添加位置二、主题style.css文件中添加美化1.引入库2.添加自定义的HTML模块效果图 提示:以下是本篇文章正文内容,下面案例可供参考 一、添加位置 在主题中 child.js…

【文心一言大模型插件制作初体验】制作面试错题本大模型插件

文心一言插件开发初体验 效果图 注意:目前插件仅支持在本地运行,虽然只能自用,但仍然是一个不错的选择。(什么?你说没有用?这不可能!文心一言app可以支持语音,网页端结合手机端就可…

16字节协议的串口通信

1.协议要求 协议为帧传输,一共16字节。主要是2字节的固定帧头 EB 90,2字节的帧计数(用来计数发出的帧),10字节的数据和2字节的校验位 帧头:2字节,固定值 8’HEB、8’H90 帧计数:2字节,用来说明发出去帧是…

说说Omega架构

分析&回答 Omega架构我们暂且称之为混合数仓。 什么是ECS设计模式 在谈我们的解法的时候,必须要先提ECS的设计模式。 简单的说,Entity、Component、System分别代表了三类模型。 实体(Entity):实体是一个普通的对象。通常&#xff0c…

Linux——进程间信号(超级详解!!)

索引 一.初始信号1.什么是信号2.前后台进程3.信号的种类4.信号的管理 二.信号产生前1.验证键盘是可以产生信号的2.通过系统调用接口发送信号3.由软件条件产生信号4.硬件异常产生信号5.总结6.core dump 信号产生中1.信号在内核中的表示2.信号集操作函数 信号产生后1.了解内核态和…

如何高效的解析Json?

Json介绍 Json是一种数据格式,广泛应用在需要数据交互的场景Json由键值对组成每一个键值对的key是字符串类型每一个键值对的value是值类型(boo1值数字值字符串值)Array类型object类型Json灵活性他可以不断嵌套,数组的每个元素还可以是数组或者键值对键值…

SpringMVC_基本使用

一、JavaWEB 1.回顾 JavaWEB 1.1新建项目结构 新建 javaweb 项目目录结构 1.2导入依赖 依赖 <dependency><groupId>javax.servlet</groupId><artifactId>javax.servlet-api</artifactId><version>3.1.0</version><scope>…

vue router进行路由跳转并携带参数(params/query)

在使用router.push进行路由跳转到另一个组件时&#xff0c;可以通过params或query来传递参数。 1. 使用params传参&#xff1a; // 在路由跳转时传递参数 router.push({ name: targetComponent, params: {paramName: paramValue // 参数名和值 } });// 在目标组件中通过$r…

企业宣传片和传统纸媒相关优劣

在当今数字化时代&#xff0c;传统纸媒和宣传片成为了企业和组织宣传推广的两种主要方式。然而&#xff0c;面对有限的资源和日益竞争的市场环境&#xff0c;我们需要仔细权衡选择哪种方式更加适合。接下来由深圳企业宣传片制作公司老友记小编从以下几个方面浅析一下它们的优势…

886. 可能的二分法

886. 可能的二分法 原题链接&#xff1a;完成情况&#xff1a;题解一&#xff1a;题解二&#xff1a; 原题链接&#xff1a; 886. 可能的二分法 https://leetcode.cn/problems/possible-bipartition/description/ 完成情况&#xff1a; 题解一&#xff1a; package LeetCod…