第二章 2.3.1 定义数据集训练神经网络

news/2024/12/12 14:09:46/文章来源:https://www.cnblogs.com/excellentHellen/p/18602309

定义数据集训练神经网络# https://github.com/PacktPublishing/Modern-Computer-Vision-with-PyTorch

# https://github.com/PacktPublishing/Modern-Computer-Vision-with-PyTorch###################  Chapter Two #######################################
#数据集的库
from torch.utils.data import Dataset, DataLoader 
import torch
import torch.nn as nn
########################################################################
x = [[1,2],[3,4],[5,6],[7,8]]
y = [[3],[7],[11],[15]]X = torch.tensor(x).float()
Y = torch.tensor(y).float()
########################################################################
# 如果有GPU,使用它
device = 'cuda' if torch.cuda.is_available() else 'cpu' X = X.to(device) Y = Y.to(device) print(device) ########################################################################
# 定义我的数据集
class MyDataset(Dataset):def __init__(self,x,y):self.x = x.clone().detach().requires_grad_(True)self.y = y.clone().detach().requires_grad_(False)def __len__(self):return len(self.x)def __getitem__(self, ix):return self.x[ix], self.y[ix] ds = MyDataset(X, Y) # 取出一个样本 print(ds.__getitem__(1)) ######################################################################## dl = DataLoader(ds, batch_size=2, shuffle=True) ######################################################################## # 定义我的神经元网络类 class MyNeuralNet(nn.Module):def __init__(self):super().__init__()self.input_to_hidden_layer = nn.Linear(2,8)self.hidden_layer_activation = nn.ReLU()self.hidden_to_output_layer = nn.Linear(8,1)def forward(self, x):x = self.input_to_hidden_layer(x)x = self.hidden_layer_activation(x)x = self.hidden_to_output_layer(x)return x ########################################################################
#定义我的神经网络对象、损失函数、优化算法
mynet = MyNeuralNet().to(device) loss_func = nn.MSELoss() from torch.optim import SGD opt = SGD(mynet.parameters(), lr = 0.001)######################################################################## import time loss_history = [] start = time.time() # 训练我的神经网络 for _ in range(50):for data in dl:x, y = dataprint(data)opt.zero_grad()loss_value = loss_func(mynet(x),y)loss_value.backward()opt.step()loss_history.append(loss_value) end = time.time() print(end - start) ######################################################################## val_x = [[7,8]] # 使用我训练好的神经网络,计算一个输入、输出 val_x = torch.tensor(val_x).float().to(device)print(mynet(val_x))

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/851378.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

纷享销客荣获2024产业互联网百强“千峰奖”

12月6日,亿邦产业互联网峰会“2024产业互联网”千峰奖正式揭晓。本届千峰奖以“供应链变革与AI落地”为主题,聚焦数字供应链、产业数字科技、数智品牌三大细分赛道。 纷享销客作为国内领先的数字化营销技术提供商,凭借扎实的基本面、资本的高度认可、良好的商业价值获得评审…

转载:【AI系统】内存分配算法

本文将介绍 AI 编译器前端优化部分的内存分配相关内容。在 AI 编译器的前端优化中,内存分配是指基于计算图进行分析和内存的管理,而实际上内存分配的实际执行是在 AI 编译器的后端部分完成的。本文将包括三部分内容,分别介绍模型和硬件的内存演进,内存的划分与复用好处,节…

蓝桥杯嵌入式模板创建(STM32 CubeMx简单使用教程)

本人在备赛22年第十二届蓝桥杯嵌入式时所记录的笔记,可能有错漏,欢迎指出问题。当时使用的开发板为蓝桥杯新板STM32G431RBT6,实际上使用STM32F103芯片也可以通过STM32CubeMX快速上手HAL库编程蓝桥杯嵌入式新板模板创建&简单经验分享补充在最前: 以下原文是22年还未毕业…

安川机器人U轴减速机 HW9381465-C维修具体细节

安川机器人U轴减速机 HW9381465-C的维修是一个相对复杂的过程,涉及到多个部件的检查、维修和更换。以下是一些具体细节: 1、故障诊断:对安川机器人U轴减速机 HW9381465-C进行彻底的检查,以确定故障的具体位置和原因。可能出现故障的部件包括齿轮、轴承、油封和润滑系统等…

RNN模型的训练和推理以及成员推断攻击的实现

循环神经网络(RNN) 2024年7月20日更新 在此教程中,我们将对循环神经网络RNN模型及其原理进行一个简单的介绍,并实现RNN模型的训练和推理,目前支持MNIST、FashionMNIST和CIFAR-10等数据集,并给用户提供一个详细的帮助文档。同时,本项目还将实现循环神经网络的模型成员推理攻…

《Django 5 By Example》阅读笔记:p493-p520

《Django 5 By Example》学习第 17 天,p493-p520 总结,总计 28 页。 一、技术总结 1.internationalization(国际化) vs localization(本地化) (1)18n,L10n,g11n 以前总觉得这两个缩写好难记,今天仔细看了下维基百科,"i18n" 中的 i 代表 “internationalization…

图模型的训练和推理以及成员推理攻击的实现

graph_model 2024年10月14日更新 在此教程中,我们将对深度学习中的图模型及其原理进行一个简单的介绍,并实现一种图模型的训练和推理,至少支持三种数据集,目前支持数据集有:Cora、CiteSeer、PubMed等,并给用户提供一个详细的帮助文档。 目录 基本介绍目前存在的问题 现有…

一套以用户体验出发的.NET8 Web开源框架

前言 今天大姚给大家分享一套以用户体验出发的.NET8 Web开源框架:YiFramework。 项目介绍 YiFramework是一个基于.NET8 + Abp.vNext + SqlSugar 的DDD领域驱动设计后端开源框架,前端使用Vue3,项目架构模式三层架构\DDD领域驱动设计,内置RBAC权限管理、BBS论坛社区系统 以用…

链表的一步步实现(需有一部分c语言基础)【缓慢更新中

链表的一步步实现(需有一部分c语言基础) (由于本人上课实在没学懂链表的具体实现步骤,于是写下这篇博客记录学习过程,有兴趣的新手也可以跟着学习 1.认识链表的结构&创建简单静态链表并输出数据 Q:什么是链表? A:链表是由一系列节点组成,每个节点包含两个域,一个…

VGGNet模型的训练和推理

VGGNet 2024年5月10日更新 在此教程中,我们将对VGGNet模型及其原理进行一个简单的介绍,并实VGGNet模型的训练和推理,目前支持数据集有:MNIST、fashionMNIST、CIFAR10等,并给用户提供一个详细的帮助文档。 目录 基本介绍VGGNett描述 创新点 网络结构 VGGNet的特点VGGNet实现…

ResNet模型的训练和推理

ResNet 2024年5月7日更新 在此教程中,我们将对ResNet模型及其原理进行一个简单的介绍,并实现ResNet模型的训练和推理,目前支持数据集有:MNIST、fashionMNIST、CIFAR10等,并给用户提供一个详细的帮助文档。 目录 基本介绍ResNet描述 为什么要引入ResNet? 网络结构分析ResN…

转载:【AI系统】AI编译器前瞻

本文首先会基于 The Deep Learning Compiler: A Comprehensive Survey 中的调研做一个热门 AI 编译器的横向对比,并简要介绍几个当前常用的 AI 编译器。随后会分析当前 AI 编译器面临的诸多挑战,并展望 AI 编译器的未来。 业界主流 AI 编译器对比 在 The Deep Learning Compi…