Pytorch深度强化学习2-1:基于价值的强化学习——DQN算法

目录

  • 0 专栏介绍
  • 1 基于价值的强化学习
  • 2 深度Q网络与Q-learning
  • 3 DQN原理分析
  • 4 DQN训练实例

0 专栏介绍

本专栏重点介绍强化学习技术的数学原理,并且采用Pytorch框架对常见的强化学习算法、案例进行实现,帮助读者理解并快速上手开发。同时,辅以各种机器学习、数据处理技术,扩充人工智能的底层知识。

🚀详情:《Pytorch深度强化学习》


1 基于价值的强化学习

根据不动点定理,最优策略和最优价值函数是唯一的(对该经典理论不熟悉的请看Pytorch深度强化学习1-4:策略改进定理与贝尔曼最优方程详细推导),通过优化价值函数间接计算最优策略的方法称为基于价值的强化学习(value-based)框架。设状态空间为 n n n维欧式空间 S = R n S=\mathbb{R} ^n S=Rn,每个维度代表状态的一个特征。此时状态-动作值函数记为

Q ( s , a ; θ ) Q\left( \boldsymbol{s},\boldsymbol{a};\boldsymbol{\theta } \right) Q(s,a;θ)

其中 s \boldsymbol{s} s是状态向量, a \boldsymbol{a} a是动作空间中的动作向量, θ \boldsymbol{\theta } θ是神经网络的参数向量。深度学习完成了从输入状态到输出状态-动作价值的映射

s → Q ( s , a ; θ ) [ Q ( s , a 1 ) Q ( s , a 2 ) ⋯ Q ( s , a m ) ] T ( a 1 , a 2 , ⋯ , a m ∈ A ) \boldsymbol{s}\xrightarrow{Q\left( \boldsymbol{s},\boldsymbol{a};\boldsymbol{\theta } \right)}\left[ \begin{matrix} Q\left( \boldsymbol{s},a_1 \right)& Q\left( \boldsymbol{s},a_2 \right)& \cdots& Q\left( \boldsymbol{s},a_m \right)\\\end{matrix} \right] ^T\,\, \left( a_1,a_2,\cdots ,a_m\in A \right) sQ(s,a;θ) [Q(s,a1)Q(s,a2)Q(s,am)]T(a1,a2,,amA)

相当于对无穷维Q-Table的一次隐式查表,对经典Q-learing算法不熟悉的请看Pytorch深度强化学习1-6:详解时序差分强化学习(SARSA、Q-Learning算法)、Pytorch深度强化学习案例:基于Q-Learning的机器人走迷宫。设目标价值函数为 Q ∗ Q^* Q,若采用最小二乘误差,可得损失函数为

J ( θ ) = E [ 1 2 ( Q ∗ ( s , a ) − Q ( s , a ; θ ) ) 2 ] J\left( \boldsymbol{\theta } \right) =\mathbb{E} \left[ \frac{1}{2}\left( Q^*\left( \boldsymbol{s},\boldsymbol{a} \right) -Q\left( \boldsymbol{s},\boldsymbol{a};\boldsymbol{\theta } \right) \right) ^2 \right] J(θ)=E[21(Q(s,a)Q(s,a;θ))2]

采用梯度下降得到参数更新公式为

θ ← θ + α ( Q ∗ ( s , a ) − Q ( s , a ; θ ) ) ∂ Q ( s , a ; θ ) ∂ θ \boldsymbol{\theta }\gets \boldsymbol{\theta }+\alpha \left( Q^*\left( \boldsymbol{s},\boldsymbol{a} \right) -Q\left( \boldsymbol{s},\boldsymbol{a};\boldsymbol{\theta } \right) \right) \frac{\partial Q\left( \boldsymbol{s},\boldsymbol{a};\boldsymbol{\theta } \right)}{\partial \boldsymbol{\theta }} θθ+α(Q(s,a)Q(s,a;θ))θQ(s,a;θ)

随着迭代进行, Q ( s , a ; θ ) Q\left( \boldsymbol{s},\boldsymbol{a};\boldsymbol{\theta } \right) Q(s,a;θ)将不断逼近 Q ∗ Q^* Q,由 Q ( s , a ; θ ) Q\left( \boldsymbol{s},\boldsymbol{a};\boldsymbol{\theta } \right) Q(s,a;θ)进行的策略评估和策略改进也将迭代至最优。

2 深度Q网络与Q-learning

Q-learning和深度Q学习(Deep Q-learning, DQN)是强化学习领域中两种重要的算法,它们在解决智能体与环境之间的决策问题方面具有相似之处,但也存在一些显著的异同。这里进行简要阐述以加深对二者的理解。

  • Q-learning是一种基于值函数的强化学习算法。它通过使用Q-Table来表示每个状态和动作对的预期回报。Q值函数用于指导智能体在每个时间步选择最优动作。通过不断更新Q值函数来使其逼近最优的Q值函数
  • DQN是对Q-learning的深度网络版本,它将神经网络引入Q-learning中,以处理具有高维状态空间的问题。通过使用深度神经网络作为函数逼近器,DQN可以学习从原始输入数据(如像素值)直接预测每个动作的Q值

在这里插入图片描述

3 DQN原理分析

深度Q网络(Deep Q-Network, DQN)的核心原理是通过

  • 经验回放池(Experience Replay):考虑到强化学习采样的是连续非静态样本,样本间的相关性导致网络参数并非独立同分布,使训练过程难以收敛,因此设置经验池存储样本,再通过随机采样去除相关性;
  • 目标网络(Target Network):考虑到若目标价值 与当前价值 是同一个网络时会导致优化目标不断变化,产生模型振荡与发散,因此构建与 结构相同但慢于 更新的独立目标网络来评估目标价值,使模型更稳定。

拟合了高维状态空间,是Q-Learning算法的深度学习版本,算法流程如表所示

在这里插入图片描述

4 DQN训练实例

最简单的例子是使用全连接网络来构造DQN

class DQN(nn.Module):def __init__(self, input_dim, output_dim):super(DQN, self).__init__()self.input_dim = input_dimself.output_dim = output_dimself.fc = nn.Sequential(nn.Linear(self.input_dim[0], 128),nn.ReLU(),nn.Linear(128, 256),nn.ReLU(),nn.Linear(256, self.output_dim))def __str__(self) -> str:return "Fully Connected Deep Q-Value Network, DQN"def forward(self, state):qvals = self.fc(state)return qvals

基于贝尔曼最优原理的损失计算如下

def computeLoss(self, batch):states, actions, rewards, next_states, dones = batchstates = torch.FloatTensor(states).to(self.device)actions = torch.LongTensor(actions).to(self.device)rewards = torch.FloatTensor(rewards).to(self.device)next_states = torch.FloatTensor(next_states).to(self.device)dones = (1 - torch.FloatTensor(dones)).to(self.device)# 根据实际动作提取Q(s,a)值curr_Q = self.model(states).gather(1, actions.unsqueeze(1)).squeeze(1)next_Q = self.target_model(next_states)max_next_Q = torch.max(next_Q, 1)[0]expected_Q = rewards.squeeze(1) + self.gamma * max_next_Q * donesloss = self.criterion(curr_Q, expected_Q.detach())return loss

基于经验回放池和目标网络的参数更新如下

def update(self, batch_size):batch = self.replay_buffer.sample(batch_size)loss = self.computeLoss(batch)self.optimizer.zero_grad()loss.backward()self.optimizer.step()# 更新target网络for target_param, param in zip(self.target_model.parameters(), self.model.parameters()):target_param.data.copy_(self.tau * param + (1 - self.tau) * target_param)# 退火self.epsilon = self.epsilon + self.epsilon_delta \if self.epsilon < self.epsilon_max else self.epsilon_max

基于DQN可以实现最基本的智能体,下面给出一些具体案例

  • Pytorch深度强化学习案例:基于DQN实现Flappy Bird游戏与分析

在这里插入图片描述

完整代码联系下方博主名片获取


🔥 更多精彩专栏

  • 《ROS从入门到精通》
  • 《Pytorch深度学习实战》
  • 《机器学习强基计划》
  • 《运动规划实战精讲》

👇源码获取 · 技术交流 · 抱团学习 · 咨询分享 请联系👇

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/301019.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java 将PDF 转为图片 工具 【Free Spire.PDF for Java】(免费版)

Java 将PDF 转为图片 使用工具&#xff1a;Free Spire.PDF for Java&#xff08;免费版&#xff09; Jar文件获取及导入&#xff1a; 方法1&#xff1a;通过官网下载jar文件包。下载后&#xff0c;解压文件&#xff0c;并将lib文件夹下的Spire.Pdf.jar文件导入Java程序。 方…

win10安装ffmpeg

1 ffmpeg官网下载 官网地址&#xff1a;https://ffmpeg.org/ ffmpeg可执行程序下载地址&#xff1a;https://www.gyan.dev/ffmpeg/builds/ ffmpeg官网文档&#xff1a;https://ffmpeg.org/documentation.html 选择对应的版本点解下载可执行程序包&#xff0c;比如6.1版本的…

YOLO在深度学习视觉应用中的使用场景与部署

​引言 OpenCV和YOLO&#xff08;You Only Look Once&#xff09;&#xff0c;这些工具在各种视觉识别任务中的强大功能和广泛应用。YOLO是一个流行的实时对象检测系统&#xff0c;它以其速度和准确性在工业和研究领域中广受欢迎。 YOLO的使用场景​ YOLO的应用场景非常广泛…

西南科技大学计算机网络实验三 (路由器基本配置与操作,RIP、OSPF路由协议配置)

一、实验目的 基于网络设备模拟软件,学习和使用路由器的各种基本配置与验证命令,学习和使用路由器的静态路由、RIP、OSPF路由协议配置。 二、实验环境 使用RouterSim Network Visualizer软件来模拟网络设备与网络环境;主机操作系统为windows。 三、实验内容 1、路由器名称…

如何进行实例管理

目录 修改实例规格 修改网络带宽 网站的访问量每天都比较高&#xff0c;网站明显变慢了&#xff0c;这是怎么回事&#xff1f; 这说明你的网站的并发访问能力已经不足了&#xff0c;并发访问是指同一时间&#xff0c;多个用户请求访问同一个域名下的资源或服务&#xff0c;请…

OAuth2.0 四种授权方式讲解

一、OAuth2.0 的理解 OAuth2是一个开放的授权标准&#xff0c;允许第三方应用程序以安全可控的方式访问受保护的资源&#xff0c;而无需用户将用户名和密码信息与第三方应用程序共享。OAuth2被广泛应用于现代Web和移动应用程序开发中&#xff0c;可以简化应用程序与资源服务器之…

代码随想录算法训练营Day9 | 20.有效的括号、1047.删除字符串中的所有相邻重复项、150.逆波兰表达式求值

LeetCode 20 有效的括号 本题思路&#xff1a;利用栈来完成&#xff0c;如果遇到左括号类型就放入栈&#xff0c;如果遇到右括号类型&#xff0c;就弹出栈顶的元素和该元素进行匹配&#xff0c;如果不匹配就返回 false。 注意点&#xff1a; 第一个就是右括号类型&#xff0c;那…

Qt Designer 常见需求

窗口 参考链接 【转载】Qt Designer 使用全攻略_qtdesigner使用-CSDN博客 QT屏幕自适应自动布局&#xff0c;拖动窗口自动变大变小&#xff08;一&#xff09;_qt布局随窗口大小变化-CSDN博客 pyqt5设置高分辨率以及icon显示模糊解决办法_python qt图显示不清晰-CSDN博客 窗…

算法基础之最长公共子序列

最长公共子序列 核心思想&#xff1a; 线性dp 集合定义 : f[i][j]存 a[1 ~ i] 和 b[1 ~ j] 的最长公共子序列长度 状态计算&#xff1a; 分为取/不取a[i]/b[j] 共四种情况 其中 中间两种会包含两个都不取的情况(去掉) 但是因为取最大值 有重复也没事用f[i-1][j] 和 f[i][j-1]表…

希尔排序详解(C语言)

前言 希尔排序是一种基于插入排序的快速排序算法。所以如果还会插入排序的小伙伴可以点击链接学习一下插入排序&#xff08;点我点我&#xff01;&#xff09; &#xff0c;相较于插入排序&#xff0c;希尔排序拥有更高的效率&#xff0c;小伙伴们肯定已经迫不及待学习了吧&…

使用防火墙是否可以应对DDoS攻击?

很多游戏行业公司对网络安全不够了解&#xff0c;觉得装个防火墙就可以万事大吉了。实际上使用防火墙确实是解决DDoS攻击问题的一种有效方法&#xff0c;一些更先进的防火墙还可以采用其他防御措施&#xff0c;例如:深度包检测、行为分析、人工智能等&#xff0c;来识别和防御各…

MyBatis见解4

10.MyBatis的动态SQL 10.5.trim标签 trim标签可以代替where标签、set标签 mapper //修改public void updateByUser2(User user);<update id"updateByUser2" parameterType"User">update user<!-- 增加SET前缀&#xff0c;忽略&#xff0c;后缀…