Pytorch实现warm up和consine decay

在深度学习领域,模型训练过程中的不稳定性是一个常见的问题。为了解决这个问题,在Resnet这篇论文也提及了Warm Up的方法,通过逐渐增加学习率,引导模型在训练初期更稳定地收敛。同时在warm up之后结合consine decay的方法让训练变得更有效。

warm up和consine decay的意义

  • warm up来自于这篇文章:https://arxiv.org/pdf/1706.02677.pdf
  • consine decay来自于这篇文章:https://arxiv.org/pdf/1812.01187.pdf

在第一轮训练的时候,每个数据点对模型来说都是新的,模型会很快地进行数据分布修正,如果这时候学习率就很大,极有可能导致开始的时候就对该数据“过拟合”,后面要通过多轮训练才能拉回来,浪费时间。
当训练了一段时间(比如两轮、三轮)后,模型已经对每个数据点看过几遍了,或者说对当前的batch而言有了一些正确的了解,较大的学习率就不那么容易会使模型学偏,所以可以适当调大学习率。这个过程就可以看做是warmup。
那么为什么之后还要decay呢?当模型训到一定阶段后(比如10个epoch),模型的分布就已经比较固定了,或者说能学到的新东西就比较少了。如果还沿用较大的学习率,就会破坏这种稳定性,用我们通常的话说,就是已经接近loss的local optimal了,为了靠近这个最低点,我们就要慢慢来。

代码实现

1. 非常简单的前期准备工作

import torch,math
import matplotlib.pyplot as plt# 定义一个简单的网络
class Model(torch.nn.Module):def __init__(self):super(Model, self).__init__()self.linear = torch.nn.Linear(1, 1, bias=False)def forward(self, x):return self.linear(x)# 声明一些超参数
epochs = 100000
warm_up_epochs = epochs*0.3
milestones = [epochs*0.3, epochs*0.7]# 优化器
optimizer = torch.optim.SGD(model.parameters(), 0.1, momentum=0.9, weight_decay=5e-4)

2. 设置scheduler调度器


# MultiStepLR without warm up
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)# warm_up_with_multistep_lr
warm_up_with_multistep_lr = lambda epoch: epoch / warm_up_epochs if epoch <= warm_up_epochs else 0.1**len([m for m in milestones if m <= epoch])
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warm_up_with_multistep_lr)# warm_up_with_cosine_lr
warm_up_with_cosine_lr = lambda epoch: epoch / warm_up_epochs if epoch <= warm_up_epochs else 0.5 * ( math.cos((epoch - warm_up_epochs) /(epochs - warm_up_epochs) * math.pi) + 1)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warm_up_with_cosine_lr)

上面的三段代码分别是:

  1. 不使用warm up + multistep learning rate 衰减
  2. 使用warm up + multistep learning rate 衰减
  3. 使用warm up + consine learning rate衰减

代码均使用pytorch中的torch.optim.lr_scheduler.LambdaLR自定义学习率衰减器

3. 模拟训练过程

lrs = []
for item in range(epochs):lr = optimizer.param_groups[0]["lr"]optimizer.zero_grad()optimizer.step()scheduler.step()lrs.append(scheduler.get_last_lr())plt.plot(range(epochs),lrs)
plt.show()

结果

1. MultiStepLR without warm up

在这里插入图片描述

2. warm_up_with_multistep_lr

warm_up_with_multistep_lr

3. warm_up_with_cosine_lr

warm_up_with_cosine_lr

参考

  1. https://pytorch.org/docs/stable/nn.html
  2. 知乎 https://zhuanlan.zhihu.com/p/148487894
  3. 知乎 https://zhuanlan.zhihu.com/p/424373231

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/19823.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

辅助驾驶功能开发-功能规范篇(22)-3-L2级辅助驾驶方案功能规范

1.3.3 TLA系统功能定义 1.3.3.1 状态机 1.3.3.2 状态迁移图 1.3.3.3 功能定义 1.3.3.3.1 信号需求列表 1.3.3.3.2 系统开启关闭 1&#xff09;初始化 车辆上电后&#xff0c;交通灯辅助系统&#xff08;TLA&#xff09;进行初始化&#xff0c;控制器需在 220ms 内发出第一帧…

Spring-Interceptor拦截器

使用步骤 申明拦截器bean&#xff0c;并实现HandlerInterceptor接口 true为放行&#xff0c;false为拦截 2.定义配置类&#xff0c;继承WebMvcConfigurationSupport&#xff0c;实现addInterceptors方法&#xff0c;该方法调用具体的拦截器进行拦截 也可以在配子类通过实现W…

JMeter进行WebSocket压力测试

背景 之前两篇内容介绍了一下 WebSocket 和 SocketIO 的基础内容。之后用 Netty-SocketIO 开发了一个简单的服务端&#xff0c;支持服务端主动向客户端发送消息&#xff0c;同时也支持客户端请求&#xff0c;服务端响应方式。本文主要想了解一下服务端的性能怎么样&#xff0c;…

一百二十九、Kettle——从MySQL增量导入到GreenPlum

一、目标 用Kettle从MySQL增量导入数据到GreePlum 二、前提准备 &#xff08;一&#xff09;kettle已连上MySQL &#xff08;二&#xff09;kettle已连上GreenPlum 三、实施步骤 &#xff08;一&#xff09;打开kettle&#xff0c;新建转换任务。拖拽2个表输入、替换NULL…

Spring 如何解决 Bean 的循环依赖(循环引用)

Component public class A {Autowiredprivate B b;}Component public class B {Autowiredprivate A a;}上面的情况就是 循环依赖 Bean的创建初始化过程如下 如果不采取措施&#xff0c;那么循环依赖就会进入死循环 但 Spring 已经帮我们解决了大部分循环依赖问题 具体是如何解…

prompt:有需求就有价值,prompt案例

prompt&#xff1a;有需求就有价值 此文章来自于小七姐 首先来看需求&#xff1a; 客户需要生成1000条俏皮灵动&#xff0c;趣味盎然&#xff0c;比喻精妙的和美食有关的短句子&#xff0c;要求文风优美&#xff0c;句子让人充满食欲。 客户使用这些句子的场景比较奇妙&…

路径规划算法:基于食肉植物优化的路径规划算法- 附代码

路径规划算法&#xff1a;基于食肉植物优化的路径规划算法- 附代码 文章目录 路径规划算法&#xff1a;基于食肉植物优化的路径规划算法- 附代码1.算法原理1.1 环境设定1.2 约束条件1.3 适应度函数 2.算法结果3.MATLAB代码4.参考文献 摘要&#xff1a;本文主要介绍利用智能优化…

2023年Q2空调行业品牌数据榜单(京东商品数据)

随着夏季的来临&#xff0c;高温天气也带动部分家电行业的销售&#xff0c;以空调为代表的家电市场正逐步恢复活力。结合鲸参谋电商数据分析平台的相关数据&#xff0c;我们来分析一下2023年Q2空调市场的具体销售表现。 根据鲸参谋平台的数据显示&#xff0c;2023年4-6月份&am…

数据库备份恢复和索引视图

样例表如下&#xff1a; /***************************样例表***************************/CREATE DATABASE booksDB;use booksDB;CREATE TABLE books(bk_id INT NOT NULL PRIMARY KEY,bk_title VARCHAR(50) NOT NULL,copyright YEAR NOT NULL);INSERT INTO booksVALUES (1107…

ICEFlow VPN敏感信息泄露

一个新的历史时期开始了&#xff0c;而眼下又是一个艰难的转折阶段&#xff1a;既要除旧&#xff0c;又要布新&#xff1b;这需要魄力&#xff0c;需要耐力&#xff0c;需要能力&#xff0c;需要精力&#xff0c;当然也需要体力——尽管这一切他都不够&#xff0c;但他自信他的…

前端学习——Web API (Day5)

BOM操作 Window对象 BOM 定时器-延时函数 案例 <!DOCTYPE html> <html lang"zh-CN"><head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible" content"IEedge"><meta name"viewport&q…

Broadcast 广播使用详解

和你一起终身学习&#xff0c;这里是程序员Android 经典好文推荐&#xff0c;通过阅读本文&#xff0c;您将收获以下知识点: 一、Broadcast概述二、Broadcast的注册三、Broadcast的注册类型四、静态注册开机广播的实现五、动态监听亮灭屏幕广播实现六、广播的发送方法七、参考文…