【PyTorch】使用回调和日志记录来监控模型训练

news/2025/1/20 22:39:28/文章来源:https://www.cnblogs.com/o-O-oO/p/18682628

就像船长依赖仪器来保持航向一样,数据科学家需要回调和日志记录系统来监控和指导他们在PyTorch中的模型训练。
在本教程中,我们将指导您实现回调和日志记录功能,以成功训练模型。

一、理解回调和日志记录

回调和日志记录是PyTorch中有效管理和监控机器学习模型训练过程的基本工具。

1.1 回调

在编程中,回调是一个作为参数传递给另一个函数的函数。这允许回调函数在调用函数的特定点执行。在PyTorch中,回调用于在训练循环的指定阶段执行操作,例如一个时期的结束或处理一个批次之后。
这些阶段可以是:

时期结束:当整个训练时期(对整个数据集的迭代)完成时。

批次结束:在一个时期内处理单个数据批次之后。

其他阶段:根据特定回调的实现,它也可能在其他点触发。

回调执行的常见操作包括:

监控:打印训练指标,如损失和准确率。

早停:如果模型性能停滞或恶化,则停止训练。

保存检查点:定期保存模型的状态,以便可能的恢复或回滚。

触发自定义逻辑:根据训练进度执行任何用户定义的代码。

【回调的好处】

模块化设计:回调通过将特定功能与核心训练循环分开封装,促进模块化。这提高了代码组织和可重用性。

灵活性:您可以轻松创建自定义回调以满足特殊需求,而无需修改核心训练逻辑。

定制化:回调允许您根据特定要求和监控偏好定制训练过程。

1.2 日志记录

日志记录是指记录软件执行过程中发生的事件。PyTorch日志记录对于监控各种指标至关重要,以理解模型随时间的性能。
存储训练指标,如:损失值准确率分数学习率其他相关的训练参数

为什么日志记录很重要?
日志记录提供了模型训练历程的历史记录。它允许您:

可视化进度:您可以绘制随时间记录的指标,以分析损失、准确率或其他参数的趋势。

比较实验:通过比较不同训练运行的日志,您可以评估超参数调整或模型变化的影响。

调试训练问题:日志记录有助于识别训练期间的潜在问题,如突然的性能下降或意外的指标值。

二、在PyTorch中实现回调和日志记录

让我们逐步了解如何在PyTorch中实现一个简单的回调和日志记录系统。
步骤1:定义一个回调类
首先,我们定义一个回调类,它将在每个时期的结束时打印一条消息。

class PrintCallback:def on_epoch_end(self, epoch, logs):print(f"Epoch {epoch}: loss = {logs['loss']:.4f}, accuracy = {logs['accuracy']:.4f}")

步骤2:修改训练循环
接下来,我们修改训练循环以接受我们的回调,并在每个时期的结束时调用它。

def train_model(model, dataloader, criterion, optimizer, epochs, callbacks):for epoch in range(epochs):for batch in dataloader:# Training process happens herepasslogs = {'loss': 0.001, 'accuracy': 0.999}  # Example metrics after an epochfor callback in callbacks:callback.on_epoch_end(epoch, logs)

步骤3:实现日志记录
对于日志记录,我们将使用Python内置的日志模块来记录训练进度。

import logging
logging.basicConfig(level=logging.INFO)def log_metrics(epoch, logs):logging.info(f"Epoch {epoch}: loss = {logs['loss']:.4f}, accuracy = {logs['accuracy']:.4f}")

步骤4:将所有内容整合在一起
最后,我们创建我们的回调实例,设置记录器,并开始训练过程。

print_callback = PrintCallback()
train_model(model, dataloader, criterion, optimizer, epochs=10, callbacks=[print_callback])

三、在PyTorch中实现回调和日志记录:示例

示例 1:合成数据集:让我们创建一个代表我们机器人绘画的随机数字的简单数据集。我们将使用PyTorch创建随机数据点。

import torch# Generate random data points
data = torch.rand(100, 3)  # 100 paintings, 3 colors each
labels = torch.randint(0, 2, (100,))  # Randomly label them as good (1) or bad (0)

步骤1:定义一个简单模型
现在,我们将定义一个简单的模型,尝试学习对绘画进行分类。

from torch import nn# A simple neural network with one layer
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.layer = nn.Linear(3, 2)def forward(self, x):return self.layer(x)
model = SimpleModel()

步骤2:设置训练
我们将准备训练模型所需的一切。

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)# DataLoader to handle our dataset
from torch.utils.data import TensorDataset, DataLoader
dataset = TensorDataset(data, labels)
dataloader = DataLoader(dataset, batch_size=10)

步骤3:实现一个回调
我们将创建一个回调,它在每个时期后打印损失。

class PrintLossCallback:def on_epoch_end(self, epoch, loss):print(f"Epoch {epoch}: loss = {loss:.4f}")

步骤4:使用回调训练
现在,我们将训练模型并使用我们的回调。

def train(model, dataloader, criterion, optimizer, epochs, callback):for epoch in range(epochs):total_loss = 0for inputs, targets in dataloader:optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, targets)loss.backward()optimizer.step()total_loss += loss.item()callback.on_epoch_end(epoch, total_loss / len(dataloader))# Create an instance of our callback
print_loss_callback = PrintLossCallback()
# Start training
train(model, dataloader, criterion, optimizer, epochs=5, callback=print_loss_callback)

输出:

Epoch 0: loss = 0.6927
Epoch 1: loss = 0.6909
Epoch 2: loss = 0.6899
Epoch 3: loss = 0.6891
Epoch 4: loss = 0.6885

步骤5:可视化训练
我们可以绘制随时间变化的损失,以可视化我们机器人的进步。

import matplotlib.pyplot as pltlosses = []  # Store the losses here
class PlotLossCallback:def on_epoch_end(self, epoch, loss):losses.append(loss)plt.plot(losses)plt.xlabel('Epoch')plt.ylabel('Loss')plt.show()
# Update our training function to use the plotting callback
plot_loss_callback = PlotLossCallback()
train(model, dataloader, criterion, optimizer, epochs=5, callback=plot_loss_callback)

输出:

示例 2:公共数据集

对于第二个示例,我们将使用在线可用的真实数据集。我们将直接使用URL加载著名的鸢尾花数据集。
步骤1:加载数据集
我们将使用pandas从URL加载数据集。

import pandas as pd# Load the Iris dataset
url = "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data"
iris_data = pd.read_csv(url, header=None)

步骤2:预处理数据
我们需要将数据转换为PyTorch可以理解的格式。

from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split# Encode the labels
encoder = LabelEncoder()
iris_labels = encoder.fit_transform(iris_data[4])
# Split the data
train_data, test_data, train_labels, test_labels = train_test_split(iris_data.iloc[:, :4].values, iris_labels, test_size=0.2, random_state=42
)
# Convert to PyTorch tensors
train_data = torch.tensor(train_data, dtype=torch.float32)
test_data = torch.tensor(test_data, dtype=torch.float32)
train_labels = torch.tensor(train_labels, dtype=torch.long)
test_labels = torch.tensor(test_labels, dtype=torch.long)
# Create DataLoaders
train_dataset = TensorDataset(train_data, train_labels)
test_dataset = TensorDataset(test_data, test_labels)
train_loader = DataLoader(train_dataset, batch_size=10)
test_loader = DataLoader(test_dataset, batch_size=10)

步骤3:为鸢尾花数据集定义一个模型
我们将为鸢尾花数据集创建一个合适的模型。

class IrisModel(nn.Module):def __init__(self):super(IrisModel, self).__init__()self.layer1 = nn.Linear(4, 10)self.layer2 = nn.Linear(10, 3)def forward(self, x):x = torch.relu(self.layer1(x))return self.layer2(x)
iris_model = IrisModel()

步骤4:训练模型
我们将按照之前的步骤训练这个模型。

# Assume the same training function and callbacks as before
train(iris_model, train_loader, criterion, optimizer, epochs=5, callback=plot_loss_callback)

输出:

步骤5:评估模型
最后,我们将检查我们的模型在测试数据上的表现如何。

def evaluate(model, test_loader):model.eval()  # Set the model to evaluation modecorrect = 0with torch.no_grad():  # No need to track gradientsfor inputs, targets in test_loader:outputs = model(inputs)_, predicted = torch.max(outputs, 1)correct += (predicted == targets).sum().item()accuracy = correct / len(test_loader.dataset)print(f"Accuracy: {accuracy:.4f}")evaluate(iris_model, test_loader)

输出:

Accuracy: 0.3333

结论

您可以通过设置回调和日志记录来进行必要的调整,获得对模型训练过程的洞察,并确保其高效学习。请记住,如果您的模型提供明确反馈,您通往训练有素的机器学习模型的道路将更加顺利。本文提供了适合初学者的代码示例和解释,让您基本掌握PyTorch中的回调和日志记录。不要犹豫尝试提供的代码;记住,实践是掌握这些主题的关键。

关注“小白玩转Python”公众号

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/872362.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

TBtools的使用

转录本:由一条基因通过转录,由于可变剪切形成的一种或多种可供编码蛋白质的成熟mRNA。另外非编码RNA也可能有转录本的说法。 最长转录本:可变剪切一个基因得到多个序列长度不同的转录本,应选序列最长的进行数据分析。 CDS:蛋白质编码区,是与蛋白序列一一对应的DNA序列,不…

finalshell远程连接Centos虚拟机配置固定ip地址

为虚拟机Centos的远程连接软件Finalshell或者xshell等软件配置固定ip地址提示:然后全点确定,就好了,这里就不演示了输入指令vim /etc/sysconfig/network-scripts/ifcfg-ens33BOOTPROTO=static #将ip设置为静态IPADDR="192.168.142.130" #静态ip地址,这个130是在0…

【scikit-opt库】智能算法开源库

这个开源库包含以下7个优化算法:网址:https://scikit-opt.github.io/ 使用文档:文档链接:https://scikit-opt.github.io/scikit-opt/#/en/README

【模拟电子技术】07-BJT特性曲线共射

【模拟电子技术】07-BJT特性曲线共射 我们要用三极管,那么就必须考虑输入和输出,即考虑输入特性曲线和输出特性曲线UCE固定,考虑UBE和IB的关系,就相当于一个PN结了!考虑多个变量的关系时,我们往往固定其他变量,然后看其中两个变量的关系,然后两两拿出来观察。三极管有三…

ssm日记04

发现springboot实在是太方便了,虽然还没有具体写过案例,不过直接免去一大堆配置真的太舒服了,明天会接着写案例,掌握是ssm整合springboot 这是学习一个小时的视频

ssm日记01

大前天 即2025.1.15-17学习了spring的前后台联合案例 就是springmvc的ssm整合包括异常处理的方法跟着敲了一边代码,不知道之后会不会使用或者说自己敲一遍,或许到自己跟着视频写案例才会使用到,或者说自己写一个小项目的时候才会用到 写这个异常处理受益匪浅,知道了从系统异…

ssm日记

前天即1.18号学习了springmvc的拦截器知识 但是好像用的不多 拦截器是相对于过滤器 是在请求进入web容器之后拦截和审核的其中主要是实现一个接口的三个方法分别是preHandle postHandle afterCompletion有相关执行顺序 这是代码和注释

学习ssm日记

补发一下前几次的学习日记 这是前两天在学maven高级的代码 都是自己敲得跟着视频 学习了分模块开发和父工程的创建以及关于依赖和pom文件相关知识

OpenWRT配置旁路由/中继模式,同时配置作为NAS必备的IPv6公网IP

1. 环境和要达成的目标 1.1 目标 主路由已配置好拨号,DHCP,IPv6 已刷OpenWRT路由B70作为中继路由,提高覆盖,解决一些老旧只能设备接入问题。 OpenWRT路由同时插入移动硬盘,配置WebDAV和smba作为NAS使用,所以此路自身要能获取到IPv6地址。 我的OpenWRT路由是极路由4,刷的…

【网关系统】通用设计

本文准备围绕七个点来讲网关,分别是网关的基本概念、网关设计思路、网关设计重点、流量网关、业务网关、常见网关对比,对基础概念熟悉的朋友可以根据目录查看自己感兴趣的部分。 一、什么是网关 网关,很多地方将网关比如成门, 没什么问题, 但是需要区分网关与网桥的区别,…

一文告诉你Linux下如何用C语言实现ini配置文件的解析和保存

嵌入式项目开发中,会有很多功能模块需要频繁修改参数,Linux下我们可以通过ini格式的文件保存配置信息。 本文通过开源库iniparser,详细讲解如何用C语言实现ini文件的参数解析和配置保存。 本文代码实例获取方式见文末。 一、ini文件 1 什么是 ini文件INI(Initialization F…

Mysql的学习

Mysql建立 索引优化: sql优化: 为了解决下面的索引失效问题序列索引优化: