第1章 线性回归

一、基本概念

1、线性模型

2、线性模型可以看成:单层的神经网络

输入维度:d

输出维度:1

每个箭头代表权重

一个输入层,一个输出层

单层神经网络:带权重的层为1(将权重和输入层放在一起)

3、LOSS

y:真实值

y^:估计值

平方损失:

4、训练数据

n个样本

5、损失学习

训练损失

最小化损失来学习参数

6、显示解

7、总结

二、优化方法

1、梯度下降

2、学习率

不能太大也不能太小

3、小批量 随机梯度下降

4、批量大小

不能太大也不能太小

5、总结

三、代码实现

1、从头开始实现

import matplotlib.pyplot as plt #plt.show()
import random
import torch
from d2l import torch as d2l# 随机生成数据集
# 权重w = 2, -3.4
# 偏差 b = -4.2
def synthetic_data(w, b, num_examples):  #@save""" y=Xw+b+噪声 """# 均值为0,方差为1的随机数;n个样本,列数=wX = torch.normal(0, 1, (num_examples, len(w)))y = torch.matmul(X, w) + b# 再加一个随机噪音y += torch.normal(0, 0.01, y.shape)# x和y做成一个列向量返回return X, y.reshape((-1, 1))# 生成训练样本
true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)print('features:', features[0],'\nlabel:', labels[0])
d2l.set_figsize()
d2l.plt.scatter(features[:, (1)].detach().numpy(), labels.detach().numpy(), 1)
plt.show()def data_iter(batch_size, features, labels):num_examples = len(features)indices = list(range(num_examples))# 这些样本是随机读取的,没有特定的顺序random.shuffle(indices)for i in range(0, num_examples, batch_size):batch_indices = torch.tensor(indices[i: min(i + batch_size, num_examples)])yield features[batch_indices], labels[batch_indices]
batch_size = 10
for X, y in data_iter(batch_size, features, labels):print(X, '\n', y)break
# true_w = torch.tensor([2, -3.4])
# true_b = 4.2# 初始化模型参数
w = torch.normal(0, 0.01, size=(2,1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)# 定义模型
def linreg(X, w, b):  #@save"""线性回归模型"""return torch.matmul(X, w) + b# 定义损失函数
def linreg(X, w, b):  #@save"""线性回归模型"""return torch.matmul(X, w) + b# 定义优化函数
def sgd(params, lr, batch_size):  #@save"""小批量随机梯度下降"""with torch.no_grad():for param in params:param -= lr * param.grad / batch_sizeparam.grad.zero_()
def squared_loss(a, b):y = (a - b) ** 2y /= 2return y# 训练
lr = 0.03
num_epochs = 3
net = linreg
loss = squared_lossfor epoch in range(num_epochs):for X, y in data_iter(batch_size, features, labels):l = loss(net(X, w, b), y)  # X和y的小批量损失# 因为l形状是(batch_size,1),而不是一个标量。l中的所有元素被加到一起,# 并以此计算关于[w,b]的梯度l.sum().backward()sgd([w, b], lr, batch_size)  # 使用参数的梯度更新参数with torch.no_grad():train_l = loss(net(features, w, b), labels)print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}')
print(f'w的估计误差: {true_w - w.reshape(true_w.shape)}')
print(f'b的估计误差: {true_b - b}')

2、简洁实现

# 1.生成数据集
import numpy as np
import torch
from torch.utils import data
from d2l import torch as d2ltrue_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = d2l.synthetic_data(true_w, true_b, 1000)# 2.读取数据集
def load_array(data_arrays, batch_size, is_train=True):  #@save"""构造一个PyTorch数据迭代器"""dataset = data.TensorDataset(*data_arrays)return data.DataLoader(dataset, batch_size, shuffle=is_train)batch_size = 10
data_iter = load_array((features, labels), batch_size)print(next(iter(data_iter)))# 3.定义模型
# nn是神经网络的缩写
from torch import nnnet = nn.Sequential(nn.Linear(2, 1))# 4.初始化模型参数
net[0].weight.data.normal_(0, 0.01)
net[0].bias.data.fill_(0)# 5.定义损失函数
loss = nn.MSELoss()# 6. 定义优化算法
trainer = torch.optim.SGD(net.parameters(), lr=0.03)# 7. 训练
num_epochs = 3
for epoch in range(num_epochs):for X, y in data_iter:l = loss(net(X) ,y)trainer.zero_grad()l.backward()trainer.step()l = loss(net(features), labels)print(f'epoch {epoch + 1}, loss {l:f}')w = net[0].weight.data
print('w的估计误差:', true_w - w.reshape(true_w.shape))
b = net[0].bias.data
print('b的估计误差:', true_b - b)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/327760.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【STM32】RTC实时时钟

1 unix时间戳 Unix 时间戳(Unix Timestamp)定义为从UTC/GMT的1970年1月1日0时0分0秒开始所经过的秒数,不考虑闰秒 时间戳存储在一个秒计数器中,秒计数器为32位/64位的整型变量 世界上所有时区的秒计数器相同,不同时区…

RocketMQ 投递消息方式以及消息体结构分析:Message、MessageQueueSelector

🔭 嗨,您好 👋 我是 vnjohn,在互联网企业担任 Java 开发,CSDN 优质创作者 📖 推荐专栏:Spring、MySQL、Nacos、Java,后续其他专栏会持续优化更新迭代 🌲文章所在专栏&…

Mac 升级ruby 升级brew update

Mac 自身版本是2.x 查看ruby版本号 打开终端 ruby -v 1.brew update 如果报错 这时候brew更新出问题了 fatal: the remote end hung up unexpectedly fatal: early EOF fatal: index-pack failed error: RPC failed; curl 18 HTTP/2 stream 3 was reset fatal: th…

SpringMVC-@RequestMapping注解

0. 多个方法对应同一个请求 RequestMapping("/")public String toIndex(){return "index";}RequestMapping("/")public String toIndex2(){return "index";}这种情况是不允许的,会报错。 1. 注解的功能 RequestMapping注…

SolidUI Gitee GVP

感谢Gitee,我是一个典型“吃软不吃硬”的人。奖励可以促使我进步,而批评往往不会得到我的重视。 我对开源有自己独特的视角,我只参与那些在我看来高于自身认知水平的项目。 这么多年来,我就像走台阶一样,一步一步参与…

Android AAudio

文章目录 基本概念启用流程基本流程HAL层对接数据流计时模型调试 基本概念 AAudio 是 Android 8.0 版本中引入的一种音频 API。 AAudio 提供了一个低延迟数据路径。在 EXCLUSIVE 模式下,使用该功能可将客户端应用代码直接写入与 ALSA 驱动程序共享的内存映射缓冲区…

metaSPAdes,megahit,IDBA-UB:宏基因组装软件安装与使用

metaSPAdes,megahit,IDBA-UB是目前比较主流的宏基因组组装软件 metaSPAdes安装 GitHub - ablab/spades: SPAdes Genome Assembler #3.15.5的预编译版貌似有问题,使用源码安装试试 wget http://cab.spbu.ru/files/release3.15.5/SPAdes-3.15.5.tar.gz tar -xzf SP…

时间序列预测 — VMD-LSTM实现单变量多步光伏预测(Tensorflow):单变量转为多变量预测多变量

目录 1 数据处理 1.1 导入库文件 1.2 导入数据集 ​1.3 缺失值分析 2 VMD经验模态分解 2.1 VMD分解实验 2.2 VMD-LSTM预测思路 3 构造训练数据 4 LSTM模型训练 5 LSTM模型预测 5.1 分量预测 5.2 可视化 时间序列预测专栏链接:https://blog.csdn.net/qq_…

libexif库介绍

libexif是一个用于解析、编辑和保存EXIF数据的库。它支持EXIF 2.1标准(以及2.2中的大多数)中描述的所有EXIF标签。它是用纯C语言编写的,不需要任何额外的库。源码地址:https://github.com/libexif/libexif ,最新发布版本为0.6.24,…

【v8漏洞利用模板】starCTF2019 -- OOB

文章目录 前言参考题目环境配置漏洞分析 前言 一道入门级别的 v8 题目,不涉及太多的 v8 知识,很适合入门,对于这个题目,网上已经有很多分析文章,笔者不再为大家制造垃圾,仅仅记录一个模板,方便…

TYPE-C接口取电芯片介绍和应用场景

随着科技的发展,USB PDTYPE-C已经成为越来越多设备的充电接口。而在这一领域中,LDR6328Q PD取电芯片作为设备端协议IC芯片,扮演着至关重要的角色。本文将详细介绍LDR6328Q PD取电芯片的工作原理、应用场景以及选型要点。 一、工作原理 LDR63…

2024年天津体育学院专升本专业考试体育教育专业素质测试说明

天津体育学院2024年高职升本科招生专业考试体育教育专业素质测试说明 一、测试内容 100米跑、立定跳远、原地推铅球 二、考试规则 1.100米跑可穿跑鞋(短钉),起跑采用蹲踞式。抢跑个人累计犯规两次,取消本项考试资格&#xff0c…

解决使用localhost或127.0.01模拟CORS失效

解决使用localhost或127.0.01模拟CORS失效 前言问题发现问题解决 前言 CORS (Cross-Origin Resource Sharing) 指的是一种机制,它允许不同源的网页请求访问另一个源服务器上的某些资源。通常情况下,如果 JavaScript 代码在一个源中发起了 AJAX 请求&…

css的一些属性

我们在写项目的时候,会遇到多种多样的样式,大部分都是由css来实现的,css可以让我们的页面更美观,css通常是配合HTML使用,代码较为简单! 下面我就给大家举几个较为常用的一些css属性。 1.CSS中怎样让元素圆角化&#…

spring boot 2升级为spring boot 3中数据库连接池druid的问题

目录 ConfigurationClassPostProcessor ConfigurationClassBeanDefinitionReader MybatisPlusAutoConfiguration ConditionEvaluator OnBeanCondition 总结 近期给了一个任务,要求是对现有的 spring boot 2.x 项目进行升级,由于 spring boot 2.x 版…

数据结构和算法-交换排序中的快速排序(演示过程 算法实现 算法效率 稳定性)

文章目录 总览快速排序(超级重要)啥是快速排序演示过程算法实现第一次quicksort函数第一次partion函数到第一次quicksort的第一个quicksort到第二次quicksort的第一个quicksort到第二次quicksort的第二个quicksort到第一次quicksort的第二个quicksort到第…

while猜数字实例——C++版

案例描述&#xff1a;系统随机生成一个1到100之间的数字&#xff0c;玩家进行猜测&#xff0c;如果猜错&#xff0c;提示玩家数字过大或过小&#xff0c;如果猜对恭喜玩家胜利并退出游戏。 逻辑框图&#xff1a; #include<bits/stdc.h> using namespace std; int main()…

The Planets: Mercury

靶场环境 整个靶场的环境&#xff0c;我出现了一点点问题&#xff0c;一直找不到主机的IP地址&#xff0c;后来参考了https://www.cnblogs.com/hyphon/p/16354436.html&#xff0c;进行了相关的配置&#xff0c;最后完成靶机环境的搭建&#xff01; 信息收集 # nmap -sn 192…

Linux Debian12系统gnome桌面环境默认提供截屏截图工具gnome-screenshot

一、简介&#xff1a; 在Debian12中系统gnome桌面环境默认提供一个截图捕获工具screenshot,可以自定义区域截图、屏幕截图、窗口截图和录制视频&#xff0c;截图默认保存在“~/图片/截图”路径下。 可以在应用程序中搜索screenshot,如下图&#xff1a; 也可以在桌面右上角找到…

mysql之视图执行计划

一.视图 1.1视图简介 1.2 创建视图 1.3视图的修改 1.4视图的删除 1.5查看视图 二.连接查询案例 三.思维导图 一.视图 1.1视图简介 虚拟表&#xff0c;和普通表一样使用 MySQL中的视图&#xff08;View&#xff09;是一个虚拟表&#xff0c;其内容由查询定义。与实际表不…