SARAS多步TD目标算法

SARAS多步TD目标算法

代码仓库:https://github.com/daiyizheng/DL/tree/master/09-rl
SARSA算法是on-policy 时序差分

在迭代的时候,我们基于 ϵ \epsilon ϵ-贪婪法在当前状态 S t S_t St 选择一个动作 A t A_t At ,然后会进入到下一个状态 S t + 1 S_{t+1} St+1 ,同时获得奖励 R t + 1 R_{t+1} Rt+1 ,在新的状态 S t + 1 S_{t+1} St+1 我们同样基于 ϵ \epsilon ϵ-贪婪法选择一个动作 A t + 1 A_{t+1} At+1 ,然后用它来更新我们的价值函数,更新公式如下:
Q ( S t , A t ) ← Q ( S t , A t ) + α [ R t + 1 + γ Q ( S t + 1 , A t + 1 ) − Q ( S t , A t ) ] Q\left(S_t, A_t\right) \leftarrow Q\left(S_t, A_t\right)+\alpha\left[R_{t+1}+\gamma Q\left(S_{t+1}, A_{t+1}\right)-Q\left(S_t, A_t\right)\right] Q(St,At)Q(St,At)+α[Rt+1+γQ(St+1,At+1)Q(St,At)]

  • 注意: 这里我们选择的动作 A t + 1 A_{t+1} At+1 ,就是下一步要执行的动作,这点是和Q-Learning算法的最大不同
  • 这里的 TD Target: δ t = R t + 1 + γ Q ( S t + 1 , A t + 1 ) \delta_t=R_{t+1}+\gamma Q\left(S_{t+1}, A_{t+1}\right) δt=Rt+1+γQ(St+1,At+1)
  • 在每一个非终止状态 S t S_t St
  • 进行一次更新,我们要获取 5 个数据, < S t , A t , R t + 1 , S t + 1 , A t + 1 > <S_t, A_t, R_{t+1}, S_{t+1}, A_{t+1}> <St,At,Rt+1,St+1,At+1>

那么n-step Sarsa如何计算

Q ( S t , A t ) ← Q ( S t , A t ) + α ( q t ( n ) − Q ( S t , A t ) ) Q\left(S_t, A_t\right) \leftarrow Q\left(S_t, A_t\right)+\alpha\left(q_t^{(n)}-Q\left(S_t, A_t\right)\right) Q(St,At)Q(St,At)+α(qt(n)Q(St,At))

其中 q ( n ) q_{(n)} q(n) 为:
q t ( n ) = R t + 1 + γ R t + 2 + ⋯ + γ n − 1 R t + n + γ n Q ( S t + n , A t + n ) q_t^{(n)}=R_{t+1}+\gamma R_{t+2}+\cdots+\gamma^{n-1} R_{t+n}+\gamma^n Q\left(S_{t+n}, A_{t+n}\right) qt(n)=Rt+1+γRt+2++γn1Rt+n+γnQ(St+n,At+n)

代码

  1. 构建环境
import gym#定义环境
class MyWrapper(gym.Wrapper):def __init__(self):#is_slippery控制会不会滑env = gym.make('FrozenLake-v1',render_mode='rgb_array',is_slippery=False)super().__init__(env)self.env = envdef reset(self):state, _ = self.env.reset()return statedef step(self, action):state, reward, terminated, truncated, info = self.env.step(action)over = terminated or truncated#走一步扣一份,逼迫机器人尽快结束游戏if not over:reward = -1#掉坑扣100分if over and reward == 0:reward = -100return state, reward, over#打印游戏图像def show(self):from matplotlib import pyplot as pltplt.figure(figsize=(3, 3))plt.imshow(self.env.render())plt.show()env = MyWrapper()env.reset()env.show()
  1. 构建Q表
import numpy as np#初始化Q表,定义了每个状态下每个动作的价值
Q = np.zeros((16, 4))Q
  1. 构建数据
from IPython import display
import random#玩一局游戏并记录数据
def play(show=False):state = []action = []reward = []next_state = []over = []s = env.reset()o = Falsewhile not o:a = Q[s].argmax()if random.random() < 0.1:a = env.action_space.sample()ns, r, o = env.step(a)state.append(s)action.append(a)reward.append(r)next_state.append(ns)over.append(o)s = nsif show:display.clear_output(wait=True)env.show()return state, action, reward, next_state, over, sum(reward)play()[-1]
  1. 训练
#训练
def train():#训练N局for epoch in range(50000):#玩一局游戏,得到数据state, action, reward, next_state, over, _ = play()for i in range(len(state)):#计算valuevalue = Q[state[i], action[i]]#计算target#累加未来N步的reward,越远的折扣越大#这里是在使用蒙特卡洛方法估计targetreward_s = 0for j in range(i, min(len(state), i + 5)):reward_s += reward[j] * 0.9**(j - i)#计算最后一步的value,这是target的一部分,按距离给折扣target = Q[next_state[j]].max() * 0.9**(j - i + 1)#如果最后一步已经结束,则不需要考虑状态价值#最后累加reward就是targettarget = target + reward_s#更新Q表Q[state[i], action[i]] += (target - value) * 0.05if epoch % 5000 == 0:test_result = sum([play()[-1] for _ in range(20)]) / 20print(epoch, test_result)train()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/171762.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Vue基础必备掌握知识点-Vue的指令系统讲解

什么是Vue&#xff1f; Vue的概念 Vue是一个用于构建用户界面的渐进式框架(通过数据渲染出用户所能够看到的界面) Vue的两种使用方式 1&#xff1a;Vue核心包开发 场景&#xff1a;局部模块的改造 2&#xff1a;Vue核心包&Vue工程化的开发 场景&#xff1a;整站开发 Vue开发…

安装virt-manger虚拟机管理器

环境&#xff1a; redhat7:192.168.1.130 安装步骤&#xff1a; 安装qemu-kvm yum install -y qemu-kvm安装libvirt yum install -y libvirt重启libvirt systemctl restart libvirtd查看libvirt的版本信息 virsh version安装virt-manager yum install -y virt-manager检验…

OCR转换技巧:如何避免图片转Word时出现多余的换行?

在将图片中的文字识别转换为Word文档时&#xff0c;我们很多时候时会遇到识别内容的一个自然段还没结束就换行的问题&#xff0c;这些就是我们常说的多余换行的问题。为什么会产生这个问题呢&#xff1f;主要是由于OCR返回的识别结果是按图片上的文字换行而换行&#xff0c;而不…

知识竞赛中常用的物料有哪些

办一场知识竞赛&#xff0c;需要准备的物料要根据具体竞赛规则和流程来定。但是要仔细分析起来&#xff0c;还是可以做一个常用物料清单的&#xff0c;下面我将知识竞赛活动中常用的物料做了一个分类和列表&#xff0c;大家以后在竞赛活动举办过程中&#xff0c;可以参考。 一、…

【STM32】定时器+基本定时器

一、定时器的基本概述 1.软件定时器原理 原来我们使用51单片机的时候&#xff0c;是通过一个__nop()__来进行延时 我们通过软件的方式来进行延时功能是不准确的&#xff0c;受到很多不确定因素。 2.定时器原理&#xff1a;计数之间的比值 因为使用软件延时受到影响&#xff0c…

IntelliJ IDEA 安装 GitHub Copilot插件 (最新)

注意&#xff1a; GitHub Copilot 插件对IDEA最低版本要求是2021.2&#xff0c;建议直接用2023.3&#xff0c;一次到位反正后续要升级的。 各个版本的依赖关系&#xff0c;请参照&#xff1a; ##在线安装&#xff1a; 打开 IntelliJ IDEA扩展商店&#xff0c;输入 "Git…

HTML 之常用标签的介绍

文章目录 h标签p标签a标签img 标签table、tr、td标签ul、ol、li 标签div 标签 h标签 <h> 标签用于定义 HTML 文档中的标题&#xff0c;其中 h 后面跟着一个数字&#xff0c;表示标题的级别。HTML 提供了 <h1> 到 <h6> 六个不同级别的标题&#xff0c;其中 &…

[C++]Leetcode17电话号码的字母组合

题目描述 解题思路&#xff1a; 这是一个深度优先遍历的题目&#xff0c;涉及到多路递归&#xff0c;下面通过画图和解析来分析这道题。 首先说到的是映射关系&#xff0c;那么我们就可以通过一个字符串数组来表示映射关系&#xff08;字符串下标访问对应着数字映射到对应的…

打破语言壁垒,实现全球商贸:多语言多商户跨境商城源码引领电商新潮流

随着全球化的不断深入&#xff0c;电子商务的蓬勃发展&#xff0c;传统的单语言电商模式已经无法满足日益多元化的市场需求。多语言多商户跨境商城源码&#xff0c;一种创新的电商解决方案&#xff0c;应运而生。它打破了语言和地域的限制&#xff0c;让全球的商家和消费者都能…

这 11 个 for 循环优化你得会

日常开发中&#xff0c;经常会遇到一些循环耗时计算的操作&#xff0c;一般也都会采用 for 循环来处理&#xff0c;for 作为编程入门基础&#xff0c;主要是处理重复的计算操作&#xff0c;虽然简单好用&#xff0c;但在写法上也有很多的考究&#xff0c;如果处理不好&#xff…

时间序列预测实战(十五)PyTorch实现GRU模型长期预测并可视化结果

往期回顾&#xff1a;时间序列预测专栏——包含上百种时间序列模型带你从入门到精通时间序列预测 一、本文介绍 本文讲解的实战内容是GRU(门控循环单元)&#xff0c;本文的实战内容通过时间序列领域最经典的数据集——电力负荷数据集为例&#xff0c;深入的了解GRU的基本原理和…

百度智能云千帆大模型平台再升级,SDK版本开源发布!

&#x1f4eb;作者简介&#xff1a;小明java问道之路&#xff0c;2022年度博客之星全国TOP3&#xff0c;专注于后端、中间件、计算机底层、架构设计演进与稳定性建设优化&#xff0c;文章内容兼具广度、深度、大厂技术方案&#xff0c;对待技术喜欢推理加验证&#xff0c;就职于…