Reinforcement Learning from Basics to Advanced, Cases and Practice [5.1]: Policy Gradient - CartPole Game Demo

  • Reinforcement learning (RL) is a branch of machine learning that, unlike supervised and unsupervised learning, focuses on how an agent should act in an environment so as to maximize its expected cumulative reward.
  • Basic loop: the agent learns in an environment; based on the environment's state (or the observation it receives), it executes an action and uses the reward fed back by the environment to guide it toward better actions.

For example, in this project's CartPole game, the agent controls the cart-and-pole system shown in the animation; at every step it can take one of two actions, pushing the cart to the left or to the right.
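As a minimal sketch of this interaction loop (assuming gym's CartPole-v0 and the classic four-value step API, the same API the code later in this post relies on), an agent taking random actions looks like this:

import gym

env = gym.make('CartPole-v0')
obs = env.reset()                       # initial observation: cart position/velocity, pole angle/velocity
done, total_reward = False, 0.0
while not done:
    action = env.action_space.sample()  # random action: 0 = push left, 1 = push right
    obs, reward, done, info = env.step(action)
    total_reward += reward              # +1 for every step the pole stays upright
print('episode return:', total_reward)

Policy Gradient replaces the random choice above with actions sampled from a learned policy network.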

1. Introduction to Policy Gradient

  • Reinforcement learning methods fall into two broad families: value-based and policy-based.

    • Typical value-based algorithms are Q-learning and SARSA: they first optimize the Q function to its optimum and then derive the optimal policy from it.
    • A typical policy-based algorithm is Policy Gradient, which optimizes the policy function directly.
  • When a neural network is used to approximate the policy, a policy gradient must be computed to optimize the policy network.

    • The optimization objective is the expected return under the policy π(s, a): the return R of every trajectory weighted by the probability p of that trajectory occurring; when N is large enough, it can be approximated by sampling N episodes and averaging their returns.

    • Differentiating this objective with respect to the parameters θ gives the policy gradient; a reconstruction of both formulas is sketched below.
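The two formulas referenced above did not survive in the original post; a standard reconstruction (the REINFORCE objective and its gradient, with τ a trajectory, R(τ) its return, and τⁿ the n-th sampled episode) is:

$$
\bar{R}_\theta \;=\; \sum_\tau R(\tau)\, p_\theta(\tau) \;\approx\; \frac{1}{N}\sum_{n=1}^{N} R(\tau^n)
$$

$$
\nabla_\theta \bar{R}_\theta \;=\; \sum_\tau R(\tau)\, p_\theta(\tau)\, \nabla_\theta \log p_\theta(\tau) \;\approx\; \frac{1}{N}\sum_{n=1}^{N} \sum_{t=1}^{T_n} R(\tau^n)\, \nabla_\theta \log \pi_\theta\!\left(a_t^n \mid s_t^n\right)
$$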

## Install dependencies
!pip install pygame
!pip install gym
!pip install atari_py
!pip install parl

import gym
import os
import random
import collections
import paddle
import paddle.nn as nn
import numpy as np
import paddle.nn.functional as F

2. Model

The model here can be built from whatever neural network components suit your needs.

PolicyGradient defines the forward network; you are free to customize its structure.

class PolicyGradient(nn.Layer):
    def __init__(self, act_dim):
        super(PolicyGradient, self).__init__()
        hid1_size = act_dim * 10
        # Two-layer MLP: 4-dim CartPole observation -> hidden layer -> action logits
        self.linear1 = nn.Linear(in_features=4, out_features=hid1_size)
        self.linear2 = nn.Linear(in_features=hid1_size, out_features=act_dim)

    def forward(self, obs):
        out = self.linear1(obs)
        out = paddle.tanh(out)
        out = self.linear2(out)
        out = F.softmax(out)  # output a probability distribution over actions
        return out
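A quick sanity check of the network (a sketch only; the batch of observations below is made up) shows that it outputs a valid probability distribution over the two CartPole actions:

model = PolicyGradient(act_dim=2)
fake_obs = paddle.randn([3, 4], dtype='float32')  # batch of 3 fake CartPole observations
probs = model(fake_obs)
print(probs.shape)                 # [3, 2]
print(probs.numpy().sum(axis=1))   # each row sums to ~1.0 thanks to the softmax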

3. The Agent's Learning Functions

This covers two parts: exploration (sampling actions) and model training.

The Agent is responsible for the interaction between the algorithm and the environment; during this interaction it feeds the generated data to the Algorithm to update the Model, and data pre-processing is usually also defined here.

def sample(obs, MODEL):
    # Sample an action from the current policy's probability distribution
    global ACTION_DIM
    obs = np.expand_dims(obs, axis=0)
    obs = paddle.to_tensor(obs, dtype='float32')
    act = MODEL(obs)
    act_prob = np.squeeze(act, axis=0)
    act = np.random.choice(range(ACTION_DIM), p=act_prob.numpy())
    return act

def learn(obs, action, reward, MODEL):
    # One policy-gradient update: loss = sum over steps of -log pi(a|s) * G_t
    obs = np.array(obs).astype('float32')
    obs = paddle.to_tensor(obs)
    act_prob = MODEL(obs)
    action = paddle.to_tensor(action.astype('int32'))
    log_prob = paddle.sum(-1.0 * paddle.log(act_prob) * F.one_hot(action, act_prob.shape[1]), axis=1)
    reward = paddle.to_tensor(reward.astype('float32'))
    cost = log_prob * reward
    cost = paddle.sum(cost)
    opt = paddle.optimizer.Adam(learning_rate=LEARNING_RATE,
                                parameters=MODEL.parameters())  # optimizer (dynamic graph)
    cost.backward()
    opt.step()
    opt.clear_grad()
    return cost.numpy()
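A minimal sketch of how these two functions are driven (the dummy observation and the single-step "return" are made up purely for illustration; LEARNING_RATE and ACTION_DIM are the globals that section 5 sets before training):

LEARNING_RATE = 0.001
ACTION_DIM = 2
model = PolicyGradient(ACTION_DIM)

obs = np.zeros(4, dtype='float32')      # a dummy CartPole observation
action = sample(obs, model)             # stochastic action drawn from the policy
batch_obs = np.array([obs])
batch_action = np.array([action])
batch_return = np.array([1.0])          # pretend the return G_t for this step is 1
loss = learn(batch_obs, batch_action, batch_return, model)
print(action, loss)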

4. Model Gradient Update Algorithm

def run_train(env, MODEL):
    MODEL.train()
    obs_list, action_list, total_reward = [], [], []
    obs = env.reset()
    while True:
        # sample an action and step the game
        obs_list.append(obs)
        action = sample(obs, MODEL)  # sample an action
        action_list.append(action)
        obs, reward, isOver, info = env.step(action)
        total_reward.append(reward)
        # episode finished
        if isOver:
            break
    return obs_list, action_list, total_reward

def evaluate(model, env, render=False):
    # Evaluate the greedy (argmax) policy over 5 episodes and return the mean reward
    model.eval()
    eval_reward = []
    for i in range(5):
        obs = env.reset()
        episode_reward = 0
        while True:
            obs = np.expand_dims(obs, axis=0)
            obs = paddle.to_tensor(obs, dtype='float32')
            action = model(obs)
            action = np.argmax(action.numpy())
            obs, reward, done, _ = env.step(action)
            episode_reward += reward
            if render:
                env.render()
            if done:
                break
        eval_reward.append(episode_reward)
    return np.mean(eval_reward)

5. Training and Evaluation Functions

Set the hyperparameters:

LEARNING_RATE = 0.001  # learning rate
OBS_DIM = None
ACTION_DIM = None

# Given the per-step reward list of one episode, compute the return Gt of every step
def calc_reward_to_go(reward_list, gamma=1.0):
    for i in range(len(reward_list) - 2, -1, -1):
        # G_t = r_t + γ·r_{t+1} + ... = r_t + γ·G_{t+1}
        reward_list[i] += gamma * reward_list[i + 1]  # Gt
    return np.array(reward_list)

def main():
    global OBS_DIM
    global ACTION_DIM
    train_step_list = []
    train_reward_list = []
    evaluate_step_list = []
    evaluate_reward_list = []
    # initialize the game
    env = gym.make('CartPole-v0')
    # observation shape and action dimension
    action_dim = env.action_space.n
    obs_dim = env.observation_space.shape[0]
    OBS_DIM = obs_dim
    ACTION_DIM = action_dim
    max_score = -int(1e4)
    # create the policy networks
    MODEL = PolicyGradient(ACTION_DIM)
    TARGET_MODEL = PolicyGradient(ACTION_DIM)
    # start training
    print("start training...")
    # train for 1000 episodes; the test episodes are not counted
    for i in range(1000):
        obs_list, action_list, reward_list = run_train(env, MODEL)
        if i % 10 == 0:
            print("Episode {}, Reward Sum {}.".format(i, sum(reward_list)))
        batch_obs = np.array(obs_list)
        batch_action = np.array(action_list)
        batch_reward = calc_reward_to_go(reward_list)
        cost = learn(batch_obs, batch_action, batch_reward, MODEL)
        if (i + 1) % 100 == 0:
            # render=True shows the rendered game; it must be run locally, AI Studio cannot display it
            total_reward = evaluate(MODEL, env, render=False)
            print("Test reward: {}".format(total_reward))

if __name__ == '__main__':
    main()
W0630 11:26:18.969960   322 gpu_resources.cc:61] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 11.2
W0630 11:26:18.974581   322 gpu_resources.cc:91] device: 0, cuDNN Version: 8.2.
start training...
Episode 0, Reward Sum 37.0.
Episode 10, Reward Sum 27.0.
Episode 20, Reward Sum 32.0.
Episode 30, Reward Sum 20.0.
Episode 40, Reward Sum 18.0.
Episode 50, Reward Sum 38.0.
Episode 60, Reward Sum 52.0.
Episode 70, Reward Sum 19.0.
Episode 80, Reward Sum 27.0.
Episode 90, Reward Sum 13.0.
Test reward: 42.8
Episode 100, Reward Sum 28.0.
Episode 110, Reward Sum 44.0.
Episode 120, Reward Sum 30.0.
Episode 130, Reward Sum 28.0.
Episode 140, Reward Sum 27.0.
Episode 150, Reward Sum 47.0.
Episode 160, Reward Sum 55.0.
Episode 170, Reward Sum 26.0.
Episode 180, Reward Sum 47.0.
Episode 190, Reward Sum 17.0.
Test reward: 42.8
Episode 200, Reward Sum 23.0.
Episode 210, Reward Sum 19.0.
Episode 220, Reward Sum 15.0.
Episode 230, Reward Sum 59.0.
Episode 240, Reward Sum 59.0.
Episode 250, Reward Sum 32.0.
Episode 260, Reward Sum 58.0.
Episode 270, Reward Sum 18.0.
Episode 280, Reward Sum 24.0.
Episode 290, Reward Sum 64.0.
Test reward: 116.8
Episode 300, Reward Sum 54.0.
Episode 310, Reward Sum 28.0.
Episode 320, Reward Sum 44.0.
Episode 330, Reward Sum 18.0.
Episode 340, Reward Sum 89.0.
Episode 350, Reward Sum 26.0.
Episode 360, Reward Sum 57.0.
Episode 370, Reward Sum 54.0.
Episode 380, Reward Sum 105.0.
Episode 390, Reward Sum 56.0.
Test reward: 94.0
Episode 400, Reward Sum 70.0.
Episode 410, Reward Sum 35.0.
Episode 420, Reward Sum 45.0.
Episode 430, Reward Sum 117.0.
Episode 440, Reward Sum 50.0.
Episode 450, Reward Sum 35.0.
Episode 460, Reward Sum 41.0.
Episode 470, Reward Sum 43.0.
Episode 480, Reward Sum 75.0.
Episode 490, Reward Sum 37.0.
Test reward: 57.6
Episode 500, Reward Sum 40.0.
Episode 510, Reward Sum 85.0.
Episode 520, Reward Sum 86.0.
Episode 530, Reward Sum 30.0.
Episode 540, Reward Sum 68.0.
Episode 550, Reward Sum 25.0.
Episode 560, Reward Sum 82.0.
Episode 570, Reward Sum 54.0.
Episode 580, Reward Sum 53.0.
Episode 590, Reward Sum 58.0.
Test reward: 147.2
Episode 600, Reward Sum 24.0.
Episode 610, Reward Sum 78.0.
Episode 620, Reward Sum 62.0.
Episode 630, Reward Sum 58.0.
Episode 640, Reward Sum 50.0.
Episode 650, Reward Sum 67.0.
Episode 660, Reward Sum 68.0.
Episode 670, Reward Sum 51.0.
Episode 680, Reward Sum 36.0.
Episode 690, Reward Sum 69.0.
Test reward: 84.2
Episode 700, Reward Sum 34.0.
Episode 710, Reward Sum 59.0.
Episode 720, Reward Sum 56.0.
Episode 730, Reward Sum 72.0.
Episode 740, Reward Sum 28.0.
Episode 750, Reward Sum 35.0.
Episode 760, Reward Sum 54.0.
Episode 770, Reward Sum 61.0.
Episode 780, Reward Sum 32.0.
Episode 790, Reward Sum 147.0.
Test reward: 123.0
Episode 800, Reward Sum 129.0.
Episode 810, Reward Sum 65.0.
Episode 820, Reward Sum 73.0.
Episode 830, Reward Sum 54.0.
Episode 840, Reward Sum 60.0.
Episode 850, Reward Sum 71.0.
Episode 860, Reward Sum 54.0.
Episode 870, Reward Sum 74.0.
Episode 880, Reward Sum 34.0.
Episode 890, Reward Sum 55.0.
Test reward: 104.8
Episode 900, Reward Sum 41.0.
Episode 910, Reward Sum 111.0.
Episode 920, Reward Sum 33.0.
Episode 930, Reward Sum 49.0.
Episode 940, Reward Sum 62.0.
Episode 950, Reward Sum 114.0.
Episode 960, Reward Sum 52.0.
Episode 970, Reward Sum 64.0.
Episode 980, Reward Sum 94.0.
Episode 990, Reward Sum 90.0.
Test reward: 72.2

Project link (fork it and it will run):

https://www.heywhale.com/mw/project/649e7dc170567260f8f12d54
