[PyTorch][chapter 60][强化学习-2-有模型学习2]

前言:

   前面我们讲了一下策略评估的原理,以及例子.

   强化学习核心是找到最优的策略,这里

   重点讲解两个知识点:

    策略改进

   策略迭代与值迭代

   最后以下面环境E 为例,给出Python 代码

目录:

     1:  策略改进

      2:  策略迭代与值迭代

      3: 策略迭代代码实现  Python 代码


一  策略改进

      理想的策略应该能够最大化累积奖赏:

       \pi^{*}= arg max_{\pi} \sum_{x \in X} V^{\pi}(x)

     最优策略对应的值函数V^{*}称为最优值函数

      \forall x\in X: V^{*}(x)= V^{\pi^*}(x)

状态值函数(Bellman 等式):

 动作求和

 V_{T}^{\pi}=\sum_{a \in A}\pi(x,a)\sum_{x^{'}\in X}P_{x\rightarrow x^{'}}^a(\frac{1}{T}R_{x \rightarrow x^{'}}^a+\frac{T-1}{T}V_{T-1}^{\pi}(x^{'}))......16.9

 V_{\gamma}^{\pi}=\sum_{a \in A}\pi(x,a)\sum_{x^{'}\in X}P_{x\rightarrow x^{'}}^a(R_{x \rightarrow x^{'}}^a+\gamma V_{\gamma}^{\pi}(x^{'}))......16.9

状态-动作值函数

状态值函数(Bellman 等式): 动作求和

 Q_{T}^{\pi}(x,a)=\sum_{x^{'}\in X}P_{x\rightarrow x^{'}}^a(\frac{1}{T}R_{x \rightarrow x^{'}}^a+\frac{T-1}{T}V_{T-1}^{\pi}(x^{'}))...16.10

 V_{\gamma}^{\pi}(x,a)=\sum_{x^{'}\in X}P_{x\rightarrow x^{'}}^a(R_{x \rightarrow x^{'}}^a+\gamma V_{\gamma}^{\pi}(x^{'}))...16.10

   由于最优值的累计奖赏已经最大,可以对前面的Bellman 等式做改动,

 即使对动作求和  改为取最优

    最优

 V_{T}^{\pi}=max_{a \in A}\sum_{x^{'}\in X}P_{x\rightarrow x^{'}}^a(\frac{1}{T}R_{x \rightarrow x^{'}}^a+\frac{T-1}{T}V_{T-1}^{*}(x^{'}))....16.13

 V_{\gamma}^{\pi}=\sum_{x^{'}\in X}P_{x\rightarrow x^{'}}^a(R_{x \rightarrow x^{'}}^a+\gamma V_{\gamma}^{*}(x^{'}))...16.13

V^{*}(x)=max_{a\in A}Q^{\pi^{*}}(x,a)....16.14  带入16.10

Q_{T}^{*}(x,a)=\sum_{x^{'}\in X}P_{x\rightarrow x^{'}}^a(\frac{1}{T}R_{x \rightarrow x^{'}}^a+\frac{T-1}{T}max_{a^{'} \in A}Q_{T-1}^{*}(x^{'},a^{'}))...16.10

 V_{\gamma}^{\pi}(x,a)=\sum_{x^{'}\in X}P_{x\rightarrow x^{'}}^a(R_{x \rightarrow x^{'}}^a+\gamma max_{a^{'} \in A} Q_{\gamma}^{*}(x^{'},a^{'}))...16.10

      最优Bellman 等式揭示了非最优策略的改进方式:

      将策略选择的动作改变为当前的最优动作。这样改进能使策略更好

   策略为\pi^{'},改变动作的条件为: Q^{\pi}(x,\pi^{'}(x)) \geq V^{\pi}(x)

带入16.10,可以得到递推不等式

    V^{\pi}(x)\leq Q^{\pi}(x,\pi^{'}(x))

             =\sum_{x^{'} \in X}P_{x\rightarrow x^{'}}^{\pi^{'}(x)}(R_{x\rightarrow x^{'}}^{\pi^{'}(x)}+\gamma V^{\pi}(x^{'}))

             =\sum_{x^{'} \in X}P_{x\rightarrow x^{'}}^{\pi^{'}(x)}(R_{x\rightarrow x^{'}}^{\pi^{'}(x)}+\gamma Q^{\pi}(x^{'},\pi^{'}(x^{'})))

             =V^{\pi^{*}}(x)    16.16


二  策略迭代与值迭代

可以看出:策略迭代法在每次改进策略后都要对策略进行重新评估,因此比较耗时。

由公式16.16  V^{\pi}(x)\leq Q^{\pi}(x,\pi^{'}(x))\leq V^{\pi^{*}}(x) 策略改进 与值函数的改进是一致的

由公式16.13可得  

V_{T}(x)=max_{a \in A}\sum_{x^{'}\in X}P_{x\rightarrow x^{'}}^a(\frac{1}{T}R_{x \rightarrow x^{'}}^a+\frac{T-1}{T}V_{T-1}^{*}(x^{'}))

 V_{\gamma}^{\pi}=max_{a\in A}\sum_{x^{'}\in X}P_{x\rightarrow x^{'}}^a(R_{x \rightarrow x^{'}}^a+\gamma V_{\gamma}^{*}(x^{'}))

于是可得值迭代(value iteration)算法.


三  策略迭代代码实现


# -*- coding: utf-8 -*-
"""
Created on Wed Nov  1 19:34:00 2023@author: cxf
"""# -*- coding: utf-8 -*-
"""
Created on Mon Oct 30 15:38:17 2023@author: chengxf2
"""
import numpy as np
from enum import Enum
import copyclass State(Enum):#状态空间X    shortWater =1 #缺水health = 2   #健康overflow = 3 #凋亡apoptosis = 4 #溢水class Action(Enum):#动作空间Awater = 1 #浇水noWater = 2 #不浇水class Env():def __init__(self):#状态空间self.X = [State.shortWater, State.health,State.overflow, State.apoptosis]   #动作空间self.A = [Action.water,Action.noWater]   #从状态x出发,执行动作a,转移到新的状态x',得到的奖赏 r为已知道self.Q ={}self.Q[State.shortWater] =          [[Action.water,0.5,   State.shortWater,-1],[Action.water,0.5,   State.health,1],[Action.noWater,0.4, State.shortWater,-1],[Action.noWater,0.6, State.overflow,-100]]self.Q[State.health] =                [[Action.water,0.6,  State.health,1],[Action.water,0.4,   State.overflow,-1],[Action.noWater,0.6, State.shortWater,-1],[Action.noWater,0.4, State.health,1]]self.Q[State.overflow] =                [[Action.water,0.6,   State.overflow,-1],[Action.water,0.4,   State.apoptosis,-100],[Action.noWater,0.6, State.health,1],[Action.noWater,0.4, State.overflow,-1]]self.Q[State.apoptosis] =[[Action.water,1, State.apoptosis,-100],[Action.noWater,1, State.apoptosis,-100]]self.curV ={} #前面的累积奖赏,t时刻的累积奖赏self.V ={} #累积奖赏,t-1时刻的累积奖赏for x in self.X:    self.V[x] =0self.curV[x]=0def GetX(self):#获取状态空间return self.Xdef GetAction(self):#获取动作空间return self.Adef GetQTabel(self):#获取状态转移概率return self.Qclass LearningAgent():def initStrategy(self):   #初始化策略stragegy ={}stragegy[State.shortWater] = Action.waterstragegy[State.health] =    Action.waterstragegy[State.overflow] = Action.waterstragegy[State.apoptosis] = Action.waterself.stragegy = stragegydef __init__(self):env = Env()self.X = env.GetX()self.A = env.GetAction()self.QTabel = env.GetQTabel()self.curV ={} #前面的累积奖赏self.V ={} #累积奖赏for x in self.X:    self.V[x] =0self.curV[x]=0def  evaluation(self,T):#策略评估for t in range(1,T):#当前策略下面的累积奖赏for  state in self.X: #状态空间reward = 0.0action = self.stragegy[state]QTabel= self.QTabel[state]for Q in QTabel:if action == Q[0]:#在状态x 下面执行了动作a,转移到了新的状态,得到的rnewstate = Q[2] p_a_ss =   Q[1]r_a_ss =   Q[-1]#print("\n p_a_ss",p_a_ss, "\t r_a_ss ",r_a_ss)reward += p_a_ss*((1.0/t)*r_a_ss + (1.0-1/t)*self.V[newstate])self.curV[state] = reward               if (T+1)== t:breakelse:self.V = self.curVdef  improve(self,T):#策略改进stragegy ={}for  state in self.X:QTabel= self.QTabel[state]max_reward = -float('inf') #计算每种Q(state, action)for action in self.A:reward = 0.0for Q in QTabel:if action == Q[0]:#在状态x 下面执行了动作a,转移到了新的状态,得到的rnewstate = Q[2] p_a_ss =   Q[1]r_a_ss =   Q[-1]#print("\n p_a_ss",p_a_ss, "\t r_a_ss ",r_a_ss)reward += p_a_ss*((1.0/T)*r_a_ss + (1.0-1/T)*self.V[newstate])if reward> max_reward:max_reward = rewardstragegy[state] = action#print("\n state ",state, "\t action ",action, "\t reward %4.2f"%reward)return stragegydef compare(self,dict1, dict2):#策略比较for key in dict1:if dict1[key] != dict2.get(key):return Falsereturn Truedef learn(self,T):#随机初始化策略self.initStrategy()n = 0while True:self.evaluation(T-1) #策略评估n = n+1print("\n 迭代次数 %d"%n ,State.shortWater.name, "\t 奖赏: %4.2f "%self.V[State.shortWater],State.health.name, "\t 奖赏: %4.2f "%self.V[State.health],State.overflow.name, "\t 奖赏: %4.2f "%self.V[State.overflow],State.apoptosis.name, "\t 奖赏: %4.2f "%self.V[State.apoptosis],)strategyN =self.improve(T) #策略改进#print("\n ---cur---\n",self.stragegy,"\n ---new-- \n ",strategyN )if self.compare(self.stragegy,strategyN):print("\n ----- 最终策略 -----\n ")for state in self.X:print("\n state ",state, "\t action: ",self.stragegy[state])breakelse:for state in self.X:self.stragegy[state] = strategyN[state]if __name__ == "__main__":T =10agent = LearningAgent()agent.learn(T)

参考:

机器学习.周志华《16 强化学习 》_51CTO博客_机器学习 周志华

CSDN

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/161039.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

nn.LayerNorm解释

这个是层归一化。我们输入一个参数,这个参数就必须与最后一个维度对应。但是我们也可以输入多个维度,但是必须从后向前对应。 import torch import torch.nn as nna torch.rand((100,5)) c nn.LayerNorm([5]) print(c(a).shape)a torch.rand((100,5,…

PDF文件解析

一、PDF文件介绍 PDF是英文Portable Document Format缩写,就是可移植的意思,它是以PostScript语言图象模型为基础,无论在哪种打印机上都可保证精确的颜色和准确的打印效果,PostScript咱也不懂,估计和SVG的原理差不多吧…

Spring Boot 常见面试题

目录 1.Spring Boot 快速入门什么是 Spring Boot?有什么优点?Spring Boot 与 Spring MVC 有什么区别?Spring 与 Spring Boot 有什么关系?✨什么是 Spring Boot Starters?Spring Boot 支持哪些内嵌 Servlet 容器?如何设…

CMake引用QT、CMake构建一个转换为3d tile的开源代码

在CMake里单独运行一下 find_package(Qt5 REQUIRED COMPONENTS Core Xml Test) ,Core Xml Test 这三个是需要的qt组件; 情况如下;提示找不到QT; 根据资料,cmake引用qt需要3-4个方面, 首先Qt包含三个编译工具:moc、uic、rcc, moc:元对象编译器(Meta O…

Magics测量两个圆形中心点距离的方法

摘要:本文介绍如何使用magics测量两个圆孔之间的距离。 问题来源:3D模型打开后,两个圆孔中心点之间的间距测量无法直接通过测距实现,需要进行一些小小的设置才行。 工具选择“量尺”,如果不设置的话,它会默…

1.Netty概述

原生NIO存在的问题(Netty要解决的问题) 虽然JAVA NIO 和 JAVA AIO框架提供了多路复用IO/异步IO的支持,但是并没有提供给上层“信息格式”的良好封装。JAVA NIO 的 API 使用麻烦,需要熟练掌握 ByteBuffer、Channel、Selector等 , 所以用这些API实现一款真正的网络应…

VSIX:C#项目 重命名所有标识符(Visual Studio扩展开发)

出于某种目的(合法的,真的合法的,合同上明确指出可以这样做),我准备了一个重命名所有标识符的VS扩展,用来把一个C#库改头换面,在简单的测试项目上工作很满意,所有标识符都被准确替换…

高斯过程回归 | 高斯过程回归(GPR)区间预测

对于高斯过程,高斯指的是多元高斯分布,过程指的是随机过程。 我们都知道随机过程就是指函数的分布,那么多元高斯分布实际上应该是指无限元的高斯分布。 协方差函数也称为核函数,是高斯过程回归的重点。核函数的选取方式有很多,包括径向基函数(高斯核函数)、线性核函数、…

幂等性(防重复提交)

文章目录 1. 实现原理2.使用示例3. Idempotent注解4. debug过程 主要用途:防止用户快速双击某个按钮,而前端没有禁用,导致发送两次重复请求。 1. 实现原理 幂等性要求参数相同的方法在一定时间内,只能执行一次。本质上是基于red…

Mysql进阶-SQL优化篇

插入数据 insert 我们需要一次性往数据库表中插入多条记录,可以从以下三个方面进行优化。 批量插入数据 一条insert语句插入多个数据,但要注意,每个insert语句最好插入500-1000行数据,就得重新写另一条insert语句 Insert into…

Rocky9 上安装 redis-dump 和redis-load 命令

一、安装依赖环境 1、依赖包 dnf -y install perl gcc gcc-c zlib-devel2、编译openssl 1.X ### 下载编译 wget https://www.openssl.org/source/openssl-1.1.1t.tar.gz tar xf openssl-1.1.1t.tar.gz cd openssl-1.1.1t ./config --prefix/usr/local/openssl make make ins…

Xcode中如何操作Git

👨🏻‍💻 热爱摄影的程序员 👨🏻‍🎨 喜欢编码的设计师 🧕🏻 擅长设计的剪辑师 🧑🏻‍🏫 一位高冷无情的编码爱好者 大家好,我是全栈工…