pytorch 实现线性回归(深度学习)

一 查看原始函数

        y=2x+4.2

初始化

%matplotlib inline
import random
import torch
from d2l import torch as d2l

1.1 生成原始数据

def synthetic_data(w, b, num_examples):x = torch.normal(0, 1, (num_examples, len(w)))y = torch.matmul(x, w) + bprint('x:', x)print('y:', y)y += torch.normal(0, 0.01, y.shape)  # 噪声return x, y.reshape((-1 , 1))
true_w = torch.tensor([2.])
true_b = 4.2
print(f'true_w: {true_w}, true_b: {true_b}')features, labels = synthetic_data(true_w, true_b, 10)

1.2 数据转换

def data_iter(batch_size, features, labels):num_examples = len(features)indices = list(range(num_examples))random.shuffle(indices)for i in range(0, num_examples, batch_size):batch_indices = torch.tensor(indices[i: min(i + batch_size, num_examples)])yield features[batch_indices], labels[batch_indices]batch_size = 10
for x, y in data_iter(batch_size, features, labels):print(f'x: {x}, \ny: {y}')

1.3 初始化权重

随机初始化,w使用 均值0,方差 0.01 的随机值, b 初始化为1

w = torch.normal(0, 0.01, size = (1,1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)
w, b

二 执行训练

查看训练过程中的 参数变化:

print(f'true_w: {true_w}, true_b: {true_b}')def squared_loss(y_hat, y):return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2def linreg(x, w, b):return torch.matmul(x, w) + bdef sgd(params, lr, batch_size):with torch.no_grad():for param in params:# print('param:', param, 'param.grad:', param.grad)param -= lr * param.grad / batch_sizeparam.grad.zero_()lr = 0.03
num_epochs = 1000
for epoch in range(num_epochs):for x, y in data_iter(batch_size, features, labels):l = squared_loss(linreg(x, w, b), y)   # 计算总损失print('w:', w, 'b:', b)  # l:', l, '\nl.sum().backward()sgd([w, b], lr, batch_size)

 


三 测试梯度更新

初始化数据

%matplotlib inline
import random
import torch
from d2l import torch as d2ldef synthetic_data(w, b, num_examples):x = torch.normal(0, 1, (num_examples, len(w)))y = torch.matmul(x, w) + bprint('x:', x)print('y:', y)y += torch.normal(0, 0.01, y.shape)  # 噪声return x, y.reshape((-1 , 1))true_w = torch.tensor([2.])
true_b = 4.2
print(f'true_w: {true_w}, true_b: {true_b}')features, labels = synthetic_data(true_w, true_b, 10)def data_iter(batch_size, features, labels):num_examples = len(features)indices = list(range(num_examples))random.shuffle(indices)for i in range(0, num_examples, batch_size):batch_indices = torch.tensor(indices[i: min(i + batch_size, num_examples)])yield features[batch_indices], labels[batch_indices]batch_size = 10
for x, y in data_iter(batch_size, features, labels):print(f'x: {x}, \ny: {y}')w = torch.normal(0, 0.01, size = (1,1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)
w, b

3.1 测试更新

print(f'true_w: {true_w}, true_b: {true_b}')def squared_loss(y_hat, y):return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2def linreg(x, w, b):return torch.matmul(x, w) + bdef sgd(params, lr, batch_size):with torch.no_grad():for param in params:print('param:', param, 'param.grad:', param.grad)
#             param -= lr * param.grad / batch_size
#             param.grad.zero_()lr = 0.03
num_epochs = 2
for epoch in range(num_epochs):for x, y in data_iter(batch_size, features, labels):l = squared_loss(linreg(x, w, b), y)   # 计算总损失print(f'\nepoch: {epoch},w:', w, 'b:', b)  # l:', l, '\nl.sum().backward()  # 计算更新梯度sgd([w, b], lr, batch_size)

使用 l.sum().backward()  # 计算更新梯度:

不使用更新时:

print(f'true_w: {true_w}, true_b: {true_b}')def squared_loss(y_hat, y):return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2def linreg(x, w, b):return torch.matmul(x, w) + bdef sgd(params, lr, batch_size):with torch.no_grad():for param in params:print('param:', param, 'param.grad:', param.grad)
#             param -= lr * param.grad / batch_size
#             param.grad.zero_()lr = 0.03
num_epochs = 2
for epoch in range(num_epochs):for x, y in data_iter(batch_size, features, labels):l = squared_loss(linreg(x, w, b), y)   # 计算总损失print(f'\nepoch: {epoch},w:', w, 'b:', b)  # l:', l, '\n# l.sum().backward()  # 计算更新梯度sgd([w, b], lr, batch_size)#     break

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/473415.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Midjourney绘图欣赏系列(四)

Midjourney介绍 Midjourney 是生成式人工智能的一个很好的例子,它根据文本提示创建图像。它与 Dall-E 和 Stable Diffusion 一起成为最流行的 AI 艺术创作工具之一。与竞争对手不同,Midjourney 是自筹资金且闭源的,因此确切了解其幕后内容尚不…

数据结构~二叉树(基础知识)

上一篇博客我们对树有了初步了解与学习,这篇我将初步学习二叉树!!(新年快乐!) 目录 二叉树 1、定义: 2、特点: 3、基本形态: 4、二叉树的种类: &…

【探索Linux】—— 强大的命令行工具 P.22(POSIX信号量)

阅读导航 引言一、POSIX信号量的基本概念二、信号量的相关操作1 . 初始化信号量sem_init ( )(1)原型(2)参数(3)返回值(4)示例代码 2 . 等待信号量(1)sem_wait…

数据结构与算法:二叉树(寻找最近公共祖先、寻找后继节点、序列化和反序列化、折纸问题的板子和相关力扣题目)

最近公共祖先 第一版(前提:p和q默认存在于这棵树中) 可以层序遍历每个节点时用个HashMap存储该结点和其直接父节点的信息。然后从p开始溯源,将所有的父节点都添加到一个HashSet集合里。然后从q开始溯源,每溯源一步看…

数据库数据加密的 4 种常见思路的对比

应用层加解密方案数据库前置处理方案磁盘存取环节:透明数据加密DB 后置处理 最近由于工作需要,我对欧洲的通用数据保护条例做了调研和学习,其中有非常重要的一点,也是常识性的一条,就是需要对用户的个人隐私数据做好加…

webpack实际实践优化项目

参考: 如何通过性能优化,将包的体积压缩了62.7% 雅虎35条 20210526-webpack深入学习,搭建和优化react项目 本文只专注于性能优化的这个部分。 总体来说分为两个方面:第一是开发环境中主要优化打包速度,第二是线上环境…

BIG DATA —— 大数据时代

大数据时代 [英] 维克托 迈尔 — 舍恩伯格 肯尼斯 库克耶 ◎ 著 盛杨燕 周涛◎译 《大数据时代》是国外大数据研究的先河之作,本书作者维克托迈尔舍恩伯格被誉为“大数据商业应用第一人”,他在书中前瞻性地指出,大数据带来的信息…

python-自动化篇-运维-网络-IP

文章目录 IP自我介绍IPy安装模块windowsLinux IPy介绍支持大多数 IP 地址格式IPv4 地址IPv6 地址网络掩码和前缀 派生网络地址将地址转换为字符串使用多个网络多网络计算方法 IP自我介绍 IP地址规划是网络设计中非常重要的一个环节,规划的好坏会直接影响路由协议算…

2024年【天津市安全员B证】考试技巧及天津市安全员B证复审模拟考试

题库来源:安全生产模拟考试一点通公众号小程序 2024年天津市安全员B证考试技巧为正在备考天津市安全员B证操作证的学员准备的理论考试专题,每个月更新的天津市安全员B证复审模拟考试祝您顺利通过天津市安全员B证考试。 1、【多选题】《建设行政处罚决定…

力扣刷题之旅:进阶篇(六)—— 图论与最短路径问题

力扣(LeetCode)是一个在线编程平台,主要用于帮助程序员提升算法和数据结构方面的能力。以下是一些力扣上的入门题目,以及它们的解题代码。 --点击进入刷题地址 引言 在算法的广阔天地中,图论是一个非常重要的领域。…

《Go 简易速速上手小册》第8章:网络编程(2024 最新版)

文章目录 8.1 HTTP 客户端与服务端编程 - Go 语言的网络灯塔与探航船8.1.1 基础知识讲解服务端编程客户端编程 8.1.2 重点案例:简易博客服务服务端实现客户端实现运行示例 8.1.3 拓展案例 1:增加文章评论功能功能描述服务端实现客户端实现 8.1.4 拓展案例…

MATLAB导出图程序

本文将以代码的形式快速介绍MATLAB导出图到Paper 1 从simulation导出数 2 与simulation同源文件夹下创建导图m文件 代码如下: % 实验后的数据处理用 M-文件 % clear all % 清空工作空间 % close all      % 关闭所有图形窗口 % load adp.mat …