pytorch 利用Tensorboar记录训练过程loss变化

文章目录

    • 1. LossHistory日志类定义
    • 2. LossHistory类的使用
      • 2.1 实例化LossHistory
      • 2.2 记录每个epoch的loss
      • 2.3 训练结束close掉SummaryWriter
    • 3. 利用Tensorboard 可视化
      • 3.1 显示可视化效果
    • 参考

利用Tensorboard记录训练过程中每个epoch的训练loss以及验证loss,便于及时了解网络的训练进展。

代码参考自 B导github仓库: https://github.com/bubbliiiing/deeplabv3-plus-pytorch

1. LossHistory日志类定义

import os
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
import scipy.signalimport torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
#from tensorboardX import SummaryWriter
class LossHistory():def __init__(self, log_dir, model, input_shape):self.log_dir    = log_dirself.losses     = []self.val_loss   = []os.makedirs(self.log_dir)self.writer     = SummaryWriter(self.log_dir)try:dummy_input     = torch.randn(2, 3, input_shape[0], input_shape[1])self.writer.add_graph(model, dummy_input)except:passdef append_loss(self, epoch, loss, val_loss):if not os.path.exists(self.log_dir):os.makedirs(self.log_dir)self.losses.append(loss)self.val_loss.append(val_loss)with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:f.write(str(loss))f.write("\n")with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f:f.write(str(val_loss))f.write("\n")self.writer.add_scalar('loss', loss, epoch)self.writer.add_scalar('val_loss', val_loss, epoch)self.loss_plot()def loss_plot(self):iters = range(len(self.losses))plt.figure()plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss')plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss')try:if len(self.losses) < 25:num = 5else:num = 15plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss')plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss')except:passplt.grid(True)plt.xlabel('Epoch')plt.ylabel('Loss')plt.legend(loc="upper right")plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))plt.cla()plt.close("all")
  • (1) 首先利用LossHistory类的构造函数__init__, 实例化TensorboardSummaryWriter对象self.writer,并将网络结构图添加到self.writer中。其中__init__方法接收的参数包括,保存log的路径log_dir以及模型model和输入的shape
def __init__(self, log_dir, model, input_shape):self.log_dir    = log_dirself.losses     = []self.val_loss   = []os.makedirs(self.log_dir)self.writer     = SummaryWriter(self.log_dir)try:dummy_input     = torch.randn(2, 3, input_shape[0], input_shape[1])self.writer.add_graph(model, dummy_input)except:pass
  • (2) 记录每个epoch的训练损失loss以及验证val_loss,并保存到tensorboar中显示
self.writer.add_scalar('loss', loss, epoch)
self.writer.add_scalar('val_loss', val_loss, epoch)

同时将训练的loss以及验证val_loss逐行保存到.txt文件中

 with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:f.write(str(loss))f.write("\n")

并且在每个epoch时,调用loss_plot绘制历史的loss曲线,并保存为epoch_loss.png, 由于每个epoch保存的图片都是重名的,因此在训练结束时,会保存最新的所有epoch绘制的loss曲线

2. LossHistory类的使用

2.1 实例化LossHistory

在训练开始前,实例化LossHistory类,调用__init__实例化时,会创建SummaryWriter对象,用于记录训练的过程中的数据,比如loss, graph以及图片信息等

local_rank  = int(os.environ["LOCAL_RANK"]) 
model   = DeepLab(num_classes=num_classes, backbone=backbone, downsample_factor=downsample_factor, pretrained=pretrained)
input_shape     = [512, 512]if local_rank == 0:time_str        = datetime.datetime.strftime(datetime.datetime.now(),'%Y_%m_%d_%H_%M_%S')log_dir         = os.path.join(save_dir, "loss_" + str(time_str))loss_history    = LossHistory(log_dir, model, input_shape=input_shape)else:loss_history    = None
  • 对于多GPU训练时,只在主进程(local_rank == 0)记录训练的日志信息
  • log 保存的路径log_dir,利用loss_ + 当前时间的形式记录
log_dir         = os.path.join(save_dir, "loss_" + str(time_str))

2.2 记录每个epoch的loss

在每个epoch中,利用loss_history的append_loss方法,利用SummaryWriter对象保存loss:

for epoch in range(start_epoch, total_epoch):...loss_history.append_loss(epoch + 1, total_loss / epoch_step, val_loss / epoch_step_val)
  • 记录了每个epoch的训练loss以及验证val_loss
  • 同时将最新的loss曲线,保存到本地epoch_loss.png
  • 并将历史的训练loss和val_loss保存为txt文件,方便查看

2.3 训练结束close掉SummaryWriter

loss_history.writer.close()

3. 利用Tensorboard 可视化

  • Tensorboard最早是在Tensorflow中开发和应用的,pytorch 中也同样支持Tensorboard的使用,pytorch中的Tensorboard工具叫TensorboardX, 它需要依赖于tensorflow库中的一些组件支持。因此在安装Tensorboardx之前,需要先安装TensorFlow, 否则直接安装Tensorboardx运行会报错。
pip install tensorflow
pip install tensorboardX

3.1 显示可视化效果

训练结束后,cd到SummaryWriter中定义好日志保存目录log_dir下,执行如下指令

cd log_dir # log_dir为定义的日志保存目录
tensorboard  --logdir=./     --port 6006 

然后会显示出访问的链接地址,点击链接就可以查看Tensorboard可视化效果

  • Scalar模块展示训练过程中,每个epoch的train_loss、Accuracy、Learn_Rating的数值变化
    在这里插入图片描述
  • GRAPH模块展示的是模型的网络结构
    在这里插入图片描述
  • HISTOGRAMS模块展示添加到tensorboard中各层的权重分布情况
    在这里插入图片描述

参考

  • 1 https://github.com/bubbliiiing/deeplabv3-plus-pytorch
  • 2 pytorch中使用tensorboard实现训练过程可视化

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/456844.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JAVA原型模式详解

原型模式 1 原型模式介绍 定义: 原型模式(Prototype Design Pattern)用一个已经创建的实例作为原型&#xff0c;通过复制该原型对象来创建一个和原型对象相同的新对象。 西游记中的孙悟空 拔毛变小猴,孙悟空这种根据自己的形状复制出多个身外化身的技巧,在面向对象软件设计领…

Java实现音乐平台 JAVA+Vue+SpringBoot+MySQL

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块三、系统展示 四、核心代码4.1 查询单首音乐4.2 新增音乐4.3 新增音乐订单4.4 查询音乐订单4.5 新增音乐收藏 五、免责说明 一、摘要 1.1 项目介绍 基于微信小程序JAVAVueSpringBootMySQL的音乐平台&#xff0c;包含了音乐…

后端软件三层架构

一、三层架构简介 三层架构是软件开发中广泛采用的一种经典架构模式&#xff0c;其核心价值在于通过清晰的任务划分来提高代码的可维护性和重用性。具体来说&#xff0c;三层架构主要包括以下三个层次&#xff1a; 持久层&#xff08;DAO层&#xff09;&#xff1a;这一层主要…

【XR806开发板试用】xr806使用tcp socket与手机通信

本文为极术社区XR806开发板活动试用文章。 参考&#xff1a;基于星辰处理器的全志XR806开源鸿蒙开发板上手体验 搭建环境。并成功编译。 项目源码 &#xff1a; https://gitee.com/kingwho/smart-home 在同一个局域网中&#xff0c;手机与xr806连接后&#xff0c;手机 APP 每隔…

背包2讲(2.6)

问题1&#xff1a;装箱问题 题解&#xff1a;这题其实本质上也是01背包&#xff0c;只不过算是变式&#xff0c;要求剩余空间最小值&#xff0c;我们可以转换成最大可以装多少的问题&#xff0c;然后就可以很快的写出题的答案 #include<bits/stdc.h> using namespace st…

c++学习第十六讲---STL常用容器---stack容器,queue容器

一、stack容器&#xff1a; 1.stack基本概念&#xff1a; stack栈容器 stack是一种先进后出的数据结构&#xff0c;它只有一个出口。 栈中只有顶端的元素才能被使用&#xff0c;因此不存在遍历操作。 栈中进数据&#xff1a;入栈 --- push() 栈中出数据&#xff1a;出栈 -…

【UE Niagara】实现物体变形的两种方式

目录 效果 步骤 方式一、通过设置粒子位置 方式二、通过线性力 效果 步骤 方式一、通过设置粒子位置 新建一个Niagara系统&#xff0c;选择Empty模板 这里命名为“NS_Morph” 打开“NS_Morph”&#xff0c;先添加一个“Spawn Burst Instantaneous”模块&#xff0c;“Spa…

用友U8 Cloud ReportDetailDataQuery SQL注入漏洞复现(QVD-2023-47860)

0x01 产品简介 用友U8 Cloud 提供企业级云ERP整体解决方案,全面支持多组织业务协同,实现企业互联网资源连接。 U8 Cloud 亦是亚太地区成长型企业最广泛采用的云解决方案。 0x02 漏洞概述 用友U8 cloud ReportDetailDataQuery 接口处存在SQL注入漏洞,攻击者未经授权可以访…

手撕spring bean的加载过程

这里我们采用手撕源码的方式&#xff0c;开始探索spring boot源码中最有意思的部分-bean的生命周期&#xff0c;也可以通过其中的原理理解很多面试以及工作中偶发遇到的问题。 springboot基于约定大于配置的思想对spring进行优化&#xff0c;使得这个框架变得更加轻量化&#…

【win】vscode无法使用ctrl+shift+p快捷键的解决方案

本文首发于 ❄️慕雪的寒舍 今天使用vscode的时候遇到的这个问题&#xff0c;明明快捷键设置的是ctrlshiftp&#xff0c;但是在电脑上怎么敲都敲不出来&#xff0c;因为用这个快捷键打开命令面板都习惯了&#xff0c;也不想换&#xff0c;就在找原因。 同时百度的时候还遇到了…

vue懒加载请求思路

当页面中不存在分页时&#xff0c;首先考虑到的就是懒加载&#xff0c;所以今天提供一个懒加载的思路。 首先是是么时候应该触发懒加载&#xff0c;以上面页面为例当页面容器中的卡片不能充满屏幕时就会触发加载出新数据&#xff0c;触发前提是1.已获取数据并非全部的。2.上一次…

Vitest 单元测试详解

一、自动化测试&#xff08;TDD&#xff09;的一些概念&#xff1a; 自动化测试&#xff08;TDD&#xff09;概念&#xff1a; 自动化测试是指 使用独立于待测软件的其他软件或程序来自动执行测试&#xff0c;比较实际结果与预期 并生成测试报告这一过程。在测试流程已经确定后…