动手学深度学习(二)线性神经网络

推荐课程:跟李沐学AI的个人空间-跟李沐学AI个人主页-哔哩哔哩视频

回归任务是指对连续变量进行预测的任务。

一、线性回归

线性回归模型是一种常用的统计学习方法,用于分析自变量与因变量之间的关系。它通过建立一个关于自变量和因变量的线性方程,来对未知数据进行预测。

1.1 线性模型

举个例子,房价预测模型

  • 假设1︰影响房价的关键因素是卧室个数,卫生间个数和居住面积,记为x1,x2,x3。
  • 假设2:成交价是关键因素的加权和,y = w_1x_1 + w_2x_2 + w_3x_3 + b

权重w和偏差b的实际值在后面决定。

  • 给定n维输入,x=[x_1,x_2, ....x_n]^T,向量x对应于单个数据样本的特征
  • 线性模型有一个n维权重和一个标量偏差,w =[w_1, w_2, ..., w_n]^Tb权重w决定了每个特征对预测值的影响。偏置b是指当所有的特征都取0时,预测值应为多少。
  • 输出是输入的加权和,\hat{y} = w_1x_1+w_2x_2+ ...+ w_nx_n + b。我们常用\hat{y}表示预测值

则,该房价预测模型为:\hat{y} = w^Tx+ b,这是一个线性预测模型。给定一个数据集(如x),我们的目标就是寻找模型的权重w和偏置b,使得根据模型做出的预测大体符合数据中真实价格y。也是就说最佳的权重w和偏置b有能力使得预测值\hat{y}逼近真实值y,找到最佳的权重w和偏置b这是我们的最终目的。

1.2 损失函数(衡量预估质量)

用于比较真实值和预估值的差异,即以特定规则计算真实值和预估值的差值,例如房屋售价和估价。

假设y是真实值,\hat{y}是预测值,平方差损失\ell(y,\hat{y})=(y-\hat{y})^2,我们以该函数作为损失函数。

设训练集有n个样本,则这n个样本的损失均值

             L(w, b)=\frac{1}{n}\sum_{n}^{i=1}\ell^i(y,\hat{y})=\frac{1}{n}\sum_{n}^{i=1}(y_i-\hat{y_i})^2=\frac{1}{n}\sum_{n}^{i=1}(y_i-w^Tx_i+b)^2

Q:那么损失函数,对我们找到最优的权重w和偏置b有什么帮助呢?

我们可以看到,最佳的预测值与真实值之间的损失值一定是尽可能小的,因此我们只要求得最小的损失值,那么得到这个损失值的权重w和偏置b一定是最优的。

Q:怎么求得最小的损失值呢?

如,平方差损失函数是一个凹函数,那么求解最小的损失值,我们只需要将该函数关于w的偏导数设为0,求导即可。求解得到的w就是最优的权重w。预测出的预估值\hat{y}也就最接近真实值。这类解称为解析解。

二、基础优化算法(梯度下降算法)

在绝大多数的情况下,损失函数是很复杂的(比如逻辑回归),根本无法得到参数估计值的表达式,也就无从获取没有显示解(解析解)

此需要一种对大多数函数都适用的方法,这就引出了“梯度下降算法”,这种方法几乎可以优化所有深度学习模型它通过不断地在损失函数递减的方向上更新参数来降低误差(原理)

2.1 梯度下降公式

首先,我们需要确定初始化模型的参数w_0,接下来重复迭代更新参数t=1、2、3、....、n,更新权重的公式为:

其中,\textup{w}_{t-1}为上一次更新权重的结果,\eta为学习率(这是一个超参数,决定了每次参数更新的步长),\frac{\partial \ell }{\partial \textup{w}_{t-1}}损失函数递增的方向(注意公式中为负)。

2.2 选择学习率

梯度下降的过程宛如一个人在走下山路,一步一步地接近谷底,学习率相当于这个人的步长

学习率的选取不易过大,也不宜过小。学习率选取过大会使得权重更新的过程一直在震荡,而不是真正的在下降。学习率选取过小,会使得权重更新的过程十分缓慢,影响效率。

2.3 小批量随机梯度下降

一个神经网络模型的训练可能需要几分钟至数个小时,我们可以采用小批量随机梯度下降的方式来加快这一过程。

在整个训练集上计算梯度太昂贵了,因此可以随机采用 b 个样本i_1,i_2,...,i_b来求取整个训练集的近似损失(原理)。求近似损失公式为:

 其中,b批量大小,另一个重要的超参数。

Q:如何选择批量大小?

选择批量大小不能太小,也不能太大。批量大小选择过小,则每次计算量太小,不适合并行来最大利用计算资源。批量大小选择过大,内存消耗增加浪费计算,例如如果所有样本都是相同的。

三、线性回归的从零开始实现(代码实现)

3.1 生成数据集

首先,我们根据带有噪声的线性模型构造一个人造数据集,我们的目的是通过这个数据集来还原线性模型中正确的参数。

我们使用线性模型参数 \textup{w}=[2,-3.4]^Tb = 4.2​ 和噪声项 \varepsilon 生成数据集及其标签。

# 生成数据集
def synthetic_data(w, b, num_examples):"""生成 y=Xw + b + 噪声"""X = torch.normal(0, 1, (num_examples, len(w))) # 正态分布(均值为0,标准差为1)y = torch.matmul(X, w) + b # 矩阵相乘y += torch.normal(0, 0.01, y.shape) # 加入噪声项# 得到的y为行向量的形式,为了使其变为一列的形式需要进行reshapereturn X, y.reshape((-1, 1))

3.2 传输数据集

def data_iter(batch_size, features, labels):num_examples = len(features)indices = list(range(num_examples))# 这些样本是随机读出的,没有特定的顺序random.shuffle(indices)for i in range(0, num_examples, batch_size):batch_indices = torch.tensor(indices[i:min(i+batch_size,num_examples)])yield features[batch_indices],labels[batch_indices]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/53236.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Toyota Programming Contest 2023#4(AtCoder Beginner Contest 311)D题题解

文章目录 [Grid Ice Floor](https://atcoder.jp/contests/abc311/tasks/abc311_d)问题建模问题分析1.分析移动时前后两个点之间的联系2.方法1通过BFS将所有按照给定运动方式可以到达的点都标记代码 3.方法2采用DFS来标记路径上的点的运动状态代码 Grid Ice Floor 问题建模 给定…

高并发负载均衡---LVS

目录 前言 一:负载均衡概述 二:为啥负载均衡服务器这么快呢? ​编辑 2.1 七层应用程序慢的原因 2.2 四层负载均衡器LVS快的原因 三:LVS负载均衡器的三种模式 3.1 NAT模式 3.1.1 什么是NAT模式 3.1.2 NAT模式实现LVS的缺点…

springboot+vue网红酒店客房预定系统的设计与实现_ui9bt

随着计算机技术发展,计算机系统的应用已延伸到社会的各个领域,大量基于网络的广泛应用给生活带来了十分的便利。所以把网红酒店预定管理与现在网络相结合,利用计算机搭建网红酒店预定系统,实现网红酒店预定的信息化。则对于进一步…

供水管网漏损监测,24小时保障城市供水安全

供水管网作为城市生命线重要组成部分,其安全运行是城市建设和人民生活的基本保障。随着我国社会经济的快速发展和城市化进程的加快,城市供水管网的建设规模日益增长。然而,由于管网老化、外力破坏和不当维护等因素导致的供水管网漏损&#xf…

电脑连接KONICA MINOLTA(柯尼卡美能达) 打印机及驱动安装

电脑系统:Windows 7 安装的打印机型号:Konica minolta bizhub 363 驱动下载:https://www.konicaminolta.com.cn/support/drivers/index.html 打印机配置好网络 1.打开控制面板,或点击桌面开始(WIN)&#x…

Webpack5新手入门简单配置

1.初始化项目 yarn init -y 2.安装依赖 yarn add -D webpack5.75.0 webpack-cli5.0.0 3.新建index.js 说明:写入下面的一句话 console.log("hello webpack"); 4.执行命令 说明:如果没有安装webpack脚手架就不能执行yarn webpack&#xff08…

47.Linux学习day01 基础命令详解1(很全面)

目录 一、Linux和Windows的区别 二、Linux系统目录结构 常见目录说明 三、Linux常见的基础命令 1.pwd 2.cd 3.ls 4.man 5. touch 6.mkdir 7. rmdir 今天正式学习了linux的一些基础操作和基础知识,以及linux和windows的区别。 一、Linux和Windows的区…

Your local changes to the following files would be overwritten by checkout

Git 之 Your local changes to the following files would be overwritten by checkout 今天在切换分支时遇到了这样一个问题: 首先翻译下: Your local changes to the following files would be overwritten by checkout 大致意思就是: 当…

HCIP---OSPF的MGRE实验

一、实验要求: 1、R6为ISP只能配置ip地址,R1-5的环回为私有网段 2、R1/4/5为全连的MGRE结构,R1/2/3为星型的拓扑结构,R1为中心站点 3、所有私有网段可以互相通讯,私有网段使用OSPF协议完成 二、实验步骤 &#xf…

Leetcode-每日一题【剑指 Offer 09. 用两个栈实现队列】

题目 用两个栈实现一个队列。队列的声明如下,请实现它的两个函数 appendTail 和 deleteHead ,分别完成在队列尾部插入整数和在队列头部删除整数的功能。(若队列中没有元素,deleteHead 操作返回 -1 ) 示例 1: 输入: [&…

数据库的约束 详解

一、约束的概述 1.概念:约束是作用于表中字段上的规则,用于限制存储在表中的数据。 2.目的:保证数据库中数据的正确、有效性和完整性。 3.分类: 约束描述关键字非空约束限制该字段的数据不能为nullNOT NULL唯一约束保证该字段的所有数据都是唯一、不…

c51单片机16个按键密码锁源代码(富proteus电路图)

注意了:这个代码你是没法直接运行的,但是如果你看得懂,随便改一改不超过1分钟就可以用 #include "reg51.h" #include "myheader.h" void displayNumber(unsigned char num) {if(num1){P10XFF;P10P11P14P15P160;}else if…