【李宏毅机器学习·学习笔记】Tips for Training: Batch and Momentum

本节课主要介绍了Batch和Momentum这两个在训练神经网络时用到的小技巧。合理使用batch,可加速模型训练的时间,并使模型在训练集或测试集上有更好的表现。而合理使用momentum,则可有效对抗critical point。

课程视频:
Youtube:https://www.youtube.com/watch?v=zzbr1h9sF54
知乎:https://www.zhihu.com/zvideo/1617121300702498816
课程PPT:
https://view.officeapps.live.com/op/view.aspx?src=https%3A%2F%2Fspeech.ee.ntu.edu.tw%2F~hylee%2Fml%2Fml2021-course-data%2Fsmall-gradient-v7.pptx&wdOrigin=BROWSELINK

一、Batch

在optimization的过程中,我们实际算微分的时候并不是对所有的data做微分,而是将data分为一个一个的batch (mini batch) 计算微分。
在这里插入图片描述
例如上图,程序先使用第一个batch的数据计算Loss L1,再用L1计算gradient g0,并使用g0来update参数(θ0→θ1);之后,程序又使用第二个batch计算Loss L2,再用L2计算gradient g1,并使用g1来update参数(θ1→θ2)。当所有的batch都过完一遍后,我们就说过完了一个epoch
shuffle是与epoch相关的另一个概念。在每一个epoch开始之前我们都会将其分为一个个batch,shuffle的作用就是确保每一个epoch的batch都不一样。

batch的大小对训练的过程和结果都有一定的影响。如下图,假设一个数据集中有N个样例,左边的batch size = N,即full batch,相当于没有使用batch,程序一次遍历完所有的样例后才update参数;右边的batch size = 1,可视为small batch,程序遍历一个样例即更新一次参数,在一个epoch里需要更新N次参数。从图中看,当batch size = 1时,由于每次只根据一个样例来计算loss,它求出来的gradient噪声是比较大的,所以update的方向看上去是曲曲折折;而左边是根据所有样例来计算loss,其参数的update看上去更为稳健。
在这里插入图片描述

small batch和large batch之间的差异,具体还可从以下几个维度来看:

1. Smaller batch requires longer time for one epoch

如果不考虑并行运算,large batch在训练的过程中,一次需要读更多的数据,它所花费的时间应该比small batch的时间长。但是时间上GPU一般都有并行运算的能力,它可以同时处理多笔数据。当batch size在一定的阈值,训练完一个batch,large batch所花费的时间并不一定比small batch所需的时间多 (如下图左边所示)。相反,在一个epoch中,batch size越小,update参数的次数就越多,训练花费的时间也就更久 (如下图右边所示)。
在这里插入图片描述

2. Smaller batch size has better performance

有实验表明,small batch在训练时取得的准确率更高,从下图可以看出,不管是在训练集还是验证集上,随着batch size增大,模型的准确率会降低。
在这里插入图片描述
对此的一种解释是,large batch更容易陷入critical point。如下图,左边的full batch如果卡在了gradient为0的点,那么update就会停在这个点;而邮编的samll batch,如果一笔batch卡在这个点,可以接着用另一笔来计算loss,或许可由此跳出critical point。
在这里插入图片描述

3. Small batch is better on testing data

有研究表明,在测试集上使用small batch得到的准确率可能更高。如下图所示,尽管在测试集上可能有的large batch的准确率可能略高于small batch,但在测试集上small batch的准确率无一例外均高于large batch。
在这里插入图片描述
一种可能的解释是,窄的峡谷没有办法困住small batch,而大的平原才有可能困住small size,而large batch则容易被困在峡谷里。从下图中可以看出,如果是在峡谷,训练集上的Loss和测试集上的Loss差距较大,从而导致准确率变低。
在这里插入图片描述
总的来说,small size和large size之间的对比如下:
在这里插入图片描述

二、Momemtum

momentum是另一个可能对抗critical point的技术。
如下图,我们假设error surface是一个斜坡,而参数是一个球,在现实的物理时间中,球从斜坡上滚下,不一定会被saddle point或local minima卡住,因为惯性在起作用,受惯性影响,即便受到阻力,球仍然会在一定的时间段内保持原来的运动状态。momentum在训练神经网络过程中的作用就相当于物理世界的惯性
在这里插入图片描述
下图是一个一般的(Vanilla)使用梯度下降法的过程(不考虑momentum)。我们先计算梯度g,再沿着梯度的反方向移动以更新θ。
在这里插入图片描述
而如果考虑momentum,整个梯度下降法更新参数 θ 的过程则如下。第 i 次update时的momentum mi,其方向与上一次参数更新的movement方向一致(如图中蓝色虚线所示)。如果考虑momentum,在第 i 次update参数 θ 时,会综合梯度 gi的反方向(如图中红色虚线所示)与mi(前一步移动的方向),来选择本次move的方向(如图中蓝色实线所示)。
在这里插入图片描述
从下图中左侧的公式推导过程可知,momentum mi其实可以看做之前之间计算出来的gradient的weighted sum。因而我们可以说,加上momentum后的uodate不是只考虑当前的gradient,而是考虑过去所有的gradient的总和
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/62080.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

# X11、Xlib、XFree86、Xorg、GTK、Qt、Gnome和KDE之间的关系

X11、Xlib、XFree86、Xorg、GTK、Qt、Gnome和KDE之间的关系 很多人对于他们是啥是傻傻分不清的,我做了个表格供大家参考。 摘抄: X11是X Window System Protocol, Version 11(RFC1013),是X server和X client之间的通…

Observability:识别生成式 AI 搜索体验中的慢速查询

作者:Philipp Kahr Elasticsearch Service 用户的重要注意事项:目前,本文中描述的 Kibana 设置更改仅限于 Cloud 控制台,如果没有我们支持团队的手动干预,则无法进行配置。 我们的工程团队正在努力消除对这些设置的限制…

100G光模块的应用案例分析:电信、云计算和大数据领域

100G光模块是一种高速光模块,由于其高速率和低延迟的特性,在电信、云计算和大数据领域得到了广泛的应用。在本文中,我们将深入探讨100G光模块在这三个领域的应用案例。 一、电信领域 在电信领域,100G光模块被广泛用于构建高速通…

ECRS工时分析:什么叫标准化作业管理?为什么要进行作业标准化管理

中国自古就有标准化。《孙子兵法》中,孙子训练射箭,射箭的姿势是“标准化操作”;中国武术中的套路是“标准化”;在中国古诗中,字数甚至被“标准化”来打开中国历史,“标准化”作业的例子数不胜数。 而在工厂…

mac-右键-用VSCode打开

1.点击访达,搜索自动操作 2.选择快速操作 3.执行shell脚本 替换代码如下: for f in "$" doopen -a "Visual Studio Code" "$f" donecommand s保存会出现一个弹框,保存为“用VSCode打开” 5.使用

Spring项目整合过滤链模式~实战应用

代码下载 设计模式代码全部在gitee上,下载链接: https://gitee.com/xiaozheng2019/desgin_mode.git 日常写代码遇到的囧 1.新建一个类,不知道该放哪个包下 2.方法名称叫A,干得却是A+B+C几件事情,随时隐藏着惊喜 3.想复用一个方法,但是里面嵌套了多余的逻辑,只能自己拆出来…

MFC计算分贝

分贝的一种定义是,表示功率量之比的一种单位,等于功率强度之比的常用对数的10倍; 主要用于度量声音强度,常用dB表示; 其计算,摘录网上一段资料; 声音的分贝值可以通过以下公式计算&#xff1…

css内容达到最底部但滚动条没有滚动到底部

也是犯了一个傻狗一样的错误 ,滚动条样式是直接复制的蓝湖的代码,有个高度,然后就出现了这样的bug 看了好久一直以为是布局或者overflow的问题,最后发现是因为我给这个滚动条加了个高度,我也是傻狗一样的,…

中国首份仿生机器人产业全景报告发布!大模型带来加速度,三大指标决定竞争格局

AGI火热发展,让仿生机器人的实现补全了最后一块重要拼图。 一直以来,仿生机器人都代表人类对于科技的一种终极想象,备受产业圈热捧。 马斯克、雷军等,纷纷押注这一赛道。特斯拉全尺寸仿生机器人Optimus、小米全尺寸通用人形机器…

docker容器监控:Cadvisor +Prometheus+Grafana的安装部署

目录 Cadvisor PrometheusGrafana的安装部署 一、安装docker: 1、安装docker-ce 2、阿里云镜像加速器 3、下载组件镜像 4、创建自定义网络 二、部署Cadvisor 1、被监控主机上部署Cadvisor容器 2、访问cAdvisor页面 三、安装prometheus 1、部署Prometheus…

【nacos】Param ‘serviceName‘ is illegal, serviceName is blank

报错信息 解决方式 一&#xff1a;缺少依赖 SpringBoot2.4之后不会默认加载bootstrap.yaml&#xff1b;需要手动在pom中加入如下依赖&#xff1a; <dependency><groupId>org.springframework.cloud</groupId><artifactId>spring-cloud-starter-boot…

unable to write symref for HEAD: Permission denied

今天从gitee上面克隆项目到本地时报错如下 warning: unable to unlink ‘D:/IDEAcode/ruiji1.0/.git/HEAD.lock’: Invalid argument error: unable to write symref for HEAD: Permission denied 解决方法&#xff1a;将要存放项目的文件夹权限修改为完全控制 原先权限&…