【需求实现】Tensorflow2的曲线拟合(二):进度条简化

文章目录

  • 导读
  • 普通的输出方式
  • 上下求索
    • TensorBoard是个不错的切入点
    • 与Callback参数对应的Callback方法
    • 官方的内置Callback
    • 官方进度条
    • 简单的猜测与简单的验证
    • 拼图凑齐了!

导读

在训练模型的过程中往往会有日志一堆一堆的困扰。我并不想知道,因为最后我会在变量里面查询,反正训练过程中也没心思看。于是,就想把进度条简化一下。下面给出解决方案。

普通的输出方式

对于一般的训练过程,我们可能在Tensorflow中的fit方法中,将verbose置为 1 1 1,或者不设置verbose而让Tensorflow默认verbose 1 1 1。这样的话就会有如下图一样长篇大论的输出。

在这里插入图片描述
虽然不至于很烦躁,但是实在不愿意去管这些事情。又不是做生物实验,完全不需要人在这边守着嘛。趁着这时间泡杯咖啡多好。

于是呢,就想着源码里面是如何将输出显示出来的。

上下求索

TensorBoard是个不错的切入点

Tensorflow海量的源码中寻找一个输出无疑是大海捞针,对于Windows用户来说找起来超级麻烦,除非Linux用户直接用grep命令作弊。

但是呢,突然就注意到,Tensorflow还有一个Tensorboard,是在fit方法里面的callback参数中出现。既然日志能够从callback参数中获得,那么这里是有什么玄机吗?

与Callback参数对应的Callback方法

于是找到了文件中tensorflow/tensorflow/python/keras/engine/training.py,也就是官方GitHub的这一页(点击直达callbacks注释那一行),他是这么说明的:

'''
callbacks: List of `keras.callbacks.Callback` instances.List of callbacks to apply during training.See `tf.keras.callbacks`. Note `tf.keras.callbacks.ProgbarLogger`and `tf.keras.callbacks.History` callbacks are created automaticallyand need not be passed into `model.fit`.`tf.keras.callbacks.ProgbarLogger` is created or not based on`verbose` argument to `model.fit`.Callbacks with batch-level calls are currently unsupported with`tf.distribute.experimental.ParameterServerStrategy`, and users areadvised to implement epoch-level calls instead with an appropriate`steps_per_epoch` value.
'''

这也就是说,官方已经把进度条内置到ProgbarLogger这个类里面了,并通过callbacks调用。

其中,对于callbacks是这么调用的:

callbacks.on_train_begin()

而这个on_train_begin又是属于CallbackList类中:

# in CallbackList class
def on_train_begin(self, logs=None):"""Calls the `on_train_begin` methods of its callbacks.Args:logs: Dict. Currently no data is passed to this argument for this methodbut that may change in the future."""logs = self._process_logs(logs)for callback in self.callbacks:callback.on_train_begin(logs)

也就是说,是遍历callbacks中的所有callback然后一一执行。

执行过程的源码在Callback类中(类名跟上一个不一样哦):

# in Callback class
@doc_controls.for_subclass_implementers
def on_train_begin(self, logs=None):"""Called at the beginning of training.Subclasses should override for any actions to run.Args:logs: Dict. Currently no data is passed to this argument for this methodbut that may change in the future."""

但很明显,这个源码就是想让我们自定义。

刚刚好发现进度条也在这里。那接下来的事情就更明确了,去找这个类就行了。一方面是查看进度条的原理,另一方面则是按照官方的进度条仿写一个简单的进度条。

官方的内置Callback

于是就找到了tensorflow/tensorflow/python/keras/callbacks.py文件中,也就是官方GitHub的这一页(点击直达ProgbarLogger类定义的那一行)。但比较可惜的是,这个类定义的时候注释太少了,并不能很确定每个类中的各个方法都在做什么。怎么办呢?

别忘了官方给的提示:【内置了ProgbarLogger类与History类】。我们既然需要了解如何在Callback里面调用,那么就需要了解这两个东西是如何插入进去的。但是这个类注释实在是太少了,很多东西看得不明不白的,该怎么办呢?

当然是【贪心搜索】了呀,按照【命名】与【个人经验】去判断哪一个最可能是我们想要的方法。虽然很不靠谱,但是万一运气好撞上了呢?于是呢,就找到了_add_default_callbacks方法,也就是默认插入的一些东西。他是这么写的:

  def _add_default_callbacks(self, add_history, add_progbar):"""Adds `Callback`s that are always present."""self._progbar = Noneself._history = Nonefor cb in self.callbacks:if isinstance(cb, ProgbarLogger):self._progbar = cbelif isinstance(cb, History):self._history = cbif self._progbar is None and add_progbar:self._progbar = ProgbarLogger(count_mode='steps')self.callbacks.insert(0, self._progbar)if self._history is None and add_history:self._history = History()self.callbacks.append(self._history)

总之就是一些判断,如果空就创建。

看来插入就是insert方法与append方法了。不难猜测,也不用猜测,callbacks将是一个数组。

官方进度条

既然知道了进度条是如何被调用的,那么接下来就是得了解官方的进度条是怎么添加的。

当然,还是在ProgbarLogger类中,为了方便点击这里就能传送。这里面明显的给出了输出Epoch,正好就是我们需要找到的输出。源码是这样的:

def on_epoch_begin(self, epoch, logs=None):self._reset_progbar()self._maybe_init_progbar()if self.verbose and self.epochs > 1:print('Epoch %d/%d' % (epoch + 1, self.epochs))

而且条件是需要verbose 0 0 0。怪不得我们把verbose置为 0 0 0就什么都没有了。

简单的猜测与简单的验证

既然官方这么设计能够输出,那么我们也就简单的提出一个想法:

我们首先需要定义一个类A,然后像这个ProgbarLogger类一样继承自Callback,然后再自定义on_epoch_begin或者on_epoch_end方法。这个方法需要具有三个参数:

  • 首先是on_epoch_begin方法作为A类一个成员的self
  • 其次是与ProgbarLoggeron_epoch_end方法一样传入epoch变量,从而获取到当前学习过程进行到哪一个epoch中了
  • 最后就是一个暂时是Nonelogs变量

那么,如何去验证呢?如果我们有一定的Java基础的话,那么我们其实大概可以猜出来,所有的东西都是有一个接口可以实现,或者一个抽象类可以继承。那么Tensorflow这种超大体量的框架也大概需要借鉴这种设计思想,否则很多东西都会乱糟糟的,没有一个统一的规范。所以,我们寻找一下有没有这类东西。

当然,我最终也是找到了:其实就是Callback类,他的注释是:

Abstract base class used to build new callbacks.

其中的on_epoch_beginon_epoch_endon_train_beginon_train_end等方法都是可以让子类实现的。

当然,在这里官方也很贴心的给了一个例子:

'''
Example:>>> training_finished = False>>> class MyCallback(tf.keras.callbacks.Callback):...   def on_train_end(self, logs=None):...     global training_finished...     training_finished = True>>> model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])>>> model.compile(loss='mean_squared_error')>>> model.fit(tf.constant([[1.0]]), tf.constant([[1.0]]),...           callbacks=[MyCallback()])>>> assert training_finished == True
'''

看来我们的猜想是正确的。

这么一想我这找了这么久都是没用的吗

拼图凑齐了!

找了这么久,我们所需要了解的一切就都明白了。

那就自定义一个进度条:

import tensorflow as tf
class TensorflowProgressBar(tf.keras.callbacks.Callback):def on_epoch_end(self, epoch, logs = None):print('\r', f'Now Processing: {epoch}, Progress: {round(epoch / EPOCHS * 100, 2)}%',end = '', flush = True)

这样的话,每当一个epoch结束的时候,就会显示当前是第几个epoch,并计算出当前进度的百分比。

这样的话就能在等结果的过程中做点别的事情,顺便时不时抬头看一眼进度。当老板问到的时候,随便看一下百分比就能报告,非常方便。

最后,在fit方法里调用一下:

record = model.fit(X_train, y_train,batch_size = BATCH_SIZE, epochs = EPOCHS,callbacks=[TensorflowProgressBar()], verbose = 0)

其中,verbose置为 1 1 1的话,Tensorflow还会继续输出大量的进度,这是我们并不想看到的。所以为了让他只输出我们想要看到的进度,就必须将verbose置为 0 0 0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/4508.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【域名详解】网络杂谈(13)之深入简出了解什么是域名

涉及知识点 什么是域名,域名的概念,域名的结构,域名地址的寻址过程,深入了解域名的寻址机制。 原创于:CSDN博主-《拄杖盲学轻声码》,更多内容可去其主页关注下哈,不胜感激 文章目录 涉及知识点…

基于Tars高并发IM系统的设计与实现-基础篇1

基于Tars高并发IM系统的设计与实现–基础篇1 作者简介 兰怀玉 毕业于中央民族大学计算机专业 先后供职国内外多家公司软件研发设计岗位,有丰富的软件研发经验。 从事IM领域设计研发十余年,先后领衔多个IM通讯系统设计与研发发,拥有丰富的IM…

linux下RabbitMQ的使用

文章目录 linux下RabbitMQ的使用首先docker启动网页打开网址:用户名和密码登录创建exchanges:创建Queues增加Queues的Bind linux下RabbitMQ的使用 首先docker启动 su rootsudo docker run -d --hostname rabbitsvr --name rabbit -p 5672:5672 -p 15672:15672 -p …

ChatGPT实战:项目管理

人工智能有可能彻底改变许多行业,包括项目管理,及时了解最新技术以及它如何影响你的工作至关重要,因为学习好项目管理,不管你能不能做项目经理,在生活、工作的方面方面都会享受到懂得项目管理后带来的收益。 下面我们借…

c# Invoke使用

在多线程编程中,我们经常要在工作线程中去更新界面显示,而在多线程中直接调用界面控件的方法是错误的做法,Invoke 和 BeginInvoke 就是为了解决这个问题而出现的,使你在多线程中安全的更新界面显示。 正确的做法是将工作线程中涉…

发送邮箱验证码【spring boot】

⭐前言⭐ ※※※大家好!我是同学〖森〗,一名计算机爱好者,今天让我们进入学习模式。若有错误,请多多指教。更多有趣的代码请移步Gitee 👍 点赞 ⭐ 收藏 📝留言 都是我创作的最大的动力! 1. 思维…

Maven manual

Download maven Download 设置 system env… E:\apache-maven-3.9.3\bin查看版本信息 mvn -v Apache Maven 3.9.3 (21122926829f1ead511c958d89bd2f672198ae9f) Maven home: E:\apache-maven-3.9.3与Eclipse integrate Referrence,通常Eclipse原本就已经集成&am…

Docker Desktop 安装使用教程

一、前言 作为开发人员,在日常开发中,我们需要在本地去启动一些服务,如:redis、MySQL等,就需要去下载这些在本地去启动,操作较为繁琐。此时,我们可以使用Docker Desktop,来搭建我们需…

78、基于STM32单片机步进电机速度调速控制系统设计(程序+原理图+PCB源文件+参考论文+开题报告+流程图+元器件清单等)

摘 要 伴随着时代的快速发展,单片机的应用也越来越广泛,促进了微电子和计算机的快速发展。我们日常生活中步进电机扮演着很重要的角色在我们身边随处可以见。因为步进电机本身的结构组成相对于比较简单、价格也比较便宜廉价。比如压榨机,打印…

vue3使用高德地图实现点击获取经纬度以及搜索功能

话不多说直接上干活 在此之前你需要有高德地图的 key&#xff0c;这个自己去申请即可 1&#xff0c;首先需要在终端安装 npm i amap/amap-jsapi-loader --save 2&#xff0c;准备一个容器 <template><div id"container"></div> </templat…

亚马逊云科技如何通过四大自研芯片助力企业创新,摆脱基础架构束缚

2023年6月27-28日&#xff0c;2023亚马逊云科技中国峰会在上海顺利举行。在此次峰会上我们可以清晰地看到为什么亚马逊云科技可以做到领先地位&#xff0c;为什么亚马逊云科技可以一直保持进步。这都与亚马逊云科技“基于客户需求&#xff0c;快速进行产品更新与技术迭代”的Da…

【前端|CSS系列第3篇】CSS盒模型、浮动及定位

在前端开发中&#xff0c;CSS是一项重要的技术&#xff0c;用于控制网页的样式和布局。在本系列的第三篇文章中&#xff0c;我们将学习CSS的盒模型、浮动以及定位&#xff0c;这些概念和技术在页面布局中起着至关重要的作用。通过本文的学习&#xff0c;希望能够帮助大家更好地…