【强化学习】10 —— DQN算法

文章目录

  • 深度强化学习
    • 价值和策略近似
    • RL与DL结合产生的问题
    • 深度强化学习的分类
  • Q-learning回顾
  • 深度Q网络(DQN)
    • 经验回放
      • 优先经验回放
    • 目标网络
    • 算法流程
  • 代码实践
    • CartPole环境
    • 代码
    • 结果
  • 参考

深度强化学习

价值和策略近似

在这里插入图片描述
我们可以利用深度神经网络建立这些近似函数

在这里插入图片描述
深度强化学习使强化学习算法能够以端到端的方式解决复杂问题

RL与DL结合产生的问题

• 价值函数和策略现在变成了深度神经网络
• 相当高维的参数空间
• 难以稳定地训练
• 容易过拟合
• 需要大量的数据
• 需要高性能计算
• CPU(用于收集经验数据)和GPU(用于训练神经网络)之间的平衡

深度强化学习的分类

  • 基于价值的方法
    • 深度Q网络及其扩展
  • 基于随机策略的方法
    • 使用神经网络的策略梯度,自然策略梯度,信任区域策略优化(TRPO),
    近端策略优化(PPO),A3C
  • 基于确定性策略的方法
    • 确定性策略梯度(DPG),DDPG

Q-learning回顾

文章链接——http://t.csdnimg.cn/Abz4v
Q-learning不直接更新策略,是一种基于值的方法。我们先来回顾一下 Q-learning 的更新规则 Q ( s , a ) ← Q ( s , a ) + α [ r + γ max ⁡ a ′ ∈ A Q ( s ′ , a ′ ) − Q ( s , a ) ] Q(s,a)\leftarrow Q(s,a)+\alpha\left[r+\gamma\max_{a^{\prime}\in\mathcal{A}}Q(s^{\prime},a^{\prime})-Q(s,a)\right] Q(s,a)Q(s,a)+α[r+γaAmaxQ(s,a)Q(s,a)]

上述公式用时序差分(temporal difference,TD)学习目标来增量式更新 r + γ max ⁡ a ′ ∈ A Q ( s ′ , a ′ ) r+\gamma\max_{a'\in\mathcal A}Q(s',a') r+γmaxaAQ(s,a),也就是说要使 Q ( s , a ) Q(s,a) Q(s,a)和 TD 目标 r + γ max ⁡ a ′ ∈ A Q ( s ′ , a ′ ) r+\gamma\max_{a'\in\mathcal A}Q(s',a') r+γmaxaAQ(s,a)靠近。于是,对于一组数据 { ( s i , a i , r i , s i ′ ) } \{(s_i,a_i,r_i,s_i')\} {(si,ai,ri,si)},我们可以很自然地将 Q 网络的损失函数构造为均方误差的形式: ω ∗ = arg ⁡ min ⁡ ω 1 2 N ∑ i = 1 N [ Q ω ( s i , a i ) − ( r i + γ max ⁡ a ′ Q ω ( s i ′ , a ′ ) ) ] 2 \omega^*=\arg\min_\omega\frac{1}{2N}\sum_{i=1}^N\left[Q_\omega\left(s_i,a_i\right)-\left(r_i+\gamma\max_{a'}Q_\omega\left(s_i',a'\right)\right)\right]^2 ω=argωmin2N1i=1N[Qω(si,ai)(ri+γamaxQω(si,a))]2

PS1: Q w ( s , a ) Q_w(s,a) Qw(s,a)表示Q-learning学习一个由 w w w 作为参数的函数 Q w ( s , a ) Q_w(s,a) Qw(s,a)
PS2: Q w ( s i ′ , a ′ ) Q_w(s_i',a') Qw(si,a)无梯度

比较直观的想法是使用神经网络来逼近上述 Q w ( s , a ) Q_w(s,a) Qw(s,a),但是深度神经网络存在以下问题:

  • 算法不稳定
    • 连续采样得到的 { ( s i , a i , r i , s i ′ ) } \{(s_i,a_i,r_i,s_i')\} {(si,ai,ri,si)}不满足独立分布。
    • 会导致 Q w ( s , a ) Q_w(s,a) Qw(s,a)的频繁更新(Q-policy-data_distribution都在变)。

解决办法

  • 经验回放
  • 使用双网络结构:评估网络(evaluation network)和目标网络(target network)

深度Q网络(DQN)

在这里插入图片描述

经验回放

在一般的有监督学习中,假设训练数据是独立同分布的,我们每次训练神经网络的时候从训练数据中随机采样一个或若干个数据来进行梯度下降,随着学习的不断进行,每一个训练数据会被使用多次。在原来的 Q-learning 算法中,每一个数据只会用来更新一次 Q Q Q值。为了更好地将 Q-learning 和深度神经网络结合,DQN 算法采用了经验回放(experience replay)方法,具体做法为维护一个回放缓冲区,将每次从环境中采样得到的四元组数据(状态、动作、奖励、下一状态)存储到回放缓冲区中,训练 Q 网络的时候再从回放缓冲区中随机采样若干数据来进行训练。这么做可以起到以下两个作用。

(1)使样本满足独立假设。在 MDP 中交互采样得到的数据本身不满足独立假设,因为这一时刻的状态和上一时刻的状态有关。非独立同分布的数据对训练神经网络有很大的影响,会使神经网络拟合到最近训练的数据上。采用经验回放可以打破样本之间的相关性,让其满足独立假设。

(2)提高样本效率。每一个样本可以被使用多次,十分适合深度神经网络的梯度学习。

优先经验回放

优先经验回放可以防止数据过拟合,可以更多地关注差距较大的那些值。
Schaul, Tom, et al. “Prioritized experience replay.” arXiv preprint arXiv:1511.05952 (2015).
衡量标准

  • 以 𝑄 函数的值与 Target 值的差异来衡量学习的价值,即 p t = ∣ r t + γ max ⁡ a ′ Q θ ( s t + 1 , a ′ ) − Q θ ( s t , a t ) ∣ p_t=|r_t+\gamma\underset{a^{\prime}}{\operatorname*{max}}Q_\theta(s_{t+1},a^{\prime})-Q_\theta(s_t,a_t)| pt=rt+γamaxQθ(st+1,a)Qθ(st,at)
  • 为了使各样本都有机会被采样,存储 e t = ( s t , a t , s t + 1 , r t , p t + ϵ ) e_{t}=(s_{t},a_{t},s_{t+1},r_{t},p_{t}+\epsilon) et=(st,at,st+1,rt,pt+ϵ)
  • 选中的概率,样本 e t e_t et 被选中的概率为 P ( t ) = p t α ∑ k p k α P(t)=\frac{p_t^\alpha}{\sum_kp_k^\alpha} P(t)=kpkαptα
  • 重要性采样(Importance Sampling),权重为 ω t = ( N × P ( t ) ) − β max ⁡ i ω i \omega_{t}=\frac{\left(N\times P(t)\right)^{-\beta}}{\max_{i}\omega_{i}} ωt=maxiωi(N×P(t))β

算法伪代码
在这里插入图片描述

目标网络

DQN 算法最终更新的目标是让 Q w ( s , a ) Q_w(s,a) Qw(s,a)逼近 r + γ max ⁡ a ′ ∈ A Q ( s ′ , a ′ ) r+\gamma\max_{a'\in\mathcal A}Q(s',a') r+γmaxaAQ(s,a)由于 TD 误差目标本身就包含神经网络的输出,因此在更新网络参数的同时目标也在不断地改变,这非常容易造成神经网络训练的不稳定性。为了解决这一问题,DQN 便使用了目标网络(target network)的思想:既然训练过程中 Q 网络的不断更新会导致目标不断发生改变,不如暂时先将 TD 目标中的 Q 网络固定住。为了实现这一思想,我们需要利用两套 Q 网络。

(1)原来的训练网络 Q w ( s , a ) Q_w(s,a) Qw(s,a),用于计算原来的损失函数 1 2 [ Q ω ( s , a ) − ( r + γ max ⁡ a ′ Q ω − ( s ′ , a ′ ) ) ] 2 \frac{1}{2}[Q_{\omega}\left(s,a\right)-\left(r+\gamma\max_{a^{\prime}}Q_{\omega^{-}}\left(s^{\prime},a^{\prime}\right)\right)]^{2} 21[Qω(s,a)(r+γmaxaQω(s,a))]2中的 Q w ( s , a ) Q_w(s,a) Qw(s,a)项,并且使用正常梯度下降方法来进行更新。

(2) 目标网络 Q w − ( s , a ) Q_{w^{-}}(s,a) Qw(s,a),用于计算原先损失函数 1 2 [ Q ω ( s , a ) − ( r + γ max ⁡ a ′ Q ω − ( s ′ , a ′ ) ) ] 2 \frac{1}{2}[Q_{\omega}\left(s,a\right)-\left(r+\gamma\max_{a^{\prime}}Q_{\omega^{-}}\left(s^{\prime},a^{\prime}\right)\right)]^{2} 21[Qω(s,a)(r+γmaxaQω(s,a))]2中的 ( r + γ max ⁡ a ′ Q ω − ( s ′ , a ′ ) ) \left(r+\gamma\max_{a^{\prime}}Q_{\omega^{-}}\left(s^{\prime},a^{\prime}\right)\right) (r+γmaxaQω(s,a))项,其中 w − w^{-} w表示目标网络中的参数。如果两套网络的参数随时保持一致,则仍为原先不够稳定的算法。为了让更新目标更稳定,目标网络并不会每一步都更新。具体而言,目标网络使用训练网络的一套较旧的参数,训练网络 Q w ( s , a ) Q_w(s,a) Qw(s,a)在训练中的每一步都会更新,而目标网络 Q w − ( s , a ) Q_{w^{-}}(s,a) Qw(s,a)的参数每隔 C C C步才会与训练网络同步一次,即 w − ← w w^{-}\leftarrow w ww。这样做使得目标网络相对于训练网络更加稳定。

算法流程

在这里插入图片描述
在这里插入图片描述

代码实践

CartPole环境

Cart Pole gymnasium文档
pytorch官方教程REINFORCEMENT LEARNING (DQN) TUTORIAL
(使用stable_baselines3)强化学习训练的模型怎么存储?比如OpenAI-gym训练好的模型? -
https://www.zhihu.com/question/67825049/answer/2794069082
在这里插入图片描述
在车杆环境中,有一辆小车,智能体的任务是通过左右移动保持车上的杆竖直,若杆的倾斜度数过大,或者车子离初始位置左右的偏离程度过大,或者坚持时间到达 500 帧,则游戏结束。智能体的状态是一个维数为 4 的向量,每一维都是连续的,其动作是离散的,动作空间大小为 2。在游戏中每坚持一帧,智能体能获得分数为 1 的奖励,坚持时间越长,则最后的分数越高,坚持 500 帧即可获得最高的分数。

状态空间Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32)

维度意义最小值最大值
0车的位置-2.42.4
1车的速度-InfInf
2杆的角度~ -41.8°~ 41.8°
3杆尖端的速度-InfInf

动作空间Discrete(2)

标号动作
0向左移动小车
1向右移动小车

代码

import random
import gymnasium as gym
import numpy as np
import collections
from tqdm import tqdm
import torch
import torch.nn.functional as F
import utilclass ReplayBuffer:''' 经验回放池 '''def __init__(self, capacity):self.buffer = collections.deque(maxlen=capacity)  # 队列,先进先出# 将数据加入bufferdef add(self, state, action, reward, next_state, terminated, truncated):self.buffer.append((state, action, reward, next_state, terminated, truncated))# 从buffer中采样数据,数量为batch_sizedef sample(self, batch_size):transitions = random.sample(self.buffer, batch_size)state, action, reward, next_state, terminated, truncated = zip(*transitions)return np.array(state), action, reward, np.array(next_state), terminated, truncated# 目前buffer中数据的数量def size(self):return len(self.buffer)class Qnet1(torch.nn.Module):''' 只有一层隐藏层的Q网络 '''def __init__(self, state_dim, hidden_dim, action_dim):super(Qnet1, self).__init__()self.fc = torch.nn.Sequential(torch.nn.Linear(state_dim, hidden_dim),torch.nn.ReLU(),torch.nn.Linear(hidden_dim, action_dim))def forward(self, x):return self.fc(x)class DQN:''' DQN算法 '''def __init__(self, state_dim, hidden_dim, action_dim, learning_rate, gamma,epsilon, target_update_rate, device, numOfEpisodes, env,buffer_size, minimal_size, batch_size):self.action_dim = action_dim# Q网络self.q_net = Qnet1(state_dim, hidden_dim, self.action_dim).to(device)# 目标网络self.target_q_net = Qnet1(state_dim, hidden_dim, self.action_dim).to(device)# 使用Adam优化器self.optimizer = torch.optim.Adam(self.q_net.parameters(),lr=learning_rate)self.gamma = gammaself.epsilon = epsilon# 目标网络更新频率self.target_update_rate = target_update_rate# 计数器,记录更新次数self.count = 0self.device = deviceself.numOfEpisodes = numOfEpisodesself.env = envself.buffer_size = buffer_sizeself.minimal_size = minimal_sizeself.batch_size = batch_size# Choose A from S using policy derived from Q (e.g., epsilon-greedy)def ChooseAction(self, state):if np.random.random() < self.epsilon:action = np.random.randint(self.action_dim)else:state = torch.tensor(np.array([state]), dtype=torch.float).to(self.device)action = self.q_net(state).argmax().item()return actiondef Update(self, transition_dict):states = torch.tensor(transition_dict['states'],dtype=torch.float).to(self.device)actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(self.device)rewards = torch.tensor(transition_dict['rewards'],dtype=torch.float).view(-1, 1).to(self.device)next_states = torch.tensor(transition_dict['next_states'],dtype=torch.float).to(self.device)terminateds = torch.tensor(transition_dict['terminateds'],dtype=torch.float).view(-1, 1).to(self.device)truncateds = torch.tensor(transition_dict['truncateds'],dtype=torch.float).view(-1, 1).to(self.device)#Q值?q_values = self.q_net(states).gather(1, actions)# 下个状态的最大Q值max_next_q_values = self.target_q_net(next_states).max(1)[0].view(-1, 1)# TD误差目标q_targets = rewards + self.gamma * max_next_q_values * (1 - terminateds + truncateds)# 均方误差损失函数dqn_loss = torch.mean(F.mse_loss(q_values, q_targets))# PyTorch中默认梯度会累积,这里需要显式将梯度置为0self.optimizer.zero_grad()# 反向传播更新参数dqn_loss.backward()self.optimizer.step()if self.count % self.target_update_rate == 0:self.target_q_net.load_state_dict(self.q_net.state_dict())  # 更新目标网络self.count += 1def DQNtrain(self):replay_buffer = ReplayBuffer(self.buffer_size)returnList = []for i in range(10):with tqdm(total=int(self.numOfEpisodes / 10), desc='Iteration %d' % i) as pbar:for episode in range(int(self.numOfEpisodes / 10)):# initialize statestate, info = self.env.reset()terminated = Falsetruncated = FalseepisodeReward = 0# Loop for each step of episode:while (not terminated) or (not truncated):action = self.ChooseAction(state)next_state, reward, terminated, truncated, info = self.env.step(action)replay_buffer.add(state, action, reward, next_state, terminated, truncated)if terminated or truncated:breakstate = next_stateepisodeReward += reward# 当buffer数据的数量超过一定值后,才进行Q网络训练if replay_buffer.size() > self.minimal_size:b_s, b_a, b_r, b_ns, b_te, b_tr = replay_buffer.sample(self.batch_size)transition_dict = {'states': b_s,'actions': b_a,'next_states': b_ns,'rewards': b_r,'terminateds': b_te,'truncateds': b_tr}self.Update(transition_dict)returnList.append(episodeReward)if (episode + 1) % 10 == 0:  # 每10条序列打印一下这10条序列的平均回报pbar.set_postfix({'episode':'%d' % (self.numOfEpisodes / 10 * i + episode + 1),'return':'%.3f' % np.mean(returnList[-10:])})pbar.update(1)return returnListdef test01():device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")env = gym.make("CartPole-v1", render_mode="human")# random.seed(0)# np.random.seed(0)# torch.manual_seed(0)returnLists1 = []ReturnList = []agent = DQN(state_dim=env.observation_space.shape[0],hidden_dim=128,action_dim=2,learning_rate=2e-3,gamma=0.98,epsilon=0.01,target_update_rate=10,device=device,numOfEpisodes=500,env=env,buffer_size=10000,minimal_size=500,batch_size=64)returnLists1.append(agent.DQNtrain())ReturnList.append(util.smooth(returnLists1, sm=100))labelList = ['DQN']util.PlotReward(500, ReturnList, labelList, 'CartPole-v1')np.save("D:\LearningRL\Hands-on-RL\DQN_CartPole\ReturnData\DQN_CartPole_v0_2.npy", returnLists1)env.close()if __name__ == "__main__":test01()

结果

一次不太理想的结果
在这里插入图片描述
pytorch教程中的结果
在这里插入图片描述

一次比较“好”的结果
在这里插入图片描述
在这里插入图片描述
左(保留训练效果不理想),右(剔除训练效果不理想)

部分结果动图
在这里插入图片描述

参考

[1] 伯禹AI
[2] https://www.davidsilver.uk/teaching/
[3] 动手学强化学习
[4] Reinforcement Learning

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/151972.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

设计模式(19)命令模式

一、介绍&#xff1a; 1、定义&#xff1a;命令模式&#xff08;Command Pattern&#xff09;是一种行为设计模式&#xff0c;它将请求封装为一个对象&#xff0c;从而使你可以使用不同的请求对客户端进行参数化。命令模式还支持请求的排队、记录日志、撤销操作等功能。 2、组…

Python深度学习实战-基于tensorflow原生代码搭建BP神经网络实现分类任务(附源码和实现效果)

实现功能 前面两篇文章分别介绍了两种搭建神经网络模型的方法&#xff0c;一种是基于tensorflow的keras框架&#xff0c;另一种是继承父类自定义class类&#xff0c;本篇文章将编写原生代码搭建BP神经网络。 实现代码 import tensorflow as tf from sklearn.datasets import…

哈希算法:如何防止数据库中的用户信息被脱库?

文章来源于极客时间前google工程师−王争专栏。 2011年CSDN“脱库”事件&#xff0c;CSDN网站被黑客攻击&#xff0c;超过600万用户的注册邮箱和密码明文被泄露&#xff0c;很多网友对CSDN明文保存用户密码行为产生了不满。如果你是CSDN的一名工程师&#xff0c;你会如何存储用…

debian 10 安装apache2 zabbix

nginx 可以略过&#xff0c;改为apache2 apt updateapt-get install nginx -ynginx -v nginx version: nginx/1.14.2mysql 安装参考linux debian10 安装mysql5.7_debian apt install mysql5.7-CSDN博客 Install and configure Zabbix for your platform a. Install Zabbix re…

SpringCore完整学习教程5,入门级别

本章从第6章开始 6. JSON Spring Boot提供了三个JSON映射库的集成: Gson Jackson JSON-B Jackson是首选的和默认的库。 6.1. Jackson 为Jackson提供了自动配置&#xff0c;Jackson是spring-boot-starter-json的一部分。当Jackson在类路径上时&#xff0c;将自动配置Obj…

uniapp 中添加 vconsole

uniapp 中添加 vconsole 一、安装 vconsole npm i vconsole二、使用 vconsole 在项目的 main.js 文件中添加如下内容 // #ifdef H5 // 提交前需要注释 本地调试使用 import * as vconsole from "vconsole"; new vconsole() // 使用 vconsole // #endif三、成功

[17]JAVAEE-HTTP协议

目录 一、什么是HTTP协议 什么时候会用到HTTP协议&#xff1f; HTTP协议的工作流程 二、HTTP的报文格式 抓包 HTTP请求报文格式 1.首行 2.header 常见键值对&#xff1a; 3.空行 4.正文&#xff08;body&#xff09;&#xff08;有的时候可以没有&#xff09; HTTP…

数据分析和互联网医院小程序:提高医疗决策的准确性和效率

互联网医院小程序已经在医疗领域取得了显著的进展&#xff0c;为患者和医疗从业者提供了更便捷和高效的医疗服务。随着数据分析技术的快速发展&#xff0c;互联网医院小程序能够利用大数据来提高医疗决策的准确性和效率。本文将探讨数据分析在互联网医院小程序中的应用&#xf…

【Kotlin精简】第6章 反射

1 反射简介 反射机制是在运行状态中&#xff0c;对于任意一个类&#xff0c;都能够知道这个类的所有属性和方法&#xff0c;对于任意一个对象&#xff0c;都能够调用它的任意一个方法和属性。 1.1 Kotlin反射 我们对比Kotlin和Java的反射类图。 1.1.1 Kotlin反射常用的数据结…

Egg.js使用MySql数据库

最近在接手一个项目&#xff0c;vuenuxtegg&#xff0c;我也是刚开始学习egg.js&#xff0c;所以会将自己踩的坑都记录下来。 安装mysql 使用sequelize连接数据库&#xff0c;首先安装egg-sequelize和mysql2。 npm install --save egg-sequelize mysql2打开package.json文件…

(a /b)*c的值

系列文章目录 进阶的卡莎C++_睡觉觉觉得的博客-CSDN博客数1的个数_睡觉觉觉得的博客-CSDN博客双精度浮点数的输入输出_睡觉觉觉得的博客-CSDN博客足球联赛积分_睡觉觉觉得的博客-CSDN博客大减价(一级)_睡觉觉觉得的博客-CSDN博客小写字母的判断_睡觉觉觉得的博客-CSDN博客纸币(…

JavaWeb——关于servlet种mapping地址映射的一些问题

6、Servlet 6.4、Mapping问题 一个Servlet可以指定一个映射路径 <servlet-mapping><servlet-name>hello</servlet-name><url-pattern>/hello</url-pattern> </servlet-mapping>一个Servlet可以指定多个映射路径 <servlet-mapping>&…