第三章 3.4 训练神经网络

news/2024/12/15 14:25:53/文章来源:https://www.cnblogs.com/excellentHellen/p/18604854

 

代码:

# https://github.com/PacktPublishing/Modern-Computer-Vision-with-PyTorch
# https://github.com/PacktPublishing/Modern-Computer-Vision-with-PyTorch###################  Chapter Three ######################################## 第三章  读取数据集并显示
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
#########################################################################%matplotlib inline
device = "cuda" if torch.cuda.is_available() else "cpu"
from torchvision import datasets
data_folder = '~/data/FMNIST' # This can be any directory you want to
# download FMNIST to
fmnist = datasets.FashionMNIST(data_folder, download=True, train=True)
tr_images = fmnist.data
tr_targets = fmnist.targets########################################################################
class FMNISTDataset(Dataset):def __init__(self, x, y):x = x.float()x = x.view(-1,28*28)self.x, self.y = x, ydef __getitem__(self, ix):x, y = self.x[ix], self.y[ix]return x.to(device), y.to(device)def __len__(self):return len(self.x)
########################################################################
def get_data():train = FMNISTDataset(tr_images, tr_targets)trn_dl = DataLoader(train, batch_size=32, shuffle=True)return trn_dl
########################################################################from torch.optim import SGD
def get_model():model = nn.Sequential(nn.Linear(28 * 28, 1000),nn.ReLU(),nn.Linear(1000, 10)).to(device)loss_fn = nn.CrossEntropyLoss()optimizer = SGD(model.parameters(), lr=1e-2)return model, loss_fn, optimizer
########################################################################
def train_batch(x, y, model, opt, loss_fn):model.train() # <- let's hold on to this until we reach dropout section# call your model like any python function on your batch of inputsprediction = model(x)# compute lossbatch_loss = loss_fn(prediction, y)# based on the forward pass in `model(x)` compute all the gradients of# 'model.parameters()'
    batch_loss.backward()# apply new-weights = f(old-weights, old-weight-gradients) where# "f" is the optimizer
opt.step()# Flush gradients memory for next batch of calculations
    opt.zero_grad()return batch_loss.item()########################################################################
def accuracy(x, y, model):model.eval() # <- let's wait till we get to dropout section# get the prediction matrix for a tensor of `x` imagesprediction = model(x)# compute if the location of maximum in each row coincides# with ground truthmax_values, argmaxes = prediction.max(-1)is_correct = argmaxes == yreturn is_correct.cpu().numpy().tolist()
########################################################################
trn_dl = get_data()
print(trn_dl)
model, loss_fn, optimizer = get_model()
# ########################################################################
losses, accuracies = [], []
for epoch in range(5):print(epoch)epoch_losses, epoch_accuracies = [], []for ix, (x, y) in enumerate(iter(trn_dl)):#x, y = batchprint("++++++++++++++++++++++++++++++++++++++++++++++")print(ix,x,y)batch_loss = train_batch(x, y, model, optimizer, loss_fn)epoch_losses.append(batch_loss)epoch_loss = np.array(epoch_losses).mean()for ix, batch in enumerate(iter(trn_dl)):x, y = batchis_correct = accuracy(x, y, model)epoch_accuracies.extend(is_correct)epoch_accuracy = np.mean(epoch_accuracies)losses.append(epoch_loss)accuracies.append(epoch_accuracy)
# ########################################################################
#
epochs = np.arange(5)+1
plt.figure(figsize=(20,5))
plt.subplot(121)
plt.title('Loss value over increasing epochs')
plt.plot(epochs, losses, label='Training Loss')
plt.legend()
plt.subplot(122)
plt.title('Accuracy value over increasing epochs')
plt.plot(epochs, accuracies, label='Training Accuracy')
#plt.gca().set_yticklabels(['{:.0f}%'.format(x*100) for x in plt.gca().get_yticks()])
plt.legend()
plt.show()

 

 文心一言修改后的代码

# 导入必要的库
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt# 设置设备为CUDA(如果可用)或CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 下载FashionMNIST数据集
import torchvision
from torchvision import datasetsdata_folder = '~/data/FMNIST'  # 数据集下载目录
fmnist = datasets.FashionMNIST(data_folder, download=True, train=True, transform= torchvision.transforms.ToTensor())
tr_images, tr_targets = fmnist.data, fmnist.targets# 自定义FashionMNIST数据集类
class FMNISTDataset(Dataset):def __init__(self, images, targets):# 将图像数据转换为浮点型,并展平self.images = images.float().view(-1, 28 * 28)self.targets = targetsdef __getitem__(self, index):# 根据索引获取图像和标签,并移动到指定设备image, target = self.images[index], self.targets[index]return image.to(device), target.to(device)def __len__(self):# 返回数据集的大小return len(self.images)# 获取数据加载器
def get_data():dataset = FMNISTDataset(tr_images, tr_targets)data_loader = DataLoader(dataset, batch_size=32, shuffle=True)return data_loader# 获取模型、损失函数和优化器
def get_model():model = nn.Sequential(nn.Linear(28 * 28, 1000),nn.ReLU(),nn.Linear(1000, 10)).to(device)loss_fn = nn.CrossEntropyLoss()optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # 使用更常见的写法return model, loss_fn, optimizer# 训练一个批次的数据
def train_batch(images, targets, model, optimizer, loss_fn):model.train()  # 设置模型为训练模式outputs = model(images)  # 前向传播loss = loss_fn(outputs, targets)  # 计算损失optimizer.zero_grad()  # 清零梯度loss.backward()  # 反向传播optimizer.step()  # 更新权重return loss.item()# 计算准确率
def calculate_accuracy(images, targets, model):model.eval()  # 设置模型为评估模式with torch.no_grad():  # 禁用梯度计算outputs = model(images)_, predicted = torch.max(outputs, 1)is_correct = (predicted == targets).cpu().numpy()return is_correct.tolist()# 主训练循环
trn_dl = get_data()
model, loss_fn, optimizer = get_model()losses, accuracies = [], []
for epoch in range(5):epoch_losses, epoch_accuracies = [], []for images, targets in trn_dl:  # 直接迭代DataLoaderbatch_loss = train_batch(images, targets, model, optimizer, loss_fn)epoch_losses.append(batch_loss)epoch_loss = np.mean(epoch_losses)# 注意:这里不应该再次迭代DataLoader来计算准确率,因为这样会打乱数据顺序# 为了简化,我们只在每个epoch结束时使用一小部分数据来计算准确率(不推荐这样做,仅用于演示)# 正确的做法是使用一个验证集或者在每个epoch结束时不打乱数据地计算整个训练集的准确率# 但为了保持示例的简洁性,我们仍然这样做
    with torch.no_grad():model.eval()  # 设置模型为评估模式(虽然在这个简化的例子中不是必需的,因为我们已经禁用了梯度)# 假设我们只使用前100个批次的数据来计算准确率(这只是一个示例,不推荐这样做)for i in range(100):  # 注意:这里应该有一个更好的方法来处理,比如使用验证集images, targets = next(iter(trn_dl))  # 小心:这会改变DataLoader的状态!if i >= len(trn_dl):  # 如果已经迭代完,则重新迭代(但这在真实场景中是不推荐的)trn_dl = get_data()  # 重新获取数据加载器(不推荐)# 更好的做法是使用一个固定的验证集is_correct = calculate_accuracy(images, targets, model)epoch_accuracies.extend(is_correct)epoch_accuracy = np.mean(epoch_accuracies)losses.append(epoch_loss)accuracies.append(epoch_accuracy)# 绘制损失和准确率曲线
epochs = np.arange(1, 6)  # 正确的epoch范围
plt.figure(figsize=(20, 5))
plt.subplot(1, 2, 1)
plt.title('Loss over Epochs')
plt.plot(epochs, losses, label='Training Loss')
plt.legend()
plt.subplot(1, 2, 2)
plt.title('Accuracy over Epochs')
plt.plot(epochs, accuracies, label='Training Accuracy')
plt.legend()
plt.show()

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/853279.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2024-2025-1 20241417 《计算机基础与程序设计》第十二周学习总结

2024-2025-1 20241417 《计算机基础与程序设计》第十二周学习总结 作业信息这个作业属于哪个课程 <班级的链接>(如2024-2025-1-计算机基础与程序设计)这个作业要求在哪里 <作业要求的链接>2024-2025-1计算机基础与程序设计第十二周作业这个作业的目标 <复习前…

PbootCMS中如何让后台输入的换行符在前台正确显示?

在PbootCMS中,如果你在后台输入的内容中包含换行符(如 <br>),但前台显示时这些换行符被当作普通文本输出(例如显示为 <br>),你可以通过使用格式化标签来解决这个问题。具体做法是在调用内容的标签中添加 decode=1 参数。例如,如果你原本的代码是 {sort:sub…

如何在PbootCMS中获取搜索页的关键词和搜索结果数量?

在PbootCMS中,你可以通过特定的标签来获取搜索页的关键词和搜索结果的数量。以下是如何使用这些标签的详细说明和一些扩展建议:获取搜索关键词:在搜索页模板search.html中,使用标签{$get.keyword}来获取用户输入的搜索关键词。 例如:html<h1>搜索结果:{$get.keywor…

PbootCMS后台登录验证码看不清怎么办?

在使用PbootCMS时,有时会遇到后台登录验证码看不清的问题。这通常是由于PHP版本不兼容导致的。以下是如何解决这一问题的详细步骤和注意事项。问题原因分析:PHP版本不支持:验证码看不清的问题通常是由于服务器上的PHP版本不支持PbootCMS的验证码生成功能。不同版本的PHP对某…

VS2022 配置openCV方法

第一步下载opencv库解压出来这里不做过多讲解第二步配置环境变量 %path%\build\x64\vc16\bin %path%这个替换成自己的路径 然后打开项目属性设置点击VC++目录 链接器、输入、附件依赖分别添加 前面的是我自己的目录 换成你们自己目录即可 第一步添加 库目录D:\Opencv\ope…

WPF TreeView实现固定表头

1、在WPF中TreeView默认不支持固定表头的我们可以修改样式实现固定表头新建一个TreeListView类 然后继承TreeView代码如下public class TreeListView : TreeView,IDisposable{public TreeListView(){//this.Loaded += TreeListView_Loaded;//this.SizeChanged += TreeListView_…

居家徒手健身

居家徒手健身 力竭组,组间歇2min,动作变形算力竭为一组 第一天:胸+三头 动作: 宽距俯卧撑6组(胸外延) 标准俯卧撑4组胸整体 钻石俯卧撑4组(胸中缝) 板凳臂屈伸4~8组(三头) 第二天:肩 +腿 动作: 折刀俯卧撑6~10组(肩中束) 腰间俯卧撑4~6组 (肩前束) 弹力绳深蹲6组…

个人网站建站日记-集成Markdown编辑器

一次偶然的机会,我体验的到了markdown的便捷,于是乎,我就着手给我的网站闲蛋博客社区集成了Markdown,现在可以自由的切换Markdown与富文本编辑的使用了。这里我特此分享记录下安装使用的过程。 一、安装Markdown编辑器 这里我采用的是md-editor-v3编辑器,目前看来还是很好…

arbitrum 资产桥合约

资产桥的作用 Rollup 的主要流程中,实际上不包含资产桥,也就是说即使没有资产桥,L2依然能正常运行但是此时L1与L2在数据上是完全独立的两条链,L1不理解L2上的数据(L1只保存L2压缩后的数据,不理解数据),L2上也不知道L1上发生了什么(只能拿到区块高度等一些基本信息)。完…

鸿蒙NEXT开发案例:经纬度距离计算

【引言】 在鸿蒙NEXT平台上,我们可以轻松地开发出一个经纬度距离计算器,帮助用户快速计算两点之间的距离。本文将详细介绍如何在鸿蒙NEXT中实现这一功能,通过简单的用户界面和高效的计算逻辑,为用户提供便捷的服务。 【环境准备】 • 操作系统:Windows 10 • 开发工具:De…

C语言中0为假,正数和负数均为真

001、[b20223040323@admin2 test]$ ls test.c [b20223040323@admin2 test]$ cat test.c #include <stdio.h>int main(void) {int i,j,k; ## 三个变量 负数、正数和0i = -5;j = 8;k = 0;if(i){puts("xxxx");}if(j){puts("yyyy");}if(k){puts(&qu…

2024-2025-1(20241321)《计算机基础与程序设计》第十二周学习总结

这个作业属于哪个课程 <班级的链接>(2024-2025-1-计算机基础与程序设计)这个作业要求在哪里 <作业要求的链接>(2024-2025-1计算机基础与程序设计第十二周作业)这个作业的目标 <深刻学习C语言,反思一周学习,温故知新>作业正文 ... 本博客链接https://www.…