【深度学习】如何找到最优学习率

经过了大量炼丹的同学都知道,超参数是一个非常玄乎的东西,比如batch size,学习率等,这些东西的设定并没有什么规律和原因,论文中设定的超参数一般都是靠经验决定的。但是超参数往往又特别重要,比如学习率,如果设置了一个太大的学习率,那么loss就爆了,设置的学习率太小,需要等待的时间就特别长,那么我们是否有一个科学的办法来决定我们的初始学习率呢?

在这篇文章中,我会讲一种非常简单却有效的方法来确定合理的初始学习率。

学习率的重要性

目前深度学习使用的都是非常简单的一阶收敛算法,梯度下降法,不管有多少自适应的优化算法,本质上都是对梯度下降法的各种变形,所以初始学习率对深层网络的收敛起着决定性的作用,下面就是梯度下降法的公式

深度学习:如何找到最优学习率

这里 α 就是学习率,如果学习率太小,会导致网络loss下降非常慢,如果学习率太大,那么参数更新的幅度就非常大,就会导致网络收敛到局部最优点,或者loss直接开始增加,如下图所示。

深度学习:如何找到最优学习率

学习率的选择策略在网络的训练过程中是不断在变化的,在刚开始的时候,参数比较随机,所以我们应该选择相对较大的学习率,这样loss下降更快;当训练一段时间之后,参数的更新就应该有更小的幅度,所以学习率一般会做衰减,衰减的方式也非常多,比如到一定的步数将学习率乘上0.1,也有指数衰减等。

这里我们关心的一个问题是初始学习率如何确定,当然有很多办法,一个比较笨的方法就是从0.0001开始尝试,然后用0.001,每个量级的学习率都去跑一下网络,然后观察一下loss的情况,选择一个相对合理的学习率,但是这种方法太耗时间了,能不能有一个更简单有效的办法呢?

一个简单的办法

Leslie N. Smith 在2015年的一篇论文“Cyclical Learning Rates for Training Neural Networks”中的3.3节描述了一个非常棒的方法来找初始学习率,同时推荐大家去看看这篇论文,有一些非常启发性的学习率设置想法。

这个方法在论文中是用来估计网络允许的最小学习率和最大学习率,我们也可以用来找我们的最优初始学习率,方法非常简单。首先我们设置一个非常小的初始学习率,比如1e-5,然后在每个batch之后都更新网络,同时增加学习率,统计每个batch计算出的loss。最后我们可以描绘出学习的变化曲线和loss的变化曲线,从中就能够发现最好的学习率。

下面就是随着迭代次数的增加,学习率不断增加的曲线,以及不同的学习率对应的loss的曲线。

深度学习:如何找到最优学习率
深度学习:如何找到最优学习率

从上面的图片可以看到,随着学习率由小不断变大的过程,网络的loss也会从一个相对大的位置变到一个较小的位置,同时又会增大,这也就对应于我们说的学习率太小,loss下降太慢,学习率太大,loss有可能反而增大的情况。从上面的图中我们就能够找到一个相对合理的初始学习率,0.1。

之所以上面的方法可以work,因为小的学习率对参数更新的影响相对于大的学习率来讲是非常小的,比如第一次迭代的时候学习率是1e-5,参数进行了更新,然后进入第二次迭代,学习率变成了5e-5,参数又进行了更新,那么这一次参数的更新可以看作是在最原始的参数上进行的,而之后的学习率更大,参数的更新幅度相对于前面来讲会更大,所以都可以看作是在原始的参数上进行更新的。正是因为这个原因,学习率设置要从小变到大,而如果学习率设置反过来,从大变到小,那么loss曲线就完全没有意义了。

实现

上面已经说明了算法的思想,说白了其实是非常简单的,就是不断地迭代,每次迭代学习率都不同,同时记录下来所有的loss,绘制成曲线就可以了。下面就是使用PyTorch实现的代码,因为在网络的迭代过程中学习率会不断地变化,而PyTorch的optim里面并没有把learning rate的接口暴露出来,导致显示修改学习率非常麻烦,所以我重新写了一个更加高层的包mxtorch,借鉴了gluon的一些优点,在定义层的时候暴露初始化方法,支持tensorboard,同时增加了大量的model zoo,包括inceptionresnetv2,resnext等等,提供预训练权重,model zoo参考于Cadene的repo。目前这个repo刚刚开始,欢迎有兴趣的小伙伴加入我。

下面就是部分代码,近期会把找学习率的代码合并到mxtorch中。这里使用的数据集是kaggle上的dog breed,使用预训练的resnet50,ScheduledOptim的源码点这里。

   
  1. criterion = torch.nn.CrossEntropyLoss()
  2. net = model_zoo.resnet50(pretrained=True)
  3. net.fc = nn.Linear(2048, 120)
  4.  
  5. with torch.cuda.device(0):
  6. net = net.cuda()
  7.  
  8. basic_optim = torch.optim.SGD(net.parameters(), lr=1e-5)
  9. optimizer = ScheduledOptim(basic_optim)
  10.  
  11.  
  12. lr_mult = (1 / 1e-5) ** (1 / 100)
  13. lr = []
  14. losses = []
  15. best_loss = 1e9
  16. for data, label in train_data:
  17. with torch.cuda.device(0):
  18. data = Variable(data.cuda())
  19. label = Variable(label.cuda())
  20. # forward
  21. out = net(data)
  22. loss = criterion(out, label)
  23. # backward
  24. optimizer.zero_grad()
  25. loss.backward()
  26. optimizer.step()
  27. lr.append(optimizer.learning_rate)
  28. losses.append(loss.data[0])
  29. optimizer.set_learning_rate(optimizer.learning_rate lr_mult)
  30. if loss.data[0] < best_loss:
  31. best_loss = loss.data[0]
  32. if loss.data[0] > 4 best_loss or optimizer.learning_rate > 1.:
  33. break
  34.  
  35. plt.figure()
  36. plt.xticks(np.log([1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1]), (1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1))
  37. plt.xlabel(‘learning rate’)
  38. plt.ylabel(‘loss’)
  39. plt.plot(np.log(lr), losses)
  40. plt.show()
  41. plt.figure()
  42. plt.xlabel(‘num iterations’)
  43. plt.ylabel(‘learning rate’)
  44. plt.plot(lr)

one more thing

通过上面的例子我们能够有一个非常有效的方法寻找初始学习率,同时在我们的认知中,学习率的策略都是不断地做decay,而上面的论文别出心裁,提出了一种循环变化学习率的思想,能够更快的达到最优解,非常具有启发性,推荐大家去阅读阅读。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/221945.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

CSS水平居中与垂直居中的方法

当我们页面布局的时候&#xff0c;通常需要把某一个元素居中&#xff0c;这一篇文章为大家介绍一下居中的几种方法&#xff0c;本人文笔有限&#xff0c;请见谅&#xff01; 一.水平居中 行内元素水平居中的方法&#xff0c;我们使用text-align:center; <!DOCTYPE html&g…

PubMedBERT:生物医学自然语言处理领域的特定预训练模型

今年大语言模型的快速发展导致像BERT这样的模型都可以称作“小”模型了。Kaggle LLM比赛LLM Science Exam 的第四名就只用了deberta&#xff0c;这可以说是一个非常好的成绩了。所以说在特定的领域或者需求中&#xff0c;大语言模型并不一定就是最优的解决方案&#xff0c;“小…

多集群部署中的 Kubernetes 弹性 (RTO/RPO)

啊&#xff0c;Kubernetes&#xff01;我们DevOps挑战的万灵药。 Kubernetes是一个开源的容器编排工具&#xff0c;本应加速软件交付、保护我们的应用程序、降低成本并减少我们的头痛问题&#xff0c;对吗&#xff1f; 不过说真的&#xff0c;Kubernetes已经彻底改变了我们编…

flink源码分析之功能组件(三)-rpc组件

简介 本系列是flink源码分析的第二个系列,上一个《flink源码分析之集群与资源》分析集群与资源,本系列分析功能组件,kubeclient,rpc,心跳,高可用,slotpool,rest,metrics,future。 本文解释rpc组件,rpc组件用于个核心组件,包括作业管理器,资源管理器和任务管理器之…

【1++的Linux】之信号量

&#x1f44d;作者主页&#xff1a;进击的1 &#x1f929; 专栏链接&#xff1a;【1的Linux】 文章目录 一&#xff0c;信号量二&#xff0c;基于环形队列的生产消费者模型三&#xff0c;线程池 一&#xff0c;信号量 1&#xff0c;什么是信号量&#xff1f; 任何时候都有一个…

深度学习之循环神经网络

视频链接&#xff1a;6 循环神经网络_哔哩哔哩_bilibili 给神经网络增加记忆能力 对全连接层而言&#xff0c;输入输出的维数固定&#xff0c;因此无法处理序列信息 对卷积层而言&#xff0c;因为卷积核的参数是共享的&#xff0c;所以卷积操作与序列的长度无关。但是因为卷积…

04_MySQL备份与恢复

任务背景 一、真实案例 某天&#xff0c;公司领导安排刚入职不久的小冯同学将生产环境中的数据(MySQL数据库)全部导入到测试环境给测试人员使用。当小冯去拿备份数据时发现&#xff0c;备份数据是1个礼拜之前的。原因是之前运维同事通过脚本每天对数据库进行备份&#xff0c;…

开发测试利器之Fiddler网络调试工具详细安装使用教程(包含汉化脚本)

一、Fiddler简介 Fiddler 是一款功能强大的网络调试工具&#xff0c;可以帮助开发人员和测试人员分析和调试网络流量。它通过截取计算机和服务器之间的HTTP/HTTPS请求&#xff0c;并提供详细的请求和响应信息来帮助我们理解和诊断网络通信。 Fiddler 可以用于各种用途&#x…

Redis之C语言底层数据结构笔记

目录 动态字符串SDS Dict ZipList QuickList ​ SkipList 动态字符串SDS Dict ZipList QuickList SkipList

【Linux】23、内存超详细介绍

文章目录 零、资料一、内存映射1.1 TLB1.2 多级页表1.3 大页 二、虚拟内存空间分布2.1 用户空间的段2.2 内存分配和回收2.2.1 小对象2.2.2 释放 三、查看内存使用情况3.1 Buffer 和 Cache3.1.1 proc 文件系统3.1.2 案例3.1.2.1 场景 1&#xff1a;磁盘和文件写案例3.1.2.2 场景…

淘宝平台商品详情平台订单接入说明

一 文档说明 本文档面向对象为电商平台商品详情数据和订单进行管理的第三方开发者或自研商家 二 支持范围 目前API已经支持订单的接单、取消、退单处理等操作。如果您的订单管理需求现有API不能满足&#xff0c;可以联系我们提出API需求。 参数说明 通用参数说明 参数不要乱…

css优化滚动条样式

css代码&#xff1a; ::-webkit-scrollbar {width: 6px;height: 6px; }::-webkit-scrollbar-track {background-color: #f1f1f1; }::-webkit-scrollbar-thumb {background-color: #c0c0c0;border-radius: 3px; }最终样式&#xff1a;