优先经验回放(prioritized experience replay)

prioritized experience replay 思路

优先经验回放出自ICLR 2016的论文《prioritized experience replay》。

prioritized experience replay的作者们认为,按照一定的优先级来对经验回放池中的样本采样,相比于随机均匀的从经验回放池中采样的效率更高,可以让模型更快的收敛。其基本思想是RL agent在一些转移样本上可以更有效的学习,也可以解释成“更多地训练会让你意外的数据”。

那优先级如何定义呢?作者们使用的是样本的TD error δ \delta δ 的幅值。对于新生成的样本,TD error未知时,将样本赋值为最大优先级,以保证样本至少将会被采样一次。每个采样样本的概率被定义为
P ( i ) = p i α ∑ k p k α P(i) = \frac {p_i^{\alpha}} {\sum_k p_k^{\alpha}} P(i)=kpkαpiα
上式中的 p i > 0 p_i >0 pi>0是回放池中的第i个样本的优先级, α \alpha α则强调有多重视该优先级,如果 α = 0 \alpha=0 α=0,采样就退化成和基础DQN一样的均匀采样了。

p i p_i pi如何取值,论文中提供了如下两种方法,两种方法都是关于TD error δ \delta δ 单调的:

  • 基于比例的优先级: p i = ∣ δ i ∣ + ϵ p_i = |\delta_i| + \epsilon pi=δi+ϵ ϵ \epsilon ϵ是一个很小的正数常量,防止当TD error为0时样本就不会被访问到的情形。(目前大部分实现都是使用的这个形式的优先级)
  • 基于排序的优先级: p i = 1 r a n k ( i ) p_i = \frac {1}{rank(i)} pi=rank(i)1, 式中的 r a n k ( i ) rank(i) rank(i)是样本根据 ∣ δ i ∣ |\delta_i| δi 在经验回放池中的排序号,此时P就变成了带有指数 α \alpha α的幂率分布了。

作者们定义的概率调整了样本的优先级,因此也就在数据分布中引入了偏差,为了弥补偏差,使用了重要性采样权重(importance-sampling (IS) weights):
w i = ( 1 N ⋅ 1 P ( i ) ) β w_i = \left( \frac{1}{N} \cdot \frac{1}{P(i)} \right)^{\beta} wi=(N1P(i)1)β
上式权重中,当 β = 1 \beta=1 β=1时就完全补偿了非均匀概率采样引入的偏差,作者们提到为了收敛性考虑,最后让 β \beta β从0到1中的某个值开始,并逐渐增加到1。在Q-learning更新时使用这些权重乘以TD error,也就是使用 w i δ i w_i \delta_i wiδi而不是原来的 δ i \delta_i δi。此外,为了使训练更稳定,总是对权重乘以 1 / m a x i w i 1/\mathcal{max}_i{w_i} 1/maxiwi进行归一化。

以Double DQN为例,使用优先经验回放的算法(论文算法1)如下图:

在这里插入图片描述

prioritized experience replay 实现

直接实现优先经验回放池如下代码(修改自代码 )

class PrioReplayBufferNaive:def __init__(self, buf_size, prob_alpha=0.6, epsilon=1e-5, beta=0.4, beta_increment_per_sampling=0.001):self.prob_alpha = prob_alphaself.capacity = buf_sizeself.pos = 0self.buffer = []self.priorities = np.zeros((buf_size, ), dtype=np.float32)self.beta = betaself.beta_increment_per_sampling = beta_increment_per_samplingself.epsilon = epsilondef __len__(self):return len(self.buffer)def size(self):  # 目前buffer中数据的数量return len(self.buffer)def add(self, sample):# 新加入的数据使用最大的优先级,保证数据尽可能的被采样到max_prio = self.priorities.max() if self.buffer else 1.0if len(self.buffer) < self.capacity:self.buffer.append(sample)else:self.buffer[self.pos] = sampleself.priorities[self.pos] = max_prioself.pos = (self.pos + 1) % self.capacitydef sample(self, batch_size):if len(self.buffer) == self.capacity:prios = self.prioritieselse:prios = self.priorities[:self.pos]probs = np.array(prios, dtype=np.float32) ** self.prob_alphaprobs /= probs.sum()indices = np.random.choice(len(self.buffer), batch_size, p=probs, replace=True)samples = [self.buffer[idx] for idx in indices]total = len(self.buffer)self.beta = np.min([1., self.beta + self.beta_increment_per_sampling])weights = (total * probs[indices]) ** (-self.beta)weights /= weights.max()return samples, indices, np.array(weights, dtype=np.float32)def update_priorities(self, batch_indices, batch_priorities):'''更新样本的优先级'''for idx, prio in zip(batch_indices, batch_priorities):self.priorities[idx] = prio + self.epsilon

直接实现的优先经验回放,在样本数很大时的采样效率不够高,作者们通过定义了sumtree的数据结构来存储样本优先级,该数据结构的每一个节点的值为其子节点之和,而样本优先级被放在树的叶子节点上,树的根节点的值为所有优先级之和 p t o t a l p_{total} ptotal,更新和采样时的效率为 O ( l o g N ) O(logN) O(logN)。在采样时,设采样批次大小为k,将 [ 0 , p t o t a l ] [0, p_{total}] [0,ptotal]均分为k等份,然后在每一个区间均匀的采样一个值,再通过该值从树中提取到对应的样本。python 实现如下(代码来源)

class SumTree:"""父节点的值是其子节点值之和的二叉树数据结构"""write = 0def __init__(self, capacity):self.capacity = capacityself.tree = np.zeros(2 * capacity - 1)self.data = np.zeros(capacity, dtype=object)self.n_entries = 0# update to the root nodedef _propagate(self, idx, change):parent = (idx - 1) // 2self.tree[parent] += changeif parent != 0:self._propagate(parent, change)# find sample on leaf nodedef _retrieve(self, idx, s):left = 2 * idx + 1right = left + 1if left >= len(self.tree):return idxif s <= self.tree[left]:return self._retrieve(left, s)else:return self._retrieve(right, s - self.tree[left])def total(self):return self.tree[0]# store priority and sampledef add(self, p, data):idx = self.write + self.capacity - 1self.data[self.write] = dataself.update(idx, p)self.write += 1if self.write >= self.capacity:self.write = 0if self.n_entries < self.capacity:self.n_entries += 1# update prioritydef update(self, idx, p):change = p - self.tree[idx]self.tree[idx] = pself._propagate(idx, change)# get priority and sampledef get(self, s):idx = self._retrieve(0, s)dataIdx = idx - self.capacity + 1return (idx, self.tree[idx], self.data[dataIdx])class PrioReplayBuffer:  # stored as ( s, a, r, s_ ) in SumTreeepsilon = 0.01alpha = 0.6beta = 0.4beta_increment_per_sampling = 0.001def __init__(self, capacity):self.tree = SumTree(capacity)self.capacity = capacitydef _get_priority(self, error):return (np.abs(error) + self.epsilon) ** self.alphadef add(self, error, sample):p = self._get_priority(error)self.tree.add(p, sample)def sample(self, n):batch = []idxs = []segment = self.tree.total() / npriorities = []self.beta = np.min([1., self.beta + self.beta_increment_per_sampling])for i in range(n):a = segment * ib = segment * (i + 1)s = random.uniform(a, b)(idx, p, data) = self.tree.get(s)priorities.append(p)batch.append(data)idxs.append(idx)sampling_probabilities = priorities / self.tree.total()is_weight = np.power(self.tree.n_entries * sampling_probabilities, -self.beta)is_weight /= is_weight.max()return batch, idxs, is_weightdef update(self, idx, error):'''这里是一次更新一个样本,所以在调用时,写for循环依次更次样本的优先级'''p = self._get_priority(error)self.tree.update(idx, p)

参考资料

  1. Schaul, Tom, John Quan, Ioannis Antonoglou, and David Silver. 2015. “Prioritized Experience Replay.” arXiv: Learning,arXiv: Learning, November.

  2. sum_tree的实现代码

  3. 相关blog: 1 (对应的代码), 2, 3

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/208724.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vue3组件化开发页面之渲染函数实现

文章目录 前言一、渲染机制虚拟 DOM渲染管线 二、渲染函数基本用法声明渲染函数Vnodes 必须唯一 三、页面使用渲染函数及组件配置总结如有启发&#xff0c;可点赞收藏哟~ 前言 组件化开发是目前开发的常态 本文记录页面拆分多个不同组件模块&#xff0c;然后再基于渲染函数实现…

const修饰

const 起保护作用&#xff0c;禁止修改。 此时a变为常量&#xff0c;常量不可修改。 const放在*p的左端限制*p&#xff0c;即不能通过修改指针变量&#xff08;*p&#xff09;的值来修改p指向空间的内容&#xff0c;但p不受限制。 const放在*的右端限制p&#xff0c;即不能修…

Unity 三维场景的搭建 软件构造实验报告

实验2&#xff1a;仿真系统功能实现 1.实验目的 &#xff08;1&#xff09;熟悉在Unity中设置仿真场景&#xff1b; &#xff08;2&#xff09;熟悉在Unity中C#语言的使用&#xff1b; &#xff08;3&#xff09;熟悉仿真功能的实现。 2.实验内容 新建一个仿真场景&#x…

中职组网络安全 Server-Hun-1.img Server-Hun-2.img

一串密码 smbuser用户和密码登录ssh还是失败提示需要密钥&#xff0c;尝试ftp登录成功 发现密钥存放在.ssh/下&#xff0c;在kali上生成一个密钥&#xff0c;通过上传到.ssh/下&#xff0c;将其替换掉 使用kali生成密钥 登录成功,但是无法拿到root目录下的flag 获取root用户权限…

《QT从基础到进阶·三十七》QWidget实现左侧导航栏效果

NavigationBarPlugin插件类实现了对左侧导航栏的管理&#xff0c;我们可以在导航栏插件中添加界面&#xff0c;并用鼠标点击导航栏能够切换对应的界面。 源码在文章末尾 实现效果如下&#xff1a; NavigationBarPlugin实现的接口如下&#xff1a; class NAVIGATIONBAR_EXP…

Nevron Vision for .NET 2023.1 Crack

Nevron Vision for .NET 适用于桌面和 Web 应用程序的高级数据可视化 Nevron Vision for .NET提供最全面的组件&#xff0c;用于构建面向 Web 和桌面的企业级数据可视化应用程序。 该套件中的组件具有连贯的 2D 和 3D 数据可视化效果&#xff0c;对观众产生巨大的视觉冲击力。我…

3.Gin 框架中的路由简要说明

3.Gin 框架中的路由简要说明 Gin 框架中的路由 路由概述 路由&#xff08;Routing&#xff09;是由一个 URI&#xff08;或者叫路径&#xff09;和一个特定的 HTTP 方法&#xff08;GET、POST 等&#xff09; 组成的&#xff0c;涉及到应用如何响应客户端对某个网站节点的访问。…

SpringBoot——启动类的原理

优质博文&#xff1a;IT-BLOG-CN SpringBoot启动类上使用SpringBootApplication注解&#xff0c;该注解是一个组合注解&#xff0c;包含多个其它注解。和类定义SpringApplication.run要揭开SpringBoot的神秘面纱&#xff0c;我们要从这两位开始就可以了。 SpringBootApplicati…

VUE语法-$refs和ref属性的使用

1、$refs和ref属性的使用 1、$refs:一个包含 DOM 元素和组件实例的对象&#xff0c;通过模板引用注册。 2、ref实际上获取元素的DOM节点 3、如果需要在Vue中操作DOM我们可以通过ref和$refs这两个来实现 总结:$refs可以获取被ref属性修饰的元素的相关信息。 1.1、$refs和re…

解决ElementUI时间选择器回显出现Wed..2013..中国标准时间.

使用饿了么组件 时间日期选择框回显到页面为啥是这样的&#xff1f; 为什么再时间框中选择日期&#xff0c;回显页面出现了这种英文格式呢&#xff1f;&#xff1f;&#xff1f;&#xff1f; 其实这个问题直接使用elementui的内置属性就能解决 DateTimePicker 日期时间选择…

搭个网页应用,让ChatGPT帮我写SQL

大家好&#xff0c;我是凌览。 开门见山&#xff0c;我搭了一个网页应用名字叫sql-translate。访问链接挂在我的个人博客(https://linglan01.cn/about)导航栏&#xff0c;也可以访问https://www.linglan01.cn/c/sql-translate/直达sql-translate。 它的主要功能有&#xff1a;…

Linux上通过SSL/TLS和start tls连接到LDAP服务器

一&#xff0c;大致流程。 1.首先在Linux上搭建一个LDAP服务器 2.在LDAP服务器上安装CA证书&#xff0c;服务器证书&#xff0c;因为SSL/TLS&#xff0c;start tls都属于机密通信&#xff0c;需要客户端和服务器都存在一个相同的证书认证双方的身份。3.安装phpldapadmin工具&am…