人工智能-优化算法之学习率调度器

学习率调度器

到目前为止,我们主要关注如何更新权重向量的优化算法,而不是它们的更新速率。 然而,调整学习率通常与实际算法同样重要,有如下几方面需要考虑:

  • 首先,学习率的大小很重要。如果它太大,优化就会发散;如果它太小,训练就会需要过长时间,或者我们最终只能得到次优的结果。我们之前看到问题的条件数很重要。直观地说,这是最不敏感与最敏感方向的变化量的比率。

  • 其次,衰减速率同样很重要。如果学习率持续过高,我们可能最终会在最小值附近弹跳,从而无法达到最优解。

  • 另一个同样重要的方面是初始化。这既涉及参数最初的设置方式,又关系到它们最初的演变方式。这被戏称为预热(warmup),即我们最初开始向着解决方案迈进的速度有多快。一开始的大步可能没有好处,特别是因为最初的参数集是随机的。最初的更新方向可能也是毫无意义的。

  • 最后,还有许多优化变体可以执行周期性学习率调整。这超出了本章的范围,我们建议读者阅读 (Izmailov et al, 2018)来了解个中细节。例如,如何通过对整个路径参数求平均值来获得更好的解。

鉴于管理学习率需要很多细节,因此大多数深度学习框架都有自动应对这个问题的工具。 在本章中,我们将梳理不同的调度策略对准确性的影响,并展示如何通过学习率调度器(learning rate scheduler)来有效管理。

们从一个简单的问题开始,这个问题可以轻松计算,但足以说明要义。 为此,我们选择了一个稍微现代化的LeNet版本(激活函数使用relu而不是sigmoid,汇聚层使用最大汇聚层而不是平均汇聚层),并应用于Fashion-MNIST数据集。 此外,我们混合网络以提高性能。 由于大多数代码都是标准的,我们只介绍基础知识,而不做进一步的详细讨论

%matplotlib inline
import math
import torch
from torch import nn
from torch.optim import lr_scheduler
from d2l import torch as d2ldef net_fn():model = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(16 * 5 * 5, 120), nn.ReLU(),nn.Linear(120, 84), nn.ReLU(),nn.Linear(84, 10))return modelloss = nn.CrossEntropyLoss()
device = d2l.try_gpu()batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)# 代码几乎与d2l.train_ch6定义在卷积神经网络一章LeNet一节中的相同
def train(net, train_iter, test_iter, num_epochs, loss, trainer, device,scheduler=None):net.to(device)animator = d2l.Animator(xlabel='epoch', xlim=[0, num_epochs],legend=['train loss', 'train acc', 'test acc'])for epoch in range(num_epochs):metric = d2l.Accumulator(3)  # train_loss,train_acc,num_examplesfor i, (X, y) in enumerate(train_iter):net.train()trainer.zero_grad()X, y = X.to(device), y.to(device)y_hat = net(X)l = loss(y_hat, y)l.backward()trainer.step()with torch.no_grad():metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])train_loss = metric[0] / metric[2]train_acc = metric[1] / metric[2]if (i + 1) % 50 == 0:animator.add(epoch + i / len(train_iter),(train_loss, train_acc, None))test_acc = d2l.evaluate_accuracy_gpu(net, test_iter)animator.add(epoch+1, (None, None, test_acc))if scheduler:if scheduler.__module__ == lr_scheduler.__name__:# UsingPyTorchIn-Builtschedulerscheduler.step()else:# Usingcustomdefinedschedulerfor param_group in trainer.param_groups:param_group['lr'] = scheduler(epoch)print(f'train loss {train_loss:.3f}, train acc {train_acc:.3f}, 'f'test acc {test_acc:.3f}')

让我们来看看如果使用默认设置,调用此算法会发生什么。 例如设学习率为0.3并训练30次迭代。 留意在超过了某点、测试准确度方面的进展停滞时,训练准确度将如何继续提高。 两条曲线之间的间隙表示过拟合。

lr, num_epochs = 0.3, 30
net = net_fn()
trainer = torch.optim.SGD(net.parameters(), lr=lr)
train(net, train_iter, test_iter, num_epochs, loss, trainer, device)

train loss 0.128, train acc 0.951, test acc 0.885

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/238590.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

acwing算法基础之动态规划--数位统计DP、状态压缩DP、树形DP和记忆化搜索

目录 1 基础知识2 模板3 工程化 1 基础知识 暂无。。。 2 模板 暂无。。。 3 工程化 题目1:求a~b中数字0、数字1、…、数字9出现的次数。 思路:先计算1~a中每位数字出现的次数,然后计算1~b-1中每位数字出现的次数,两个相减即…

7、单片机与W25Q128(FLASH)的通讯(SPI)实验(STM32F407)

SPI接口简介 SPI 是英语Serial Peripheral interface的缩写,顾名思义就是串行外围设备接口。是Motorola首先在其MC68HCXX系列处理器上定义的。 SPI,是一种高速的,全双工,同步的通信总线,并且在芯片的管脚上只占用四根…

uniapp uni-popup组件在微信小程序中滚动穿透问题

起因 在微信小程序中使用uni-popup组件时&#xff0c;出现滚动穿透&#xff0c;并且uni-popup内部内容不会滚动问题。 解决 滚动穿透 查阅官方文档&#xff0c;发现滚动穿透是由于平台差异性造成的&#xff0c;具体解决可以参照文档禁止滚动穿透 <template><page-…

python 实现链表

链表基础知识 链表是在物理内存中不连续&#xff0c;数据通过链表中的指针来链接到下一个元素。 链表由一系列节点组成&#xff0c;节点在运行时动态生成&#xff0c;节点一般包括两个部分&#xff1a;存储数据的数据域&#xff0c;存储下一个节点的指针域 链表的常用操作&a…

熬夜会秃头——beta冲刺Day4

这个作业属于哪个课程2301-计算机学院-软件工程社区-CSDN社区云这个作业要求在哪里团队作业—beta冲刺事后诸葛亮-CSDN社区这个作业的目标记录beta冲刺Day4团队名称熬夜会秃头团队置顶集合随笔链接熬夜会秃头——Beta冲刺置顶随笔-CSDN社区 一、团队成员会议总结 1、成员工作进…

C语言:写一个函数,输入一个十六进制数,输出相应的十进制数

分析&#xff1a; 当用户运行该程序时&#xff0c;程序会提示用户输入一个十六进制数。用户需要在命令行中输入一个有效的十六进制数&#xff0c;例如&#xff1a;"1A3F"。 接下来&#xff0c;程序调用了名为 xbed 的函数&#xff0c;并将用户输入的十六进制数作…

【翻译】直流电动机的控制

直流电&#xff08;DC&#xff09;电机由于其转矩易于控制&#xff0c;速度控制范围广&#xff0c;已广泛应用于可调速驱动或可变转矩控制中。然而&#xff0c;直流电机有一个主要的缺点&#xff0c;即它们需要机械装置&#xff0c;如换向器和刷子来连续旋转。这些机械部件需要…

《opencv实用探索·六》简单理解图像膨胀

1、图像膨胀原理简单理解 膨胀是形态学最基本的操作&#xff0c;都是针对白色部分&#xff08;高亮部分&#xff09;而言的。膨胀就是使图像中高亮部分扩张&#xff0c;效果图拥有比原图更大的高亮区域。 2、图像膨胀的作用 注意一般情况下图像膨胀和腐蚀是联合使用的。 &…

前端面试高频考点—事件循环Event loop

目录 事件循环 执行步骤 概念讲解 主线程 微任务(micro task) 宏任务(macro task) Event Loop经典例题 这段代码的执行结果是什么&#xff1f; 正确答案&#xff1a; 具体流程&#xff1a; 事件循环 主线程从"任务队列"中读取执行事件&#xff0c;这个过程…

利用 NRF24L01 无线收发模块实现传感器数据的无线传输

NRF24L01 是一款常用的无线收发模块&#xff0c;适用于远程控制和数据传输应用。本文将介绍如何利用 NRF24L01 模块实现传感器数据的无线传输&#xff0c;包括硬件的连接和配置&#xff0c;以及相应的代码示例。 一、引言 NRF24L01 是一款基于 2.4GHz 射频通信的低功耗无线收发…

RHCE学习笔记(RHEL8) - RH294

Chapter Ⅰ 介绍Ansible ansible ansible是一款开源自动化平台 ansible围绕一种无代理架构构建,在控制节点上安装ansible,且客户端不需要任何特殊的代理软件;ansible使用SSH等标准协议连接受管主机,并在受管主机上运行代码或命令来确保他们处于ansible指定的状态 Ansible帮…

[每周一更]-(第75期):Go相关粗浅的防破解方案

Go作为编译语言&#xff0c;天然存在跨平台的属性&#xff0c;我们在编译完成后&#xff0c;可以再不暴露源代码的情况下&#xff0c;运行在对应的平台中&#xff0c;但是 还是架不住有逆向工程师的反编译、反汇编的情形&#xff1b;&#xff08;当然我们写的都不希望被别人偷了…