Pytorch线性回归教程

import torch
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt

生成测试数据

# 长期趋势
def trend(time, slope=0):return slope * time# 季节趋势
def seasonal_pattern(season_time):return np.where(season_time < 0.4,np.cos(season_time * 2 * np.pi),1 / np.exp(3 * season_time))
def seasonality(time, period, amplitude=1, phase=0):season_time = ((time + phase) % period) / periodreturn amplitude * seasonal_pattern(season_time)# 噪声
def noise(time, noise_level=1):return np.random.randn(len(time)) * noise_level
X = torch.arange(1, 1001)
# Y = 0.7 * X + 100 + torch.randn(X.size())
Y = trend(X, 0.3) + seasonality(X, period=365, amplitude=30) + noise(X, 15) + 200
X.shape, Y.shape
(torch.Size([1000]), torch.Size([1000]))
plt.plot(X.numpy(), Y.numpy());

对测试数据进行处理

# 模型的数据的类型需要是32位浮点型
X = X.type(torch.float32)
Y = Y.type(torch.float32)
X.dtype, Y.dtype
(torch.float32, torch.float32)
# 模型的数据需要进行归一化或者标准化,下面是归一化
X = (X - X.min()) / (X.max() - X.min())
Y = (Y - Y.min()) / (Y.max() - Y.min())
plt.plot(X.numpy(), Y.numpy());

定义模型和模型参数

# 线性模型只有两个参数斜率k,和偏置b
# 线性模型的方程为y = k * x + b
k = nn.Parameter(torch.rand(1, dtype=torch.float32))
b = nn.Parameter(torch.rand(1, dtype=torch.float32))
# 下面输出中的requires_grad=True 表示该参数需要计算梯度
# 梯度用于在反向传播中对参数进行优化,优化方法即梯度下降
k, b 
(Parameter containing:tensor([0.6231], requires_grad=True),Parameter containing:tensor([0.0044], requires_grad=True))
def linear_model(x):return k * x + b

梯度下降优化参数

# 可以通过改变学习率lr和epoch_num学习各自的用途
# 定义优化器,用于更新模型的参数,即传入的k和b
optimizer = torch.optim.SGD([k, b], lr=0.01)
# 损失函数,模型优化的目的
loss_func = nn.MSELoss()# 每个epoch表示把全部的数据过一遍
epoch_num = 2000
for epoch in range(epoch_num):# 获取模型预测结果y_pred = linear_model(X)# 计算损失值loss = loss_func(y_pred, Y)# 将梯度设为0optimizer.zero_grad()# 反向传播,计算梯度loss.backward()# 执行梯度下降,优化参数optimizer.step()
k, b
(Parameter containing:tensor([0.8825], requires_grad=True),Parameter containing:tensor([0.0419], requires_grad=True))
# detach()函数用于将参数设置为不需要梯度
k2 = k.detach().numpy()[0]
b2 = b.detach().numpy()[0]plt.plot(X, Y);
plt.plot(X, k2 * X + b2);

优化模型

class LinearModel(nn.Module):def __init__(self):super().__init__()self.k = nn.Parameter(torch.rand(1, dtype=torch.float32))self.b = nn.Parameter(torch.rand(1, dtype=torch.float32))def forward(self, x):return self.k * x + self.b
model = LinearModel()
# 定义优化器,用于更新模型的参数,即传入的k和b
optimizer = torch.optim.SGD([k, b], lr=0.01)
# 损失函数,模型优化的目的
loss_func = nn.MSELoss()epoch_num = 2000
for epoch in range(epoch_num):y_pred = model(X)loss = loss_func(y_pred, Y)optimizer.zero_grad()loss.backward()optimizer.step()
k, b
(Parameter containing:tensor([0.8825], requires_grad=True),Parameter containing:tensor([0.0419], requires_grad=True))
k2 = k.detach().numpy()[0]
b2 = b.detach().numpy()[0]plt.plot(X, Y);
plt.plot(X, k2 * X + b2);

随机梯度下降

# 前面执行梯度下降时,我们是一次将全部的数据都传入模型
# 但在实际应用中,可能会由于数据太大,没法全部传入模型
# 因此,可以一次传入一部分数据,这便是随机梯度下降
# 随机梯度下降的核心是,梯度是期望。期望可使用小规模的样本近似估计。
model = LinearModel()
# 定义优化器,用于更新模型的参数,即传入的k和b
optimizer = torch.optim.SGD([k, b], lr=0.01)
# 损失函数,模型优化的目的
loss_func = nn.MSELoss()# 每个epoch表示把全部的数据过一遍
epoch_num = 2000 
# iter_step表示在一个epoch内抽取几个小规模样本
iter_step = 10
# batch_size表示小规模样本的大小
batch_size = 100
for epoch in range(epoch_num):for i in range(iter_step):random_samples = torch.randint(X.size()[0], (batch_size, ))X_i, Y_i = X[random_samples], Y[random_samples]y_pred = model(X_i)loss = loss_func(y_pred, Y_i)optimizer.zero_grad()loss.backward()optimizer.step()
k, b
(Parameter containing:tensor([0.8825], requires_grad=True),Parameter containing:tensor([0.0419], requires_grad=True))
k2 = k.detach().numpy()[0]
b2 = b.detach().numpy()[0]plt.plot(X, Y);
plt.plot(X, k2 * X + b2);

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/256644.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

微信小程序js数组对象根据某个字段排序

一、排序栗子 注: 属性字段需要进行转换,如String类型或者Number类型 //升序排序 首元素(element1)在前 降序则(element1)元素在后 data data.sort((element1, element2) >element1.属性 - element2.属性 ); 二、代码 Page({/*** 页面的初始数据*/data: {user:…

算能 MilkV Duo开发板实战——opencv-mobile (迷你版opencv库)的移植和应用

前言 OpenCV是一种开源的计算机视觉和机器学习软件库&#xff0c;旨在提供一组通用的计算机视觉工具。它用于图像处理、目标识别、人脸识别、机器学习等领域&#xff0c;广泛应用于计算机视觉任务。 OpenCV-Mobile是OpenCV库的轻量版本&#xff0c;专为移动平台&#xff08;A…

水果党flstudio用什么midi键盘?哪个版本的FL Studio更适合我

好消息&#xff01;好消息&#xff01;特大好消息&#xff01; 水果党们&#xff01;终于有属于自己的专用MIDI键盘啦&#xff01; 万众期待的Novation FLKEY系列 正式出炉&#xff01; 话有点多话&#xff0c;先分享一份干货&#xff0c;尽快下载 FL Studio 21 Win-安装包&…

Android Audio实战——音频链路分析(二十五)

在 Android 系统的开发过程当中,音频异常问题通常有如下几类:无声、调节不了声音、爆音、声音卡顿和声音效果异常(忽大忽小,低音缺失等)等。尤其声音效果这部分问题通常从日志上信息量较少,相对难定位根因。想要分析此类问题,便需要对声音传输链路有一定的了解,能够在链…

qt creator配置opencv库 (MSVC版本)

目录 1. MSVC版本 1.1 使用cmake编译opencv 1.2 再使用visual studio 2019生成opencv的lib,dll 1.3 配置opencv的系统环境变量 1.4 新建qt项目 1. MSVC版本 1.1 使用cmake编译opencv 1.2 再使用visual studio 2019生成opencv的lib,dll 1.3 配置opencv的系统环境变量 D:…

【推荐系统】了解推荐系统的生态(重点:推荐算法的主要分类)

【大家好&#xff0c;我是爱干饭的猿&#xff0c;本文重点介绍推荐系统的关键元素和思维模式、推荐算法的主要分类、推荐系统常见的问题、推荐系统效果评测。 后续会继续分享其他重要知识点总结&#xff0c;如果喜欢这篇文章&#xff0c;点个赞&#x1f44d;&#xff0c;关注一…

万户协同办公平台ezoffice wpsservlet接口任意文件上传漏洞

声明 本文仅用于技术交流&#xff0c;请勿用于非法用途 由于传播、利用此文所提供的信息而造成的任何直接或者间接的后果及损失&#xff0c;均由使用者本人负责&#xff0c;文章作者不为此承担任何责任。 一、漏洞描述 万户ezOFFICE协同管理平台是一个综合信息基础应用平台&am…

【webpack】初始化

webpack 旧项目的问题下一代构建工具 Vite 主角 &#xff1a;webpack安装webpack1&#xff0c;mode的选项2&#xff0c;使用source map 精准定位错误行数3&#xff0c;使用watch mode(观察模式)&#xff0c;自动运行4&#xff0c;使用webpack-dev-server工具&#xff0c;自动刷…

【Flutter】vs2022上开发flutter

在vs上开发flutter&#xff0c;结果扩展仓库上没办法找到Dart&#xff0c;Flutter。 在 这 搜索Dart时也无法找到插件。 最后发现是安装工具出错了 安装了 开发需要的是

UDS诊断 10服务

文章目录 简介诊断会话切换请求和响应1、请求2、子功能3、肯定响应4、否定响应5、特殊的NRC 为什么划分不同会话报文示例UDS中常用 NRC参考 简介 10服务&#xff0c;即 Diagnostic Session Control&#xff08;诊断会话控制&#xff09;服务用于启用服务器中的不同诊断会话&am…

HTML5+CSS3小实例:3D翻转Tab选项卡切换特效

实例:3D翻转Tab选项卡切换特效 技术栈:HTML+CSS 效果: 源码: 【HTML】 <!DOCTYPE html> <html><head><meta http-equiv="content-type" content="text/html; charset=utf-8"><meta name="viewport" content=…

Jsoup爬取HTTPS页面数据资源,并导入数据库(Java)

一、实现思路 示例页面&#xff1a; 2020年12月中华人民共和国县以上行政区划代码 忽略https请求的SSL证书通过Jsoup获取页面标签遍历行标签&#xff0c;分别获取每个行标签的第二个和第三个列标签将获取到的行政代码和单位名称分别插入sql语句占位符执行sql语句&#xff0c…