《PyTorch深度学习实践》第五讲 用PyTorch实现线性回归

b站刘二大人《PyTorch深度学习实践》课程第五讲用PyTorch实现线性回归笔记与代码:https://www.bilibili.com/video/BV1Y7411d7Ys?p=5&vd_source=b17f113d28933824d753a0915d5e3a90


PyTorch官网教程:https://pytorch.org/tutorials/beginner/pytorch_with_examples.html


PyTorch Fashion

  1. 准备数据集
  2. 设计模型,写成类的形式(nn.Module)
    • 前向传播,计算 y ^ \hat{y} y^
  3. 构造损失函数loss和优化器(使用PyTorch的API)
    • 构造loss用于反向传播;优化器用于更新梯度
  4. 写训练周期(前馈 -> 反馈 -> 更新)

线性回归第一步:准备数据集

  • 在PyTorch中,计算图是采用的mini-batch形式计算
import torchx_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])

线性回归第二步:设计模型

  • 线性单元

    • 要确定权重 w w w的维度,则需要知道输入 x x x和输出 y ^ \hat{y} y^的维度;
    image-20230630102424959
  • 将模型定义成一个类

"""
Our model class should be inherit from nn.Module, which is Base class for all neural network modules
模型类都是从nn.Module继承,nn.Module是所有神经网络模型的基类
成员方法至少包含__init__()和forward()
"""
class LinearModel(torch.nn.Module):def __int__(self):# 构造函数,用于初始化对象super(LinearModel, self).__int__()  # super是调用父类的构造,第一个参数LinearModel是类名称self.linear = torch.nn.Linear(1, 1) # 构造对象。nn.Linear包含两个张量成员:权重w和偏置bdef forward(self, x):# 前馈计算y_pred = self.linear(x)	# y_hat,在一个对象(linear)后面加括号,表明实现了一个可调用的对象return y_predmodel = LinearModel()  # 实例化,model是可调用的,如model(x),x会传入forward中
image-20230629213425516
  • in_features:输入样本的维度(特征)

  • out_features:输出样本的维度

    image-20230630104714599
  • *args:表示可变参数,会存放所有未命名的变量参数,在函数调用的时候自动组装为一个元组

  • **kwargs:表示关键字参数,在函数内部自动组装成一个字典

    # 例子:假设定义一个func函数,并定义了形参
    def func(a, b, c, x, y):pass# 在调用的时候,传入的实参必须要和形参对应
    func(1, 2, 3, x=3, y=5)# 问题是如果调用的时候参数更多该怎么办?
    func(1, 2, 4, 3, x=3, y=5) # 比上面多一个值,这样调用就会出错---
    # 对func进行修改,将a,b,c换成*args,那么在调用func的时候所有没有命名的实参都会传到args中
    def func(*args, x, y):pass---
    # 对于x和y这种命名的参数可以写成**kwargs,在调用func的时候命名的实参都会传到kwargs中
    def func(*args, **kwargs):pass
    
    image-20230630113754366
# 定义一个可调用的类
class Foobar:def __init__(self):# 先定义__init__,因为没起作用就写个passpass# 要想对象可调用,则需要定义一个__call__函数。pycharm中会自动提示如下形式#  *args:表示可变参数,会存放所有未命名的变量参数,在函数调用的时候自动组装为一个元组#  **kwargs:表示关键字参数,在函数内部自动组装成一个字典def __call__(self, *args, **kwargs):print("Hello" + str(args[0]))  # 假设就接受args的第一个参数foobar = Foobar()  # 定义一个Foobar类的变量foobar
# 由于类中定义了__call__()函数,所以可以进行如下操作,给foobar传入参数
foobar(1, 2, 3)
image-20230630114332162
  • PyTorch中的Module的call函数里面有一条语句是要调用forward(),因此在我们自己写的module类中必须要实现forward()来覆盖掉父类中的forward()

线性回归第三步:构造loss函数和优化器

criterion = torch.nn.MSELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
  • 损失函数使用MSE

    • MSELoss继承自nn.Module,参与计算图的构建
    image-20230630120312661 image-20230630120353109
    • size_average:是否要求均值(可求可不求)
    • reduce:是否要降维(一般只考虑size_average)
  • 优化器使用SGD

    • torch.optim.SGD()是一个类,与nn.Module无关,不参与计算图的构建

      image-20230630120854394
    • model.parameters()是权重

      • model中并没有定义相应的权重,但里面的成员函数linear有权重
      • 方法parameters是继承自Module,它会检查model中的所有成员函数,如果成员中有相应的权重,那就将其都加到最终的训练结果上
    • lr:learning rate,一般都设定一个固定的学习率


线性回归第四步:训练过程

三个步骤:

  • 前馈
  • 反馈
    • 开始反馈前要先将梯度归零
  • 更新
for epoch in range(100):y_pred = model(x_data)              # 前馈:计算y_hatloss = criterion(y_pred, y_data)    # 前馈:计算损失print(epoch, loss.item())optimizer.zero_grad()   # 反馈:在反向传播开始将上一轮的梯度归零loss.backward()         # 反馈:反向传播(计算梯度)optimizer.step()        # 更新权重w和偏置b
image-20230630125150399

完整的代码(包含模型测试和loss曲线绘制)

import torch
import matplotlib.pyplot as plt# 数据集
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])# 用于绘图
epoch_list = []
loss_list =[]class LinearModel(torch.nn.Module):def __init__(self):super(LinearModel, self).__init__()self.linear = torch.nn.Linear(1, 1)def forward(self, x):y_pred = self.linear(x)return y_predmodel = LinearModel()# criterion = torch.nn.MSELoss(size_average=False) pytorch更新后被弃用了
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)# 训练过程
for epoch in range(100):y_pred = model(x_data)              # 前馈:计算y_hatloss = criterion(y_pred, y_data)    # 前馈:计算损失print(epoch, loss.item())epoch_list.append(epoch)loss_list.append(loss.item())optimizer.zero_grad()   # 反馈:在反向传播开始将上一轮的梯度归零loss.backward()         # 反馈:反向传播(计算梯度)optimizer.step()        # 更新权重w和偏置b# 输出权重和偏置
print('w = ', model.linear.weight.item())
print('b = ', model.linear.bias.item())# 测试模型
x_test = torch.tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)# 绘制loss曲线
plt.plot(epoch_list, loss_list)
plt.ylabel('loss')
plt.xlabel('epoch')
plt.show()
  • 训练100轮:
image-20230630125329678
  • 训练1000轮:
image-20230630125424128

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/7389.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数据中心动环监控系统分析与应用

摘要:介绍了数据中心动环监控系统,并结合原理图详细分析。本系统主要对数据中心的电源设备和环境参数 进行监控,如 UPS、蓄电池、配电柜、温湿度、漏水监测等,将实现多机房、微模块远程联网集中监管, 从而为提高数据中…

Linux服务器扩容VG时报错 Couldn‘t create temporary archive name

今天扩容磁盘遇到失败报错。 [rootmysql ~]# vgextend rhel /dev/sdc1 Couldnt create temporary archive name. 原因:磁盘使用100%,无法执行挂载,须预留部分空间出来。解决办法:删掉其中无用文件、log日志继续操作即可。释放空间…

React之hooks

Hooks函数 1.useState():状态钩子。纯函数组件没有状态,用于为函数组件引入state状态, 并进行状态数据的读写操作。 const [state, setState] useState(initialValue); // state:初始的状态属性,指向状态当前值,类似…

6.26学习 es6中的类

学习 es6中的类 1.了解构造函数的属性2.类的继承2.1继承父类实例上的属性2.2继承父类原型上的属性或则方法(公共属性或则方法)2.2.1 Object.create2.2.2 Object.setPrototypeOf 3.es6中的类3.1定义3.2 继承 1.了解构造函数的属性 先上一份代码思考一下它…

数据结构-链表

链表结构 链表结构五花八门,今天我重点给你介绍三种最常见的链表结构,它们分别是:单链表、双向链表和循环链表。我们首先来看最简单、最常用的单链表。 单链表 我们习惯性地把第一个结点叫作头结点,把最后一个结点叫作尾结点。其…

Spring Boot 中的事务超时时间

Spring Boot 中的事务超时时间 在 Spring Boot 中,事务管理是一个非常重要的话题。当我们在数据库中执行一些复杂的操作时,需要确保这些操作能够在一定的时间内完成,否则可能会导致数据一致性问题。为了解决这个问题,Spring Boot…

springboot 整合mybatis plus,使用druid 切换多数据源实现单数据库事务,附赠项目源码地址

项目源码地址 GitHub - liyanlei58/ssm: springboot druid mybatis plus 事务 最近想搭一套spring cloud开发环境,各种不顺利吧,先是spring cloud的组件某些功能不好用,是版本自身的bug。后来又碰到了事务无法回滚,这个搞了好几个…

windows的环回网卡(loopback adapter) 安装方法

0.说明:windows的环回网卡(loopback adapter)的作用: microsoft loopback adapter就是安装在本机上的一块虚拟网卡,它跟本机上的其它物理网卡、和物理网卡连接的网络是没有关系的,你可以理解成这块网卡上的网线接到了另外一个空白…

Python 字节数组方式写入kafka(含报错return ‘<SimpleProducer batch=%s>‘ % self.async)

一、背景 项目开发了一个类似kafka tools查询工具的kafka 查询,现在需要测试一下如果通过字节数组的形式写入,看看查询有没有问题 二、kafka查询代码 Python代码示例: from kafka import KafkaProducer import json# 创建Kafka生产者 pro…

美好未来“一束光”儿童安全教育项目在四川泸定正式启动

6月26日,由中华少年儿童慈善救助基金会和北京臻爱公益基金会共同发起的美好未来计划“一束光”儿童安全教育公益项目启动仪式,在四川省甘孜藏族自治州泸定县贡嘎山片区寄宿制学校举行。 出席本次启动仪式活动的嘉宾有:中华少年儿童慈善救助基…

【Spring Boot 事务】万字详解Spring Boot 事务,赶快跟随良辰一起去学习Spring Boot 事务吧! ! !

前言: 大家好,我是良辰丫,这篇文章我将带领大家一起去学习Spring Boot 事务文章,我们在学习数据库的时候已经接触过事务了,来跟随我的脚步一起来瞧一下Spring Boot 事务吧.💌💌💌 🧑个人主页:良辰针不戳 📖…

element框架select值更新页面不回显的问题,动态表单props绑定问题

1、页面中使用form表单&#xff0c;引入select组件 当data中默认没有定义form.region的值时&#xff0c;会出现选择select后input没有回显选择数据值&#xff1b;所以使用select时&#xff0c;必须定义默认值 <el-form ref"form" :model"form" label-…