强化学习A3C算法

强化学习A3C算法

效果:
在这里插入图片描述

a3c.py

import  matplotlib
from    matplotlib import pyplot as plt
matplotlib.rcParams['font.size'] = 18
matplotlib.rcParams['figure.titlesize'] = 18
matplotlib.rcParams['figure.figsize'] = [9, 7]
matplotlib.rcParams['font.family'] = ['KaiTi']
matplotlib.rcParams['axes.unicode_minus']=Falseplt.figure()import os
import  threading
import  gym
import  multiprocessing
import  numpy as np
from    queue import Queueimport  tensorflow as tf
from    tensorflow import keras
from    tensorflow.keras import layers,optimizers,losses# os.environ["CUDA_VISIBLE_DEVICES"] = "0" #使用GPU
# 按需占用GPU显存
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:try:# 设置 GPU 显存占用为按需分配,增长式for gpu in gpus:tf.config.experimental.set_memory_growth(gpu, True)except RuntimeError as e :# 异常处理print(e)SEED_NUM = 1234
tf.random.set_seed(SEED_NUM)
np.random.seed(SEED_NUM)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
assert tf.__version__.startswith('2.')# 互斥锁,用于线程同步数据
g_mutex = threading.Lock()class ActorCritic(keras.Model):""" Actor-Critic模型 """def __init__(self, state_size, action_size):super(ActorCritic, self).__init__()self.state_size = state_size # 状态向量长度self.action_size = action_size # 动作数量# 策略网络Actorself.dense1 = layers.Dense(128, activation='relu')self.policy_logits = layers.Dense(action_size)# V网络Criticself.dense2 = layers.Dense(128, activation='relu')self.values = layers.Dense(1)def call(self, inputs):# 获得策略分布Pi(a|s)x = self.dense1(inputs)logits = self.policy_logits(x)# 获得v(s)v = self.dense2(inputs)values = self.values(v)return logits, valuesdef record(episode,episode_reward,worker_idx,global_ep_reward,result_queue,total_loss,num_steps):""" 统计工具函数  """if global_ep_reward == 0:global_ep_reward = episode_rewardelse:global_ep_reward = global_ep_reward * 0.99 + episode_reward * 0.01print(f"{episode} | "f"Average Reward: {int(global_ep_reward)} | "f"Episode Reward: {int(episode_reward)} | "f"Loss: {int(total_loss / float(num_steps) * 1000) / 1000} | "f"Steps: {num_steps} | "f"Worker: {worker_idx}")result_queue.put(global_ep_reward) # 保存回报,传给主线程return global_ep_rewardclass Memory:""" 数据 """def __init__(self):self.states = []self.actions = []self.rewards = []def store(self, state, action, reward):self.states.append(state)self.actions.append(action)self.rewards.append(reward)def clear(self):self.states = []self.actions = []self.rewards = []class Agent:""" 智能体,包含了中央参数网络server """def __init__(self):# 服务模型优化器,client不需要,直接从server拉取参数self.opt = optimizers.Adam(1e-3)# 服务模型(状态向量,动作数量)self.server = ActorCritic(4, 2) self.server(tf.random.normal((2, 4)))def train(self):# 共享队列,线程安全,不需要加锁同步res_queue = Queue() # 根据cpu线程数量创建多线程Workerworkers = [Worker(self.server, self.opt, res_queue, i)for i in range(10)] #multiprocessing.cpu_count()# 启动多线程Workerfor i, worker in enumerate(workers):print("Starting worker {}".format(i))worker.start()# 统计并绘制总回报曲线returns = []while True:reward = res_queue.get()if reward is not None:returns.append(reward)else: # 结束标志break# 等待线程退出 [w.join() for w in workers] print(returns)plt.figure()plt.plot(np.arange(len(returns)), returns)# plt.plot(np.arange(len(moving_average_rewards)), np.array(moving_average_rewards), 's')plt.xlabel('回合数')plt.ylabel('总回报')plt.savefig('a3c-tf-cartpole.svg')class Worker(threading.Thread): def __init__(self,  server, opt, result_queue, idx):super(Worker, self).__init__()self.result_queue = result_queue # 共享队列self.server = server # 服务模型self.opt = opt # 服务优化器self.client = ActorCritic(4, 2) # 线程私有网络self.worker_idx = idx # 线程idself.env = gym.make('CartPole-v1').unwrapped #私有环境self.ep_loss = 0.0def run(self): # 每个worker自己维护一个memorymem = Memory() # 1回合最大500步for epi_counter in range(500): # 复位client游戏状态current_state,info = self.env.reset(seed=SEED_NUM) mem.clear()ep_reward = 0.0ep_steps = 0  done = Falsewhile not done:# 输入AC网络状态获得Pi(a|s),未经softmaxlogits, _ = self.client(tf.constant(current_state[None, :],dtype=tf.float32))# 归一化概率probs = tf.nn.softmax(logits)# 随机采样动作action = np.random.choice(2, p=probs.numpy()[0])# 交互 new_state, reward, done, truncated, info = self.env.step(action) # 累加奖励ep_reward += reward # 记录mem.store(current_state, action, reward) # 计算回合步数ep_steps += 1# 刷新状态 current_state = new_state # 最长500步或者规则结束,回合结束if ep_steps >= 500 or done: # 计算当前client上的误差with tf.GradientTape() as tape:total_loss = self.compute_loss(done, new_state, mem) # 计算梯度grads = tape.gradient(total_loss, self.client.trainable_weights)# 梯度提交到server,在server上更新梯度global g_mutexg_mutex.acquire()self.opt.apply_gradients(zip(grads,self.server.trainable_weights))g_mutex.release()# 从server拉取最新的梯度g_mutex.acquire()self.client.set_weights(self.server.get_weights())g_mutex.release()# 清空Memory mem.clear() # 统计此回合回报self.result_queue.put(ep_reward)print(f"thread worker_idx : {self.worker_idx}, episode reward : {ep_reward}")break# 线程结束self.result_queue.put(None) def compute_loss(self,done,new_state,memory,gamma=0.99):if done:reward_sum = 0. # 终止状态的v(终止)=0else:# 私有网络根据新状态计算回报reward_sum = self.client(tf.constant(new_state[None, :],dtype=tf.float32))[-1].numpy()[0]# 统计折扣回报discounted_rewards = []for reward in memory.rewards[::-1]:  # reverse buffer rreward_sum = reward + gamma * reward_sumdiscounted_rewards.append(reward_sum)discounted_rewards.reverse()# 输入AC网络环境状态获取 Pi(a|s) v(s) 预测值logits, values = self.client(tf.constant(np.vstack(memory.states), dtype=tf.float32))# 计算advantage = R() - v(s) = 真实值 - 预测值advantage = tf.constant(np.array(discounted_rewards)[:, None], dtype=tf.float32) - values# Critic网络损失value_loss = advantage ** 2# 归一化概率预测值Pi(a|s)policy = tf.nn.softmax(logits)# 真实动作a 概率预测值Pi(a|s) 交叉熵policy_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=memory.actions, logits=logits)# 计算策略网络损失时,并不会计算V网络policy_loss = policy_loss * tf.stop_gradient(advantage)# 动作概率测值Pi(a|s) 熵entropy = tf.nn.softmax_cross_entropy_with_logits(labels=policy, logits=logits)policy_loss = policy_loss - 0.01 * entropy# 聚合各个误差total_loss = tf.reduce_mean((0.5 * value_loss + policy_loss))return total_lossif __name__ == '__main__':master = Agent()master.train()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/70393.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Streamlit 讲解专栏(九):深入探索布局和容器

文章目录 1 前言2 st.sidebar - 在侧边栏增添交互元素2.1 将交互元素添加至侧边栏2.2 示例:在侧边栏添加选择框和单选按钮2.3 特殊元素的注意事项 3 st.columns - 并排布局多元素容器3.1 插入并排布局的容器3.2 嵌套限制 4 st.tabs - 以选项卡形式布局多元素容器4.1…

Keepalived源码安装

文章目录 Keepalived源码安装安装准备缺少OpenSSL解决方法 Keepalived 源码安装 安装准备 tar zxf keepalived-2.2.8.tar.gz /root/ ll drwxrwxr-x. 10 1000 1000 4096 Aug 9 18:29 keepalived-2.2.8 #进入目录执行以下命令查看帮助 ./configure --help #重要编译参数 -…

【爬虫】P1 对目标网站的背景调研(robot.txt,advanced_search,builtwith,whois)

对目标网站的背景调研 检查 robot.txt估算网站大小识别网站所用技术寻找网站的所有者 检查 robot.txt 目的: 大多数的网站都会包含 robot.txt 文件。该文件用于指出使用爬虫爬取网站时有哪些限制。而我们通过读 robot.txt 文件,亦可以最小化爬虫被封禁的…

E8—Aurora 64/66B ip实现GTX与GTY的40G通信2023-08-12

1. 场景 要在贴有K7系列FPGA芯片的板子和贴有KU系列FPGA芯片的板子之间通过光模块光纤QSFP实现40G的高速通信。可以选择的方式有多种,但本质的方案就一种,即实现4路GTX与GTY之间的通信。可以选择8B/10B编码通过GT IP核实现,而不能通过Aurora…

数据结构--拓扑排序

数据结构–拓扑排序 AOV⽹ A O V ⽹ \color{red}AOV⽹ AOV⽹(Activity On Vertex NetWork&#xff0c;⽤顶点表示活动的⽹)&#xff1a; ⽤ D A G 图 \color{red}DAG图 DAG图&#xff08;有向⽆环图&#xff09;表示⼀个⼯程。顶点表示活动&#xff0c;有向边 < V i , V j …

Actuator微服务信息完善-Eureka—SpringCloud(版)微服务学习教程(11)

一、Actuator是什么&#xff1f; Actuator是Springboot提供的用来对应用系统进行自省和监控的功能模块&#xff0c;借助于Actuator开发者可以很方便地对应用系统某些监控指标进行查看、统计等。 在Springboot中使用Actuator监控非常简单&#xff0c;只需要在工程POM文件中引入…

前端新手学习路线

文章目录 前端学习路线&#xff01;特点符号表大纲前言 - 学编程需要的特质一、前端入门⭐️ 开发工具浏览器编辑器文档笔记 ⭐️ HTML⭐️ CSS⭐️ JavaScript✅ ES6 特性 二、巩固基础前端基础知识计算机基础✅ 算法和数据结构✅ 计算机网络✅ 操作系统 软件开发基础✅ 设计模…

RabbitMq-1基础概念

RabbitMq-----分布式中的一种通信手段 1. MQ的基本概念&#xff08;message queue,消息队列&#xff09; mq:消息队列&#xff0c;存储消息的中间件 分布式系统通信的两种方式&#xff1a;直接远程调用&#xff0c;借助第三方完成间接通信 消息的发送方是生产者&#xff0c…

Mysql_5.7下载安装与配置基础操作教程

目录 一、Mysql57下载与安装 二、尝试登录Mysql 三、配置Mysql环境变量 一、Mysql57下载与安装 首先&#xff0c;进入Mysql下载官网&#xff1a;MySQL Community Downloads 随后&#xff0c;选择版本5.7.43&#xff0c;系统选择Windows&#xff0c;随后下方会出现两个下载选…

图片懒加载指令-vueUse

基于Vue的自定义钩子集合 https://vueuse.org/ 适用于Vue 3和Vue2.7版本之后 基于vueUse定义懒加载指令

ZooKeeper的应用场景(命名服务、分布式协调通知)

3 命名服务 命名服务(NameService)也是分布式系统中比较常见的一类场景&#xff0c;在《Java网络高级编程》一书中提到&#xff0c;命名服务是分布式系统最基本的公共服务之一。在分布式系统中&#xff0c;被命名的实体通常可以是集群中的机器、提供的服务地址或远程对象等一这…

优思学院|五大工具:APQP、FMEA、MSA、SPC、PPAP

在现代制造业中&#xff0c;质量是企业成功的关键之一。为了确保产品和过程的质量&#xff0c;需要采用一系列有效的工具和方法。APQP、FMEA、MSA、SPC和PPAP被认定为质量管理体系的五大核心工具&#xff0c;这些工具不仅在汽车行业中得到广泛应用&#xff0c;还被其他制造领域…