【强化学习】13 —— Actor-Critic 算法

文章目录

  • REINFORCE 存在的问题
  • Actor-Critic
  • A2C: Advantageous Actor-Critic
  • 代码实践
    • 结果
  • 参考

REINFORCE 存在的问题

  • 基于片段式数据的任务
    • 通常情况下,任务需要有终止状态,REINFORCE才能直接计算累计折扣奖励
  • 低数据利用效率
    • 实际中,REINFORCE需要大量的训练数据
  • 高训练方差(最重要的缺陷
    • 从单个或多个片段中采样到的值函数具有很高的方差

Actor-Critic

在 REINFORCE 算法中,目标函数的梯度中有一项轨迹回报,用于指导策略的更新。REINFOCE 算法用蒙特卡洛方法来估计 Q ( s , a ) Q(s,a) Q(s,a),能不能考虑拟合一个值函数来指导策略进行学习呢?这正是 Actor-Critic 算法所做的。
在这里插入图片描述
评论家Critic Q Φ ( s , a ) Q_\Phi (s,a) QΦ(s,a):

  • 学会准确估计当前演员策略(actor policy)的动作价值。通过 Actor 与环境交互收集的数据学习一个价值函数,这个价值函数会用于判断在当前状态什么动作是好的,什么动作不是好的,进而帮助 Actor 进行策略更新。 Q Φ ( s , a ) ≃ r ( s , a ) + γ E s ′ ∼ p ( s ′ ∣ s , a ) , a ′ ∼ π θ ( a ′ ∣ s ′ ) [ Q Φ ( s ′ , a ′ ) ] Q_\Phi(s,a)\simeq r(s,a)+\gamma\mathbb{E}_{s^{\prime}\thicksim p(s^{\prime}|s,a),a^{\prime}\thicksim\pi_\theta(a^{\prime}|s^{\prime})}[Q_\Phi(s^{\prime},a^{\prime})] QΦ(s,a)r(s,a)+γEsp(ss,a),aπθ(as)[QΦ(s,a)]

演员Actor π θ ( s , a ) \pi_\theta(s,a) πθ(s,a):

  • 要做的是与环境交互,并在 Critic 价值函数的指导下用策略梯度学习一个更好的策略。 J ( θ ) = E s ∼ p , π θ [ π θ ( a ∣ s ) Q Φ ( s , a ) ] ∂ f ( θ ) ∂ θ = E π θ [ ∂ log ⁡ π θ ( a ∣ s ) ∂ θ Q Φ ( s , a ) ] \begin{aligned}J(\theta)&=\mathbb{E}_{s\sim p,\pi_\theta}[\pi_\theta(a|s)Q_\Phi(s,a)]\\\\\frac{\partial f(\theta)}{\partial\theta}&=\mathbb{E}_{\pi_\theta}\left[\frac{\partial\log\pi_\theta(a|s)}{\partial\theta}Q_\Phi(s,a)\right]\end{aligned} J(θ)θf(θ)=Esp,πθ[πθ(as)QΦ(s,a)]=Eπθ[θlogπθ(as)QΦ(s,a)]

A2C: Advantageous Actor-Critic

思想:通过减去一个基线函数来标准化评论家的打分

  • 更多信息指导:降低较差动作概率,提高较优动作概率
  • 进一步降低方差

优势函数(Advantage Function) A π ( s , a ) = Q π ( s , a ) − V π ( s ) A^\pi(s,a)=Q^\pi(s,a)-V^\pi(s) Aπ(s,a)=Qπ(s,a)Vπ(s)
在这里插入图片描述
若只采用动作值的方式,虽然也会选择A2,但是方差相对会更大,同时所有的动作都是出于上升的状态,只是上升程度的问题。而采用优势函数的方式,部分动作的优势函数值是负的,可以直接降低相应动作的概率,同时方差更小。

状态-动作值和状态值函数 Q π ( s , a ) = r ( s , a ) + γ E s ′ ∼ p ( s ′ ∣ s , a ) , a ′ ∼ π θ ( a ′ ∣ s ′ ) [ Q Φ ( s ′ , a ′ ) ] = r ( s , a ) + γ E s ′ ∼ p ( s ′ ∣ s , a ) [ V π ( s ′ ) ] \begin{aligned} Q^{\pi}(s,a)& =r(s,a)+\gamma\mathbb{E}_{s^{\prime}\sim p(s^{\prime}|s,a),a^{\prime}\sim\pi_\theta(a^{\prime}|s^{\prime})}\left[Q_\Phi(s^{\prime},a^{\prime})\right] \\ &=r(s,a)+\gamma\mathbb{E}_{s^{\prime}\sim p(s^{\prime}|s,a)}[V^{\pi}(s^{\prime})] \end{aligned} Qπ(s,a)=r(s,a)+γEsp(ss,a),aπθ(as)[QΦ(s,a)]=r(s,a)+γEsp(ss,a)[Vπ(s)]

因此我们只需要拟合状态值函数来拟合优势函数 A π ( s , a ) = Q π ( s , a ) − V π ( s ) = r ( s , a ) + γ E s ′ ∼ p ( s ′ ∣ s , a ) [ V π ( s ′ ) − V π ( s ) ] ≃ r ( s , a ) + γ ( V π ( s ′ ) − V π ( s ) ) \begin{aligned} A^{\pi}(s,a)& =Q^\pi(s,a)-V^\pi(s) \\ &=r(s,a)+\gamma\mathbb{E}_{s^{\prime}\sim p(s^{\prime}|s,a)}[V^{\pi}(s^{\prime})-V^{\pi}(s)] \\ &\simeq r(s,a)+\gamma(V^{\pi}(s^{\prime})-V^{\pi}(s)) \end{aligned} Aπ(s,a)=Qπ(s,a)Vπ(s)=r(s,a)+γEsp(ss,a)[Vπ(s)Vπ(s)]r(s,a)+γ(Vπ(s)Vπ(s))


在策略梯度中,可以把梯度写成下面这个更加一般的形式: g = E [ ∑ t = 0 T ψ t ∇ θ log ⁡ π θ ( a t ∣ s t ) ] g=\mathbb{E}\left[\sum_{t=0}^T\psi_t\nabla_\theta\log\pi_\theta(a_t|s_t)\right] g=E[t=0Tψtθlogπθ(atst)]其中, ψ t \psi_t ψt可以有很多种形式: 1. ∑ t ′ = 0 T γ t ′ r t ′ : 轨迹的总回报; 2. ∑ t ′ = t T γ t ′ − t r t ′ : 动作 a t 之后的回报; 3. ∑ t ′ = t T γ t ′ − t r t ′ − b ( s t ) : 基准线版本的改进 ; 4. Q π θ ( s t , a t ) : 动作价值函数; 5. A π θ ( s t , a t ) : 优势函数; 6. r t + γ V π θ ( s t + 1 ) − V π θ ( s t ) : 时序差分残差。 \begin{aligned} &1.\sum_{t^{\prime}=0}^T\gamma^{t^{\prime}}r_{t^{\prime}}:\textit{轨迹的总回报;} \\ &2.\sum_{t^{\prime}=t}^T\gamma^{t^{\prime}-t}r_{t^{\prime}}:\textit{动作}a_t\textit{之后的回报;} \\ &\begin{aligned}3.\sum_{t^{\prime}=t}^T\gamma^{t^{\prime}-t}r_{t^{\prime}}-b(s_t):\textit{基准线版本的改进};\end{aligned} \\ &4.Q^{\pi_\theta}(s_t,a_t):\textit{动作价值函数;} \\ &5.A^{\pi_\theta}(s_t,a_t):\textit{优势函数;} \\ &6.r_t+\gamma V^{\pi_\theta}(s_{t+1})-V^{\pi_\theta}(s_t):\textit{时序差分残差。} \end{aligned} 1.t=0Tγtrt:轨迹的总回报;2.t=tTγttrt:动作at之后的回报;3.t=tTγttrtb(st):基准线版本的改进;4.Qπθ(st,at):动作价值函数;5.Aπθ(st,at):优势函数;6.rt+γVπθ(st+1)Vπθ(st):时序差分残差。

REINFORCE 通过蒙特卡洛采样的方法对策略梯度的估计是无偏的,但是方差非常大。我们可以用形式(3)引入基线函数 b ( s t ) b(s_t) b(st)(baseline function)来减小方差。此外,我们也可以采用 Actor-Critic 算法估计一个动作价值函数 Q Q Q,代替蒙特卡洛采样得到的回报,这便是形式(4)。这个时候,我们可以把状态价值函数 V V V作为基线,从 Q Q Q函数减去这个 V V V函数则得到了 A A A函数,我们称之为优势函数(advantage function),这便是形式(5)。更进一步,我们可以利用等式 Q = r + γ V Q=r+\gamma V Q=r+γV得到形式(6)。

Actor 的更新采用策略梯度的原则,那 Critic 如何更新呢?我们将 Critic 价值网络表示为 V ω V_\omega Vω,参数为 ω \omega ω。于是,我们可以采取时序差分残差的学习方式,对于单个数据定义如下价值函数的损失函数: L ( ω ) = 1 2 ( r + γ V ω ( s t + 1 ) − V ω ( s t ) ) 2 \mathcal{L}(\omega)=\frac12(r+\gamma V_\omega(s_{t+1})-V_\omega(s_t))^2 L(ω)=21(r+γVω(st+1)Vω(st))2

与 DQN 中一样,我们采取类似于目标网络的方法,将上式中 r + γ V ω ( s t + 1 ) r+\gamma V_\omega(s_{t+1}) r+γVω(st+1)作为时序差分目标,不会产生梯度来更新价值函数。因此,价值函数的梯度为:
∇ ω L ( ω ) = − ( r + γ V ω ( s t + 1 ) − V ω ( s t ) ) ∇ ω V ω ( s t ) \nabla_{\omega}\mathcal{L}(\omega)=-(r+\gamma V_{\omega}(s_{t+1})-V_{\omega}(s_{t}))\nabla_{\omega}V_{\omega}(s_{t}) ωL(ω)=(r+γVω(st+1)Vω(st))ωVω(st)

然后使用梯度下降方法来更新 Critic 价值网络参数即可。

算法伪代码:

在这里插入图片描述

代码实践

import gymnasium as gym
import numpy as np
from tqdm import tqdm
import torch
import torch.nn.functional as F
import utilclass PolicyNet(torch.nn.Module):def __init__(self, state_dim, hidden_dim, action_dim):super(PolicyNet, self).__init__()self.fc1 = torch.nn.Linear(state_dim, hidden_dim)self.fc2 = torch.nn.Linear(hidden_dim, action_dim)def forward(self, x):x = F.relu(self.fc1(x))return F.softmax(self.fc2(x), dim=1)# 输入是某个状态,输出则是状态的价值。
class ValueNet(torch.nn.Module):def __init__(self, state_dim, hidden_dim):super(ValueNet, self).__init__()self.fc1 = torch.nn.Linear(state_dim, hidden_dim)self.fc2 = torch.nn.Linear(hidden_dim, 1)def forward(self, x):x = F.relu(self.fc1(x))return self.fc2(x)class ActorCritic:def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr, gamma,device, numOfEpisodes, env):self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)self.critic = ValueNet(state_dim, hidden_dim).to(device)self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)self.gamma = gammaself.device = deviceself.env = envself.numOfEpisodes = numOfEpisodes# 根据动作概率分布随机采样def takeAction(self, state):state = torch.tensor(np.array([state]), dtype=torch.float).to(self.device)action_probs = self.actor(state)action_dist = torch.distributions.Categorical(action_probs)action = action_dist.sample()return action.item()def update(self, transition_dict):states = torch.tensor(np.array(transition_dict['states']), dtype=torch.float).to(self.device)actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(self.device)rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).view(-1, 1).to(self.device)next_states = torch.tensor(np.array(transition_dict['next_states']), dtype=torch.float).to(self.device)terminateds = torch.tensor(transition_dict['terminateds'], dtype=torch.float).view(-1, 1).to(self.device)truncateds = torch.tensor(transition_dict['truncateds'], dtype=torch.float).view(-1, 1).to(self.device)# 时序差分目标td_target = rewards + self.gamma * self.critic(next_states) * (1 - terminateds + truncateds)# 时序差分误差td_delta = td_target - self.critic(states)log_probs = torch.log(self.actor(states).gather(1, actions))# 均方误差损失函数actor_loss = torch.mean(-log_probs * td_delta.detach())critic_loss = torch.mean(F.mse_loss(self.critic(states), td_target.detach()))self.actor_optimizer.zero_grad()self.critic_optimizer.zero_grad()actor_loss.backward()critic_loss.backward()self.actor_optimizer.step()self.critic_optimizer.step()def ACTrain(self):returnList = []for i in range(10):with tqdm(total=int(self.numOfEpisodes / 10), desc='Iteration %d' % i) as pbar:for episode in range(int(self.numOfEpisodes / 10)):# initialize statestate, info = self.env.reset()terminated = Falsetruncated = FalseepisodeReward = 0transition_dict = {'states': [],'actions': [],'next_states': [],'rewards': [],'terminateds': [],'truncateds': []}# Loop for each step of episode:while 1:action = self.takeAction(state)next_state, reward, terminated, truncated, info = self.env.step(action)transition_dict['states'].append(state)transition_dict['actions'].append(action)transition_dict['next_states'].append(next_state)transition_dict['rewards'].append(reward)transition_dict['terminateds'].append(terminated)transition_dict['truncateds'].append(truncated)state = next_stateepisodeReward += rewardif terminated or truncated:breakself.update(transition_dict)returnList.append(episodeReward)if (episode + 1) % 10 == 0:  # 每10条序列打印一下这10条序列的平均回报pbar.set_postfix({'episode':'%d' % (self.numOfEpisodes / 10 * i + episode + 1),'return':'%.3f' % np.mean(returnList[-10:])})pbar.update(1)return returnList

结果

在这里插入图片描述
在这里插入图片描述
可以发现,Actor-Critic 算法很快便能收敛到最优策略,并且训练过程非常稳定,抖动情况相比 REINFORCE 算法有了明显的改进,这说明价值函数的引入减小了方差。

参考

[1] 伯禹AI
[2] https://www.davidsilver.uk/teaching/
[3] 动手学强化学习
[4] Reinforcement Learning

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/155601.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

AI:47-基于深度学习的人像背景替换研究

🚀 本文选自专栏:AI领域专栏 从基础到实践,深入了解算法、案例和最新趋势。无论你是初学者还是经验丰富的数据科学家,通过案例和项目实践,掌握核心概念和实用技能。每篇案例都包含代码实例,详细讲解供大家学习。 📌📌📌本专栏包含以下学习方向: 机器学习、深度学…

Py之transformers_stream_generator:transformers_stream_generator的简介、安装、使用方法之详细攻略

Py之transformers_stream_generator:transformers_stream_generator的简介、安装、使用方法之详细攻略 目录 transformers_stream_generator的简介 1、Web Demo T1、original T2、stream transformers_stream_generator的安装 transformers_stream_generator的…

恒驰服务 | 华为云数据使能专家服务offering之数仓建设

恒驰大数据服务主要针对客户在进行智能数据迁移的过程中,存在业务停机、数据丢失、迁移周期紧张、运维成本高等问题,通过为客户提供迁移调研、方案设计、迁移实施、迁移验收等服务内容,支撑客户实现快速稳定上云,有效降低时间成本…

[论文笔记]GTE

引言 今天带来今年的一篇文本嵌入论文GTE, 中文题目是 多阶段对比学习的通用文本嵌入。 作者提出了GTE,一个使用对阶段对比学习的通用文本嵌入。使用对比学习在多个来源的混合数据集上训练了一个统一的文本嵌入模型,通过在无监督预训练阶段和有监督微调阶段显著增加训练数…

重磅消息!优维发布全新产品“应急管理”

近日,蚂蚁集团旗下的在线文档编辑与协同工具语雀平台发生了一次严重的宕机事件,导致用户无法正常使用其各项功能。从故障发生到完全恢复正常,语雀整个宕机时间将近 8 小时,如此长时间的宕机已经达到了 P0 级事故,并在网…

1-1 暴力破解-枚举

枚举:枚举所有的情况,逐个判断是否是问题的解 判断题目是否可以使用枚举:估算算法复杂度 1秒计算机大约能处理107的数据量,即O(n)107,则O(n2)能处理的数据量就是根号107≈3162 复杂度对应的数据量是该算法能处理的最…

Git 入门指南:从新手到高手的完全指南

Git是一种强大的分布式版本控制系统,广泛应用于软件开发中。它的使用不仅可以帮助开发团队更好地管理代码,还可以提高团队协作效率和代码质量。随着软件开发的不断发展,版本控制成为了程序员必备的一项技能。 Git的基本概念 Git的基本概念对…

华为云资源搭建过程

网络搭建 EIP: 弹性EIP,支持IPv4和IPv6。 弹性公网IP(Elastic IP)提供独立的公网IP资源,包括公网IP地址与公网出口带宽服务。可以与弹性云服务器、裸金属服务器、虚拟IP、弹性负载均衡、NAT网关等资源灵活地绑定及解绑…

【嵌入式】【GIT】如何迁移老的GIF到新的仓库时使用LFS功能并保持LOG不变

一、正常迁移流程 假设有仓库 ssh://old/buildroot-201902 需要迁移到新的仓库 ssh://old/buildroot-201902时,我们可以使用以下命令来完成: # 下载老的仓库 git clone ssh://old/buildroot-201902 # 向新的仓库上传所有的tags git push ssh://new/buildroot-201902 --tag…

成为一个优秀的测试工程师需要具备哪些知识和经验?

看到这个题目,头脑中马上就拆分出了3个小问题: 1、什么是优秀的测试工程师? 2、优秀测试工程师需要哪些知识? 3、优秀测试工程师需要哪些经验? 一个个讲解。 一、什么才是一名优秀的测试工程师呢? 什么才是…

window11最新版终于可以取消任务栏合并了

windows11一个软件开了多个窗口之后,会自动合并任务栏,很不方便选择其中一个窗口,且没有选项能关闭这一配置 今日发现,最新版完善了这一功能,现在可以关闭自动合并任务栏了 右击任务栏,选择任务栏设置选择…

python实现Excel自动化办公

准备工作 安装相关模块 pip install openpyxl lxml pillow 基本定义 工作簿:一个电子表文件为一个工作簿 活动表:用户当前查看的表活关闭Excel最后查看的表 sheet表 单元格 Excel数据读取操作 打开工作簿并创建一个对象: wb openpyxl.loa…