Noisy DQN 跑 CartPole-v1

gym 0.26.1
CartPole-v1
NoisyNet DQN

NoisyNet 就是把原来Linear里的w/b 换成 mu + sigma * epsilon, 这是一种非常简单的方法,但是可以显著提升DQN的表现。
和之前最原始的DQN相比就是改了两个地方,一个是Linear改成了NoisyLinear,另外一个是在agenttake_action的时候策略 由ε-greedy改成了直接取argmax。详细见下面的代码。

本文的实现参考王树森的深度强化学习。

引用书上的一段话, 噪声DQN本身就带有随机性,可以鼓励探索,起到与ε-greedy策略相同的作用,直接用a_t = argmax Q(s,a,epsilon; mu,sigma), 作为行为策略,效果比ε-greedy更好。

import gym
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import random
import collections
from tqdm import tqdm
import matplotlib.pyplot as plt
from d2l import torch as d2l
import rl_utils
import mathclass ReplayBuffer:"""经验回放池"""def __init__(self, capacity):self.buffer = collections.deque(maxlen=capacity) # 队列,先进先出def add(self, state, action, reward, next_state, done): # 将数据加入bufferself.buffer.append((state, action, reward, next_state, done))def sample(self, batch_size): # 从buffer中采样数据,数量为batch_sizetransition = random.sample(self.buffer, batch_size)state, action, reward, next_state, done = zip(*transition)return np.array(state), action, reward, np.array(next_state), donedef size(self): # 目前buffer中数据的数量return len(self.buffer)class NoisyLinear(nn.Linear):def __init__(self, in_features, out_features, sigma_init=0.017, bias=True):super().__init__(in_features, out_features, bias)self.sigma_weight = nn.Parameter(torch.full((out_features, in_features), sigma_init))self.register_buffer("epsilon_weight", torch.zeros(out_features, in_features))if bias:self.sigma_bias = nn.Parameter(torch.full((out_features,), sigma_init))self.register_buffer("epsilon_bias", torch.zeros(out_features))self.reset_parameters()def reset_parameters(self):std = math.sqrt(3 / self.in_features)self.weight.data.uniform_(-std, std)self.bias.data.uniform_(-std, std)def forward(self, x, is_training=True):self.epsilon_weight.normal_()bias = self.biasif bias is not None:self.epsilon_bias.normal_()bias = bias + self.sigma_bias * self.epsilon_bias.dataif is_training:return F.linear(x, self.weight + self.sigma_weight * self.epsilon_weight.data, bias)else:return F.linear(x, self.weight, bias)class Q(nn.Module):def __init__(self, state_dim, hidden_dim, action_dim):super().__init__()self.fc1 = NoisyLinear(state_dim, hidden_dim)self.fc2 = NoisyLinear(hidden_dim, action_dim)def forward(self, x, is_training=True):x = F.relu(self.fc1(x, is_training)) # 隐藏层之后使用ReLU激活函数return self.fc2(x, is_training)class DQN:"""DQN算法"""def __init__(self, state_dim, hidden_dim, action_dim, lr, gamma, target_update, device):self.action_dim = action_dimself.q = Q(state_dim, hidden_dim, action_dim).to(device) # Q网络self.target_q = Q(state_dim, hidden_dim, action_dim).to(device) # 目标网络self.target_q.load_state_dict(self.q.state_dict())  # 加载参数self.optimizer = torch.optim.Adam(self.q.parameters(), lr=lr)self.gamma = gammaself.target_update = target_update # 目标网络更新频率self.count = 0 # 计数器,记录更新次数self.device = devicedef take_action(self, state): # 这个地方就不用epsilon-贪婪策略state = torch.tensor(np.array([state]), dtype=torch.float).to(self.device)action = self.q(state).argmax().item()return actiondef update(self, transition_dict):states = torch.tensor(transition_dict['states'], dtype=torch.float).to(self.device)actions = torch.tensor(transition_dict['actions']).reshape(-1,1).to(self.device)rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).reshape(-1,1).to(self.device)next_states = torch.tensor(transition_dict['next_states'], dtype=torch.float).to(self.device)dones = torch.tensor(transition_dict['dones'], dtype=torch.float).reshape(-1,1).to(self.device)q_values = self.q(states).gather(1, actions) # Q值# 下个状态的最大Q值max_next_q_values = self.target_q(next_states).max(1)[0].reshape(-1,1)q_targets = rewards + self.gamma * max_next_q_values * (1- dones) # TD误差loss = F.mse_loss(q_values, q_targets) # 均方误差self.optimizer.zero_grad() # 梯度清零,因为默认会梯度累加loss.mean().backward() # 反向传播self.optimizer.step() # 更新梯度if self.count % self.target_update == 0:self.target_q.load_state_dict(self.q.state_dict())self.count += 1
lr = 2e-3
num_episodes = 500
hidden_dim = 128
gamma = 0.98
target_update = 10
buffer_size = 10000
minimal_size = 500
batch_size = 64
device = d2l.try_gpu()
print(device)env_name = "CartPole-v1"
env = gym.make(env_name)
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
replay_buffer = ReplayBuffer(buffer_size)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = DQN(state_dim, hidden_dim, action_dim, lr, gamma, target_update, device)
return_list = []for i in range(10):with tqdm(total=int(num_episodes/10), desc=f'Iteration {i}') as pbar:for i_episode in range(int(num_episodes/10)):episode_return = 0state = env.reset()[0]done, truncated= False, Falsewhile not done and not truncated :action = agent.take_action(state)next_state, reward, done, truncated, info = env.step(action)replay_buffer.add(state, action, reward, next_state, done)state = next_stateepisode_return += reward# 当buffer数据的数量超过一定值后,才进行Q网络训练if replay_buffer.size() > minimal_size:b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample(batch_size)transition_dict = {'states': b_s, 'actions': b_a, 'next_states': b_ns, 'rewards': b_r, 'dones': b_d}agent.update(transition_dict)return_list.append(episode_return)if (i_episode+1) % 10 == 0:pbar.set_postfix({'episode': '%d' % (num_episodes / 10 * i + i_episode+1), 'return': '%.3f' % np.mean(return_list[-10:])})pbar.update(1)episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title(f'Noisy DQN on {env_name}')
plt.show()mv_return = rl_utils.moving_average(return_list, 9)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title(f'Noisy DQN on {env_name}')
plt.show()

这次是在pycharm上运行jupyter file,结果如下:




效果对比之前的DQN 详细参考这篇 表现是显著提升。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/325986.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

关于位运算

只出现一次的数字&#xff1a; 给你一个非空整数数组 nums&#xff0c;除了某个元素只出现一次以外&#xff0c;其余每个元素均出现两次。找出那个只出现一次的元素。 class Solution { public:int singleNumber(vector<int>& nums) {int j0;for(auto i:nums){j^i…

62.网游逆向分析与插件开发-游戏增加自动化助手接口-游戏公告类的C++还原

内容来源于&#xff1a;易道云信息技术研究院VIP课 上一个内容&#xff1a;游戏红字公告功能的逆向分析-CSDN博客 码云地址&#xff08;master分支&#xff09;&#xff1a;https://gitee.com/dye_your_fingers/sro_-ex.git 码云版本号&#xff1a;0888e34878d9e7dd0acd08ef…

VMware Tools 启动脚本未能在虚拟机中成功运行。如果您在此虚拟机中配置了自定义启动脚本,请确保该脚本没有错误。您也可以提交支持请求,报告此问题。

问题描述&#xff1a;今天打开centos7虚拟机就是直接打不开了报了下面的错误&#xff0c;也没有动任何东西&#xff0c;点确定后&#xff0c;也是依然没有反应 问题原因&#xff1a;可能是虚拟机中的内存满了&#xff0c;需要清理内存 解决方法如下 首先cmd打开终端敲入如下命…

DNS主从服务器、转发(缓存)服务器

一、主从服务器 1、基本含义 DNS辅助服务器是一种容错设计&#xff0c;考虑的是一旦DNS主服务器出现故障或因负载太重无法及时响应客户机请求&#xff0c;辅助服务器将挺身而出为主服务器排忧解难。辅助服务器的区域数据都是从主服务器复制而来&#xff0c;因此辅助服务器的数…

pytest安装失败,报错Could not find a version that satisfies the requirement pytest

问题 安装pytest失败&#xff0c;尝试使用的命令有 pip install pytest pip3 install pytest pip install -U pytest pip install pytest -i https://pypi.tuna.tsinghua.edu.cn/simple但是都会报同样的错&#xff1a; 解决方案 发现可能是挂了梯子的原因&#xff0c;关掉…

Android WiFi基础概览

Android WiFi 基础概览 1、WiFi协议2、Android WLAN 架构2.1 应用框架2.2 Wi-Fi 服务2.3 Wi-Fi HAL 3、相关编译 android13-release 1、WiFi协议 Wi-Fi&#xff08;无线通信技术&#xff09;_百度百科 2.4GHz 频段支持以下标准&#xff08;802.11b/g/n/ax&#xff09;&#xff…

根据MySql的表名,自动生成实体类,模仿ORM框架

ORM框架可以根据数据库的表自动生成实体类&#xff0c;以及相应CRUD操作 本文是一个自动生成实体类的工具&#xff0c;用于生成Mysql表对应的实体类。 新建Winform窗体应用程序AutoGenerateForm&#xff0c;框架(.net framework 4.5)&#xff0c; 添加对System.Configuration的…

java基于SSM的毕业生就业管理系统+vue论文

摘 要 现代经济快节奏发展以及不断完善升级的信息化技术&#xff0c;让传统数据信息的管理升级为软件存储&#xff0c;归纳&#xff0c;集中处理数据信息的管理方式。本毕业生就业管理系统就是在这样的大环境下诞生&#xff0c;其可以帮助管理者在短时间内处理完毕庞大的数据信…

更改ERPNEXT源

更改ERPNEXT源 一&#xff0c; 更改源 针对已经安装了erpnext的&#xff0c;需要更改源的情况&#xff1a; 1, 更改为官方默认源, 进入frapp-bench的目录&#xff0c; 然后执行: bench remote-reset-url frappe //重设frappe的源为官方github地址。 bench remote-reset-url…

java基于ssm框架的校园闲置物品交易平台论文

摘 要 现代经济快节奏发展以及不断完善升级的信息化技术&#xff0c;让传统数据信息的管理升级为软件存储&#xff0c;归纳&#xff0c;集中处理数据信息的管理方式。本校园闲置物品交易平台就是在这样的大环境下诞生&#xff0c;其可以帮助管理者在短时间内处理完毕庞大的数据…

10分钟设置免费海外远程桌面使用Amazon Lightsail服务的免费套餐轻松搭建远程桌面

本篇文章授权活动官方亚马逊云科技文章转发、改写权&#xff0c;包括不限于在 亚马逊云科技开发者社区, 知乎&#xff0c;自媒体平台&#xff0c;第三方开发者媒体等亚马逊云科技官方渠道。 目录 前言 使用教程 启动 Amazon Lightsail 实例 配置远程桌面 启动远程桌面 使…

【Matplotlib】基础设置之文本公式04

处理文本&#xff08;数学表达式&#xff09; 在字符串中使用一对 $$ 符号可以利用 Tex 语法打出数学表达式&#xff0c;而且并不需要预先安装 Tex。在使用时我们通常加上 r 标记表示它是一个原始字符串&#xff08;raw string&#xff09; import matplotlib.pyplot as plt …