强化学习10——免模型控制Q-learning算法

Q-learning算法

主要思路

由于 V π ( s ) = ∑ a ∈ A π ( a ∣ s ) Q π ( s , a ) V_\pi(s)=\sum_{a\in A}\pi(a\mid s)Q_\pi(s,a) Vπ(s)=aAπ(as)Qπ(s,a) ,当我们直接预测动作价值函数,在决策中选择Q值最大即动作价值最大的动作,则可以使策略和动作价值函数同时最优,那么由上述公式可得,状态价值函数也是最优的。
Q ( s t , a t ) ← Q ( s t , a t ) + α [ r t + γ max ⁡ a Q ( s t + 1 , a ) − Q ( s t , a t ) ] Q(s_t,a_t)\leftarrow Q(s_t,a_t)+\alpha[r_t+\gamma\max_aQ(s_{t+1},a)-Q(s_t,a_t)] Q(st,at)Q(st,at)+α[rt+γamaxQ(st+1,a)Q(st,at)]
Q-learning基于时序差分的更新方法,具体流程如下所示:

  • 初始化 Q ( s , a ) Q(s,a) Q(s,a)
  • for 序列 e = 1 → E e=1\to E e=1E do:
    • 得到初始状态s
    • for 时步 t = 1 → T t=1\to T t=1T do:
      • 使用 ϵ − g r e e d y \epsilon -greedy ϵgreedy 策略根据Q选择当前状态s下的动作a
      • 得到环境反馈 r , s ′ r,s' r,s
      • Q ( s , a ) ← Q ( s , a ) + α [ r + γ max ⁡ a ′ Q ( s ′ , a ′ ) − Q ( s , a ) ] Q(s,a)\leftarrow Q(s,a)+\alpha[r+\gamma\max_{a^{\prime}}Q(s^{\prime},a^{\prime})-Q(s,a)] Q(s,a)Q(s,a)+α[r+γmaxaQ(s,a)Q(s,a)]
      • s ← s ′ s\gets s' ss
    • end for
  • end for

算法实战

我们在悬崖漫步环境下实习Q-learning算法。

首先创建悬崖漫步的环境:

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm  # tqdm是显示循环进度条的库class CliffWalkingEnv:def __init__(self, ncol, nrow):self.nrow = nrowself.ncol = ncolself.x = 0  # 记录当前智能体位置的横坐标self.y = self.nrow - 1  # 记录当前智能体位置的纵坐标def step(self, action):  # 外部调用这个函数来改变当前位置# 4种动作, change[0]:上, change[1]:下, change[2]:左, change[3]:右。坐标系原点(0,0)# 定义在左上角change = [[0, -1], [0, 1], [-1, 0], [1, 0]]self.x = min(self.ncol - 1, max(0, self.x + change[action][0]))self.y = min(self.nrow - 1, max(0, self.y + change[action][1]))next_state = self.y * self.ncol + self.xreward = -1done = Falseif self.y == self.nrow - 1 and self.x > 0:  # 下一个位置在悬崖或者目标done = Trueif self.x != self.ncol - 1:reward = -100return next_state, reward, donedef reset(self):  # 回归初始状态,坐标轴原点在左上角self.x = 0self.y = self.nrow - 1return self.y * self.ncol + self.x

创建Q-learning算法

class QLearning:def __init__(self, ncol, nrow, epsilon, alpha, gamma,n_action=4):self.epsilon = epsilon  # 随机探索的概率self.alpha = alpha  # 学习率self.gamma = gamma  # 折扣因子self.n_action = n_action  # 动作数量# 给每一个状态创建一个长度为4的列表。self.Q_table = np.zeros([nrow*ncol,n_action])  # 初始化Q(s,a)def take_action(self,state):# 选取下一步的操作if np.random.random()<self.epsilon:action = np.random.randint(self.n_action)  # 随机探索else:action = np.argmax(self.Q_table[state])  # 贪婪策略,选择Q值最大的动作return actiondef best_action(self, state):  # 用于打印策略Q_max = np.max(self.Q_table[state])a = [0 for _ in range(self.n_action)]for i in range(self.n_action):if self.Q_table[state, i] == Q_max:a[i] = 1return adef update(self,s0,a0,r,s1):td_error = r+self.gamma*self.Q_table[s1].max()-self.Q_table[s0,a0]self.Q_table[s0, a0] += self.alpha * td_error
ncol = 12
nrow = 4    
np.random.seed(0)
epsilon = 0.1
alpha = 0.1
gamma = 0.9
env = CliffWalkingEnv(ncol, nrow)
agent = QLearning(ncol, nrow, epsilon, alpha, gamma)
num_episodes = 500  # 智能体在环境中运行的序列的数量
return_list = [] # 记录每一条序列的回报
# 显示10个进度条
for i in range(10):# tqdm的进度条功能with tqdm(total=int(num_episodes / 10), desc='Iteration %d' % i) as pbar:for i_episode in range(int(num_episodes / 10)):  # 每个进度条的序列数episode_return = 0state = env.reset()done = Falsewhile not done:action = agent.take_action(state)next_state, reward, done = env.step(action)episode_return += reward  # 这里回报的计算不进行折扣因子衰减agent.update(state, action, reward, next_state)state = next_statereturn_list.append(episode_return)if (i_episode + 1) % 10 == 0:  # 每10条序列打印一下这10条序列的平均回报pbar.set_postfix({'episode':'%d' % (num_episodes / 10 * i + i_episode + 1),'return':'%.3f' % np.mean(return_list[-10:])})pbar.update(1)episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('Q-learning on {}'.format('Cliff Walking'))
plt.show()action_meaning = ['^', 'v', '<', '>']
print('Q-learning算法最终收敛得到的策略为:')
def print_agent(agent, env, action_meaning, disaster=[], end=[]):for i in range(env.nrow):for j in range(env.ncol):if (i * env.ncol + j) in disaster:print('****', end=' ')elif (i * env.ncol + j) in end:print('EEEE', end=' ')else:a = agent.best_action(i * env.ncol + j)pi_str = ''for k in range(len(action_meaning)):pi_str += action_meaning[k] if a[k] > 0 else 'o'print(pi_str, end=' ')print()action_meaning = ['^', 'v', '<', '>']
print('Sarsa算法最终收敛得到的策略为:')
print_agent(agent, env, action_meaning, list(range(37, 47)), [47])
print_agent(agent, env, action_meaning, list(range(37, 47)), [47])
Iteration 0: 100%|███████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 2040.03it/s, episode=50, return=-105.700]
Iteration 1: 100%|███████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 2381.99it/s, episode=100, return=-70.900] 
Iteration 2: 100%|███████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 3209.35it/s, episode=150, return=-56.500] 
Iteration 3: 100%|███████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 3541.95it/s, episode=200, return=-46.500] 
Iteration 4: 100%|███████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 5005.26it/s, episode=250, return=-40.800] 
Iteration 5: 100%|███████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 3936.76it/s, episode=300, return=-20.400] 
Iteration 6: 100%|███████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 4892.00it/s, episode=350, return=-45.700] 
Iteration 7: 100%|███████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 5502.60it/s, episode=400, return=-32.800] 
Iteration 8: 100%|███████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 6730.49it/s, episode=450, return=-22.700] 
Iteration 9: 100%|███████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 6768.50it/s, episode=500, return=-61.700] 
Q-learning算法最终收敛得到的策略为:
Qling算法最终收敛得到的策略为:
^ooo ovoo ovoo ^ooo ^ooo ovoo ooo> ^ooo ^ooo ooo> ooo> ovoo
ooo> ooo> ooo> ooo> ooo> ooo> ^ooo ooo> ooo> ooo> ooo> ovoo
ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ovoo
^ooo **** **** **** **** **** **** **** **** **** **** EEEE
^ooo ovoo ovoo ^ooo ^ooo ovoo ooo> ^ooo ^ooo ooo> ooo> ovoo
ooo> ooo> ooo> ooo> ooo> ooo> ^ooo ooo> ooo> ooo> ooo> ovoo
ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ovoo
^ooo **** **** **** **** **** **** **** **** **** **** EEEE

image.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/336551.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

RT-Thread基于AT32单片机的CAN应用

1 硬件电路 2 RT-Thread驱动配置 RT-Studio中没有CAN相关的图形配置&#xff0c;需要手动修改board.h。在board.h的末尾&#xff0c;增加相关的BSP配置。 #define RT_CAN_USING_HDR #define BSP_USING_CAN13 IO配置 at32_msp.c中的IO配置是PB9和PB10&#xff0c;掌上实验室V…

JIRA新BUG单浏览器通知

以下代码请放在油猴内使用&#xff1a; // UserScript // name JIRA未处理任务通知 // namespace https://blog.csdn.net/weixin_43515759 // version 1.0 // description Polls an API endpoint and sends a notification if conditions are met // author …

ChatGPT付费创作系统V2.5.5独立版+前端

ChatGPT付费创作系统V2.5.5版本优化了很多细节&#xff0c;功能增加增加长篇写作功能。该版本为编译版无开源&#xff0c;本版本特别处理了后台弹窗、暗链网址。特别优化了数据库。升级过程中未发现任何BUG&#xff0c;全新安装或者升级安装均未出现400或者500错误&#xff0c;…

高级JavaScript中最有趣的原型、原型链?

封装、继承、多态 基于类 class&#xff0c;JavaScript没有类&#xff1b;JavaScript可以实现面向对象语言特征&#xff1a;封装、继承、多态 封装&#xff1a;通俗的来说就是封装函数&#xff0c;通过私有化的变量和私有化的方法&#xff0c;不让外部访问到 继承&#xff1…

【模拟IC学习笔记】 PSS和Pnoise仿真

目录 PSS Engine Beat frequency Number of harmonics Accuracy Defaults Run tranisent?的3种设置 Pnoise type noise Timeaverage sampled(jitter) Edge Crossing Edge Delay Sampled Phase sample Ratio 离散时间网络(开关电容电路)的噪声仿真方法 PSS PSS…

【web】springboot3 生成本地文件 url

文章目录 流程效果静态资源访问ServiceServiceImplController 流程 avatar_dir&#xff1a;请求图片在服务端的存放路径user.dir&#xff1a;项目根目录 效果 静态资源访问 application.yml 设置静态文件存储路径custom:upload:avatar_dir: ${user.dir}/avatar_dir/avatar_d…

Kubernetes(K8S)云服务器实操TKE

一、 Kubernetes(K8S)简介 Kubernetes源于希腊语,意为舵手,因为首尾字母中间正好有8个字母,简称为K8S。Kubernetes是当今最流行的开源容器管理平台,是 Google 发起并维护的基于 Docker 的开源容器集群管理系统。它是大名鼎鼎的Google Borg的开源版本。 K8s构建在 Docker …

【C++】十大排序算法

文章目录 十大排序算法插入排序O(n^2^)冒泡排序O(n^2^)选择排序O(n^2^)希尔排序——缩小增量排序O(nlogn)快速排序O(nlogn)堆排序O(nlogn)归并排序(nlogn)计数排序O(nk)基数排序O(n*k)桶排序O(nk) 十大排序算法 排序算法的稳定性&#xff1a;在具有多个相同关键字的记录中&…

C# WPF 数据绑定

需求 后台变量发生改变&#xff0c;前端对应的相关属性值也发生改变 实现 接口 INotifyPropertyChanged 用于通知客户端&#xff08;通常绑定客户端&#xff09;属性值已更改。 示例 示例一 官方示例代码如下 using System; using System.Collections.Generic; using Sy…

社交距离 - 华为OD统一考试

OD统一考试(C卷) 分值: 200分 题解: Java / Python / C++ 题目描述 疫情期间,需要大家保证一定的社交距离,公司组织开交流会议,座位有一排共N个座位,编号分别为[0…N-1],要求员工一个接着一个进入会议室,并且可以在任何时候离开会议室。 满足:每当一个员工进入时,…

鱼哥赠书活动第⑥期:《内网渗透实战攻略》看完这本书教你玩转内网渗透测试成为实战高手!!!!

鱼哥赠书活动第⑥期&#xff1a;《内网渗透实战攻略》 如何阅读本书&#xff1a;本书章节介绍&#xff1a;本书大致目录&#xff1a;适合阅读对象&#xff1a;赠书抽奖规则:往期赠书福利&#xff1a; 当今&#xff0c;网络系统面临着越来越严峻的安全挑战。在众多的安全挑战中&…

springboot2.7集成sharding-jdbc4.1.1实现业务分表

1、引入maven <dependency><groupId>org.apache.shardingsphere</groupId><artifactId>sharding-jdbc-spring-boot-starter</artifactId><version>4.1.1</version></dependency> 2、基本代码示例 基本逻辑&#xff1a;利用数…