DI-engine强化学习入门(十又二分之一)如何使用RNN——数据处理、隐藏状态、Burn-in

一、数据处理
用于训练 RNN 的 mini-batch 数据不同于通常的数据。 这些数据通常应按时间序列排列。 对于 DI-engine, 这个处理是在 collector 阶段完成的。 用户需要在配置文件中指定 learn_unroll_len 以确保序列数据的长度与算法匹配。 对于大多数情况, learn_unroll_len 应该等于 RNN 的历史长度(a.k.a 时间序列长度),但在某些情况下并非如此。比如,在 r2d2 中, 我们使用burn-in操作, 序列长度等于 learn_unroll_len + burnin_step 。 这里将在下一节中具体解释。

什么是数据处理?
数据处理指的是为循环神经网络(RNN)训练准备时间序列数据的过程。这个过程包括将收集到的数据组织成适当格式的小批量(mini-batches),这些批量数据将用于网络的训练。这一步骤通常发生在DI-engine的collector阶段,也就是数据收集和预处理发生的地方。用户需要在配置文件中指定 learn_unroll_len 以确保序列数据的长度与算法匹配。 对于大多数情况, learn_unroll_len 应该等于 RNN 的历史长度(a.k.a 时间序列长度),但在某些情况下并非如此。比如,在 r2d2 中, 我们使用burn-in操作, 序列长度等于 learn_unroll_len + burnin_step 。例如,如果你设置 learn_unroll_len = 10 和 burnin_step = 5,那么 RNN 实际接收的输入序列长度将是 15:前 5 步为 burn-in(用于预热隐藏状态),接下来的 10 步作为学习的一部分。这样设置可以帮助 RNN 在计算梯度和进行权重更新时,有一个更加准确的隐藏状态作为起点。
部分名词解释

  • mini-batches:在机器学习中,特别是在训练神经网络时,数据一般被分成小的批次进行处理,这些批次被称为 “mini-batch”。一个 mini-batch 包含了一组样本,这组样本用于执行单次迭代的前向传播和反向传播,以更新网络的权重。使用 mini-batches 而不是单个样本或整个数据集(后者称为 “batch” 或 “full-batch”)可以平衡计算效率和内存限制,有助于提高学习的稳定性和收敛速度。
  • collector阶段:在 DI-engine中,collector 阶段是指环境与智能体交互并收集经验数据的过程。在这个阶段,智能体根据其当前的策略执行操作,环境则返回新的状态、奖励和其他可能的信息,如是否达到终止状态。收集到的数据(经常被称为经验或转换)随后被用于训练智能体的模型,例如对策略或价值函数进行更新。

为什么要进行数据处理:

  1. 保持时间依赖性:RNN的核心优势是处理具有时间序列依赖性的数据,比如语言、视频帧、股票价格等。正确的数据处理确保了这些时间依赖性在训练数据中得以保留,使得模型能够学习到数据中的序列特征。
  2. 提高学习效率:通过将数据划分为与模型期望的序列长度匹配的批次,可以提高模型学习的效率。这样做可以确保网络在每次更新时都接收到足够的上下文信息。
  3. 适配算法要求:不同的RNN算法可能需要不同形式的输入数据。例如,标准的RNN只需要过去的信息,而一些变体如LSTM或GRU可能会处理更长的序列。特定的算法,如R2D2,还可能需要额外的步骤(如burn-in),以便更好地初始化网络状态。
  4. 处理不规则长度:在现实世界的数据集中,序列长度往往是不规则的。数据处理确保了每个mini-batch都有统一的序列长度,这通常通过截断过长的序列或填充过短的序列来实现。
  5. 优化内存和计算资源:通过将数据组织成具有固定时间步长的批次,可以更有效地利用GPU等计算资源,因为这些资源在处理固定大小的数据时通常更高效。
  6. 稳定学习过程:特别是在强化学习中,使用如n-step返回或经验回放的技术,可以帮助模型从环境反馈中学习,并减少方差,从而稳定学习过程。

如何进行数据处理

def _get_train_sample(self, data: list) -> Union[None, List[Any]]:    data = get_nstep_return_data(data, self._nstep, gamma=self._gamma)    return get_train_sample(data, self._sequence_len)

 代码段 def _get_train_sample(self, data: list) 是一个方法,它的作用是从收集到的数据中提取用于训练 RNN 的样本。这个方法会在两个步骤中处理数据:

  • N步返回计算(get_nstep_return_data): 这个函数接受原始的经验数据,然后计算所谓的 N 步返回值。N 步返回是一个在强化学习中用于临时差分(Temporal Difference, TD)学习的概念,它考虑了从当前状态开始的未来 N 步的累积奖励。计算这个值需要使用折现因子 gamma。这个步骤的目的是为了让智能体学习如何根据当前的行动预测未来的奖励,这是强化学习中价值函数估计的重要部分。
  • 训练样本获取(get_train_sample): 在得到 N 步返回值之后,这个函数进一步处理数据以生成训练样本。具体地,它会根据 self._sequence_len(即时间序列长度或者 RNN 的历史长度)来选择数据序列。这意味着每个训练样本将是一个具有 self._sequence_len 长度的数据序列,这对于训练 RNN 来说是必要的,因为 RNN 需要一定长度的历史来维护其内部状态(或记忆)。

有关这两个数据处理功能的工作流程见下图:

二、初始化隐藏状态 (Hidden State)
RNN用于处理具有时间依赖性的信息。RNN的隐藏状态(Hidden State)是其记忆的一部分,它能够捕捉到前一时间步长的信息。这些信息对于预测下一个动作或状态非常关键。在此上下文中,初始化RNN的隐藏状态是一个重要的步骤,它确保了RNN在开始新的数据批次处理时具有正确的起始状态。
策略的 _learn_model 需要初始化 RNN。这些隐藏状态来自 _collect_model 保存的 prev_state。 用户需要通过 _process_transition 函数将这些状态添加到 _learn_model 输入数据字典中。 

def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:    transition = {        'obs': obs,        'action': model_output['action'],        'prev_state': model_output['prev_state'], # add ``prev_state`` key here        'reward': timestep.reward,        'done': timestep.done,    }    return transition

点击DI-engine强化学习入门(十又二分之一)如何使用RNN——数据处理、隐藏状态、Burn-in - 古月居 可查看全文

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/675861.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

什么是web3D?应用场景有哪些?如何实现web3D展示?

Web3D是一种将3D技术与网络技术完美结合的全新领域,它可以实现将数字化的3D模型直接在网络浏览器上运行,从而实现在线交互式的浏览和操作。 Web3D通过将多媒体技术、3D技术、信息网络技术、计算机技术等多种技术融合在一起,实现了它在网络上…

2024深圳杯数学建模C题完整思路+配套解题代码+半成品参考论文持续更新

所有资料持续更新,最晚我们将于5.9号更新参考论文。 【无水印word】2024深圳杯A题成品论文23页mtlab(python)双版本代码https://www.jdmm.cc/file/27105652024深圳杯数学建模C题完整思路配套解题代码半成品参考论文持续更新https://www.jdmm.cc/file/2710545 深圳杯…

扭蛋机小程序在互联网浪潮中的崛起与发展

随着互联网的快速发展,各种线上娱乐方式层出不穷,其中扭蛋机小程序凭借其独特的魅力,在互联网浪潮中迅速崛起并发展壮大。扭蛋机小程序不仅打破了传统扭蛋机的地域限制和操作不便,还融入了丰富的互动元素和便捷性,满足…

C#高级编程笔记-委托、lambda表达式和事件

本章的主要内容如下: ● 委托 ● lambda表达式 ● 闭包 ● 事件 目录 1.1 引用方法 1.2 委托 1.2.1 在C#中声明委托 1.2.2 在C#中使用委托 1.2.3 Action和Func委托 1.2.4 多播委托 1.2.5 匿名方法 1.3 lambda表达式 1.3.…

LED显示屏的维护与使用指南

LED显示屏作为一种先进的显示技术,广泛应用于广告、信息显示、舞台背景等领域。然而,为了确保显示屏的长期稳定运行和良好的显示效果,对其进行正确的维护和使用是非常必要的。以下是一些专业的维护与使用建议: 维护建议&#xff1…

ws注入js逆向调用函数

这里需要选择一个文件夹 随便 紫色为修改保存 记得ctrls保存 注入代码如下 (function() {var ws new WebSocket("ws://127.0.0.1:8080")ws.onmessage function(evt) {console.log("收到消息:" evt.data);if (evt.data "exit") {…

Vue + Element-plus 快速入门

1. 构建项目 npm init vuelatest # 可选项一路回车,使用默认NO,按提示执行3条命令 cd 项目名 npm install npm run dev 2. 下载element-plus npm install element-plus --save 3.替换main.js import { createApp } from vue import ElementPlus from element-plu…

cmake进阶:目录属性之 INCLUDE_DIRECTORIES说明二

一. 简介 前面几篇文章学习了 cmake的一些目录属性,主要有两个重要的目录属性INCLUDE_DIRECTORIES 属性、LINK_DIRECTORIES 属性。文章如下: cmake进阶:目录属性之 INCLUDE_DIRECTORIES-CSDN博客 本文学习 父目录的 INCLUDE_DIRECTORIES …

MFC DLL注入失败一些错误总结

使用cheat Engine为MFC窗口程序注入DLL时一定要注意,被注入的exe程序和注入的DLL 的绝对路径中一定不要带有中文字符,否则会遇到各种各样的奇怪错误,如下所示: 以下是dll绝对路径中均含有中文字符,会报错误&#xff…

python从0开始学习(五)

目录 前言 1、顺序结构 2、选择结构 2.1双分支结构 2.2多分枝结构 2.3嵌套使用 2.4多个条件的链接 总结 前言 在上篇文章中,我们学习了python中的运算符,本篇文章继续往下讲解。本篇文章主要讲解程序的组织结构。 1、顺序结构 顺序结构是程序按照…

407627-60-5,AF647 NHS酯一种高亮度的红色荧光试剂

一、产品概述 中文名称:Alexa Fluor 647活化酯,AF647 NHS酯,AF 647 琥珀酰亚胺酯 英文名称:AF647 NHS,Alexa Fluor 647 NHS ester CAS号:1620475-28-6,407627-60-5,1453856-34-2 …

uniapp video 层级覆盖

层级覆盖 cover-view组件 我这里做了个判断 监听全屏时隐藏按钮 根据项目需求自行更改