PyTorch官网demo解读——第一个神经网络(3)

上一篇:PyTorch官网demo解读——第一个神经网络(2)-CSDN博客

上一篇文章我们讲解了第一个神经网络的模型,这一篇我们来聊聊梯度下降。

大佬说梯度下降是深度学习的灵魂;梯度是损失函数(代价函数)的导数,而下降的目的是让我们的损失不断减少,达到模型收敛的效果,最终拟合出最优的参数w。

所以,我们要先从损失函数(代价函数)说起。

  • 损失函数

从上一篇我们知道这个神经网络的模型是:y = wx + b

对于单一个样本(x, y),它的损失值就是:loss = wx + b - y

为了简单好理解,我们先把b去掉,那么 loss = wx - y,这个误差值可能是负数,而我们衡量一个误差值使用负数好像有点奇怪,于是我们使用均方差,那么单个样本的损失就变成这样:

loss = (wx - y) ^2

假定我们有n个样本【(x1, y1), (x2, y2) …… (xn, yn)】,那么我们的损失函数就是:

这个损失函数实际上是一个开口向上的抛物线,我随机取了5个样本值,手搓了一下,画出来如下图:

  • 梯度下降

梯度实际上就是损失函数的导数,即抛物线上某个点的变化率

所以梯度函数是:

t = 2aw + b

下降就是每次迭代,将当前的权重减去 梯度值乘以学习率,即:

w = w - t * lr

其中 lr 表示学习率,学习率可以理解为我们每次迈的步伐大小,如果迈的步伐太大,会导致在逼近最优参数时难以收敛,步伐太大跨过去了,在最优参数左右摇摆。所以通常学习率会设置比较小的值。

  • 代码

上面我们万般无奈地使用了一堆数学公式...,但有时候数学是最好的抽象方式,就像程序员喜欢说 read the f**king source code 一样。语言是对事物的抽象,而数学是对语言的进一步抽象吧,语言无法表达某些自然的规律,所以需要通过数学来表达,哈哈!

好啦,回顾第一篇的demo代码,关于损失函数,pytorch demo的代码极为精简:

# 丢失函数 loss function
def nll(input, target):return -input[range(target.shape[0]), target].mean()
loss_func = nll

而梯度下降就更为精简了,梯度是自动推导的,只需设置一个标志:

weights.requires_grad_()

没错,只需上面这行代码,在每次训练迭代后调用 loss.backward(),梯度就被计算出来啦,而我们只需在每次迭代中减去梯度值就好,如下:

pred = model(xb) # 通过模型预测
loss = loss_func(pred, yb) # 通过与实际结果比对,计算丢失值loss.backward() # 反向传播
with torch.no_grad():weights -= weights.grad * lr  # 调整权重值bias -= bias.grad * lr  # 调整偏差值weights.grad.zero_()bias.grad.zero_()

weights.grad.zero_()的作用是将每次迭代的梯度清零,不然下次计算梯度的时候会进行叠加。

关于梯度下降,就聊到这里吧!有问题可以留言探讨,共同学习!

未有知而不行者,知而不行,只是未知!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/298951.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

labelme目标检测数据类型转换

1. labelme数据类型 LabelMe是一个开源的在线图像标注工具,旨在帮助用户创建和标记图像数据集。它提供了一个用户友好的界面,让用户可以直观地在图像上绘制标记框、多边形、线条等,以标识和注释图像中的对象或区域。 GitHub:http…

python使用apscheduler定时任务,固定周几运行程序

在add_job中添加参数day_of_week即可: day_of_week "0"表示:只有周一运行day_of_week "0-4"表示:周一到周五运行day_of_week "0,1,2"表示:周一二三运行 示例程序 from datetime import datet…

Docker介绍、常用命令与操作

Docker介绍、常用命令与操作 学习前言为什么要学习DockerDocker里的必要基础概念常用命令与操作1、基础操作a、查看docker相关信息b、启动或者关闭docker 2、容器操作a、启动一个镜像i、后台运行ii、前台运行 b、容器运行情况查看c、日志查看d、容器删除 3、镜像操作a、镜像拉取…

Unity网格篇Mesh(一)

Unity网格篇Mesh(一) 本文的目标1.渲染仔细看下面的图你会发现,锯齿状 2.创建网格顶点4 x 2网格网格的顶点 3.创建网格网格只在Play模式下显示逆时针和顺时针三角形第一个三角面一个四边形由两个三角面组成第一个四边形填充剩余网格 接下一篇…

Wordpress对接Lsky Pro 兰空图床插件

Wordpress对接Lsky Pro 兰空图床插件 wordpress不想存储图片到本地,访问慢,wordpress图片没有cdn想要使用图床,支持兰空自定义接口 安装教程—在wp后台选择插件zip—然后启用—设置自己图床API接口就ok了,文件全部解密&#xff0c…

LazyForEach常见使用问题

目录 1、渲染结果非预期 2、重新渲染时图片闪烁 3、ObjectLink属性变化UI未更新 上篇文章中我们介绍了LazyForEach的基本使用,展示了如何使用LazyForEach构造一个列表,并演示数据的添加、删除、修改如何与LazyForEach配合并正确的更新UI。本篇将介绍使…

全自动洗衣机什么牌子好?最好用的四款内衣洗衣机推荐

随着科技的快速发展,现在的人们越来越注重自己的卫生问题,不仅在吃上面会注重卫生问题,在用的上面也会更加严格要求,而衣服做为我们最贴身的东西,我们对它的要求也会更加高,所以最近这几年较火爆的无疑是内…

游戏行业变天,游戏股遭暴击,腾讯网易等股票还能投资吗?

来源:猛兽财经 作者:猛兽财经 国家新闻出版署发布游戏新规 12月22日国家新闻出版署发布了《网络游戏管理办法》(草案征求意见稿),其中提到网络游戏不得设置每日登陆、首次充值、连续充值等诱导性奖励,而且…

Ubuntu系统如何安装SVN服务端并通过客户端无公网ip实现远程访问?

文章目录 前言1. Ubuntu安装SVN服务2. 修改配置文件2.1 修改svnserve.conf文件2.2 修改passwd文件2.3 修改authz文件 3. 启动svn服务4. 内网穿透4.1 安装cpolar内网穿透4.2 创建隧道映射本地端口 5. 测试公网访问6. 配置固定公网TCP端口地址6.1 保留一个固定的公网TCP端口地址6…

小狐狸ChatGPT系统 H5前端底部菜单导航文字修改方法

小狐狸ChatGPT系统后端都前端都是编译过的,需要改动点什么非常难处理,开源版修改后也需要编译后才能使用,大部分会员也不会使用,像简单的修改下底部菜单文字、图标什么的可以对照处理。这里以小狐狸ChatGPT系统1.9.2版本H5端为例&…

Vue中Render函数、_ref属性、_props配置的使用

Render函数 由于导入的vue为vue.runtime.xxx.js是运行版的vue.只包含:核心功能:没有模板解析器 完整版的Vue为vue.js包含:核心功能模板解析器 vue.runtime.esm.js中的esm为ES6的模块化 //导入的vue并非完整的vue,这样做的好处是…

Pandas有了平替Polars

Polars是一个Python数据处理库,旨在提供高性能、易用且功能丰富的数据操作和分析工具。它的设计灵感来自于Pandas,但在性能上更加出色。 Polars具有以下主要特点: 强大的数据操作功能:Polars提供了类似于Pandas的数据操作接口&am…