【强化学习】Deep Q Learning

Deep Q Learning

在前两篇文章中,我们发现RL模型的目标是基于观察空间 (observations) 和最大化奖励和 (maximumize sum rewards) 的。

如果我们能够拟合出一个函数 (function) 来解决上述问题,那就可以避免存储一个 (在Double Q-Learning中甚至是两个) 巨大的Q_table。

Tabular -> Function

  • Continous Observation: 函数能够让我们处理连续的观察空间,而表只能处理离散的。
  • Saving the space: 不用存储 len(state) * len(action) 大小的Q_table

在早期人们试过使用核函数或者线性函数等各种方法去拟合这个function,但后来深度神经网络出现后人们纷纷开始研究如何用DNN来拟合。

然而以上的拟合方式不免存在一个问题,我们期望得到一个DNN,使得DNN(state)->Q-value

可是强化学习中,最好的Q-value在开始时是不知道的 (这也是强化学习和机器学习不一样的地方:我们不知道能否训练到一个Q值,直到有人把它训练出来),这就导致我们在训练过程中没有目标函数。

Natural Deep Q Learning

所有的第一步必须从高维的感官输入中获得对环境的有效表示

深度Q网络(DQN)是一种将深度学习和Q学习相结合的强化学习方法。DQN由DeepMind于2015年提出,并在玩Atari视频游戏方面取得了显著的成功。DQN的核心原理是使用深度神经网络来近似Q函数,即在给定状态下采取某一动作的预期累积奖励。

DQN的关键创新

  1. 使用神经网络近似Q函数

    • 传统的Q学习使用表格(Q表)来存储每个状态-动作对的Q值。当状态空间很大或连续时,这变得不切实际。
    • DQN通过使用深度神经网络来近似Q函数,克服了这一限制。网络输入是状态,输出是该状态下所有可能动作的Q值。
  2. 经验回放

    • DQN引入了经验回放机制,即将代理的经验(状态、动作、奖励、新状态)存储在回放缓冲区中。

      image-20231114211049019
    • 训练时,从这个缓冲区中随机抽取小批量经验进行学习。这增加了数据的多样性,减少了样本之间的相关性,从而稳定了训练。

  3. 目标网络

    • DQN使用两个结构相同但参数不同的网络:一个是在线网络 (dqn_model),用于当前Q值的估计;另一个是目标网络 (target_model),用于计算目标Q值。
    • 目标网络的参数定期从在线网络复制过来,但不是每个训练步骤都更新。这减少了学习过程中的震荡,提高了稳定性。
    image-20231114211236348

训练过程

  • 在每个时间步,代理根据当前的Q值(通常结合探索策略,如ε-贪婪)选择一个动作,接收环境的反馈(新状态和奖励),并将这个转换存储在经验回放缓冲区中。
  • 训练神经网络时,从缓冲区中随机抽取一批经验,然后使用贝尔曼方程计算目标Q值和预测Q值,通过最小化这两者之间的差异来更新网络参数。

DQN解决月球着陆问题

导入环境

import time
from collections import defaultdictimport gymnasium as gym
import numpy as np
import randomfrom matplotlib import pyplot as plt, animation
from IPython.display import display, clear_output
env = gym.make("LunarLander-v2", continuous=False, render_mode='rgb_array')

定义经验池

class ExperienceBuffer:def __init__(self, size=0):self.states = []self.actions = []self.rewards = []self.states_next = []self.actions_next = []self.size = 0def clear(self):self.__init__()def append(self, s, a, r, s_n, a_n):self.states.append(s)self.actions.append(a)self.rewards.append(r)self.states_next.append(s_n)self.actions_next.append(a_n)self.size += 1def batch(self, batch_size=128):indices = np.random.choice(self.size, size=batch_size, replace=True)return  (np.array(self.states)[indices],np.array(self.actions)[indices],np.array(self.rewards)[indices],np.array(self.states_next)[indices],np.array(self.actions_next)[indices],)
import torchfrom torch import nn
from torch.nn.functional import relu
import torch.nn.functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

定义DQN

class DQN(nn.Module):def __init__(self, state_size, action_size):super().__init__()self.state_size = state_sizeself.action_size = action_sizeself.hidden_size = 32self.linear_1 = nn.Linear(self.state_size, self.hidden_size)self.linear_2 = nn.Linear(self.hidden_size, self.action_size)nn.init.uniform_(self.linear_1.weight, a=-0.1, b=0.1)nn.init.uniform_(self.linear_2.weight, a=-0.1, b=0.1)def forward(self, state):if not isinstance(state, torch.Tensor):state = torch.tensor([state], dtype=torch.float)state = state.to(device)return self.linear_2(relu(self.linear_1(state)))

定义policy

def policy(model, state, eval=False):eps = 0.1if not eval and random.random() < eps:return random.randint(0, model.action_size - 1)else:q_values = model(torch.tensor([state], dtype=torch.float))action = torch.multinomial(F.softmax(q_values), num_samples=1)return int(action[0])

collect

dqn_model = DQN(state_size=8, action_size=4).to(device)
target_model = DQN(state_size=8, action_size=4).to(device)
from tqdm.notebook import tqdm
# 学习率
alpha = 0.9
# 折扣因子
gamma = 0.95
# 训练次数
episode = 1000
experience_buffer = ExperienceBuffer()eval_iter = 100
eval_num = 100# collect
def collect():for e in tqdm(range(episode)):state, info = env.reset()action = policy(dqn_model, state)sum_reward = 0while True:state_next, reward, terminated, truncated, info_next = env.step(action)action_next= policy(dqn_model, state_next)sum_reward += rewardexperience_buffer.append(state, action, reward, state_next, action_next)if terminated or truncated:breakstate = state_nextinfo = info_nextaction = action_next

learning

## learning
from torch.optim import Adamloss_fn = nn.MSELoss()
optimizer = Adam(lr=1e-5, params=dqn_model.parameters())losses = []
target_fix_period = 5
epoch = 3def train():for e in range(epoch):batch_size = 128for i in range(experience_buffer.size // batch_size):s, a, r, s_n, a_n = experience_buffer.batch(batch_size)s = torch.tensor(s, dtype=torch.float).to(device)s_n = torch.tensor(s_n, dtype=torch.float).to(device)r = torch.tensor(r, dtype=torch.float).to(device)a = torch.tensor(a, dtype=torch.long).to(device)a_n = torch.tensor(a_n, dtype=torch.long).to(device)y = r + target_model(s_n).gather(1, a_n.unsqueeze(1)).squeeze(1)y_hat = dqn_model(s).gather(1, a.unsqueeze(1)).squeeze(1)loss = loss_fn(y, y_hat)optimizer.zero_grad()loss.backward()optimizer.step()if i % 500 == 0:print(f'i == {i}, loss = {loss} ')if i % target_fix_period == 0:target_model.load_state_dict(dqn_model.state_dict())

a_n:动作
s_n:状态

image-20231205221613164

image-20231205221643890

将状态 s_n 作为输入,target_model的输出是针对每个可能动作的 Q 值;如果 s_n 包含多个状态(比如一个批量),输出将是一个批量的 Q 值

image-20231205221710717

image-20231205221746045

image-20231205221827050

训练

for i in range(10):print(f'collect/train: {i}')experience_buffer.clear()collect()train()

结果

task_num = 10
frames = []for _ in range(10):state, _ = env.reset()while True:action = policy(dqn_model, state, eval=True)state_next, reward, terminated, truncated, info_next = env.step(action)frames.append(env.render())if terminated or truncated:break

output

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/288498.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于python的图表生成系统,python导入数据生成图表

大家好&#xff0c;小编来为大家解答以下问题&#xff0c;用python将excel的内容生成图像&#xff0c;python画的图表如何导入word&#xff0c;现在让我们一起来看看吧&#xff01; 今天的主题是 Excel&#xff0c;相信大家都比较熟悉吧。而且我相信&#xff0c;大家在日常使用…

计算机组件操作系统BIOS的相关知识思维导图

&#x1f3ac; 艳艳耶✌️&#xff1a;个人主页 &#x1f525; 个人专栏 &#xff1a;《产品经理如何画泳道图&流程图》 ⛺️ 越努力 &#xff0c;越幸运 目录 一、运维实施工程师需要具备的知识 1、运维工程师、实施工程师是啥&#xff1f; 2、运维工程师、实施工…

3.基本数据类型

3.基本数据类型 int整型 作用:用来记录人的年龄&#xff0c;出生年份&#xff0c;学生人数等整数相关的状态 定义:age18 birthday1990 student_count48 float浮点型 作用&#xff1a;用来记录人的身高&#xff0c;体重&#xff0c;薪资等小数相关的状态 定义:height172.3 w…

el-select 全选

<template><div class"container"><el-selectv-model"choosedList"clearablemultiplecollapse-tagsplaceholder"请选择"change"select_Change"><div style"padding: 0 20px; line-height: 34px">&l…

【离散数学】——期末刷题题库(树其二)

&#x1f383;个人专栏&#xff1a; &#x1f42c; 算法设计与分析&#xff1a;算法设计与分析_IT闫的博客-CSDN博客 &#x1f433;Java基础&#xff1a;Java基础_IT闫的博客-CSDN博客 &#x1f40b;c语言&#xff1a;c语言_IT闫的博客-CSDN博客 &#x1f41f;MySQL&#xff1a…

MyBatis-Plus(一):根据指定字段更新或插入

根据指定字段更新或插入 1、概述2、实现方式2、总结 1、概述 MyBatis-Plus中提供了一个saveOrUpdate()方法&#xff0c;默认情况下可以根据主键是否存在进行更新或插入操作&#xff0c;但是实际场景中&#xff0c;根据指定字段进行更新或插入的情况也非常多见&#xff0c;今天…

计算机组成原理(复习题)

更多复习详情请见屌丝笔记 一、选择题 计算机系统概述 1、至今为止&#xff0c;计算机中的所有信息仍以二进制方式表示的理由是&#xff08; C &#xff09;。 A.运算速度快 B.信息处理方便 C.物理器件性能所致 D.节约元件 2、运算器的核心功能部件是&#xff08; D &am…

安防监控EasyCVR平台如何通过api接口设置实时流的sei数据实现画框等操作?

国标GB28181视频监控系统EasyCVR平台采用了开放式的网络结构&#xff0c;支持高清视频的接入和传输、分发&#xff0c;能提供实时远程视频监控、视频录像、录像回放与存储、告警、语音对讲、云台控制、平台级联、磁盘阵列存储、视频集中存储、云存储等丰富的视频能力&#xff0…

DC-6靶场

DC-6靶场下载&#xff1a; https://www.five86.com/downloads/DC-6.zip 下载后解压会有一个DC-3.ova文件&#xff0c;直接在vm虚拟机点击左上角打开-->文件-->选中这个.ova文件就能创建靶场&#xff0c;kali和靶机都调整至NAT模式&#xff0c;即可开始渗透 首先进行主…

基于C#的线上特价商品购物系统asp.net+sqlserver

基于asp.net架构和sql server数据库&#xff0c; 三层架构 功能模块&#xff1a; 基于C#的线上特价商品购物系统 前台主要实现了购买商品和查看商品信息的功能 后台主要对前台的商品信息及订单进行管理。 &#xff08;1&#xff09;订单管理&#xff1a;在前台会员购买商品…

新型智慧视频监控系统:基于TSINGSEE青犀边缘计算AI视频识别技术的应用

边缘计算AI智能识别技术在视频监控领域的应用有很多。这项技术结合了边缘计算和人工智能技术&#xff0c;通过在摄像头或网关设备上运行AI算法&#xff0c;可以在现场实时处理和分析视频数据&#xff0c;从而实现智能识别和分析。目前来说&#xff0c;边缘计算AI视频智能技术可…

【Java异常】idea 报错:无效的目标发行版:17 的解决办法

【Java异常】idea 报错&#xff1a;无效的目标发行版&#xff1a;17 的解决办法 一&#xff0c;问题来源 springcloud的第一个demo项目就给我干趴了 二、原因分析 java: 无效的目标发行版: 17 原因就是 JDK 版本不对。从 IDEA 编辑器中可以找到问题的原因所在&#xff0c;…