6-2 Three Ways to Train a Model


PyTorch usually requires users to write their own training loop, and the code style of that loop varies from person to person.

There are three typical styles of training-loop code: script-style training loops, function-style training loops, and class-style training loops.

Below, a multi-class classification model trained on the MNIST dataset is used to demonstrate these three styles of training a model.

import torch
import torchkeras

print('torch.__version__=' + torch.__version__)
print('torchkeras.__version__=' + torchkeras.__version__)

"""
torch.__version__=2.3.1+cu121
torchkeras.__version__=3.9.6
"""

1. Prepare the Data

import torch
from torch import nn
import torchvision
from torchvision import transforms

transform = transforms.Compose([transforms.ToTensor()])

ds_train = torchvision.datasets.MNIST(root='./dataset/mnist/', train=True, download=True, transform=transform)
ds_val = torchvision.datasets.MNIST(root='./dataset/mnist/', train=False, download=True, transform=transform)

dl_train = torch.utils.data.DataLoader(ds_train, batch_size=128, shuffle=True, num_workers=4)
dl_val = torch.utils.data.DataLoader(ds_val, batch_size=128, shuffle=False, num_workers=4)

print(len(ds_train))
print(len(ds_val))

"""
60000
10000
"""

%matplotlib inline
%config InlineBackend.figure_format = 'svg'

# view a few samples
import matplotlib.pyplot as plt

plt.figure(figsize=(8, 8))
for i in range(9):
    img, label = ds_train[i]
    img = torch.squeeze(img)
    ax = plt.subplot(3, 3, i+1)
    ax.imshow(img.numpy())
    ax.set_title("label = %d" % label)
    ax.set_xticks([])
    ax.set_yticks([])
plt.show()
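
As a quick sanity check on the DataLoader pipeline (a small illustrative sketch, not part of the original notebook), one batch can be pulled to confirm the tensor shapes that MNIST with ToTensor produces:

# peek at one batch to confirm shapes (assumes dl_train defined above)
features, labels = next(iter(dl_train))
print(features.shape)  # torch.Size([128, 1, 28, 28])
print(labels.shape)    # torch.Size([128])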

2. Script Style

The script-style training loop is very common.

net = nn.Sequential()
net.add_module("conv1",nn.Conv2d(in_channels=1,out_channels=32,kernel_size = 3))
net.add_module("pool1",nn.MaxPool2d(kernel_size = 2,stride = 2))
net.add_module("conv2",nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5))
net.add_module("pool2",nn.MaxPool2d(kernel_size = 2,stride = 2))
net.add_module("dropout",nn.Dropout2d(p = 0.1))
net.add_module("adaptive_pool", nn.AdaptiveMaxPool2d((1,1)))
net.add_module("flatten",nn.Flatten())
net.add_module("linear1",nn.Linear(64,32))
net.add_module("relu",nn.ReLU())
net.add_module("linear2",nn.Linear(32,10))print(net)"""
Sequential((conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))(pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(conv2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))(pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(dropout): Dropout2d(p=0.1, inplace=False)(adaptive_pool): AdaptiveMaxPool2d(output_size=(1, 1))(flatten): Flatten(start_dim=1, end_dim=-1)(linear1): Linear(in_features=64, out_features=32, bias=True)(relu): ReLU()(linear2): Linear(in_features=32, out_features=10, bias=True)
)
"""
import os, sys, time
import numpy as np
import pandas as pd
import datetime
from tqdm import tqdm
import torch
from torch import nn
from copy import deepcopy
from torchmetrics import Accuracy

If Softmax is applied manually: use nn.NLLLoss, whose input must be log-probabilities, i.e. log(softmax(x)).

If Softmax is not applied manually: use nn.CrossEntropyLoss directly on the raw, unprocessed logits.

In general, to avoid unnecessary complexity and potential numerical issues, it is recommended not to apply Softmax manually and to use nn.CrossEntropyLoss directly.
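
To make the point concrete, here is a small sketch (the names logits and labels are made up for the illustration) showing that nn.CrossEntropyLoss on raw logits and nn.NLLLoss on log-probabilities give the same value:

import torch
from torch import nn

logits = torch.randn(4, 10)           # raw model outputs, no Softmax applied
labels = torch.randint(0, 10, (4,))   # target class indices

loss_ce = nn.CrossEntropyLoss()(logits, labels)                     # consumes logits directly
loss_nll = nn.NLLLoss()(torch.log_softmax(logits, dim=1), labels)   # consumes log(softmax(x))

print(torch.allclose(loss_ce, loss_nll))  # True: the two formulations are equivalent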

def printlog(info):
    nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print("\n" + "========"*8 + "%s" % nowtime)
    print(str(info) + "\n")

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
metrics_dict = {"acc": Accuracy(task="multiclass", num_classes=10)}

epochs = 20
ckpt_path = 'checkpoint.pt'

# early stopping settings
monitor = "val_acc"
patience = 5
mode = "max"

history = {}

for epoch in range(1, epochs+1):
    print("Epoch {0} / {1}".format(epoch, epochs))

    # 1 train
    net.train()
    total_loss, step = 0, 0
    loop = tqdm(enumerate(dl_train), total=len(dl_train))
    train_metrics_dict = deepcopy(metrics_dict)
    for i, batch in loop:
        features, labels = batch
        # forward
        preds = net(features)
        loss = loss_fn(preds, labels)
        # backward
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # metrics
        step_metrics = {"train_" + name: metric_fn(preds, labels).item()
                        for name, metric_fn in train_metrics_dict.items()}
        step_log = dict({"train_loss": loss.item()}, **step_metrics)
        total_loss += loss.item()
        step += 1
        if i != len(dl_train) - 1:
            loop.set_postfix(**step_log)
        else:
            epoch_loss = total_loss / step
            epoch_metrics = {"train_" + name: metric_fn.compute().item()
                             for name, metric_fn in train_metrics_dict.items()}
            epoch_log = dict({"train_loss": epoch_loss}, **epoch_metrics)
            loop.set_postfix(**epoch_log)
            for name, metric_fn in train_metrics_dict.items():
                metric_fn.reset()
    for name, metric in epoch_log.items():
        history[name] = history.get(name, []) + [metric]

    # 2 validate
    net.eval()
    total_loss, step = 0, 0
    loop = tqdm(enumerate(dl_val), total=len(dl_val))
    val_metrics_dict = deepcopy(metrics_dict)
    with torch.no_grad():
        for i, batch in loop:
            features, labels = batch
            # forward
            preds = net(features)
            loss = loss_fn(preds, labels)
            # metrics
            step_metrics = {"val_" + name: metric_fn(preds, labels).item()
                            for name, metric_fn in val_metrics_dict.items()}
            step_log = dict({"val_loss": loss.item()}, **step_metrics)
            total_loss += loss.item()
            step += 1
            if i != len(dl_val) - 1:
                loop.set_postfix(**step_log)
            else:
                epoch_loss = total_loss / step
                epoch_metrics = {"val_" + name: metric_fn.compute().item()
                                 for name, metric_fn in val_metrics_dict.items()}
                epoch_log = dict({"val_loss": epoch_loss}, **epoch_metrics)
                loop.set_postfix(**epoch_log)
                for name, metric_fn in val_metrics_dict.items():
                    metric_fn.reset()
    epoch_log["epoch"] = epoch
    for name, metric in epoch_log.items():
        history[name] = history.get(name, []) + [metric]

    # 3 early stopping
    arr_scores = history[monitor]
    best_score_idx = np.argmax(arr_scores) if mode == "max" else np.argmin(arr_scores)
    if best_score_idx == len(arr_scores) - 1:
        torch.save(net.state_dict(), ckpt_path)
        print(">>>>>>>>> reach best {0} : {1} >>>>>>>>>".format(monitor, arr_scores[best_score_idx]), file=sys.stderr)
    if len(arr_scores) - best_score_idx > patience:
        print(">>>>>>>>> {} without improvement in {} epoch, early stopping >>>>>>>>>".format(monitor, patience), file=sys.stderr)
        break

net.load_state_dict(torch.load(ckpt_path))
df_history = pd.DataFrame(history)
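
df_history then holds one row per epoch. A minimal sketch of how one might plot the loss curves from it (the column names follow the history keys built above; this plotting snippet is an addition, not part of the original notebook):

# assumes matplotlib.pyplot was imported as plt earlier
plt.figure(figsize=(6, 4))
plt.plot(df_history["epoch"], df_history["train_loss"], label="train_loss")
plt.plot(df_history["epoch"], df_history["val_loss"], label="val_loss")
plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend()
plt.show()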

3. Function Style

This style wraps the script form in a further layer of function (and helper-class) encapsulation.

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout2d(p=0.1),
            nn.AdaptiveMaxPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 10)
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

net = Net()
print(net)

"""
Net(
  (layers): ModuleList(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Dropout2d(p=0.1, inplace=False)
    (5): AdaptiveMaxPool2d(output_size=(1, 1))
    (6): Flatten(start_dim=1, end_dim=-1)
    (7): Linear(in_features=64, out_features=32, bias=True)
    (8): ReLU()
    (9): Linear(in_features=32, out_features=10, bias=True)
  )
)
"""
import os, sys, time
import numpy as np
import pandas as pd
import datetime
from tqdm import tqdm
import torch
from torch import nn
from copy import deepcopy

def printlog(info):
    nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print("\n" + "=========="*8 + "%s" % nowtime)
    print(str(info) + "\n")

class StepRunner:
    def __init__(self, net, loss_fn, stage="train", metrics_dict=None, optimizer=None):
        self.net, self.loss_fn, self.metrics_dict, self.stage = net, loss_fn, metrics_dict, stage
        self.optimizer = optimizer

    def step(self, features, labels):
        # loss
        preds = self.net(features)
        loss = self.loss_fn(preds, labels)
        # backward
        if self.optimizer is not None and self.stage == "train":
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
        # metrics
        step_metrics = {self.stage + "_" + name: metric_fn(preds, labels).item()
                        for name, metric_fn in self.metrics_dict.items()}
        return loss.item(), step_metrics

    def train_step(self, features, labels):
        self.net.train()  # train mode: dropout layers are active
        return self.step(features, labels)

    @torch.no_grad()
    def eval_step(self, features, labels):
        self.net.eval()  # eval mode: dropout layers are inactive
        return self.step(features, labels)

    def __call__(self, features, labels):
        if self.stage == "train":
            return self.train_step(features, labels)
        else:
            return self.eval_step(features, labels)

class EpochRunner:
    def __init__(self, steprunner):
        self.steprunner = steprunner
        self.stage = steprunner.stage

    def __call__(self, dataloader):
        total_loss, step = 0, 0
        loop = tqdm(enumerate(dataloader), total=len(dataloader))
        for i, batch in loop:
            loss, step_metrics = self.steprunner(*batch)
            step_log = dict({self.stage + "_loss": loss}, **step_metrics)
            total_loss += loss
            step += 1
            if i != len(dataloader) - 1:
                loop.set_postfix(**step_log)
            else:
                epoch_loss = total_loss / step
                epoch_metrics = {self.stage + "_" + name: metric_fn.compute().item()
                                 for name, metric_fn in self.steprunner.metrics_dict.items()}
                epoch_log = dict({self.stage + "_loss": epoch_loss}, **epoch_metrics)
                loop.set_postfix(**epoch_log)
                for name, metric_fn in self.steprunner.metrics_dict.items():
                    metric_fn.reset()
        return epoch_log

def train_model(net, optimizer, loss_fn, metrics_dict, train_data, val_data=None, epochs=10,
                ckpt_path='checkpoint.pt', patience=5, monitor='val_loss', mode='min'):
    history = {}
    for epoch in range(1, epochs + 1):
        printlog("Epoch {0} / {1}".format(epoch, epochs))

        # 1 train
        train_step_runner = StepRunner(net=net, stage="train", loss_fn=loss_fn,
                                       metrics_dict=deepcopy(metrics_dict), optimizer=optimizer)
        train_epoch_runner = EpochRunner(train_step_runner)
        train_metrics = train_epoch_runner(train_data)
        for name, metric in train_metrics.items():
            history[name] = history.get(name, []) + [metric]

        # 2 validate
        if val_data:
            val_step_runner = StepRunner(net=net, stage="val", loss_fn=loss_fn,
                                         metrics_dict=deepcopy(metrics_dict))
            val_epoch_runner = EpochRunner(val_step_runner)
            with torch.no_grad():
                val_metrics = val_epoch_runner(val_data)
            val_metrics["epoch"] = epoch
            for name, metric in val_metrics.items():
                history[name] = history.get(name, []) + [metric]

        # 3 early stopping
        arr_scores = history[monitor]
        best_score_idx = np.argmax(arr_scores) if mode == "max" else np.argmin(arr_scores)
        if best_score_idx == len(arr_scores) - 1:
            torch.save(net.state_dict(), ckpt_path)
            print("<<<<<< reach best {0} : {1} >>>>>>".format(monitor, arr_scores[best_score_idx]), file=sys.stderr)
        if len(arr_scores) - best_score_idx > patience:
            print("<<<<<< {} without improvement in {} epoch, early stopping >>>>>>".format(monitor, patience), file=sys.stderr)
            break

    net.load_state_dict(torch.load(ckpt_path))
    return pd.DataFrame(history)
from torchmetrics import Accuracy

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
metrics_dict = {"acc": Accuracy(task="multiclass", num_classes=10)}

df_history = train_model(net, optimizer, loss_fn, metrics_dict,
                         train_data=dl_train, val_data=dl_val,
                         epochs=10, patience=3, monitor='val_acc', mode='max')

4. Class Style

Here the model is trained with the fit method of the high-level torchkeras.KerasModel API.

Training a model in this form is very concise and clear.

from torchkeras import KerasModel

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout2d(p=0.1),
            nn.AdaptiveMaxPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 10)
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

net = Net()
print(net)

"""
Net(
  (layers): ModuleList(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Dropout2d(p=0.1, inplace=False)
    (5): AdaptiveMaxPool2d(output_size=(1, 1))
    (6): Flatten(start_dim=1, end_dim=-1)
    (7): Linear(in_features=64, out_features=32, bias=True)
    (8): ReLU()
    (9): Linear(in_features=32, out_features=10, bias=True)
  )
)
"""
from torchmetrics import Accuracy

model = KerasModel(net,
                   loss_fn=nn.CrossEntropyLoss(),
                   metrics_dict={"acc": Accuracy(task="multiclass", num_classes=10)},
                   optimizer=torch.optim.Adam(net.parameters(), lr=0.01))

model.fit(train_data=dl_train, val_data=dl_val, epochs=10, patience=3,
          monitor="val_acc", mode="max", plot=True, cpu=True)
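
Once fit finishes, the trained net can be used for inference with plain PyTorch. A small sketch added here for illustration (not from the original notebook), checking accuracy on a single validation batch:

# use the trained network directly for inference on one validation batch
net.eval()
with torch.no_grad():
    features, labels = next(iter(dl_val))
    preds = net(features).argmax(dim=1)
print((preds == labels).float().mean().item())  # fraction of correct predictions in this batch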
