第三章 3.6 批大小的影响

news/2024/12/16 10:30:28/文章来源:https://www.cnblogs.com/excellentHellen/p/18604859

第三章 3.4 训练神经网络

 

# https://github.com/PacktPublishing/Modern-Computer-Vision-with-PyTorch
# https://github.com/PacktPublishing/Modern-Computer-Vision-with-PyTorch###################  Chapter Three ######################################## 第三章  读取数据集并显示
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
########################################################################
from torchvision import datasets
import torch
data_folder = '~/data/FMNIST' # This can be any directory you want to
# download FMNIST to
fmnist = datasets.FashionMNIST(data_folder, download=True, train=True)
tr_images = fmnist.data
tr_targets = fmnist.targetsval_fmnist = datasets.FashionMNIST(data_folder, download=True, train=False)
val_images = val_fmnist.data
val_targets = val_fmnist.targets########################################################################
import matplotlib.pyplot as plt
#matplotlib inline
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
device = 'cuda' if torch.cuda.is_available() else 'cpu'########################################################################
class FMNISTDataset(Dataset):def __init__(self, x, y):x = x.float()/255 #归一化x = x.view(-1,28*28)self.x, self.y = x, ydef __getitem__(self, ix):x, y = self.x[ix], self.y[ix]return x.to(device), y.to(device)def __len__(self):return len(self.x)from torch.optim import SGD, Adam
def get_model():model = nn.Sequential(nn.Linear(28 * 28, 1000),nn.ReLU(),nn.Linear(1000, 10)).to(device)loss_fn = nn.CrossEntropyLoss()optimizer = Adam(model.parameters(), lr=1e-2)return model, loss_fn, optimizerdef train_batch(x, y, model, optimizer, loss_fn):model.train()prediction = model(x)batch_loss = loss_fn(prediction, y)batch_loss.backward()optimizer.step()optimizer.zero_grad()return batch_loss.item()def accuracy(x, y, model):model.eval()# this is the same as @torch.no_grad# at the top of function, only difference# being, grad is not computed in the with scope
    with torch.no_grad():prediction = model(x)max_values, argmaxes = prediction.max(-1)is_correct = argmaxes == yreturn is_correct.cpu().numpy().tolist()########################################################################
def get_data():train = FMNISTDataset(tr_images, tr_targets)trn_dl = DataLoader(train, batch_size=10000, shuffle=True) # 比较批大小分别是10000,和32是,损失函数值 和 准确度值,注意:数据集只有60000个样本val = FMNISTDataset(val_images, val_targets)val_dl = DataLoader(val, batch_size=len(val_images), shuffle=False)return trn_dl, val_dl
########################################################################
#@torch.no_grad()
def val_loss(x, y, model):with torch.no_grad():prediction = model(x)val_loss = loss_fn(prediction, y)return val_loss.item()########################################################################
trn_dl, val_dl = get_data()
model, loss_fn, optimizer = get_model()########################################################################
train_losses, train_accuracies = [], []
val_losses, val_accuracies = [], []
for epoch in range(5):print(epoch)train_epoch_losses, train_epoch_accuracies = [], []for ix, batch in enumerate(iter(trn_dl)):x, y = batchbatch_loss = train_batch(x, y, model, optimizer, loss_fn)train_epoch_losses.append(batch_loss)train_epoch_loss = np.array(train_epoch_losses).mean()for ix, batch in enumerate(iter(trn_dl)):x, y = batchis_correct = accuracy(x, y, model)train_epoch_accuracies.extend(is_correct)train_epoch_accuracy = np.mean(train_epoch_accuracies)for ix, batch in enumerate(iter(val_dl)):x, y = batchval_is_correct = accuracy(x, y, model)validation_loss = val_loss(x, y, model)val_epoch_accuracy = np.mean(val_is_correct)train_losses.append(train_epoch_loss)train_accuracies.append(train_epoch_accuracy)val_losses.append(validation_loss)val_accuracies.append(val_epoch_accuracy)########################################################################
epochs = np.arange(5)+1
import matplotlib.ticker as mtick
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
#%matplotlib inline
plt.figure(figsize=(20,5))
plt.subplot(211)
plt.plot(epochs, train_losses, 'bo', label='Training loss')
plt.plot(epochs, val_losses, 'r', label='Validation loss')
plt.gca().xaxis.set_major_locator(mticker.MultipleLocator(1))
plt.title('Training and validation loss when batch size is 32')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid('off')
#plt.show()
plt.subplot(212)
plt.plot(epochs, train_accuracies, 'bo', label='Training accuracy')
plt.plot(epochs, val_accuracies, 'r', label='Validation accuracy')
plt.gca().xaxis.set_major_locator(mticker.MultipleLocator(1))
plt.title('Training and validation accuracy when batch size is 32')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.gca().set_yticklabels(['{:.0f}%'.format(x*100) for x in plt.gca().get_yticks()])
plt.legend()
plt.grid('off')
plt.show()

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/853658.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vue2 脚手架安装及使用

1.安装npm install -g @vue/cli 2.查看版本vue -V 3.使用3.1 命令形式vue create my-project 3.2可视化操作

.NET8升级.NET9,CodeFirst模式迁移Add-Migration执行Update-DataBase报错

在做netcore开发时,如果net8一直是正常的,只升级了一下框架net9,在使用Entity Framework Core的Code First模式进行迁移时,执行Add-Migration后尝试使用Update-DataBase时出现了如下错误。Unhandled exception. System.InvalidOperationException: An error was generated …

响应式圆形js轮播图插件

jcircle.js是一款响应式圆形js轮播图插件。该轮播图插件能够将图片或文字以圆形轮播图的方式进行展示。并且该轮播图以响应式设计,可以自动进行圆形轮播。在线演示 下载使用方法 在页面中引入jCircle.css和jCircle.min.js文件<link href="jCircle.css" rel=&quo…

阿里云联合中国信通院等单位发布首个云计算智能化可观测性能力成熟度模型标准

随着云计算技术与现代企业技术架构的飞速发展,IT 运维场景愈发多元与复杂,需要观测的对象、观测数据类型、数据规模、数据结构复杂度相较于传统监控发生了翻天覆地的变化。这给企业可观测性的准确、实时、高效与智能化发展带来了巨大挑战。如何借助大模型等智能化技术成为应对…

交易系统:应用层、领域层分层架构设计

大家好,我是汤师爷~ 线上线下交易系统的应用架构包括终端、应用层、领域层和关联系统。应用层能力 应用层定义软件的应用功能,负责接收用户请求、协调领域层执行任务并返回结果。主要包括以下模块: 1)C端服务模块 为消费者提供完整的交易链路功能,包括加购、下单、支付、结…

卖点

什么是埋点? 埋点是一种用于跟踪用户在网站或应用中行为的数据采集技术,通过记录点击、浏览等操作,帮助团队进行用户行为分析、AB 实验、错误监听,指导优化方向和资源分配 监控类型 基于要监控的内容,可以分为:数据监控、性能监控、异常监控 上报方式 手动上报在用户点击…

vue3项目构建流程

1.项目包管理工具选择pnpm npm i -g pnpm 2.选择用vite管理项目 注意node的版本需要16+,项目才能正常使用,在cmd中输入pnpm create vite命令,按照指示创建初始项目 3.下载eslint项目代码校验 执行pnpm i eslint -D安装eslint依赖,然后执行命令npx eslint --init生成配置文件…

APS计划排产在金属制管行业的应用及效益提升分析

金属管材是众多工业产品、基础设施建设的基础材料,其所属的金属制管行业通过自身技术创新、质量管控,带动整个制造业的进步。金属制管行业是典型的离散行业特征,具有多品种、小批量、非标定制的生产特性,其工艺复杂和多样性给从事计划生产的人员带来了巨大的挑战。今天我们…

软件开发项目管理(从立项到上线的全流程解析)

软件开发项目管理(从立项到上线的全流程解析)图1 传统软件开发流程研发项目流程是组织研发活动的重要方式,可以帮助企业高效地开展研发工作,实现研发成果的快速转化。本文将介绍研发项目流程的八个阶段,包括规划阶段、需求分析阶段、设计阶段、编码阶段、测试阶段、部署阶…

JavaWeb-1 动态网页开发

1.新建动态网页项目 2.新建类作者:太一吾鱼水 宣言:在此记录自己学习过程中的心得体会,同时积累经验,不断提高自己! 声明:博客写的比较乱,主要是自己看的。如果能对别人有帮助当然更好,不喜勿喷! 文章未经说明均属原创,学习笔记可能有大段的引用,一…

如何在易优CMS中调试新模板?

在易优CMS中调试新模板是一项重要的任务,以确保新模板能够正常工作并且不会影响网站的其他功能。以下是一些详细的步骤和建议,帮助你在易优CMS中调试新模板:备份现有数据:在开始调试之前,务必备份现有的数据库和文件。这样即使在调试过程中出现问题,也可以快速恢复到原来…

如何在易优CMS中创建分页加载文件 formreply_block001.htm?

在易优CMS中,为了实现分页加载功能,需要创建一个分页加载文件 formreply_block001.htm。以下是详细的步骤和说明:创建文件:在模板目录 pc/system 或 mobile/system 下创建一个名为 formreply_block001.htm 的文件。具体路径如下:PC端:模板目录/pc/system/formreply_block…