创建用于预测序列的人工智能模型,用Keras Tuner探索模型的超参数。

news/2024/12/28 8:36:28/文章来源:https://www.cnblogs.com/jellyai/p/18637050

上一篇:《创建用于预测序列的人工智能模型(五),调整模型的超参数》

序言:在完成初步的模型研发后,接下来的重点是探索和优化超参数。通过合理调整超参数(如学习率、动量参数、神经元数量等),可以进一步提高模型的性能和准确性。这一过程需要结合工具(如 Keras Tuner)进行自动化测试和优化,从而高效找到最优配置方案。

探索使用 Keras Tuner 调整超参数

在上一节中,你学会了如何粗略地优化随机梯度下降(SGD)损失函数的学习率。这确实是一个非常粗略的尝试:每隔几个 epoch 改变一次学习率并测量损失值变化。然而,这种方式受到损失函数本身在每个 epoch 间波动的影响,因此你可能并没有真正找到最佳值,而只是得到了一个近似值。要真正找到最佳值,你需要在每个潜在值的情况下进行完整的轮次训练,然后比较结果。

而且,这还仅仅是针对一个超参数——学习率。如果你还想优化动量参数(momentum),或者调整其他内容,比如每层的神经元数量、层数等,那么可能需要测试成千上万种选项,而手动实现所有这些训练代码几乎是不可能的。

幸运的是,Keras Tuner 工具让这些变得相对简单。你可以通过以下命令安装 Keras Tuner:

!pip install keras-tuner

安装完成后,你就可以使用它来参数化超参数,并指定需要测试的值范围。Keras Tuner 会为每组参数训练模型,评估其性能,并根据你的目标(例如最小化损失)报告最佳的模型结果。我不会在这里详细介绍所有功能,但会展示如何在这个特定模型中使用它。

假设我们想实验两个方面,首先是调整模型架构中输入层的神经元数量。目前的模型架构是输入层 10 个神经元、隐藏层 10 个神经元,然后是输出层。但如果通过增加输入层的神经元数量,网络的表现可以变得更好呢?比如,你可以尝试将输入层的神经元数量从 10 增加到 30。

回忆一下,输入层的定义如下:

tf.keras.layers.Dense(10, input_shape=[window_size], activation="relu")

如果你想测试比硬编码的 10 更大的值,可以这样写:

tf.keras.layers.Dense(units=hp.Int('units', min_value=10, max_value=30, step=2),

activation='relu', input_shape=[window_size])

这里定义了输入层会用多种值进行测试,从 10 开始,每次增加 2,一直到 30。现在,Keras Tuner 将不再只训练一次模型,而是会训练 11 次!

同时,回忆一下优化器中的动量参数是硬编码为 0.9 的:

optimizer = tf.keras.optimizers.SGD(lr=1e-5, momentum=0.9)

你可以使用 hp.Choice 方法测试多个动量值,例如:

optimizer = tf.keras.optimizers.SGD(hp.Choice('momentum', values=[.9, .7, .5, .3]),

lr=1e-5)

这为动量参数提供了 4 种可能的选择。因此,当与前述输入层神经元数量组合时,总共会有 44 种组合需要测试。Keras Tuner 会自动完成这些训练任务,并报告表现最佳的模型。

为了完成设置,首先需要定义一个函数,用于根据超参数构建模型:

def build_model(hp):

model = tf.keras.models.Sequential()

model.add(tf.keras.layers.Dense(

units=hp.Int('units', min_value=10, max_value=30, step=2),

activation='relu', input_shape=[window_size]))

model.add(tf.keras.layers.Dense(10, activation='relu'))

model.add(tf.keras.layers.Dense(1))

model.compile(loss="mse",

optimizer=tf.keras.optimizers.SGD(hp.Choice('momentum',

values=[.9, .7, .5, .3]),

lr=1e-5))

return model

接着,使用安装好的 Keras Tuner 创建一个 RandomSearch 对象来管理所有的迭代:

tuner = RandomSearch(

build_model,

objective='loss', max_trials=150,

executions_per_trial=3, directory='my_dir',

project_name='hello')

注意,你需要通过传递前面定义的函数来指定模型。hp 参数用于控制需要调整的值范围。目标(objective)被设置为 loss,表示我们想要最小化损失值。max_trials 参数限制总实验次数,executions_per_trial 参数可以指定每次实验的训练和评估次数,从而减少随机波动的影响。

开始搜索时,只需调用 tuner.search,就像调用 model.fit 一样:

tuner.search(dataset, epochs=100, verbose=0)

运行本章中所使用的合成序列数据后,Keras Tuner 会根据你定义的选项训练模型并完成所有可能的超参数组合测试。

完成后,你可以调用 tuner.results_summary 查看基于目标的前 10 次实验结果:

tuner.results_summary()

你会看到类似以下的输出:

Results summary

|-Results in my_dir/hello

|-Showing 10 best trials

|-Objective(name='loss', direction='min')

Trial summary

|-Trial ID: dcfd832e62daf4d34b729c546120fb14

|-Score: 33.18723194615371

|-Best step: 0

Hyperparameters:

|-momentum: 0.5

|-units: 28

Trial summary

|-Trial ID: 02ca5958ac043f6be8b2e2b5479d1f09

|-Score: 33.83273440510237

|-Best step: 0

Hyperparameters:

|-momentum: 0.7

|-units: 28

从结果中可以看到,最低损失值是在动量为 0.5 和输入神经元数量为 28 时达到的。你可以通过调用 get_best_models 来检索这些模型以及其他最佳模型。例如,如果你想获取前 4 个最佳模型,可以这样调用:

tuner.get_best_models(num_models=4)

你可以测试这些模型,或者使用找到的最佳超参数从头创建一个新模型,例如:

dataset = windowed_dataset(x_train, window_size, batch_size, shuffle_buffer_size)

model = tf.keras.models.Sequential([

tf.keras.layers.Dense(28, input_shape=[window_size], activation="relu"),

tf.keras.layers.Dense(10, activation="relu"),

tf.keras.layers.Dense(1)

])

optimizer = tf.keras.optimizers.SGD(lr=1e-5, momentum=0.5)

model.compile(loss="mse", optimizer=optimizer)

history = model.fit(dataset, epochs=100, verbose=1)

当我使用这些超参数进行训练,并像之前一样对整个验证集进行预测时,我得到了一个类似于图 10-6 的图表:

                                                      图 10-6:优化超参数后的预测图表

对这个模型的 MAE(平均绝对误差)计算结果是 4.47,相比之前的 4.51 有了轻微的改进,相较于上一章统计方法的 5.13 结果更是显著提升。这是在学习率调整为 1e−51e^{-5}1e−5 的情况下完成的,而这个学习率可能还不是最优值。通过 Keras Tuner,你可以进一步调整像这样的超参数,还可以尝试调整中间层的神经元数量,甚至实验不同的损失函数和优化器。尝试一下,看看能否进一步改进这个模型吧!

总结

在本篇中,我们从前几篇的时间序列统计分析出发,尝试将机器学习应用于改进预测。人工智能(机器学习)的核心在于模式匹配,正如预期的那样,我们通过使用深度神经网络(DNN)来发现数据中的模式,成功将平均绝对误差(MAE)降低了近 10%。接着,我们再利用 Keras Tuner 进一步优化超参数,改进了模型的损失值并提升了预测精度。

在接下来的文章中,我们将超越简单的人工智能模型( DNN),探索使用循环神经网络(RNN)来预测序列数据的可能性,并分析其对序列预测的影响。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/860243.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

开源GTKSystem.Windows.Forms框架让C# Winform支持跨平台运行

前言 在咱们的印象中C# WinForm一直只支持Windows系统运行,无法支持跨平台运行。今天大姚给大家分享一个开源框架:GTKSystem.Windows.Forms,它能够让C# Winform支持跨平台运行。 项目介绍 GTKSystem.Windows.Forms是一个C#桌面应用程序跨平台(Windows、Linux、macOS)开发框…

SPIR-V的开源编译器生态系统API分层

API分层 SPIR-V的开源编译器生态系统越来越强大。 1.行分层 无需额外的内核级驱动程序即可实现内容,从而使平台受益。 OpenCL接口分层,如图1-30所示。图1-30 OpenCL接口分层 2.列分层 即使没有本机驱动程序,也可以跨多个平台提供API,以便提供应用程序部署灵活性并消除碎片,…

分层OpenCL实现

分层OpenCL实现 OpenCL接口分层实现,如图1-31所示。图1-31 OpenCL接口分层实现人工智能芯片与自动驾驶

《智能汽车传感器:原理设计应用》《AI芯片开发核心技术详解》新书推荐

两本书推荐《AI芯片开发核心技术详解》、《智能汽车传感器:原理设计应用》由清华大学出版社资深编辑赵佳霓老师策划编辑的新书《AI芯片开发核心技术详解》已经出版,京东、淘宝天猫、当当等网上,相应陆陆续续可以购买。该书强力解析AI芯片的核心技术开发,内容翔实、知识点新…

Excel+Python 飞速搞定数据分析与处理(图灵出品)PDF免费下载

零基础Python编程数据分析,Excel办公自动化处理,告别烦琐公式,办公人士也能轻松学习Python数据处理自动化,让你的Excel快得飞起来!适读人群 :本书既适合Excel用户,也适合Python用户阅读。电子版仅供预览,下载后24小时内务必删除,支持正版,喜欢的请购买正版书籍:http…

移动端滑动,better-scroll使用

背景 为博客园做移动端适配,有一个控件需要固定大小,但是里面的内容是动态的,很有可能放不下,因此需要滑动。 设置了滑动后,我发现划不动,原来原生的滑动是不管你什么移动端的,于是找移动端适配的滑动。 Better-Scroll 名声很大,坑不少。 划不动 官方文档写的快速开始实…

JAVA-第三次大作业blog

一.前言 在深入探索Java编程的征途中,我迎来了第七、八次PTA题目集的挑战。这两次作业不仅是对我学习成果的检验,更是深化我对Java核心概念——继承理解的宝贵契机。通过亲身实践,我不仅巩固了继承在Java中的应用技巧,还跨越性地深化了对子类与父类关系的洞察。每一次编码,…

Unity音频管理方案

AudioManager类的创建可以序列化,就可以在外面看到然后在Awake里面初始化一下AudioManager类的完善 写个单例:这样就可以直接在外面AudioManager.去调用比较方便 使用AudioMixer对音频进行分组使用unity自带的AudioMixer使用unity自带的AudioMixer进行音量统一处理在UI框架里…

没有xml configuration file

点击new菜单发现没有看到XML Configuration File选项。1、正确导入spring jar包,spring5.6 maven坐标<dependency> <groupId>org.springframework</groupId> <artifactId>spring-context</artifactId> <version>6.2.1</versi…

PostgreSQL 数据库的启动与停止管理

title: PostgreSQL 数据库的启动与停止管理 date: 2024/12/28 updated: 2024/12/28 author: cmdragon excerpt: 作为一个强大的开源关系数据库管理系统,PostgreSQL在众多应用场景中发挥着关键作用。在实际使用过程中,对于数据库的启动和停止操作至关重要。这不仅关系到数据…

人工智能Agent提示工程的六个关键要素

一个构造良好的提示封装了所有必要的信息,确保AI Agent生成准确的响应并有效地执行任务。 通过系统地组合特定组件,提示符为LLM提供了一个全面的框架,以实现最佳功能。 六个关键要素如下: 1.用户请求:这是用户提供的原始任务描述,概述了目标和期望的结果。它作为代理行为…

一个Java实现的OCR系统

一个Java实现的OCR系统 利用java17实现的一套OCR推理系统,兼容paddleocr。如下图,目前功能如下,https://github.com/jiangnanboy/JiaJiaOCR: