强化学习快速复习笔记--待更新

目录

      • 蒙特卡洛方法
      • 动态规划算法
        • 策略迭代
      • 时序差分方法
        • Sarsa算法
        • Q-learning算法
        • 如何区分在线学习和离线学习
        • DQN深度强化Q学习
          • 概念介绍
          • 代码解析
        • DQN改进算法
          • Double DQN网络

蒙特卡洛方法

求解价值函数和状态价值函数,可以使用蒙特卡洛方法和动态规划。首先介绍一下蒙特卡洛的方法,这个方法是统计模拟方法,基于概率统计来进行数值计算。

  • 优点: 不需要知道环境模型,直接从交互中学习
  • 缺点: 每一次更新都需要完整的轨迹,使用一些具有明确任务终止的情况,而且采样效率很低下。

如果用这个方法来估计策略在一个马尔可夫决策中的状态价值函数。需要采集到很多条的序列。,然后使用这些轨迹进行更新和迭代。计算方法如下:

请添加图片描述

首先采样到的序列如下:

第一条序列[('s1', '前往s2', 0, 's2'), ('s2', '前往s3', -2, 's3'), ('s3', '前往s5', 0, 's5')]
第二条序列[('s4', '概率前往', 1, 's4'), ('s4', '前往s5', 10, 's5')]
第五条序列[('s2', '前往s3', -2, 's3'), ('s3', '前往s4', -2, 's4'), ('s4', '前往s5', 10, 's5')]

这里采用的是增量式计算方法:

# 对所有采样序列计算所有状态的价值
def MC(episodes, V, N, gamma):for episode in episodes:G = 0for i in range(len(episode) - 1, -1, -1):  #一个序列从后往前计算(s, a, r, s_next) = episode[i]G = r + gamma * GN[s] = N[s] + 1V[s] = V[s] + (G - V[s]) / N[s]	

强化学习的目的:学习到一个策略,能够让agent从初始点到达终点且取到最大回报。

动态规划算法

本方法要求马尔可夫决策过程是已知的,也就是智能体交互的环境完全知道,在这种条件下,智能体不需要和环境进行交互来采样数据,直接用动态规划就可以算出来最优价值和策略。类似于监督学习任务,直接给出数据的分布公式,可以通过期望的层面来最小化模型的泛化误差更新模型。

利用马尔可夫决策过程的模型MDP, 进行求解,根据贝尔曼方程进行迭代更新,逐渐找到最优策略和值函数。

  • 优点:理论基础牢固,能够找到全局最优解。在这样一个白盒环境中,不需要通过智能体和环境的大量交互来学习,可以直接用动态规划求解状态价值函数
  • 缺点:需要事先了解环境的完整模型(状态转移概率和即时的奖励), 大规模问题求解比较困难。

动态规划方法主要分为两种:策略迭代价值迭代

策略迭代

包括两个部分:策略评估策略提升

策略迭代是不需要和环境进行交互的,因为环境的状态状态转移概率和奖励函数都已经确定了。

  • 策略评估
    • 计算一下当前策略下的值函数,使用贝尔曼方程迭代更新值函数,知道值函数收敛为止。
  • 策略改进
    • 我门将当前的值函数选择在每一个状态上具有最大值的动作,从而更新策略为新的最优策略。策略改进也不需要与环境进行互动。只是对值函数的一种估计。
    • 更新策略,让每一个状态选择最大值的动作。
  • 重复评估和更新直到收敛。

问:策略选择动作的概率和状态转移概率是一样吗?

策略 π(a|s) 表示在给定状态 s 下选择动作 a 的概率,而状态转移概率 P(s’|s,a) 则表示在给定状态 s 和动作 a 后,从状态 s 转移到状态 s’ 的概率。

假设有一个网格世界,包含两个状态 s1 和 s2,以及两个动作 a1 和 a2。我们定义一个策略 π,其中在状态 s1 下选择动作 a1 的概率为 0.6,选择动作 a2 的概率为 0.4。状态转移概率如下:

P(s2|s1, a1) = 0.8 P(s1|s1, a1) = 0.2 P(s2|s1, a2) = 0.3 P(s1|s1, a2) = 0.7


时序差分方法

大部分时候环境模型是不可能获得的,这就需要智能体和环境进行交互,这类方法被成为 无模型强化学习

首先简单理解一下时序差分方法:

回顾一下蒙特卡洛的更新方法:

请添加图片描述
这里对于价值更新是要等整个序列都结束以后,才能计算这一次回报G。在实际计算中,对于一条序列而言,从后面往前开始累加G,依次更新每一个V。

时序差分方法利用了这里增量式更新的方法,可以在每一个时间步进行值函数更新。具体而言,时序差分方法用当前获得的奖励加上了下一个状态的价值估计来作为当前状态会获得的回报,即:
请添加图片描述
折扣因子后面一串就是时序差分(Temporal difference)误差,时序差分算法将步长的乘积作为状态价值的更新量。

时序差分的特点:

时序差分(Temporal Difference, TD)方法是一类强化学习算法,它结合了动态规划和蒙特卡洛方法的优点。与蒙特卡洛方法需要等待完整的回合结束后才能更新值函数不同,TD 方法可以在每个时间步对值函数进行更新。这使得 TD 方法具有以下特点:

  1. 增量式更新:每一个时间步都可以对值函数进行更新。
  2. 时序性:通过比较当前状态的值函数和下一个状态的值函数估计来更新值函数。
  3. 学习策略灵活:TD方法可以结合不同的策略,如贪心策略,或者softmax策略,进行值函数的更新和动作选择。
  4. Off-policy和On-policy:TD方法适用于离线学习和在线学习。离线学习可以直接使用历史数据进行值函数更新,不需要和环境交互。在线学习方法需要交互,根据当前策略选择动作和更新值函数。

具体到经典的算法Sarsa和Q-learning算法来理解。

Sarsa算法

Sarsa:在每个时间步,根据当前状态 s、采取的动作 a、即时奖励 r、转移到的下一个状态 s’ 以及在下一个状态选择的动作 a’。需要五元组数据,SARSA。

Q值更新公式:
Q ( s , a ) ← Q ( s , a ) + α ∗ ( r + γ ∗ Q ( s ′ , a ′ ) − Q ( s , a ) ) Q(s, a) ← Q(s, a) + α * (r + γ * Q(s', a') - Q(s, a)) Q(s,a)Q(s,a)+α(r+γQ(s,a)Q(s,a))
其中 α 是学习率,γ 是折扣因子。

伪代码如下:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-a7GjzcnD-1689171098747)(C:\Users\风净尘\AppData\Roaming\Typora\typora-user-images\image-请添加图片描述
.png)]

SARSA中,对于时序差分的计算用的下一个状态采取贪心算法下动作得到的价值来计算。而Q-learning采用的式直接选取下一个状态下能够获得最大价值的动作。而不是和SARSA那样根据策略来选取的动作。

SARSA是在线学习方法,他用State-Action-Reward-State-Action的序列进行更新,并且每一个时间步都进行值函数的更新和策略的选择。

具体看看一个简单的例子来理解这个过程:

# Sarsa算法
for episode in range(num_episodes):state = 0  # 初始状态if np.random.rand() < epsilon:action = np.random.randint(num_actions)  # 使用ε-贪心策略随机选择动作else:action = np.argmax(Q[state])  # 使用当前值函数估计选择动作while state != num_states - 1:  # 当前状态不是终止状态时# 执行动作并观察下一个状态和即时奖励next_state = np.random.choice(np.where(maze[state] >= 0)[0])  # 随机选择下一个状态reward = maze[state, action]# 使用ε-贪心策略选择下一个动作if np.random.rand() < epsilon:next_action = np.random.randint(num_actions)  # 随机选择下一个动作else:next_action = np.argmax(Q[next_state])  # 使用当前值函数估计选择下一个动作# 使用Sarsa更新值函数估计Q[state, action] += alpha * (reward + gamma * Q[next_state, next_action] - Q[state, action])state = next_stateaction = next_action  # 更新当前状态和动作,Q-learning只需要更新状态即可。

Q-learning算法

Q-learning:在每个时间步,根据当前状态 s、采取的动作 a、即时奖励 r、转移到的下一个状态 s’,使用贝尔曼最优方程更新 Q 值:Q(s, a) ← Q(s, a) + α * (r + γ * max(Q(s’, a’)) - Q(s, a))。只需要一个四元组数据即可。
请添加图片描述

如何区分在线学习和离线学习

我们称采样数据的策略为行为策略(behavior policy),称用这些数据来更新的策略为目标策略(target policy)。在线策略(on-policy)算法表示行为策略和目标策略是同一个策略;而离线策略(off-policy)算法表示行为策略和目标策略不是同一个策略。Sarsa 是典型的在线策略算法,而 Q-learning 是典型的离线策略算法。判断二者类别的一个重要手段是看计算时序差分的价值目标的数据是否来自当前的策略,如图 5-1 所示。具体而言:

  • 对于 Sarsa,它的更新公式必须使用来自当前策略采样得到的五元组,因此它是在线策略学习方法;
  • 对于 Q-learning,它的更新公式使用的是四元组来更新当前状态动作对的价值,数据中的s和a是给定的条件,reward和s_new皆由环境采样得到,该四元组并不需要一定是当前策略采样得到的数据,也可以来自行为策略,因此它是离线策略算法。

img

在某些情况下,Sarsa 更加稳定,因为它考虑了当前策略下的动作选择,可以更好地控制探索和利用的权衡。Q-learning 则更倾向于选择具有最大 Q 值的动作,可能导致更快地收敛到最优策略,但也可能导致过度估计和不稳定性。

DQN深度强化Q学习

概念介绍

DQN就是用神经网络来拟合这个动作价值函数。DQN只能处理动作离散的情况,因为在动作选择中有max的操作。动作值函数更新公式如下:
Q ( s , a ) = Q ( s , a ) + α ∗ ( r + γ ∗ m a x ( Q ( s ′ , a ′ ) ) − Q ( s , a ) ) Q(s, a) = Q(s, a) + α * (r + γ * max(Q(s', a')) - Q(s, a)) Q(s,a)=Q(s,a)+α(r+γmax(Q(s,a))Q(s,a))
其中,Q(s, a)是状态s下采取动作a的动作值估计,在DQN中,我们使用深度神经网络来估计Q值函数。网络的输入是状态s,输出是每个可能动作的Q值。网络的参数通过优化目标来不断更新,目标函数如下:
T a r g e t = r + γ ∗ m a x ( Q ( s ′ , a ′ ) ) Target = r + γ * max(Q(s', a')) Target=r+γmax(Q(s,a))
其中,r是即时奖励,γ是折扣因子,s’是下一个状态,a’是在下一个状态下采取的动作。目标值表示在下一个状态下,通过选择最优的动作所能获得的最大累积奖励。

DQN算法的训练过程中,通过最小化预测Q值与目标Q值之间的均方误差来更新网络的参数。具体来说,我们使用以下损失函数:
L o s s = ( Q ( s , a ) − T a r g e t ) 2 Loss = (Q(s, a) - Target)^2 Loss=(Q(s,a)Target)2
利用神经网络的反向传播算法,更新网络参数,让预测的Q(s, a)逼近Target。

两个关键的方法:

  • 经验回放

DQN属于离线学习方法,所以可以对数据加以重复利用,维护一个回放缓冲区,每次将环境中采样得到的四元组(s, a, r, s’)存放在缓冲区,在训练过程中再对缓冲区的历史数据进行随机训练。这样可以使样本满足独立假设。打破样本之间的相关性。

  • 目标网络

为了让Q(s, a)逼近r + γ * max(Q(s’, a’)),神经网络更新参数的同时,目标r + γ * max(Q(s’, a’))也不断改变,容易让神经网络不太稳定,因此加入了目标网络思想,利用两套神经网络,先将目标网络的参数固定住,然后隔N步再进行更新。原来的训练网络Q(s, a)使用正常梯度下降方法来更新,每一步都要进行更新。

综合所述得到伪代码如下:

请添加图片描述

代码解析
  • 首先是经验回放的类设计
class ReplayBuffer:''' 经验回放池 '''def __init__(self, capacity):# collections是python标准库,有list、deque、字典等数据结构。self.buffer = collections.deque(maxlen=capacity)  # 双端队列,先进先出,一端添加,一端删除def add(self, state, action, reward, next_state, done):  # 将数据加入bufferself.buffer.append((state, action, reward, next_state, done))def sample(self, batch_size):  # 从buffer中采样数据,数量为batch_size# random.sample(population, k) 在抽样的序列中,随机抽取k个元素,返回包含元素的列表。transitions = random.sample(self.buffer, batch_size)# 将每一个元素的对应位置值打包为一个元组,返回以这个元组为列的迭代器。state, action, reward, next_state, done = zip(*transitions)return np.array(state), action, reward, np.array(next_state), donedef size(self):  # 目前buffer中数据的数量return len(self.buffer)
  • 然后是Q值网络
class Qnet(torch.nn.Module):''' 只有一层隐藏层的Q网络 '''def __init__(self, state_dim, hidden_dim, action_dim):super(Qnet, self).__init__()# 输入层有状态个神经元self.fc1 = torch.nn.Linear(state_dim, hidden_dim)self.fc2 = torch.nn.Linear(hidden_dim, action_dim)def forward(self, x):x = F.relu(self.fc1(x))  # 隐藏层使用ReLU激活函数return self.fc2(x)
  • DQN算法类实现
class DQN:''' DQN算法 '''def __init__(self, state_dim, hidden_dim, action_dim, learning_rate, gamma,epsilon, target_update, device):self.action_dim = action_dimself.q_net = Qnet(state_dim, hidden_dim,self.action_dim).to(device)  # Q网络# 目标网络self.target_q_net = Qnet(state_dim, hidden_dim,self.action_dim).to(device)# 使用Adam优化器self.optimizer = torch.optim.Adam(self.q_net.parameters(),lr=learning_rate)self.gamma = gamma  # 折扣因子self.epsilon = epsilon  # epsilon-贪婪策略self.target_update = target_update  # 目标网络更新频率self.count = 0  # 计数器,记录更新次数self.device = devicedef take_action(self, state):  # epsilon-贪婪策略采取动作if np.random.random() < self.epsilon:action = np.random.randint(self.action_dim)else:state = torch.tensor([state], dtype=torch.float).to(self.device)action = self.q_net(state).argmax().item()return actiondef update(self, transition_dict):states = torch.tensor(transition_dict['states'],dtype=torch.float).to(self.device)actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(self.device)rewards = torch.tensor(transition_dict['rewards'],dtype=torch.float).view(-1, 1).to(self.device)next_states = torch.tensor(transition_dict['next_states'],dtype=torch.float).to(self.device)dones = torch.tensor(transition_dict['dones'],dtype=torch.float).view(-1, 1).to(self.device)# 状态传入Q网络, 在输出的每一个状态对应的Q值中,选出actions对应的Q值,也就是类似监督学习中的yiq_values = self.q_net(states).gather(1, actions)  # Q值# 下个状态的最大Q值,这个输出类似监督学习的y标签,也就是目标值。max_next_q_values = self.target_q_net(next_states).max(1)[0].view(-1, 1)q_targets = rewards + self.gamma * max_next_q_values * (1 - dones)  # TD误差目标dqn_loss = torch.mean(F.mse_loss(q_values, q_targets))  # 均方误差损失函数self.optimizer.zero_grad()  # PyTorch中默认梯度会累积,这里需要显式将梯度置为0dqn_loss.backward()  # 反向传播更新参数self.optimizer.step()if self.count % self.target_update == 0:self.target_q_net.load_state_dict(self.q_net.state_dict())  # 更新目标网络self.count += 1

DQN改进算法

Double DQN网络

在传统DQN网络中,我们训练网络是用来一个Q网络进行估计动作价值函数的,并且利用贪婪策略来选择动作,这样容易到值对Q值的过高估计,从而影响到训练的稳定性和收敛性。

Double DQN网络利用了两个Q网络来解决这个问题,一个在线网络用来选择动作,另一个用来估计动作价值函数。具体来说,使用一个Q网络(称为"online"网络)选择下一个动作,然后使用另一个Q网络(称为"target"网络)估计该动作的值。通过将选择动作和估计值的过程分离,Double DQN可以更准确地估计动作值,并减少过高估计的影响。Double DQN的流程如下:

  1. 选择动作:从 online 网络中选择动作 a,即 a = argmax Q_online(s, a)。
  2. 执行动作a:获得奖励s’和r;
  3. 估计下一个状态动作值:先使用online网络选取s’下的动作a’ = argmax Q_online(s’, a; θ),然后利用目标网络来计算s’和a’下的Q值。
  4. 计算TD目标值:y = r + γ * Q_target(s’, a’; θ’)
  5. 计算损失函数:loss = (Q_online(s, a; θ) - y)^2
  6. 更新在线网络参数:使用梯度下降法根据损失函数对在线网络的参数θ进行更新。
  7. 更新目标网络参数:定期将在线网络的参数θ复制到目标网络的参数θ’。

这么看来,DQN与Double DQN的差异仅仅是体现在计算Q值得时候动作的选取,DQN中是直接用目标网络进行max选取动作,而Double DQN中利用在线网络选取动作,再用这个动作来在目标网络中计算Q值。

请添加图片描述

因此在代码差异上其实并不太大,DDQN代码如下:

# 下个状态的最大Q值if self.dqn_type == 'DoubleDQN': # DQN与Double DQN的区别max_action = self.q_net(next_states).max(1)[1].view(-1, 1)max_next_q_values = self.target_q_net(next_states).gather(1, max_action)else: # DQN的情况max_next_q_values = self.target_q_net(next_states).max(1)[0].view(-1, 1)

这个环境用的是倒立摆,由于动作空间是连续的,但是DQN只能处理离散动作空间,因此这里提一嘴如何将连续动作空间离散化。

倒立摆动作空间:力矩大小[-2,2]

倒立摆的Q值:倒立摆向上直立时候能选取的最大Q值就为0,如果Q值大于0说明出现过高估计。

离散化:将动作空间离散化为11个动作。动作[0, 1, 2, 3, … 9, 10]分别代表力矩[-2, -1.6, -1.2…, 1.2 , 1.6, 2.0]


在Gym中,要step环境需要把离散得动作专户为连续的数值,代码如下:

def dis_to_con(discrete_action, env, action_dim):  # 离散动作转回连续的函数action_lowbound = env.action_space.low[0]  # 连续动作的最小值action_upbound = env.action_space.high[0]  # 连续动作的最大值return action_lowbound + (discrete_action /(action_dim - 1)) * (action_upbound -action_lowbound)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/19248.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

从零开始的前后端分离项目学习(前后端从零环境搭建)

一、 前后端分离介绍&#xff1a; 前端独立编写客户端代码&#xff08;用户交互数据展示&#xff09;&#xff0c;后端独立编写服务端代码&#xff08;提供数据处理接口&#xff09;&#xff0c;并提供数据接口就行。 前端通过Ajax访问后端数据借口&#xff0c;将model展示到…

数据结构05:树与二叉树[C++][哈夫曼树HuffmanTree]

图源&#xff1a;文心一言 小白友好、代码可跑&#xff0c;但是不一定适合考研~~&#x1f95d;&#x1f95d; 第1版&#xff1a;查资料、画导图、画配图~&#x1f9e9;&#x1f9e9; 参考用书&#xff1a;王道考研《2024年 数据结构考研复习指导》 参考用书配套视频&#xf…

form表单使用Select 选择器

案例: ps&#xff1a;年度的值类型要与select 选择器中 value 类型一致&#xff01;&#xff01; 如果input框中显示的是数字&#xff0c;说明年度的值没有与选择器中的的value一致&#xff01;&#xff01;&#xff01; YearNum 要与 value 类型一致&#xff01;&#xff01…

Jmeter的常用设置(一)

文章目录 前言一、Jmeter设置中文 方法一&#xff08;临时改为中文&#xff09;方法二&#xff08;永久改成中文&#xff09;二、启动Jmeter的两种方式 方法一&#xff08;直接启动&#xff0c;不打开cmd窗口&#xff09;方法二&#xff08;带有cmd窗口的启动&#xff09;三、调…

走进Vue2飞入Vue3

✅作者简介&#xff1a;大家好&#xff0c;我是Cisyam&#xff0c;热爱Java后端开发者&#xff0c;一个想要与大家共同进步的男人&#x1f609;&#x1f609; &#x1f34e;个人主页&#xff1a;Cisyam-Shark的博客 &#x1f49e;当前专栏&#xff1a; 前端相关 ✨特色专栏&…

第一代Spring Cloud核心组件

第一代Spring Cloud核心组件&#xff08;Spring Cloud Netflix&#xff09; Eureka服务注册中心(服务注册中心:Eureka,Nacos,Zookeeper,Consul) Ribbon负载均衡 Hystrix熔断器 Feign远程调用组件(Feign RestTemplate Ribbon Hystrix) GateWay网关组件 Config分布式配置中心 …

【C语言】-- 死循环了怎么办?

#include <stdio.h> int main() {int i 0;int arr[] {1,2,3,4,5,6,7,8,9,10};for(i0; i<12; i){arr[i] 0;printf("hello\n");}return 0; } 阅读上面这个代码&#xff0c;我们会认为这不就是简单的数组访问越界么。那么这段代码就应该会报错&#xff0c;…

三维重建以及神经渲染中的学习(三)

三维重建以及神经渲染中的学习 公众号AI知识物语 本文内容为参加过去一次暑期课程学习时的笔记&#xff0c;浅浅记录下。 三维图形可控生成&#xff1a; 1&#xff1a;学习一个图形生成模型 2&#xff1a;具有可控三维变量&#xff1a;1物体形状&#xff1b;2物体位置&…

Ubuntu 放弃了战斗向微软投降

导读这几天看到 Ubuntu 放弃 Unity 和 Mir 开发&#xff0c;转向 Gnome 作为默认桌面环境的新闻&#xff0c;作为一个Linux十几年的老兵和Linux桌面的开发者&#xff0c;内心颇感良多。Ubuntu 做为全世界Linux界的桌面先驱者和创新者&#xff0c;突然宣布放弃自己多年开发的Uni…

七牛云的使用(图片超详讲解)

一、为什么要使用七牛云的OSS(对象存储服务)&#xff1f; 二、七牛云使用&#xff1a; 登录七牛云官网&#xff0c;注册并认证 (初次认证有30天免费使用权限)新建存储空间 点击创建的空间名字&#xff0c;进入 空间概括如下&#xff1a; 阅读帮助文档&#xff0c;在自己的…

Java微服务金融项目智牛股-基础知识三(Restful、HATEOAS、GRPC、SEATA )

Restful定义 Restful是一种软件架构与设计风格&#xff0c; 并非一套标准&#xff0c; 只提供了一些原则与约定条件。REST提供了一组架构约束&#xff0c;当作为一个整体来应⽤用时&#xff0c;强调组件交互的可伸缩性。接⼝口的通⽤用性、组件的独⽴立部署、以及⽤用来减少交…

spring cloud 之 Hystrix

Hystrix概述 Hystrix是一个供分布式系统使用&#xff0c;提供延迟和容错功能&#xff0c;保证复杂的分布系统在面临不可避免的失败是时&#xff0c;仍具有弹性。 当服务器A调用服务器B时&#xff0c;如果服务器B宕机&#xff0c;则服务器A不去调用。当服务器B在时间范围内未响…