7 pytorch DataLoader, TensorDataset批数据训练方法

前言

本文主要介绍pytorch里面批数据的处理方法,以及这个算法的效果是什么样的。具体就是要弄明白这个批数据选取的算法是在干什么,不会涉及到网络的训练。

from torch.utils.data import DataLoader, TensorDataset

主要实现就是上面的数据集和数据载入两个类来实现该算法功能,这里只要求会调用接口就够了。

一、生成数据集

import torch
from torch.utils.data import DataLoader, TensorDataset
# 准备数据集与定义batch_size
batch_size = 8
x = torch.linspace(1,10,10)
y = torch.linspace(10,1,10)
print(x)
print(y)

输出:
在这里插入图片描述

二、将训练数据进行batch处理

# 将训练数据放入torch的数据集
train_dataset = TensorDataset(x, y)
# 载入batch批次选取数据规则
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,   # True表示每一个epoch都打乱抽取num_workers=2   # 定义工作线程个数)

三、epoch训练

# 训练模型
epochs = 3
for epoch in range(epochs):# 每一个epoch表示将整个数据集所有数据都训练一遍for step,(batch_x, batch_y) in enumerate(train_loader):# training......# 这里用enumerate是为了让你更加情况观察,batch的逻辑是怎么样的# 实际中只要  for batch_x,batch_y in train_loader就可以了print('Epoch:',epoch,'| Step:',step,'| batch x:',batch_x.data.numpy(),'| batch y:',batch_y.data.numpy())
# 测试模型(略)

输出:
在这里插入图片描述
【注】:可以看到每一个epoch将所有样本点都涉及到了一次,并且还是打乱顺序了的。
下面看看将shuffle=False不打乱顺序会发生什么:
在这里插入图片描述
【注】:可以看到每一个epoch,都是相同的结果,可想而知这样训练效果肯定没有打乱的好。
注意到,上半batch=5,恰好将样本总数10均分为2分,那么要是不能均分会发生什么,下面将batch=8,看看会发生什么。
在这里插入图片描述
可以看到直接将不够的组就直接剩下的了。

总结

后面我们会经常用到这种batch和epoch的训练方法。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/624773.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

maven 基础用法 (终端界面和IDEA界面)

目录 maven定义 Maven环境配置 仓库 本地仓库 关于pom.xml 运行方式 关于maven在IDEA创建 maven定义 Maven 是一个项目管理和整合工具。通过对 目录结构和构建生命周期 的标准化, 使开发团队用极少的时间就能够自动完成工程的基础构建配置。 ​ Maven 简化了…

计算机网络的七层模型

序 OSl(Open System Interconnect),即开放式系统互联。一般都叫OSI参考模型。在网络编程中最重要的模型就是OSI七层网络模型和TCP/IP四层网络模型 一、OSI七层参考模型以及功能概述 二、各层的具体职能以及实际应用 1.应用层: OSI参考模型中最接近用…

Docker 入门介绍及简单使用

Docker 的简单介绍 中文官网:Docker中文网 官网 英文官网:Docker: Accelerated Container Application Development Docker 是一个开源的应用容器引擎,它允许开发者打包应用及其依赖包到一个可移植的容器中,然后发布到任何流行的 …

C语言单向链表的经典算法

1.分割链表 2.移除链表元素 3.反转链表 4.合并两个有序链表 5.链表的中间结点 6.环形链表的约瑟夫问题 1.分割链表: 1.思路:创建新链表,小链表和大链表。如图 代码如下 /*** Definition for singly-linked list.* struct ListNode {* int val…

【读论文】【泛读】三篇生成式自动驾驶场景生成: Bevstreet, DisCoScene, BerfScene

文章目录 1. Street-View Image Generation from a Bird’s-Eye View Layout1.1 Problem introduction1.2 Why1.3 How1.4 My takeaway 2. DisCoScene: Spatially Disentangled Generative Radiance Fields for Controllable 3D-aware Scene Synthesis2.1 What2.2 Why2.3 How2.4…

如何在PPT中获得网页般的互动效果

如何在PPT中获得网页般的互动效果 效果可以看视频 PPT中插入网页有互动效果 当然了,获得网页般的互动效果,最简单的方法就是在 PPT 中插入网页呀。 那么如何插入呢? 接下来为你讲解如何获得(此方法在 PowerPoint中行得通&#…

Unity 点击次数统计功能

介绍 💡.调用方便,发生点击事件后直接通过"xxx".CacheClick缓存 💡. 在允许的时间间隔内再次点击会累计点击次数,直到超出后触发事件 传送门👈

记录一下hive跑spark的insert,update语句报类找不到的问题

我hive能正常启动,建表没问题,我建了一个student表,没问题,但执行了下面一条insert语句后报如下错误: hive (default)> insert into table student values(1,abc); Query ID atguigu_20240417184003_f9d459d7-199…

【GD32】_时钟架构及系统时钟频率配置

文章目录 一、有关时钟源二、系统时钟架构三、时钟树分析四、修改参数步骤1、设置外部晶振2、选择外部时钟源。3、 设置系统主频率大小4、修改PLL分频倍频系数 学习系统时钟架构和时钟树,验证及学习笔记如下,如有错误,欢迎指正。主要记录了总…

基于springboot的扶贫助农系统

文章目录 项目介绍主要功能截图:部分代码展示设计总结项目获取方式 🍅 作者主页:超级无敌暴龙战士塔塔开 🍅 简介:Java领域优质创作者🏆、 简历模板、学习资料、面试题库【关注我,都给你】 &…

用海外云手机高效率运营TikTok!

很多做国外社媒运营的公司,想要快速引流,往往一个账号是不够的,多数都是矩阵养号的方式,运营多个TikToK、Facebook、Instagram等账号,慢慢沉淀流量变现,而他们都在用海外云手机这款工具! 海外云…

二级综合医院云HIS系统源码,B/S架构,采用JAVA编程,集成相关医保接口

二级医院云HIS系统源码 云HIS系统是一款满足基层医院各类业务需要的健康云产品。该产品能帮助基层医院完成日常各类业务,提供病患预约挂号支持、病患问诊、电子病历、开药发药、会员管理、统计查询、医生工作站和护士工作站等一系列常规功能,还能与公卫…