pytorch08:学习率调整策略

在这里插入图片描述

目录

  • 一、为什么要调整学习率?
    • 1.1 class _LRScheduler
  • 二、pytorch的六种学习率调整策略
    • 2.1 StepLR
    • 2.2 MultiStepLR
    • 2.3 ExponentialLR
    • 2.4 CosineAnnealingLR
    • 2.5 ReduceLRonPlateau
    • 2.6 LambdaLR
  • 三、学习率调整小结
  • 四、学习率初始化

一、为什么要调整学习率?

学习率(learning rate):控制更新的步伐
一般在模型训练过程中,在开始训练的时候我们会设置学习率大一些,随着模型训练epoch的增加,学习率会逐渐设置小一些。

1.1 class _LRScheduler

学习率调整的父类函数
在这里插入图片描述
主要属性:
• optimizer:关联的优化器
• last_epoch:记录epoch数
• base_lrs:记录初始学习率
主要方法:
• step():更新下一个epoch的学习率,该操作必须放到epoch循环下面
• get_lr():虚函数,计算下一个epoch的学习率

二、pytorch的六种学习率调整策略

2.1 StepLR

在这里插入图片描述

功能:等间隔调整学习率
主要参数:
• step_size:调整间隔数
• gamma:调整系数
调整方式:lr = lr * gamma

代码实现:

import torch
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plttorch.manual_seed(1)LR = 0.1
iteration = 10
max_epoch = 200
# ------------------------------ fake data and optimizer  ------------------------------weights = torch.randn((1), requires_grad=True)
target = torch.zeros((1))optimizer = optim.SGD([weights], lr=LR, momentum=0.9)# ------------------------------ 1 Step LR ------------------------------
# flag = 0
flag = 1
if flag:scheduler_lr = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)  # 设置学习率下降策略,50轮下降一次,每次下降10倍lr_list, epoch_list = list(), list()for epoch in range(max_epoch):lr_list.append(scheduler_lr.get_lr())epoch_list.append(epoch)for i in range(iteration):loss = torch.pow((weights - target), 2)loss.backward()optimizer.step()optimizer.zero_grad()scheduler_lr.step()  # 学习率更新策略plt.plot(epoch_list, lr_list, label="Step LR Scheduler")plt.xlabel("Epoch")plt.ylabel("Learning rate")plt.legend()plt.show()

输出结果:
在这里插入图片描述

因为我们设置每50个epoch降低一次学习率,所以在7774554

2.2 MultiStepLR

在这里插入图片描述

功能:按给定间隔调整学习率
主要参数:
• milestones:设定调整时刻数
• gamma:调整系数
调整方式:lr = lr * gamma

代码实现

flag = 1
if flag:milestones = [50, 125, 160]  # 设置学习率下降的位置scheduler_lr = optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)lr_list, epoch_list = list(), list()for epoch in range(max_epoch):lr_list.append(scheduler_lr.get_lr())epoch_list.append(epoch)for i in range(iteration):loss = torch.pow((weights - target), 2)loss.backward()optimizer.step()optimizer.zero_grad()scheduler_lr.step()plt.plot(epoch_list, lr_list, label="Multi Step LR Scheduler\nmilestones:{}".format(milestones))plt.xlabel("Epoch")plt.ylabel("Learning rate")plt.legend()plt.show()

输出结果
在这里插入图片描述

根据我们设置milestones = [50, 125, 160],发现学习率在这三个地方发生下降。

2.3 ExponentialLR

在这里插入图片描述

功能:按指数衰减调整学习率
主要参数:
• gamma:指数的底
调整方式:lr = lr * gamma^epoch;这里的gamma通常设置为接近1的数值,例如:0.95

代码实现

flag = 1
if flag:gamma = 0.95scheduler_lr = optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)lr_list, epoch_list = list(), list()for epoch in range(max_epoch):lr_list.append(scheduler_lr.get_lr())epoch_list.append(epoch)for i in range(iteration):loss = torch.pow((weights - target), 2)loss.backward()optimizer.step()optimizer.zero_grad()scheduler_lr.step()plt.plot(epoch_list, lr_list, label="Exponential LR Scheduler\ngamma:{}".format(gamma))plt.xlabel("Epoch")plt.ylabel("Learning rate")plt.legend()plt.show()

输出结果
在这里插入图片描述

可以发现学习率是呈指数下降的。

2.4 CosineAnnealingLR

在这里插入图片描述

功能:余弦周期调整学习率
主要参数:
• T_max:下降周期
• eta_min:学习率下限
调整方式:
在这里插入图片描述

代码实现

flag = 1
if flag:t_max = 50scheduler_lr = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=t_max, eta_min=0.)lr_list, epoch_list = list(), list()for epoch in range(max_epoch):lr_list.append(scheduler_lr.get_lr())epoch_list.append(epoch)for i in range(iteration):loss = torch.pow((weights - target), 2)loss.backward()optimizer.step()optimizer.zero_grad()scheduler_lr.step()plt.plot(epoch_list, lr_list, label="CosineAnnealingLR Scheduler\nT_max:{}".format(t_max))plt.xlabel("Epoch")plt.ylabel("Learning rate")plt.legend()plt.show()

输出结果
在这里插入图片描述

2.5 ReduceLRonPlateau

在这里插入图片描述

功能:监控指标,当指标不再变化则调整,例如:可以监控我们的loss或者准确率,当其不发生变化的时候,调整学习率。
主要参数:
• mode:min/max 两种模式
min模式:当某一个值不下降的时候我们调整学习率,通常用于监控损失
max模型:当某一个值不上升的时候我们调整学习率,通常用于监控精确度
• factor:调整系数
• patience:“耐心”,接受几次不变化
• cooldown:“冷却时间”,停止监控一段时间
• verbose:是否打印日志
• min_lr:学习率下限
• eps:学习率衰减最小值

代码实现

flag = 1
if flag:loss_value = 0.5accuray = 0.9factor = 0.1  # 学习率变换参数mode = "min"patience = 10  # 能接受多少轮不变化cooldown = 10  # 停止监控多少轮min_lr = 1e-4  # 设置学习率下限verbose = True  # 打印更新日志scheduler_lr = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=factor, mode=mode, patience=patience,cooldown=cooldown, min_lr=min_lr, verbose=verbose)for epoch in range(max_epoch):for i in range(iteration):# train(...)optimizer.step()optimizer.zero_grad()#if epoch == 5:# loss_value = 0.4scheduler_lr.step(loss_value) #监控的标量是否下降

输出结果
在这里插入图片描述

2.6 LambdaLR

在这里插入图片描述
功能:自定义调整策略
主要参数:
• lr_lambda:function or list

代码实现

flag = 1
if flag:lr_init = 0.1weights_1 = torch.randn((6, 3, 5, 5))weights_2 = torch.ones((5, 5))optimizer = optim.SGD([{'params': [weights_1]},{'params': [weights_2]}], lr=lr_init)# 设置两种不同的学习率调整方法lambda1 = lambda epoch: 0.1 ** (epoch // 20)  # 每到20轮的时候学习率变为原来的0.1倍lambda2 = lambda epoch: 0.95 ** epoch  # 将学习率进行指数下降scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])lr_list, epoch_list = list(), list()for epoch in range(max_epoch):for i in range(iteration):# train(...)optimizer.step()optimizer.zero_grad()scheduler.step()lr_list.append(scheduler.get_lr())epoch_list.append(epoch)print('epoch:{:5d}, lr:{}'.format(epoch, scheduler.get_lr()))plt.plot(epoch_list, [i[0] for i in lr_list], label="lambda 1")plt.plot(epoch_list, [i[1] for i in lr_list], label="lambda 2")plt.xlabel("Epoch")plt.ylabel("Learning Rate")plt.title("LambdaLR")plt.legend()plt.show()

输出结果
在这里插入图片描述

通过lambda方法定义了两种不同的学习率下降策略。

三、学习率调整小结

  1. 有序调整:Step、MultiStep、Exponential 和 CosineAnnealing
  2. 自适应调整:ReduceLROnPleateau
  3. 自定义调整:Lambda

四、学习率初始化

1、设置较小数:0.01、0.001、0.0001
2、搜索最大学习率: 参考该篇《Cyclical Learning Rates for Training Neural Networks》
方法:我们可以设置学习率逐渐从小变大观察精确度的一个变化,下面这幅图,当学习率为0.055左右的时候模型精确度最高,当学习率大于0.055的时候精确度出现下降情况,所以在模型训练过程中我们可以设置学习率为0.055作为我们的初始学习率。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/322981.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【基础篇】十四、GC算法

文章目录 1、实现思路2、SWT3、GC算法4、标记清除算法Mark Sweep GC5、复制算法Copying GC5、标记整理算法6、分代算法Generational GC7、分代的整体流程 1、实现思路 Java实现垃圾回收的步骤: 根据GC Root对象可达性分析,将内存中对象标记为存活的、可…

【python入门】day18:文件、os相关操作

编码格式 1、ASCLL–ISO8859-1–GBK–UTF-8– ISO8859-1–GBK等文件在格式上会显示为ASCLL 2、python文件默认格式 utf-8,看格式流程 选择python文件–用记事本打开–另存为… 这时可看到该文件的格式 3、修改python文件默认编码格式,在文件开头添加上: …

Docker容器相关操作

文章目录 容器相关操作1 新建并启动容器2 容器日志3 删除容器4 列出容器5 创建容器6 启动、重启、终止容器7 进入容器8 查看容器9 更新容器10 杀掉容器11 docker常用命令汇总 容器相关操作 ​ 容器是镜像的运行时实例。正如从虚拟机模板上启动 VM 一样,用户也同样可…

分布式【Zookeeper三大核心之数据节点ZNode】

ZooKeeper在分布式领域,能够帮助解决很多很多的分布式难题,但是底层却只是依赖于两个主要的组件:ZNode文件/数据存储系统和watch监听系统,另外还有一大模块,就是ACL系统。本节我们介绍下znode文件/数据存储系统。 一、…

uniapp微信小程序投票系统实战 (SpringBoot2+vue3.2+element plus ) -后端架构搭建

锋哥原创的uniapp微信小程序投票系统实战: uniapp微信小程序投票系统实战课程 (SpringBoot2vue3.2element plus ) ( 火爆连载更新中... )_哔哩哔哩_bilibiliuniapp微信小程序投票系统实战课程 (SpringBoot2vue3.2element plus ) ( 火爆连载更新中... )共计21条视频…

VuePress部署到GitHub Pages

一、git push自动部署 1、创建用于工作流的文件 在项目根目录下创建一个用于 GitHub Actions 的工作流 .yml 文件 name: docson:# 每当 push 到 main 分支时触发部署push:branches: [main]# 手动触发部署workflow_dispatch:jobs:docs:runs-on: ubuntu-lateststeps:- uses: a…

Python-CSV文件的存储

CSV文件存储 CSV其文件以纯文本形式存储表格数据。CSV文件是一个字符序列,可以由任意数目的记录组成,各种记录由某种换行符分隔开。它比Excel文件更加简洁,XLS文本是电子表格,包含文本、数值、公式和格式等内容,CSV中则…

安装extiverse/mercury时报错

问题描述 作者在安装 Flarum 的插件 extiverse/mercury 时报错,内容如下图所示 解决方案 ⚠警告:请备份所有数据再进行接下来的操作,此操作可能会导致网站不可用! 报错原因:主要问题是在安装过程中解决依赖关系。具…

vue项目使用vue-pdf插件预览pdf文件

1、安装vue-pdf&#xff1a;npm install --save vue-pdf 2、使用 具体实现代码&#xff1a;pdfPreview.vue <template><div class"container"><pdfref"pdf":src"pdfUrl":page"currentPage":rotate"pageRotate&qu…

提升办公效率:掌握批量文件重命名的技巧

在日常生活和工作中&#xff0c;经常要处理大量的文件&#xff0c;如文档、图片、音频等。在这些情况下&#xff0c;会遇到要批量重命名文件的情况。如果一个一个地重命名&#xff0c;不仅耗时&#xff0c;而且效率低下。今天来讲解一些技巧通过批量重命名文件&#xff0c;从而…

【springboot+vue项目(零)】开发项目经验积累(处理问题)

一、VUEElement UI &#xff08;一&#xff09;elementui下拉框默认值不是对应中文问题 v-model绑定的值必须是字符串&#xff0c;才会显示默认选中对应中文&#xff0c;如果是数字&#xff0c;则显示数字&#xff0c;修改为&#xff1a; handleOpenAddDialog() {this.dialogT…

Android中的Intent

一.显式Intent 显示Intent是明确目标Activity的类名 1. 通过Intent(Context packageContext, Class<?> cls)构造方法 2.通过Intent的setComponent()方法 3.通过Intent的setClass/setClassName方法 通过Intent(Context packageContext, Class<?> cls)构造方法 通…