torch神经网络--线性回归

news/2024/10/5 13:10:55/文章来源:https://www.cnblogs.com/jackchen28/p/18408065

简单线性回归

y = 2*x + 1

import numpy as np
import torch
import torch.nn as nnclass LinearRegressionModel(nn.Module):def __init__(self, input_dim, output_dim):super(LinearRegressionModel, self).__init__()self.linear = nn.Linear(input_dim, output_dim)def forward(self, x):out = self.linear(x)return outx_values = [i for i in range(11)]
x_train = np.array(x_values, dtype=np.float32)
x_train = x_train.reshape(-1, 1)
x_train.shapey_values = [2*i+1 for i in x_values]
y_train = np.array(y_values, dtype=np.float32)
y_train = y_train.reshape(-1, 1)
y_train.shape
input_dim = 1
output_dim = 1
model = LinearRegressionModel(input_dim, output_dim)# 如果使用GPU训练,增加以下两行代码
# device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# model.to(device)# 指定好参数和损失函数
epochs = 1000
learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()# 训练模型
for epoch in range(epochs):epoch += 1# 使用cpu时,注意转行成tensorinputs = torch.from_numpy(x_train)labels = torch.from_numpy(y_train)# 如果使用GPU训练,将以上两行代码修改为# inputs = torch.from_numpy(x_train).to(device)# labels = torch.from_numpy(y_train).to(device)# 梯度要清零每一次迭代optimizer.zero_grad()# 前向传播outputs = model(inputs)# 计算损失loss = criterion(outputs, labels)# 反向传播loss.backward()# 更新权重参数optimizer.step()# 打印if epoch % 50 == 0:print('epoch {}, loss {}'.format(epoch, loss.item()))# CPU测试模型预测结果
predicted = model(torch.from_numpy(x_train).requires_grad_()).data.numpy()# 模型的保存
torch.save(model.state_dict(), 'model.pkl')
# 模型读取
model.load_state_dict(torch.load('model.pkl'))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/808446.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

找到织梦CMS的数据库配置文件,以便了解数据库的具体连接信息

首先,找到织梦CMS的数据库配置文件,以便了解数据库的具体连接信息。 数据库配置文件路径织梦CMS安装目录假设织梦CMS安装在 /var/www/html 目录下。 数据库配置文件位于 include/config.inc.php。打开配置文件使用FTP工具或服务器上的文件管理器,打开织梦CMS安装目录下的 in…

织梦的数据库在哪,告诉我路径

织梦CMS(DedeCMS)的数据库并不是直接存储在文件系统中的某个特定路径下,而是存储在MySQL数据库服务器中。不过,织梦CMS的数据库配置文件和一些相关文件还是有固定的路径。以下是一些关键路径及其说明: 织梦CMS安装目录 假设你的织梦CMS安装在 /var/www/html 目录下,那么以…

vs code如何配置C/C++环境,实现完美运行.c/.cpp文件,以及终端乱码问题

环境配置 在 Visual Studio Code (VS Code) 中安装了 C/C++ Extension Pack 后,你可以通过以下步骤来运行 C++ 文件:安装编译器配置编译任务:在 VS Code 中,你可以创建一个编译任务来编译你的 C++ 文件。这通常通过创建一个 tasks.json 文件来完成。你可以通过以下步骤创建…

blender拖动视角到一定程度很慢

配置 win11 - blender3.6点击 编辑 - 偏好设置视图切换 - 旋转&平移 - 自动 - 深度(勾选)后期可根据需要进行勾选和取消勾选

查看织梦CMS源码中的数据库相关文件

如果你想查看织梦CMS源码中的数据库相关文件,可以参考以下路径:织梦CMS安装目录/var/www/html 这里包含织梦CMS的所有文件。核心文件/var/www/html/inc 包含一些核心配置文件。 /var/www/html/include 包含数据库配置文件 config.inc.php 和其他核心文件。数据库表前缀默认表…

uv --- replacement of conda + pip (python version + package version install) python版本和包管理集大成者

uv https://docs.astral.sh/uv/An extremely fast Python package and project manager, written in Rust. Installing Trios dependencies with a warm cache. Highlights🚀 A single tool to replace pip, pip-tools, pipx, poetry, pyenv, virtualenv, and more. ⚡️ 10…

织梦怎么进数据库,织梦网站源码在哪里看数据库

假设你的织梦CMS安装在 /var/www/html 目录下,且数据库配置如下:织梦CMS安装目录:/var/www/html数据库配置文件:/var/www/html/include/config.inc.php数据库配置:$cfg_dbhost = localhost; $cfg_dbname = mydatabase; $cfg_dbuser = myusername; $cfg_dbpw = mypassword;…

blender贴图丢失,贴图显示紫色

闲言 一般在模型复制粘贴或转移过程中, 发生贴图加载失败, 导致模型贴图位置显示紫色. 如果是上述相关情况, 那么本文章应能为你提供相关帮助. 本人配置: win11 - blender3.6(本案例演示版本) - blender4.2 打开丢失材质模型(.blend).fbx导入也是一样的, 这里不赘述.打开材质预…

R3CTF2024 WP

一、PWN1.Nullullullllu在直接给 libc_base 的情况下,一次任意地址写 \x00 。直接修改 IO_2_1_stdin 的 _IO_buf_base 末尾为 \x00 ,那么 _IO_buf_base 就会指向 IO_2_1_stdin 的 _IO_write_base,接下来就是利用 getchar 函数触发写操作修改 IO_buf_base 为 IO_2_1_stdout ,…

WMCTF 2024 wp

WEBPasswdStealer前言本来题目叫PasswdStealer的:)考点就是CVE-2024-21733在SpringBoot场景下的利用。漏洞基本原理参考 https://mp.weixin.qq.com/s?__biz=Mzg2MDY2ODc5MA==&mid=2247484002&idx=1&sn=7936818b93f2d9a656d8ed48843272c0不再赘述。SpringBoot场景…

Z-library数字图书馆镜像地址及客户端/app(持续更新)

Z-library数字图书馆介绍 Z-library,被誉为全球范围内最为庞大的数字图书馆之一,其藏书量之丰富令人叹为观止,总计囊括了超过9,826,996册电子书及84,837,646篇学术期刊文章。这座庞大的知识宝库覆盖了从经典文学巨著到前沿理工学科,从人文艺术瑰宝到专业学术论文的广泛领域…

T3 玄泡面求调

觉得模拟赛题解还是单独放出来比较好。 A.挤压 好像不难?二进制表示下的平方展开没推出来,不然就成简单题了。 首先我们需要知道对于一个数 \(x\),把它拆成 29 位的二进制形式后,用 \(s_i\) 表示二进制下第 \(i\) 位上的数,那么其实这个数就是 \((\overline{s_{29} s_{28}…