《PyTorch深度学习实践》第八讲加载数据集

一、

1、DataSet 是抽象类,不能实例化对象,主要是用于构造我们的数据集

2、DataLoader 需要获取DataSet提供的索引[i]和len;用来帮助我们加载数据,比如说做shuffle(提高数据集的随机性),batch_size,能拿出Mini-Batch进行训练。它帮我们自动完成这些工作。DataLoader可实例化对象。DataLoader is a class to help us loading data in Pytorch.

3、__getitem__目的是为支持下标(索引)操作
 

二、

import torch
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader# prepare datasetclass DiabetesDataset(Dataset):def __init__(self, filepath):xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)self.len = xy.shape[0] # shape(多少行,多少列)self.x_data = torch.from_numpy(xy[:, :-1])self.y_data = torch.from_numpy(xy[:, [-1]])def __getitem__(self, index):return self.x_data[index], self.y_data[index]def __len__(self):return self.lendataset = DiabetesDataset('diabetes.csv')
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=0) #num_workers 多线程# design model using classclass Model(torch.nn.Module):def __init__(self):super(Model, self).__init__()self.linear1 = torch.nn.Linear(8, 6)self.linear2 = torch.nn.Linear(6, 4)self.linear3 = torch.nn.Linear(4, 1)self.sigmoid = torch.nn.Sigmoid()def forward(self, x):x = self.sigmoid(self.linear1(x))x = self.sigmoid(self.linear2(x))x = self.sigmoid(self.linear3(x))return xmodel = Model()# construct loss and optimizer
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)# training cycle forward, backward, update
if __name__ == '__main__':for epoch in range(100):for i, data in enumerate(train_loader, 0): # train_loader 是先shuffle后mini_batchinputs, labels = datay_pred = model(inputs)loss = criterion(y_pred, labels)print(epoch, i, loss.item())optimizer.zero_grad()loss.backward()optimizer.step()

1、需要mini_batch 就需要import DataSet和DataLoader

2、继承DataSet的类需要重写init,getitem,len魔法函数。分别是为了加载数据集,获取数据索引,获取数据总量。

3、DataLoader对数据集先打乱(shuffle),然后划分成mini_batch。

4、len函数的返回值 除以 batch_size 的结果就是每一轮epoch中需要迭代的次数。

5、inputs, labels = data中的inputs的shape是[32,8],labels 的shape是[32,1]。也就是说mini_batch在这个地方体现的

6、diabetes.csv数据集老师给了下载地址,该数据集需和源代码放在同一个文件夹内。

问题:loss没有收敛

网友解决:

做了两个实验:(1)输出每批次的loss,不收敛,loss在0.6上下浮动(2)每个epoch都不分批,把所有样本都输入,收敛,最后结果在0.6附近。所以猜测:小样本之间的loss差距相对于0.6而言有点大,所以看着像是没收敛,实际上从总loss来看已经收敛了

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/502217.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

simple-pytest 框架使用指南

simple-pytest 框架使用指南 一、框架介绍简介框架理念:框架地址 二、实现功能三、目录结构四、依赖库五、启动方式六、使用教程1、快速开始1.1、创建用例:1.2、生成py文件1.3、运行脚本1.3.1 单个脚本运行1.3.2 全部运行 1.4 报告查看 2、功能介绍2.1、…

dll文件丢失如何修复,电脑dll文件作用及解决方法分享

dll 是一个属于 Microsoft Visual C Redistributable Package 的动态链接库 (DLL) 文件。这个文件是 Microsoft Visual Studio 版本中的 C 运行时库的一部分,具体对应的是 Visual Studio 2013 或更高版本。它的主要功能在于提供并行计算支持,特别是对于使…

GoFrame:如何简单地搭建一个简单地微服务

一切资料来源于GoFrame官网, 感兴趣的, 可以直接去官网查阅相关资料。 首先下载框架工具, 下载地址:https://github.com/gogf/gf/releases 然后进入你想要放置的项目文件夹, 执行命令行 gf init {project_name} #project_name为你的项目名 执行完后项目结构如图所示 然…

MySQL5.7.44版本压缩包在Win11系统快速安装

一.背景 主要还是为了公司的带徒弟任务。我自己也喜欢MySQL的绿色版本。 1.软件版本说明 MySQL版本:5.7.44 压缩包版本,相当于绿色版。当然,你也可以使用window系统的Installer版本去安装。 操作系统:Win11家庭版 二.MySQL软…

【Linux进程】进程状态(运行阻塞挂起)

目录 前言 1. 进程状态 2. 运行状态 3. 阻塞状态 4. 挂起状态 5. Linux中具体的状态 总结 前言 在Linux操作系统中,进程状态非常重要,它可以帮助我们了解进程在系统中的运行情况,从而更好地管理和优化系统资源,在Linux系统中&am…

使用maven项目引入jQuery

最近在自学 springBoot ,期间准备搞一个前后端不分离的东西,于是需要在 maven 中引入jQuery 依赖,网上百度了很多,这里来做一个总结。 1、pom.xml 导入依赖 打开我们项目的 pom.xml 文件,输入以下坐标。这里我使用的是…

ABAP - OOALV 用户交互事件

当用户要根据ALV进行某些功能操作比如打印表单时,OOALV标准按钮无法满足用户需求的时候,就要用到自定义按钮来实现了。思路:在OOALV增加一个自定义按钮,类CL_GUI_ALV_GRID提供了内置事件toolbar来完成,通过自定义按钮的…

SpringBoot+Vue实战:打造企业级项目管理神器

✍✍计算机编程指导师 ⭐⭐个人介绍:自己非常喜欢研究技术问题!专业做Java、Python、微信小程序、安卓、大数据、爬虫、Golang、大屏等实战项目。 ⛽⛽实战项目:有源码或者技术上的问题欢迎在评论区一起讨论交流! ⚡⚡ Java实战 |…

c# 获取源码路径与当前程序所在路径

获取源码路径 private static string GetFilePath([CallerFilePath] string path null) {return path;}//当程序所在路径string str67 System.Environment.CurrentDirectory;//源码路径 var path GetFilePath();var directory Path.GetDirectoryName(path);参考

作业1-224——P1331 海战

思路 深搜的方式&#xff0c;让它只遍历矩形块&#xff0c;然后在下面的遍历中判断是否出现矩形块交叉&#xff0c;但是很难实现&#xff0c;然后发现可以通过在遍历过程中判断是否合法。 参考代码 #include<iostream> #include<cstdio> using namespace std; …

vue3 构建项目

一.使用vite构建&#xff1a; npm init vitelatest 项目名称 构建的项目模板 进入项目 cd 项目名称 安装项目依赖包 npm install 启动项目 npm run dev 二.使用vue脚手架构建&#xff1a; npm init vuelatest 后续基本差不多

你真的了解C语言中的【柔性数组】吗~

柔性数组 1. 什么是柔性数组2. 柔性数组的特点3. 柔性数组的使用4. 柔性数组的优势 1. 什么是柔性数组 也许你从来没有听说过柔性数组这个概念&#xff0c;但是它确实是存在的。 C99中&#xff0c;结构体中的最后⼀个元素允许是未知大小的数组&#xff0c;这就叫做柔性数组成员…