TensorFlow入门(二十三、退化学习率)

学习率

        学习率,控制着模型的学习进度。模型训练过程中,如果学习率的值设置得比较大,训练速度会提升,但训练结果的精度不够,损失值容易爆炸;如果学习率的值设置得比较小,精度得到了提升,但训练过程会耗费太多的时间,收敛速度慢,同时也容易出现过拟合的情况。

退化学习率

        退化学习率又叫学习率衰减或学习率更新。更新学习率是希望训练过程中,在精度和速度之间找到一个平衡,兼得学习率大核学习率小的优点。即当训练刚开始时使用大的学习率加快速度,训练到一定程度后使用小的学习率来提高精度。

TensorFlow中常用的退化学习率方法

        ①指数衰减方法

                指数衰减是较为常用的衰减方法,学习率是跟当前的训练轮次指数相关的。

                tf.train.exponential_decay(learning_rate,global_step,decay_steps,decay_rate,staircase = False,name = None)

                参数learning_rate为初始学习率;global_step为当前训练轮次,即epoch;decay_steps用于定义衰减周期,跟参数staircese配合,可以在decay_step个训练轮次内保持学习率不变;decay_rate为衰减率系数;staircase用于定义是阶梯型衰减,还是连续衰减,默认是False,即连续衰减(标准的指数型衰减)。

                指数衰减方法中学习率的具体计算公式如下:

                        decayed_learning_rate = learning_rate*decay_rate^(global_step/decay_steps)

                指数衰减方法中学习率的衰减轨迹如下图:

                        

                        红色的是阶梯型指数衰减,在一定轮次内学习率保持一致

                        绿色的是标准的指数衰减,即连续型指数衰减

        ②自然指数衰减方法

                指数衰减的一种特殊情况,学习率也是跟当前的训练轮次指数相关,只不过是以e为底数。函数中的参数意义与指数衰减方法中的参数相同。

                tf.train.natural_exp_decay(learning_rate,global_step,decay_steps,decay_rate,staircase = False,name = None)

                自然指数衰减方法中的学习率的具体计算公式如下:

                        decayed_learning_rate = learning_rate*exp(-decay_rate*global_step)

                自然指数衰减方法中学习率的衰减轨迹如下图:

                        

                        左下部分的两条曲线是自然指数衰减,右上部分的两条曲线是指数衰减。可以明显看到,自然指数衰减对学习率的衰减程度要远大于一般的指数衰减,它一般用于可以较快收敛的网络,或者是训练数据集比较大的场合。

        ③倒数衰减方法

                训练过程中,倒数衰减方法不固定最小学习率,迭代次数越多,学习率越小。学习率的大小跟训练次数有一定的反比关系。

               tf.train.inverse_time_decay(learning_rate,global_step,decay_steps,decay_rate,staircase = False,name = None)

                参数global_step为用于衰减计算的全局步数,decay_steps为衰减步数,decay_rate为衰减率,staircase用于定义是应用离散阶梯型衰减,还是连续衰减。

                倒数衰减方法中学习率的具体计算公式如下:

                        decayed_learning_rate = learning_rate/(1+decay_rate*global_step/decay_step)

                倒数衰减方法中学习率的衰减轨迹如下图:

                        

                        绿色的是离散阶梯型衰减,红色的是连续型衰减

        ④分段常数衰减方法

                分段常数衰减可以针对不同任务设置不同的学习率,从而进行精细调参。

                tf.train.piecewise_constant(x,boundaries,values,name = None)

                参数x是标量,指的是global_step,即训练次数;boundaries为学习率参数应用区间列表,即迭代次数所在的区间;values为学习率列表,存放在不同区间该使用的学习率的值。需要注意 : values的长度比boundaries的长度多1,因为两个数可以制定出三个区间,有三个区间要用3个学习率。

                分段常数衰减方法中学习率的衰减轨迹如下图:

                        ​​​​​​​

                        每个区间内,学习率的值是不一样的

        ⑤多项式衰减方法

                多项式衰减方法的原理为 : 定义一个初始的学习率和一个最低的学习率,然后按照设置的衰减规则,学习率从初始学习率逐渐降低到最低的学习率,并且可以定义学习率降低到最低的值之后,是一直保持使用这个最低的学习率,还是再升高到一定的值,然后再降低到最低的学习率,循环反复这个过程。

                tf.train.polynomial_decay(learning_rate,global_step,decay_steps,end_learning_rate = 0.0001,power = 1.0,cycle = False,name = None)

                参数global_step为当前训练轮次,即epoch;decay_steps为定义衰减周期;end_learning_rate是最小的学习率,默认值是0.0001;power是多项式的幂,默认值是1,即线性的。cycle用于定义学习率是否到达最低学习率后升高,然后再降低,默认False,保持最低的学习率。

                一般情况下多项式衰减方法中学习率的具体计算公式如下:

                        global_step = min(global_step,decay_steps)

                        decayed_learning_rate = (learning_rate - end_learning_rate) * (1 - global_step / decay_steps)^(power) + end_learning_rate

                如果定义cycle为True,学习率在到达最低学习率后反复升高降低,学习率计算公式如下:

                        decay_steps = decay_steps * ceil(global_step / decay_steps)

                        decayed_learning_rate = (learning_rate - end_learning_rate) * (1 - global_step / decay_steps)^(power) + end_learning_rate

                多项式衰减方法中学习率的衰减轨迹如下图:

                        

                        红色的为cycle = False时的情况,下降后不再上升,保持不变;绿色的为cycle = True时的情况,下降后反复升降。

                多项式衰减中设置学习率反复升降的目的是为了防止神经网络后期训练的学习率过小,导致网络参数陷入某个局部,找不到最优解;设置学习率升高机制,有可能使网络找出局部最优解。

指数衰减示例代码如下:

import tensorflow.compat.v1 as tf
tf.compat.v1.disable_v2_behavior()global_step = tf.Variable(0,trainable=False)
#定义初始学习率
initial_learning_rate = 0.1
#使用指数衰减方法
learning_rate = tf.train.exponential_decay(initial_learning_rate,global_step,decay_steps = 20,decay_rate = 0.8)#定义一个操作,global_step每次加1后完成计步
opt = tf.train.GradientDescentOptimizer(learning_rate)
add_global = global_step.assign_add(1)init = tf.global_variables_initializer()
with tf.Session() as sess:sess.run(init)print(sess.run(learning_rate))#循环20次,将每次的学习率打印出来for i in range(20):g,rate = sess.run([add_global,learning_rate])print(g,rate)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/132182.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux文件与目录的增删改查

一、增 1、mkdir命令 作用: 创建一个新目录。格式: mkdir [选项] 要创建的目录 常用参数: -p:创建目录结构中指定的每一个目录,如果目录不存在则创建,如果目录已存在也不会被覆盖。用法示例: 1、mkdir directory:创建单个目录 这个命令会在当前目录下创建一个名为…

MySQL(存储过程,store procedure)——存储过程的前世今生 MySQL存储过程体验 MybatisPlus中使用存储过程

前言 SQL(Structured Query Language)是一种用于管理关系型数据库的标准化语言,它用于定义、操作和管理数据库中的数据。SQL是一种通用的语言,可以用于多种关系型数据库管理系统(RDBMS),如MySQ…

win10 wsl安装步骤

参考&#xff1a; 安装 WSL | Microsoft Learn 一、安装wsl 1.若要查看可通过在线商店下载的可用 Linux 发行版列表&#xff0c;请输入&#xff1a; wsl --list --online 或 wsl -l -o> wsl -l -o 以下是可安装的有效分发的列表。 使用 wsl.exe --install <Distro>…

WebKit Insie: Active 样式表

WebKit Inside: CSS 样式表的匹配时机介绍了当 HTML 页面有不同 CSS 样式表引入时&#xff0c;CSS 样式表开始匹配的时机。后续文章继续介绍 CSS 样式表的匹配过程&#xff0c;但是在匹配之前&#xff0c;首先需要收集页面里面的 Active 样式表。 1 Active 样式表 在一个 HTML …

访问Apache Tomcat的manager页面

配置访问Tomcat manager页面的用户名、密码、角色 Tomcat安装完成后&#xff0c;包含了一个管理应用&#xff0c;默认安装在 <Tomcat安装目录>/webapps/manager 例如&#xff1a; 要使用管理页面的功能&#xff0c;需要在conf/tomcat-users.xml文件中配置用户、密码及…

【C++初阶(一)】学习前言 命名空间与IO流

本专栏内容为&#xff1a;C学习专栏&#xff0c;分为初阶和进阶两部分。 通过本专栏的深入学习&#xff0c;你可以了解并掌握C。 &#x1f493;博主csdn个人主页&#xff1a;小小unicorn ⏩专栏分类&#xff1a;C &#x1f69a;代码仓库&#xff1a;小小unicorn的代码仓库&…

C++11 Qt QFutureWatcher lambda

目录 Lambda 介绍 【QT】Qt之QFutureWatcher 简述 传参&#xff1a; 还可以使用 QProgressDialog 作为阻堵 函数&#xff0c;变成同步&#xff1b; 完成后&#xff0c;关闭&#xff1b; MyQProgressDialog 效果&#xff1a; Lambda 介绍 Lambda 函数也叫匿名函数&…

MAX17058_MAX17059 STM32 iic 驱动设计

本文采用资源下载链接&#xff0c;含完整工程代码 MAX17058-MAX17059STM32iic驱动设计内含有代码、详细设计过程文档&#xff0c;实际项目中使用代码&#xff0c;稳定可靠资源-CSDN文库 简介 MAX17058/MAX17059 IC是微小的锂离子(Li )在手持和便携式设备的电池电量计。MAX170…

2562. 找出数组的串联值

2562. 找出数组的串联值 难度: 简单 来源: 每日一题 2023.10.12 给你一个下标从 0 开始的整数数组 nums 。 现定义两个数字的 串联 是由这两个数值串联起来形成的新数字。 例如&#xff0c;15 和 49 的串联是 1549 。 nums 的 串联值 最初等于 0 。执行下述操作直到 nu…

信钰证券:汇金增持提振市场情绪 保险、银行等板块集体拉升

12日&#xff0c;两市股指盘中全线走高&#xff0c;沪指一度克复3100点&#xff0c;上证50指数涨超1%。 稳妥、银行、券商板块团体拉升&#xff0c;到发稿&#xff0c;银行板块方面&#xff0c;瑞丰银行涨约6%&#xff0c;盘中一度涨停&#xff1b;紫金银行、渝农银行、西安银…

04_学习springdoc与oauth结合_简述

文章目录 1 前言2 基本结构3 需要做的配置 简述4 需要做的配置 详述4.1 backend-api-gateway 的配置4.1.1 application.yml 4.2 backend-film 的配置4.2.1 pom.xml 引入依赖4.2.2 application.yml 的配置4.2.3 Spring Security 资源服务器的配置类 MyResourceServerConfig4.2.4…

WebRTC 系列(四、多人通话,H5、Android、iOS)

WebRTC 系列&#xff08;三、点对点通话&#xff0c;H5、Android、iOS&#xff09; 上一篇博客中&#xff0c;我们已经实现了点对点通话&#xff0c;即一对一通话&#xff0c;这一次就接着实现多人通话。多人通话的实现方式呢也有好几种方案&#xff0c;这里我简单介绍两种方案…