策略梯度AC算法 - CartPole环境, 使用RNN作为策略网络

news/2025/1/8 14:46:53/文章来源:https://www.cnblogs.com/tshaaa/p/18659675

参考资料:

  • Vanila Policy Gradient with a Recurrent Neural Network Policy – Abhishek Mishra – Artificial Intelligence researcher
  • 动手学强化学习-HandsOnRL,本文中展示的代码都是基于《动手学RL》的代码库。
  • rlorigro/cartpole_policy_gradient: a simple 1st order PG solution to the openAI cart pole problem (adapted from @ts1839) and an RNN based PG variation on this

为什么使用RNN

对于一些简单的环境,只需要知道当前时刻的状态以及动作,就可以预测下一个时刻的状态 (即环境满足一元的马尔可夫假设)。比如说车杆环境:

但是:

  • 对于一些复杂的环境,可能需要多个时刻的状态才足以预测下一时刻。
  • 在部分可观测的环境,我们无从知道环境的真实状态。

在DQN玩雅达利游戏中,作者是将连续几帧的图像作为状态传入网络。一种可能更好的替代方法是使用RNN作为策略网络。

代码

基于动手学强化学习-HandsOnRL的ActorCritic代码进行修改。主要改动:

  • 策略网络使用RNN.
  • 每次tack_action的时候就计算log_probs,并且记录下来,而不是最后在update时一起计算。记录log_probs时,应该使用Tensor,保留梯度。
  • 每个episode开始的时候,应该清空隐状态h.

导入包

import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import rl_utils
from tqdm import tqdm

值网络和RNN策略网络

class ValueNet(torch.nn.Module):def __init__(self, state_dim, hidden_dim):super(ValueNet, self).__init__()self.fc1 = torch.nn.Linear(state_dim, hidden_dim)self.fc2 = torch.nn.Linear(hidden_dim, 1)def forward(self, x):x = F.relu(self.fc1(x))return self.fc2(x)class RNNPolicy(torch.nn.Module):def __init__(self, state_dim, hidden_dim, action_dim):super(RNNPolicy, self).__init__()self.rnn = nn.GRUCell(input_size=state_dim, hidden_size=hidden_dim)self.fc = nn.Linear(hidden_dim, action_dim)def forward(self, x, hidden_state=None): # 传入x和hidden_stateh = self.rnn(x, hidden_state)x = F.leaky_relu(h)return F.softmax(self.fc(x), dim=-1), h

ActorCritic-Agent

class RNN_ActorCritic:def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr, gamma, device) -> None:self.actor = RNNPolicy(state_dim, hidden_dim, action_dim)self.critic = ValueNet(state_dim, hidden_dim)self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)self.gamma = gammaself.device = devicedef take_action(self, state, hidden_state=None):state = torch.tensor([state], dtype=torch.float).to(self.device)action_prob, hidden_state = self.actor(state, hidden_state)action_dist = torch.distributions.Categorical(action_prob)action = action_dist.sample()log_prob = action_dist.log_prob(action)return action.item(), hidden_state, log_prob # 返回action, 隐状态, log_probdef update(self, transition_dict):states = torch.tensor(transition_dict['states'],dtype=torch.float).to(self.device)rewards = torch.tensor(transition_dict['rewards'],dtype=torch.float).view(-1, 1).to(self.device)next_states = torch.tensor(transition_dict['next_states'],dtype=torch.float).to(self.device)dones = torch.tensor(transition_dict['dones'],dtype=torch.float).view(-1, 1).to(self.device)# 使用torch.stack()将log_probs转换为tensor,可以保留梯度log_probs = torch.stack(transition_dict['log_probs']).to(self.device)td_target = rewards + self.gamma * self.critic(next_states) * (1 - dones)td_delta = td_target - self.critic(states) actor_loss = torch.mean(-log_probs * td_delta.detach())critic_loss = torch.mean(F.mse_loss(self.critic(states), td_target.detach()))self.actor_optimizer.zero_grad()self.critic_optimizer.zero_grad()actor_loss.backward()  critic_loss.backward() self.actor_optimizer.step() self.critic_optimizer.step() 

训练过程

def train_on_policy_agent(env, agent, num_episodes, render=False):return_list = []for i in range(10):with tqdm(total=int(num_episodes/10), desc='Iteration %d' % i) as pbar:for i_episode in range(int(num_episodes/10)):episode_return = 0transition_dict = {'states': [], 'next_states': [], 'rewards': [], 'dones': [], 'log_probs': []}state = env.reset()# 初始化hidden_statehidden_state = Nonedone = Falsewhile not done:action, hidden_state, log_prob = agent.take_action(state, hidden_state)next_state, reward, done, _ = env.step(action)if render:env.render()transition_dict['states'].append(state)transition_dict['next_states'].append(next_state)transition_dict['rewards'].append(reward)transition_dict['dones'].append(done)# 无需记录action, 因为记录action是为了计算log_prob,而我们已经算好了log_probtransition_dict['log_probs'].append(log_prob)state = next_stateepisode_return += rewardreturn_list.append(episode_return)agent.update(transition_dict)if (i_episode+1) % 10 == 0:pbar.set_postfix({'episode': '%d' % (num_episodes/10 * i + i_episode+1), 'return': '%.3f' % np.mean(return_list[-10:])})pbar.update(1)return return_list

训练

actor_lr = 1e-3
critic_lr = 1e-2
num_episodes = 1000
hidden_dim = 128
gamma = 0.98
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")env_name = 'CartPole-v0'
env = gym.make(env_name)
env.seed(0)
torch.manual_seed(0)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = RNN_ActorCritic(state_dim, hidden_dim, action_dim, actor_lr, critic_lr,gamma, device)
return_list = train_on_policy_agent(env, agent, num_episodes)

CartPole结果

这是使用RNN策略网络的结果:
image

这是使用普通MLP的结果 (Hands-On-RL原来的结果):
image

用RNN效果相对差的可能原因:

  • 环境太简单。
  • RNN要学习的参数更多,收敛速度更慢。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/865882.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Harbor配置https

harbor是不附带任何证书的,因此默认情况下使用http来进行访问 K8S在使用harbor作为私有仓库时或生产环境下强烈建议使用https 生成证书 生产环境下,需要从CA获取证书,测试或者开发可以使用OpenSSL自己生成证书 生成私钥 # 创建证书的存储目录 mkdir /home/ssl cd /home/sslo…

vue3引入ts以及js文件使用案例

ts:先确保项目正确集成TypeScript 添加tsconfig.json文件{"compilerOptions": {"target": "esnext","module": "esnext","strict": true,"jsx": "preserve","importHelpers": true…

coenzyme A 辅酶A

coenzyme A 化学式 C21H36N7O16P3S 分子量 767.535密度1.1335 g/cm3 (20C) 熔点 -5C 沸点 146 - 147是一种辅酶,值得注意的是其在合成和氧化脂肪酸的角色,和在三羧酸循环中氧化丙酮酸。

【每日一题】20250108

路的尽头什么都没有。但在路上遇见的不起眼的瞬间和记忆,最终会成就我们。【每日一题】一物体作匀加速直线运动,通过一段位移 \(\Delta x\) 所用的时间为 \(t_1\),紧接着通过下一段位移 \(\Delta x\) 所用时间为 \(t_2\).则物体运动的加速度为 A. \(\frac {2\Delta x( t_{1…

JS-21 字符串方法_charAt()

charAt方法返回指定位置的字符,参数是从0开始编号的var s =new String (zifuchuan)s.charAt(1)//"t"s.charAt(s.length-1)//"n" 如果参数为负数,或大于等于字符串的长度,charAt返回空字符串zifuchuan.charAt(-1)//""zifuchuan.charAt(9)//&…

应用质数和模算法

生成RSA加密密钥 密钥生成时先选择两个素数p和q,计算他们的乘积n=p*q,RSA的安全性是基于从n推导出p和q是很困难的,p和q越大,在给定n推到p和q的值越难,简单逻辑如下: 1、选择两个大的素数 2、计算n和phi(欧拉商函数) 3、选择一个公共指数e 4、计算私有指数d 5、使用公钥…

ASE100N03-ASEMI中低压N沟道MOS管ASE100N03

ASE100N03-ASEMI中低压N沟道MOS管ASE100N03编辑:ll ASE100N03-ASEMI中低压N沟道MOS管ASE100N03 型号:ASE100N03 品牌:ASEMI 封装:TO-252 最大漏源电流:100A 漏源击穿电压:30V 批号:最新 RDS(ON)Max:5.0mΩ 引脚数量:3 沟道类型:N沟道MOS管 芯片尺寸:MIL 漏电流: …

【信息安全】发布漏洞信息是否违法?如何量刑?

引言 在全球数字化进程加速的今天,信息安全问题成为了国家、企业乃至个人面临的重大挑战。网络漏洞作为信息安全的薄弱环节,一旦被恶意利用,可能导致数据泄露、系统崩溃甚至经济损失。随着安全研究人员和黑客的逐步增加,漏洞信息的披露也成为网络安全领域的一个重要议题。昨…

智能驾驶场地和道路测试服务

智能驾驶产品不断迭代更新,智驾功能日新月异。实车测试是智能驾驶功能和性能测试必不可少的手段之一,根据测试环境和测试项不同分为场地测试和道路测试。经纬恒润通过多年智能驾驶系统产品开发经验、实际场地和道路测试经验以及工具开发经验的积累,可以为客户提供智能驾驶相…

【unity】学习制作2D横板冒险游戏-1-

创建项目2D(Built-In Render Pipeline)核心模板 为2D游戏开发提供基础架构。 配置了适合2D项目的纹理导入、Sprite Packer、Scene视图、光照和正交摄像机等设置。3D(Built-In Render Pipeline)核心模板 开启3D游戏开发之旅,提供强大的3D场景构建能力。 配置了使用Unity内置…

水位自动监测摄像机

水位自动监测摄像机作为现代智能监控技术的重要应用,正在广泛应用于水利工程、防洪管理和环境监测等领域,显著提升了监测效率和数据准确性。水位自动监测摄像机利用高精度摄像头和先进的图像处理技术,能够实时监测水体的液位变化。其原理是通过安装在指定监测点的摄像头,连…

工厂安全生产检测系统 车间作业异常行为识别系统

工厂安全生产检测系统 车间作业异常行为识别系的核心是基于YOLOv5+Python深度学习算法,工厂安全生产检测系统 车间作业异常行为识别系统通过车间部署的摄像头能够更准确地分析判断工人是否按照规定的操作流程进行操作,是否存在违规行为,如未佩戴安全帽、未按规定使用工具等。…