Hands on RL 之 Deep Deterministic Policy Gradient(DDPG)

Hands on RL 之 Deep Deterministic Policy Gradient(DDPG)

文章目录

  • Hands on RL 之 Deep Deterministic Policy Gradient(DDPG)
    • 1. 理论部分
      • 1.1 回顾 Deterministic Policy Gradient(DPG)
      • 1.2 Neural Network Difference
      • 1.3 Why is off-policy?
      • 1.4 Soft target update
      • 1.5 Maintain Exploration
      • 1.6 Other Techniques
      • 1.7 Pesudocode
    • 2. 代码实践
    • Reference

1. 理论部分

1.1 回顾 Deterministic Policy Gradient(DPG)

在介绍DDPG之前,我们先回顾一下DPG中最重要的结论,

Deterministic Policy Gradient Theorem即确定性策略梯度定理

∇ θ J ( μ θ ) = ∫ S ρ μ ( s ) ∇ θ μ θ ( s ) ∇ a Q μ ( s , a ) ∣ a = μ θ ( s ) d s = E s ∼ ρ μ [ ∇ θ μ θ ( s ) ∇ a Q μ ( s , a ) ∣ a = μ θ ( s ) ] \begin{aligned} \nabla_\theta J(\mu_\theta) & = \int_{\mathcal{S}} \rho^\mu(s) \nabla_\theta \mu_\theta(s) \nabla_a Q^\mu(s,a)|_{a=\mu_\theta(s)} \mathrm{d}s \\ & = \mathbb{E}_{s\sim\rho^\mu} \Big[ \nabla_\theta \mu_\theta(s) \nabla_a Q^\mu(s,a)|_{a=\mu_\theta(s)} \Big] \end{aligned} θJ(μθ)=Sρμ(s)θμθ(s)aQμ(s,a)a=μθ(s)ds=Esρμ[θμθ(s)aQμ(s,a)a=μθ(s)]

其中, a = μ θ ( s ) a=\mu_\theta(s) a=μθ(s)表示确定性的策略是从状态空间到动作空间的映射 μ θ : S → A \mu_\theta: \mathcal{S}\to\mathcal{A} μθ:SA,网络的参数为 θ \theta θ s ∼ ρ μ s\sim\rho^\mu sρμ表示状态 s s s符合在策略 μ \mu μ下的状态访问分布。如何推导的,这里不详细阐述。(可以参考Deterministic policy gradient algorithms)

接下来逐点介绍DDPG相较于DPG的改进

1.2 Neural Network Difference

​ DDPG在相较于传统的AC算法在网络结构上也有很大不同,首先看看传统算法的网络结构

Image

然后再看看DDPG的网络结构

Image

为什么DDPG会是这样的网络结构呢,这是因为DDPG中的actor输出的是确定性动作,而不是动作的概率分布,因此确定性的动作是连续的可以看作动作空间的维度为无穷,如果采用AC中critic的结构,我们无法通过遍历所有动作来取出某个特定动作对应的Q-value。因此DDPG中将actor的输出作为critic的输入,再联合状态输入,就能直接获得所采取动作 a = μ ( s t ) a=\mu(st) a=μ(st)的Q-value。

1.3 Why is off-policy?

​ 首先为什么DDPG或者说DPG是off-policy的?我们回顾stochastic policy π θ ( a ∣ s ) \pi_\theta(a|s) πθ(as)定义下的Q-value
Q π ( s t , a t ) = E r t , s t + 1 ∼ E [ r ( s t , a t ) + γ E a t + 1 ∼ π [ Q π ( s t + 1 , a t + 1 ) ] ] Q^\pi(s_t,a_t) = \mathbb{E}_{r_t, s_{t+1}\sim E}[r(s_t,a_t) + \gamma \mathbb{E}_{a_{t+1}\sim\pi}[Q^\pi(s_{t+1}, a_{t+1})]] Qπ(st,at)=Ert,st+1E[r(st,at)+γEat+1π[Qπ(st+1,at+1)]]
其中, E E E表示的是环境,即状态 s ∼ E s\sim E sE状态符合环境本身的分布。当我们使用确定性策略的时候 a = μ θ ( s ) a=\mu_\theta(s) a=μθ(s),那么inner expectation就自动被抵消掉了
Q π ( s t , a t ) = E r t , s t + 1 ∼ E [ r ( s t , a t ) + γ Q π ( s t + 1 , a t + 1 = μ ( s t + 1 ) ) ] Q^\pi(s_t,a_t) = \mathbb{E}_{r_t, s_{t+1}\sim E}[r(s_t,a_t) + \gamma Q^\pi(s_{t+1}, a_{t+1}=\mu(s_{t+1}))] Qπ(st,at)=Ert,st+1E[r(st,at)+γQπ(st+1,at+1=μ(st+1))]
这就意味着Q-value不再依赖于动作的访问分布,即没有了 a t + 1 ∼ π a_{t+1}\sim\pi at+1π。那么我们就可以通过行为策略behavior policy β \beta β产生的结果来计算该值,这让off-policy成为可能。

​ 实际上Q-value不再依赖于动作的访问分布,那么确定性梯度定理可以写作
∇ θ J ( μ θ ) ≈ E s ∼ ρ β [ ∇ θ μ θ ( s ) ∇ a Q μ ( s , a ) ∣ a = μ θ ( s ) ] \textcolor{red}{\nabla_\theta J(\mu_\theta) \approx \mathbb{E}_{s\sim\rho^\beta} \Big[ \nabla_\theta \mu_\theta(s) \nabla_a Q^\mu(s,a)|_{a=\mu_\theta(s)} \Big]} θJ(μθ)Esρβ[θμθ(s)aQμ(s,a)a=μθ(s)]
可以写作依赖于behavior policy β \beta β产生的状态访问分布的期望,这就是一种off-policy的形式。

1.4 Soft target update

​ 在DDPG中维护了四个神经网络,分别是policy network, target policy network, action value network, target action value network。使用了DQN中的将目标网络和训练网络分离的思想,并且采用soft更新的方式,能够更有效维护训练中的稳定性。soft更新方式如下
θ − ← τ θ + ( 1 − τ ) θ − \theta^- \leftarrow \tau \theta + (1-\tau)\theta^- θτθ+(1τ)θ
其中, θ − \theta^- θ表示目标网络参数, θ \theta θ表示训练网络参数, τ ≪ 1 \tau \ll 1 τ1 τ \tau τ是软更新参数。

1.5 Maintain Exploration

​ 确定性的策略是不具有探索性的,为了保持策略的探索性,我们可以在策略网络的输出中增加高斯噪声,让输出的动作值有些许偏差来增加网络的探索性。用数学的方式来表示即是
μ ′ ( s t ) = μ θ ( s t ) + N \mu^\prime(s_t) = \mu_\theta(s_t) + \mathcal{N} μ(st)=μθ(st)+N
其中 μ ′ \mu^\prime μ表示探索性的策略, N \mathcal{N} N表示高斯噪声。

1.6 Other Techniques

​ DDPG还集成了一些别的算法的常用技巧,比如Replay Buffer来产生independent and identically distribution的样本,使用了Batch Normalization来预处理数据。

1.7 Pesudocode

伪代码如下

Image

2. 代码实践

我们采用gym中的Pendulum-v1作为本次实验的环境,Pendulum-v1是典型的确定性连续动作空间环境,整体的代码如下

import torch
import torch.nn as nn
import torch.nn.functional as F
import gym
import random
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import collections# Policy Network
class PolicyNet(nn.Module):def __init__(self, state_dim, hidden_dim, action_dim, action_bound):super(PolicyNet, self).__init__()self.fc1 = nn.Linear(state_dim, hidden_dim)self.fc2 = nn.Linear(hidden_dim, action_dim)self.action_bound = action_bounddef forward(self, observation):x = F.relu(self.fc1(observation))x = F.tanh(self.fc2(x))return x * self.action_bound# Q Value Network
class QValueNet(nn.Module):def __init__(self, state_dim, hidden_dim, action_dim):super(QValueNet, self).__init__()self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim)self.fc2 = nn.Linear(hidden_dim, hidden_dim)self.fc_out = nn.Linear(hidden_dim, 1)def forward(self, x, a):cat = torch.cat([x, a], dim=1)    # 拼接状态和动作x = F.relu(self.fc1(cat))x = F.relu(self.fc2(x))return self.fc_out(x)# Deep Deterministic Policy Gradient
class DDPG():def __init__(self, state_dim, hidden_dim, action_dim, action_bound, actor_lr, critic_lr, sigma, tau, gamma, device):self.actor = PolicyNet(state_dim, hidden_dim, action_dim, action_bound).to(device)self.critic = QValueNet(state_dim, hidden_dim, action_dim).to(device)self.target_actor = PolicyNet(state_dim, hidden_dim, action_dim, action_bound).to(device)self.target_critic = QValueNet(state_dim, hidden_dim, action_dim).to(device)# initialize target actor network with same parametersself.target_actor.load_state_dict(self.actor.state_dict())# initialize target critic network with same parametersself.target_critic.load_state_dict(self.critic.state_dict())self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)self.gamma = gammaself.sigma = sigma  # 高斯噪声的标准差,均值直接设置为0self.action_dim = action_dimself.device = deviceself.tau = taudef take_action(self, state):state = torch.tensor(np.array([state]), dtype=torch.float).to(self.device)action = self.actor(state).item()# add noise to increase exploratoryaction = action + self.sigma * np.random.randn(self.action_dim)return actiondef soft_update(self, net, target_net):# implement soft update rulefor param_target, param in zip(target_net.parameters(), net.parameters()):param_target.data.copy_(param_target.data * (1.0-self.tau) + param.data * self.tau)def update(self, transition_dict):states = torch.tensor(transition_dict['states'], dtype=torch.float).to(self.device)rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).view(-1,1).to(self.device)actions = torch.tensor(transition_dict['actions'], dtype=torch.float).view(-1,1).to(self.device)next_states = torch.tensor(transition_dict['next_states'], dtype=torch.float).to(self.device)dones = torch.tensor(transition_dict['dones'], dtype=torch.float).view(-1,1).to(self.device)next_q_values = self.target_critic(next_states, self.target_actor(next_states))td_targets = rewards + self.gamma * next_q_values * (1-dones)critic_loss = torch.mean(F.mse_loss(self.critic(states, actions), td_targets))self.critic_optimizer.zero_grad()critic_loss.backward()self.critic_optimizer.step()actor_loss = torch.mean( - self.critic(states, self.actor(states)))self.actor_optimizer.zero_grad()actor_loss.backward()self.actor_optimizer.step()# soft update actor net and critic netself.soft_update(self.actor, self.target_actor)self.soft_update(self.critic, self.target_critic)class ReplayBuffer():def __init__(self, capacity):self.buffer = collections.deque(maxlen=capacity)def add(self, s, a, r, s_, d):self.buffer.append((s,a,r,s_,d))def sample(self, batch_size):transitions = random.sample(self.buffer, batch_size)states, actions, rewards, next_states, dones = zip(*transitions)return np.array(states), actions, np.array(rewards), np.array(next_states), donesdef size(self):return len(self.buffer)def train_off_policy_agent(env, agent, num_episodes, replay_buffer, minimal_size, batch_size, render, seed_number):return_list = []for i in range(10):with tqdm(total=int(num_episodes/10), desc='Iteration %d'%(i+1)) as pbar:for i_episode in range(int(num_episodes/10)):observation, _ = env.reset(seed=seed_number)done = Falseepisode_return = 0while not done:if render:env.render()action = agent.take_action(observation)observation_, reward, terminated, truncated, _ = env.step(action)done = terminated or truncatedreplay_buffer.add(observation, action, reward, observation_, done)# swap statesobservation = observation_episode_return += rewardif replay_buffer.size() > minimal_size:b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample(batch_size)transition_dict = {'states': b_s,'actions': b_a,'rewards': b_r,'next_states': b_ns,'dones': b_d}agent.update(transition_dict)return_list.append(episode_return)if(i_episode+1) % 10 == 0:pbar.set_postfix({'episode': '%d'%(num_episodes/10 * i + i_episode + 1),'return': "%.3f"%(np.mean(return_list[-10:]))})pbar.update(1)env.close()return return_listdef moving_average(a, window_size):cumulative_sum = np.cumsum(np.insert(a, 0, 0)) middle = (cumulative_sum[window_size:] - cumulative_sum[:-window_size]) / window_sizer = np.arange(1, window_size-1, 2)begin = np.cumsum(a[:window_size-1])[::2] / rend = (np.cumsum(a[:-window_size:-1])[::2] / r)[::-1]return np.concatenate((begin, middle, end))def plot_curve(return_list, mv_return, algorithm_name, env_name):episodes_list = list(range(len(return_list)))plt.plot(episodes_list, return_list, c='gray', alpha=0.6)plt.plot(episodes_list, mv_return)plt.xlabel('Episodes')plt.ylabel('Returns')plt.title('{} on {}'.format(algorithm_name, env_name))plt.show()if __name__ == "__main__":# reproducibleseed_number = 0random.seed(seed_number)np.random.seed(seed_number)torch.manual_seed(seed_number)num_episodes = 250     # episodes lengthhidden_dim = 128        # hidden layers dimensiongamma = 0.98            # discounted rateactor_lr = 1e-3         # lr of actorcritic_lr = 1e-3        # lr of critictau = 0.005             # soft update parametersigma = 0.01            # std variance of guassian noisebuffer_size = 10000minimal_size = 1000batch_size = 64device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')env_name = 'Pendulum-v1'render = Falseif render:env = gym.make(id=env_name, render_mode='human')else:env = gym.make(id=env_name)state_dim = env.observation_space.shape[0]action_dim = env.action_space.shape[0]  action_bound = env.action_space.high[0]replay_buffer = ReplayBuffer(buffer_size)        agent = DDPG(state_dim, hidden_dim, action_dim, action_bound, actor_lr, critic_lr, sigma, tau, gamma, device)return_list = train_off_policy_agent(env, agent, num_episodes, replay_buffer, minimal_size, batch_size, render, seed_number)mv_return = moving_average(return_list, 9)plot_curve(return_list, mv_return, 'DDPG', env_name)

DDPG训练的回报曲线如图所示

Image

Reference

Tutorial: Hands on RL

Paper: Continuous control with deep reinforcement learning

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/68901.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Elisp之获取PC电池状态(二十八)

简介: CSDN博客专家,专注Android/Linux系统,分享多mic语音方案、音视频、编解码等技术,与大家一起成长! 优质专栏:Audio工程师进阶系列【原创干货持续更新中……】🚀 人生格言: 人生…

SAP Fiori 问题收集

事务代码篇 启动工作台:/N/UI2/FLP 错误日志: /n/IWFND/ERROR_LOG 服务清单: /n/IWFND/MAINT_SERVICE 创建语义对象:/N/UI2/SEMOBJ 创建目录:/N/UI2/FLPD_CONF(cross-client)或 /N/UI2…

OpenCV-Python中的图像处理-视频分析

OpenCV-Python中的图像处理-视频分析 视频分析Meanshift算法Camshift算法光流Lucas-Kanade Optical FlowDense Optical Flow 视频分析 学习使用 Meanshift 和 Camshift 算法在视频中找到并跟踪目标对象: Meanshift算法 Meanshift 算法的基本原理是和很简单的。假设我们有一堆…

Spring Security用户授权

用户认证在上一篇用户认证 用户授权 总体流程: 在SpringSecurity中,会使用默认的FilterSecurityInterceptor来进行权限校验。在FilterSecurityInterceptor中会从SecurityContextHolder获取其中的Authentication,然后获取其中的权限信息。…

idea中Maven报错Unable to import maven project: See logs for details问题的解决方法

idea中Maven报错Unable to import maven project: See logs for details问题的解决方法。 在查看maven的环境配置和idea的maven配置后,发现是idea 2020版本和maven 3.9.3版本的兼容性问题。在更改为Idea自带的maven 3.6.1版本后问题解决,能成功下载jar包…

word 应用 打不开 显示一直是正在启动中

word打开来显示一直正在启动中,其他调用word的应用也打不开,网上查了下以后进程关闭spoolsv.exe,就可以正常打开word了

AI绘画 | 一文学会Midjourney绘画,创作自己的AI作品(快速入门+参数介绍)

一、生成第一个AI图片 首先,生成将中文描述词翻译成英文 然后在输入端输入:/imagine prompt:Bravely running boy in Q version, cute head portrait 最后,稍等一会即可输出效果 说明: 下面的U1、U2、U3、U4代表的第一张、第二张…

AI Chat 设计模式:15. 桥接模式

本文是该系列的第十五篇,采用问答式的方式展开,问题由我提出,答案由 Chat AI 作出,灰色背景的文字则主要是我的一些思考和补充。 问题列表 Q.1 如果你是第一次接触桥接模式,那么你会有哪些疑问呢?A.1Q.2 什…

JavaScript(JavaEE初阶系列13)

目录 前言: 1.初识JavaScript 2.JavaScript的书写形式 2.1行内式 2.2内嵌式 2.3外部式 2.4注释 2.5输入输出 3.语法 3.1变量的使用 3.2基本数据类型 3.3运算符 3.4条件语句 3.5循环语句 3.6数组 3.7函数 3.8对象 3.8.1 对象的创建 4.案例演示 4…

广告ROI可洞察到订单转化率啦

toB广告营销人的一日三问&#xff1a; 如何实现线索增长&#xff1f;如何获取更多高质量线索&#xff1f;如何能用更少的钱拿到更多高质量的线索&#xff1f; < 广告营销的终极目标&#xff0c;就是提升ROI > 从ROI公式中&#xff0c;可以找到提升广告营销ROI的路径&…

爬楼梯(一次爬1或2层)

一&#xff0c;题目描述 二&#xff0c;解题思路 动态规划 动规五部曲&#xff1a; 1. 确认dp数组以及下标含义 2. 推导递推公式 3. 确认dp数组如何初始化 4. 确认遍历顺序 5. 打印dp数组 dp数组含义&#xff1a;到第i层的方法数目 下标含义&#xff1a;层数 递推公式&…

WebRTC | SDP详解

目录 一、SDP标准规范 1. SDP结构 2. SDP内容及type类型 二、WebRTC中的SDP结构 1. 媒体信息描述 &#xff08;1&#xff09;SDP中媒体信息格式 i. “artpmap”属性 ii. “afmtp”属性 &#xff08;2&#xff09;SSRC与CNAME &#xff08;3&#xff09;举个例子 &…