Pytorch笔记之回归

文章目录

  • 前言
  • 一、导入库
  • 二、数据处理
  • 三、构建模型
  • 四、迭代训练
  • 五、结果预测
  • 总结


前言

以线性回归为例,记录Pytorch的基本使用方法。


一、导入库

import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.autograd import Variable # 定义求导变量
from torch import nn, optim # 定义网络模型和优化器

二、数据处理

将数据类型转为tensor,第一维度变为batch_size

# 构建数据
x = np.random.rand(100)
noise = np.random.normal(0, 0.01, x.shape)
y = 0.1 * x + 0.2 + noise
# 数据处理
x_data = torch.FloatTensor(x.reshape(-1, 1))
y_data = torch.FloatTensor(y.reshape(-1, 1))
inputs = Variable(x_data)
target = Variable(y_data)

三、构建模型

1、继承nn.Module,定义一个线性回归模型。在__init__中定义连接层,定义前向传播的方法
2、实例化模型,定义损失函数与优化器

# 继承模型
class LinearRegression(nn.Module):def __init__(self):super().__init__()self.fc = nn.Linear(1, 1)def forward(self, x):out = self.fc(x)return out
# 定义模型
print('模型参数')
model = LinearRegression()
mse_loss = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)
for name, param in model.named_parameters():print('{}:{}'.format(name, param))

四、迭代训练

1、梯度清零:optimizer.zero_grad()
2、反向传播计算梯度值:loss.backward()
3、执行参数更新:optimizer.step()
循环迭代,定期输出损失值

print('损失值')
for i in range(1001):out = model.forward(inputs)loss = mse_loss(out, target)optimizer.zero_grad()loss.backward()optimizer.step()if i % 200 == 0:print(i, loss.item())

五、结果预测

绘制样本的散点图与预测值的折线图

print('结果预测')
y_pred = model(x_data)
plt.plot(x, y, 'b.')
plt.plot(x, y_pred.data.numpy(), 'r-')
plt.show()


总结

使用Pytorch进行训练主要的三步:
(1)数据处理:将数据维度转换为(batch, *),数据类型转换为可训练的tensor;
(2)构建模型:继承nn.Module,定义连接层与运算方法,实例化,定义损失函数与优化器;
(3)迭代训练:循环迭代,依次执行梯度清零、梯度计算、参数更新。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/129136.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

阿里云服务器ECS详细介绍_云主机_服务器托管_弹性计算

阿里云服务器ECS英文全程Elastic Compute Service,云服务器ECS是一种安全可靠、弹性可伸缩的云计算服务,阿里云提供多种云服务器ECS实例规格,如经济型e实例、通用算力型u1、ECS计算型c7、通用型g7、GPU实例等,阿里云服务器网分享阿…

深入了解归并排序:原理、性能分析与 Java 实现

归并排序(Merge Sort)是一种高效且稳定的排序算法,其优雅的分治策略使它成为排序领域的一颗明珠。它的核心思想是将一个未排序的数组分割成两个子数组,然后递归地对子数组进行排序,最后将这些排好序的子数组合并起来。…

【Godot4.1】Godot实现闪烁效果(Godot使用定时器实现定时触发的效果)

文章目录 准备工作创建Sprite2D创建Timer节点 编写脚本完整代码运行效果 准备工作 如果你希望配置C#编写脚本,可以查看如下教程: Godot配置C#语言编写脚本 创建Sprite2D 首先弄一个用于显示的Sprite2D,右键单击任意节点,然后选…

Windows10打开应用总是会弹出提示窗口的解决方法

用户们在Windows10电脑中打开应用程序,遇到了总是会弹出提示窗口的烦人问题。这样的情况会干扰到用户的正常操作,给用户带来不好的操作体验,接下来小编给大家详细介绍关闭这个提示窗口的方法,让大家可以在Windows10电脑中舒心操作…

react create-react-app v5配置 px2rem (不暴露 eject方式)

环境信息: create-react-app v5 “react”: “^18.2.0” “postcss-plugin-px2rem”: “^0.8.1” 配置步骤: 不暴露 eject 配置自己的webpack: 1.下载react-app-rewired 和 customize-cra-5 npm install react-app-rewired customize-cra…

OpenCV4(C++)—— 仿射变换、透射变换和极坐标变换

文章目录 一、仿射变换1. getRotationMatrix2D()2. warpAffine() 二、透射变换三、极坐标变换 一、仿射变换 在OpenCV中没有专门用于图像旋转的函数,而是通过图像的仿射变换实现图像的旋转。实现图像的旋转首先需要确定旋转角度和旋转中心,之后确定旋转…

排序算法之【归并排序】

📙作者简介: 清水加冰,目前大二在读,正在学习C/C、Python、操作系统、数据库等。 📘相关专栏:C语言初阶、C语言进阶、C语言刷题训练营、数据结构刷题训练营、有感兴趣的可以看一看。 欢迎点赞 &#x1f44d…

TS类中属性的封装

我们在如下的代码中,我们在类中设置属性,创建的对象可以随意修改自身的属性,对象中的属性可以任意被修改导致对象中的数据非常不安全。 // 创建一个Person类 class Person {name: string;age: number;constructor(name: string, age: number…

安卓玩机----解锁system分区 可读写系统分区 magisk面具模块

玩机教程----安卓机型解锁system分区 任意修改删除系统文件 system分区可读写 参考上个博文可以了解到解锁system分区的有关常识。但目前很多机型都在安卓12 13 基础上。其实最简单的方法就在于刷写一个解锁system分区的第三方补丁包。在面具更新不能解锁系统分区的前提下。…

ElasticSearch搜索引擎:数据的写入流程

一、ElasticSearch 写数据的总体流程: (1)ES 客户端选择一个节点 node 发送请求过去,这个节点就是协调节点 coordinating node (2)协调节点对 document 进行路由,通过 hash 算法计算出数据应该…

3D孪生场景搭建:3D漫游

上一篇 文章介绍了如何使用 NSDT 编辑器 制作模拟仿真应用场景,今天这篇文章将介绍如何使用NSDT 编辑器 设置3D漫游。 1、什么是3D漫游 3D漫游是指基于3D技术,将用户带入一个虚拟的三维环境中,通过交互式的手段,让用户可以自由地…

【计算机视觉|人脸建模】学习从4D扫描中获取的面部形状和表情的模型

本系列博文为深度学习/计算机视觉论文笔记,转载请注明出处 标题:Learning a model of facial shape and expression from 4D scans 链接:Learning a model of facial shape and expression from 4D scans | ACM Transactions on Graphics Pe…