强化学习Agent系列(二)——PyGame虚拟环境创建与Python 贪吃蛇Agent制作实战教学

文章目录

  • 一、前言
  • 二、gymnasium 简单虚拟环境创建
    • 1、gymnasium介绍
    • 2、gymnasium 贪吃蛇简单示例
  • 三、基于gymnasium创建的虚拟环境训练贪吃蛇Agent
    • 1、虚拟环境
    • 2、虚拟环境注册
    • 3、训练程序
    • 4、模型测试
  • 三、卷积虚拟环境
    • 1、卷积神经网络虚拟环境
    • 2、训练代码

一、前言

大家好,未来的开发者们请上座
随着人工智能的发展,强化学习基本会再次来到人们眼前,遂想制作一下相关的教程。强化学习第一步基本离不开虚拟环境的搭建,下面用大家耳熟能详的贪吃蛇游戏为基础,制作一个Agent,完成对这个游戏的绝杀。
万里长城第二步:用python开发贪吃蛇智能体****加粗样式

二、gymnasium 简单虚拟环境创建

1、gymnasium介绍

gymnasium(此前称为gym)是一个由 OpenAI 开发的 Python 库,用于开发和比较强化学习算法。它提供了一组丰富的环境,模拟了各种任务,包括但不限于经典的控制问题、像素级游戏、机器人模拟等。

以下是gymnasium库的一些主要特点:

  1. 环境多样性:gymnasium包含了一系列不同的环境,每个环境都有其独特的观察空间(输入)和动作空间(输出)。这些环境涵盖了从简单的文本控制任务到复杂的三维视觉任务的广泛范围。

  2. 标准化API:gymnasium库提供了一个简单且统一的API来与这些环境交互。这使得研究人员和开发人员可以轻松地用相同的代码测试和比较不同的强化学习算法。

  3. 扩展性:用户可以创建自定义环境并将其集成到gymnasium框架中,这使得库能够适应各种不同的研究需求和应用场景。

  4. 评估标准:gymnasium环境通常包括预定义的评估标准,如累积回报或任务完成时间,这有助于在不同算法间进行公平的比较。

  5. 社区支持:由于gymnasium是由OpenAI推出并得到了强化学习社区的广泛支持,因此有大量的教程、论坛讨论和第三方资源可供学习和参考。

  6. 可视化和监控:gymnasium提供了工具来可视化智能体的性能,并允许监控和记录实验过程,便于分析和调试。

使用gymnasium的基本步骤通常包括:

导入gymnasium库。
创建一个环境实例。
初始化环境。
在一个循环中,根据当前观察值选择动作,执行动作,并接收环境的反馈(新的观察值、奖励、完成状态等)。
结束实验并关闭环境。

2、gymnasium 贪吃蛇简单示例

下面是贪吃蛇虚拟环境的一个简单的示例
在本次示例中,暂未进行任何训练。一切行为主要是从状态空间中随机抽取一个动作并执行。
下面是gymnasium创建的虚拟环境的三个核心函数介绍

  • reset(): 这个函数用于重置环境到初始状态,并返回初始状态的观测值。在开始每个新的episode时,通常会调用这个函数来初始化环境。

  • step(action): 这个函数用于让Agent在环境中执行一个动作(action),并返回四个值:观测值(observation),奖励(reward),是否终止(done),以及额外信息(info)。Agent根据环境返回的信息来决定下一步的动作。

  • render(): 这个函数用于在屏幕上渲染当前环境的状态,通常用于可视化环境以便观察Agent的行为。不是所有的环境都支持渲染,具体取决于环境的实现。

下面是具体的示例代码

import timeimport pygame
import sys
import random
import numpy as np
import gymnasium as gym
class SnakeEnv(gym.Env):def __init__(self):super().__init__()# 初始化Pygamepygame.init()# 屏幕宽高self.SCREEN_WIDTH=240self.SCREEN_HEIGHT=240#蛇的方块大小self.snakeCell=10# 创建窗口self.screen = pygame.display.set_mode((self.SCREEN_WIDTH,self.SCREEN_HEIGHT))pygame.display.set_caption('Snake_Game')self.action_space=gym.spaces.Discrete(4) #动作空间为4self.observation_space=gym.spaces.Box(low=0,high=7,shape=(self.SCREEN_WIDTH,self.SCREEN_HEIGHT),dtype=np.uint8)# 重启def reset(self):"""重置蛇和食物的位置"""# 蛇的初始位置self.snake_head=[100,50]self.snake_body=[[100,50],[100-self.snakeCell,50],[100-self.snakeCell*2,50]]self.len=3# 食物的初始位置self.food_pos=[random.randint(1,self.SCREEN_WIDTH//10-1)*10,random.randint(1,self.SCREEN_HEIGHT//10-1)*10]return self._get_observation()# 根据当前状态 和action 执行动作def step(self,action):# 定义动作到方向的映射directionDict={'LEFT':[1,0],'RIGHT':[-1,0],'UP':[0,-1],'DOWN':[0,1]}action_to_direction = {0: "UP",1: "DOWN",2: "LEFT",3: "RIGHT"}directionTarget=action_to_direction[action]nextPosDelay=np.array(directionDict[directionTarget])*self.snakeCell #加的位置self.snake_head=list(np.array(self.snake_body[0])+nextPosDelay)if self.snake_head in self.snake_body:return self._get_observation(), 0, True, False, {}self.snake_body.insert(0,self.snake_head)# 如果是吃到食物,就重新刷新果子,同时长度 +1if self.food_pos == self.snake_head:self.food_pos = [random.randrange(1, (self.SCREEN_WIDTH // 10)) * 10,random.randrange(1, (self.SCREEN_HEIGHT // 10)) * 10]self.len+=1# 弹出while self.len<len(self.snake_body):self.snake_body.pop()# 奖励reward,done=self._get_reward()truncated = Truereturn self._get_observation(), reward, truncated, done, {}# 渲染def render(self,mode="human"):# 实现可视化screen = self.screen# 颜色定义WHITE= (255,255,255)GREEN = (0,255,0)RED = (255,0,0)# 清空屏幕screen.fill(WHITE)# 画蛇和食物for pos in self.snake_body:pygame.draw.rect(screen,GREEN,pygame.Rect(pos[0],pos[1],self.snakeCell,self.snakeCell))pygame.draw.rect(screen,RED,pygame.Rect(self.food_pos[0],self.food_pos[1],self.snakeCell,self.snakeCell))pygame.display.update()# 获取奖励def _get_reward(self):# 计算奖励reward = 0done = False# 检查蛇是否吃到食物if self.snake_head:reward+=10# 检查蛇是否撞到墙壁或自身head=self.snake_headif head[0]<0 or head[0]>self.SCREEN_WIDTH-10 or head[1]<0 or head[1]>self.SCREEN_HEIGHT-10:reward = -10done = Truereturn reward,done# 获取当前观察空间def _get_observation(self):# 获取窗口内容作为观察值observation = pygame.display.get_surface()# 将观察值调整为指定的宽度和高度# observation = pygame.transform.scale(observation, (self.SCREEN_WIDTH, self.SCREEN_HEIGHT))return observationdef main():snakeEnv=Snake()snakeEnv.reset()done=Falsewhile not done:# 获取事件for event in pygame.event.get():# 处理退出事件if event.type == pygame.QUIT:pygame.quit()done = True# 从动作空间随机获取一个动作action= snakeEnv.action_space.sample()screen, reward, truncated, done,_=snakeEnv.step(action)snakeEnv.render()time.sleep(0.03)
if __name__=="__main__":main()

程序运行截图:
在这里插入图片描述

三、基于gymnasium创建的虚拟环境训练贪吃蛇Agent

在上一步中,你已经创建出你要的虚拟环境了,现在让我们在这个创建好的环境中进行训练吧!

1、虚拟环境

SnakeEnv2.py

import timeimport pygame
import sys
import random
import numpy as np
import gymnasium as gym
from typing import Optionalclass SnakeEnv(gym.Env):metadata = {"render_modes": ["human", "rgb_array"],"render_fps": 30,}def __init__(self, render_mode="human"):super().__init__()# 初始化Pygamepygame.init()# 屏幕宽高self.SCREEN_WIDTH = 100self.SCREEN_HEIGHT = 100# 蛇的方块大小self.snakeCell = 10# 游戏速度self.speed = 12self.clock = pygame.time.Clock()# 创建窗口self.screen = pygame.display.set_mode((self.SCREEN_WIDTH, self.SCREEN_HEIGHT))pygame.display.set_caption('Snake_Game')self.render_mode = render_modeself.action_space = gym.spaces.Discrete(4)  # 动作空间为4self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(self.SCREEN_WIDTH, self.SCREEN_HEIGHT),dtype=np.float32)# 初始化蛇和食物的位置等属性# ...self.num_timesepts = 0# 重启def reset(self, seed: Optional[int] = None, options: Optional[dict] = None):"""重置蛇和食物的位置"""super().reset(seed=seed)self.curStep = 0  # 步数# 蛇的初始位置self.snake_next = [60, 50]self.snake_body = [[60, 50], [60 - self.snakeCell, 50], [60 - self.snakeCell * 2, 50]]self.len = 3# 食物的初始位置self.food_pos = [random.randint(1, self.SCREEN_WIDTH // 10 - 1) * 10,random.randint(1, self.SCREEN_HEIGHT // 10 - 1) * 10]info = {}#return self._get_observation(), info# 根据当前状态 和action 执行动作def step(self, action):self.num_timesepts += 1  # 步骤统计加1# 定义动作到方向的映射directionDict = {'LEFT': [-1, 0], 'RIGHT': [1, 0], 'UP': [0, -1], 'DOWN': [0, 1]}action_to_direction = {0: "UP",1: "DOWN",2: "LEFT",3: "RIGHT"}directionTarget = action_to_direction[action]nextPosDelay = np.array(directionDict[directionTarget]) * self.snakeCell  # 加的位置# 输入action 获取到 snake_next 下一步self.snake_next = list(np.array(self.snake_body[0]) + nextPosDelay)if self.snake_next == self.snake_body[1]:return self._get_observation(), -0.5, False, False, {}# 奖励reward, terminated = self._get_reward()# 如果是吃到食物,就重新刷新果子,同时长度 +1if self.food_pos == self.snake_next:self.food_pos = [random.randrange(1, (self.SCREEN_WIDTH // 10-1)) * 10,random.randrange(1, (self.SCREEN_HEIGHT // 10-1)) * 10]self.len += 1truncated = Falseinfo = {}# if self.render_mode == "human":if self.render_mode == "human" and self.num_timesepts % 5000 > 4000 and self.num_timesepts > 10000:self.render()for event in pygame.event.get():if event == pygame.QUIT:pygame.quit()sys.exit()if not terminated:self.snake_body.insert(0, self.snake_next)# 弹出while self.len < len(self.snake_body):self.snake_body.pop()return self._get_observation(), reward,terminated,  truncated,{}# 渲染def render(self):# 实现可视化screen = self.screen# 颜色定义WHITE = (255, 255, 255)GREEN = (0, 255, 0)RED = (255, 0, 0)# 清空屏幕screen.fill(WHITE)# 画蛇和食物snakecolor = np.linspace(0.9, 0.5, len(self.snake_body), dtype=np.float32)for i in range(len(self.snake_body)):pos = self.snake_body[len(self.snake_body) - i - 1]color = [int(round(component * snakecolor[i])) for component in GREEN]pygame.draw.rect(screen, color, pygame.Rect(pos[0], pos[1], self.snakeCell, self.snakeCell))pygame.draw.rect(screen, RED, pygame.Rect(self.food_pos[0], self.food_pos[1], self.snakeCell, self.snakeCell))pygame.display.update()self.clock.tick(self.speed)def GetDic(self,p1,p2):return np.linalg.norm(np.array(p1) - np.array(p2))# 获取奖励def _get_reward(self):# 计算奖励self.curStep += 1  # 步数reward = 0terminated = Falseflag=0# 正向激励# 检查蛇是否吃到食物 ,吃到食物,就开始猛猛奖励if self.snake_next == self.food_pos:reward += 500 + pow(5, self.len)self.curStep = 0#print(reward)# 负向激励# 检查蛇是否撞到墙壁或自身,游戏结束就负向奖励head = self.snake_nextif head[0] < 0 or head[0] > self.SCREEN_WIDTH-10  or head[1] < 0 or head[1] > self.SCREEN_HEIGHT-10  or self.snake_next in self.snake_body or self.curStep>500:reward -= 100 / self.lenterminated = Trueself.curStep = 0# 摸鱼步数超过一定值就开始负向奖励if self.curStep > 100 * self.len:reward -=  1 / self.len# 中向激励if  self.GetDic(self.snake_next,self.food_pos)< self.GetDic(self.snake_body[0],self.food_pos):reward += 2 / self.len * (self.SCREEN_WIDTH-self.GetDic(self.snake_body[0],self.food_pos)) /self.SCREEN_WIDTH # No upper limit might enable the agent to master shorter scenario faster and more firmly.else:reward -= 1 / self.len#print(reward * 0.3)if reward<0:#print(reward * 0.2)pass#print(reward* 0.2)return reward * 0.2, terminated# 获取当前观察空间def _get_observation(self):# 返回观察空间,也就是一个二维数组obs = np.zeros((self.SCREEN_WIDTH, self.SCREEN_HEIGHT), dtype=np.float32)obs[tuple(np.transpose(self.snake_body))] = np.linspace(0.8, 0.2, len(self.snake_body), dtype=np.float32)obs[tuple(self.snake_body[0])] = 1.0obs[tuple(self.food_pos)] = -1.0return obsdef main():snakeEnv = SnakeEnv()snakeEnv.reset()done = Falsewhile not done:# 获取事件for event in pygame.event.get():# 处理退出事件if event.type == pygame.QUIT:pygame.quit()done = True# 从动作空间随机获取一个动作action = snakeEnv.action_space.sample()screen, reward, truncated, done, _ = snakeEnv.step(action)snakeEnv.render()if __name__ == "__main__":main()

2、虚拟环境注册

打开当前项目的site-packages
在这里插入图片描述
找到gymnasium
将其SnakeEnv2.py放置如下,并在init.py中添加调用注册函数
在这里插入图片描述
到这里就注册完毕,可以进行训练了

3、训练程序

snake_train.py 具体代码如下

# 1、导入必要的库并创建环境:
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback
import os
import sys
# Linear scheduler
def linear_schedule(initial_value, final_value=0.0):if isinstance(initial_value, str):initial_value = float(initial_value)final_value = float(final_value)assert (initial_value > 0.0)def scheduler(progress):return final_value + progress * (initial_value - final_value)return scheduler# 2、创建环境,例如 CartPole
env = gym.make('SnakeEnv-test',render_mode="human")
# 3、创建 PPO 模型并指定环境:
lr_schedule = linear_schedule(2.5e-2, 2.5e-6)
clip_range_schedule = linear_schedule(0.15, 0.025)model = PPO("MlpPolicy", env, verbose=1, device="cuda",n_steps=2048,batch_size=512,n_epochs=4,gamma=0.94,learning_rate=lr_schedule,clip_range=clip_range_schedule,)
# 4、训练模型:# Set the save directory
num=1
save_dir="trained_models_mlp"
while True:save_dir = "trained_models_mlp_{}".format(num)if not os.path.exists(save_dir):os.mkdir(save_dir)breakelse:num +=1checkpoint_interval = 30000  # checkpoint_interval * num_envs = total_steps_per_checkpoint
checkpoint_callback = CheckpointCallback(save_freq=checkpoint_interval, save_path=save_dir, name_prefix="ppo_snake")# Writing the training logs from stdout to a file
original_stdout = sys.stdout
log_file_path = os.path.join(save_dir, "training_log.txt")
print('开始训练'+save_dir)model.learn(total_timesteps=int(200000),callback=[checkpoint_callback]
)# Restore stdout
sys.stdout = original_stdout# Save the final model
model.save(os.path.join(save_dir, "ppo_snake_final.zip"))

4、模型测试

对训练好的模型进行测试,可以用如下代码

import time
import random
from sb3_contrib import MaskablePPO
from stable_baselines3 import PPO
from snakecnn23 import SnakeEnv
import pygameMODEL_PATH=r'H:\AILab\RL\Snaker2\trained_models_cnn\ppo_snake_final'
# Load the trained model
model = MaskablePPO.load(MODEL_PATH)snakeEnv = SnakeEnv()
for i in range(10):obs,info=snakeEnv.reset()terminated = Falsewhile not terminated:# 获取事件for event in pygame.event.get():if event == pygame.QUIT:pygame.quit()# 从动作空间随机获取一个动作action ,_=  model.predict(obs, action_masks=snakeEnv.get_action_mask())prev_mask = snakeEnv.get_action_mask()action_value=int(action.item())obs, reward,  terminated, truncated, _ = snakeEnv.step(action_value)snakeEnv.render()

三、卷积虚拟环境

上面的是基于多层感知机,上限有限,可能效果不是很好,可以对其进行一点点改进
核心是修改
self.observation_space = gym.spaces.Box(low=0, high=255, shape=(self.SCREEN_WIDTH, self.SCREEN_HEIGHT,3),dtype=np.uint8)
和 _get_observation() 观察空间
改完这两个其他的基本不用变

1、卷积神经网络虚拟环境

import timeimport pygame
import sys
import random
import numpy as np
import gymnasium as gym
from typing import Optionalclass SnakeEnv(gym.Env):metadata = {"render_modes": ["human", "rgb_array"],"render_fps": 30,}def __init__(self, render_mode="human"):super().__init__()# 初始化Pygamepygame.init()# 屏幕宽高self.SCREEN_WIDTH = 84self.SCREEN_HEIGHT = 84# 蛇的方块大小self.snakeCell = 7# 游戏速度self.speed = 12self.clock = pygame.time.Clock()# 创建窗口self.screen = pygame.display.set_mode((self.SCREEN_WIDTH, self.SCREEN_HEIGHT))pygame.display.set_caption('Snake_Game')self.render_mode = render_modeself.action_space = gym.spaces.Discrete(4)  # 动作空间为4self.observation_space = gym.spaces.Box(low=0, high=255, shape=(self.SCREEN_WIDTH, self.SCREEN_HEIGHT,3),dtype=np.uint8)# 初始化蛇和食物的位置等属性# ...self.num_timesepts = 0# 重启def reset(self, seed: Optional[int] = None, options: Optional[dict] = None):"""重置蛇和食物的位置"""super().reset(seed=seed)self.curStep = 0  # 步数# 蛇的初始位置self.snake_next = [60, 50]self.snake_body = [[60, 50], [60 - self.snakeCell, 50], [60 - self.snakeCell * 2, 50]]self.len = 3# 食物的初始位置self.food_pos = [random.randint(1, self.SCREEN_WIDTH // 10 - 1) * 10,random.randint(1, self.SCREEN_HEIGHT // 10 - 1) * 10]info = {}#return self._get_observation(), info# 根据当前状态 和action 执行动作def step(self, action):self.num_timesepts += 1  # 步骤统计加1# 定义动作到方向的映射directionDict = {'LEFT': [-1, 0], 'RIGHT': [1, 0], 'UP': [0, -1], 'DOWN': [0, 1]}action_to_direction = {0: "UP",1: "DOWN",2: "LEFT",3: "RIGHT"}directionTarget = action_to_direction[action]nextPosDelay = np.array(directionDict[directionTarget]) * self.snakeCell  # 加的位置# 输入action 获取到 snake_next 下一步self.snake_next = list(np.array(self.snake_body[0]) + nextPosDelay)if self.snake_next == self.snake_body[1]:return self._get_observation(), -0.5, False, False, {}# 奖励reward, terminated = self._get_reward(action)# 如果是吃到食物,就重新刷新果子,同时长度 +1if self.food_pos == self.snake_next:self.food_pos = [random.randrange(1, (self.SCREEN_WIDTH // 10-1)) * 10,random.randrange(1, (self.SCREEN_HEIGHT // 10-1)) * 10]self.len += 1truncated = Falseinfo = {}# if self.render_mode == "human":if self.render_mode == "human" and self.num_timesepts % 5000 > 4500 and self.num_timesepts > 10000:self.render()for event in pygame.event.get():if event == pygame.QUIT:pygame.quit()sys.exit()if not terminated:self.snake_body.insert(0, self.snake_next)# 弹出while self.len < len(self.snake_body):self.snake_body.pop()return self._get_observation(), reward,terminated,  truncated,{}# 渲染def render(self):# 实现可视化screen = self.screen# 颜色定义WHITE = (255, 255, 255)GREEN = (0, 255, 0)RED = (255, 0, 0)# 清空屏幕screen.fill(WHITE)# 画蛇和食物snakecolor = np.linspace(0.9, 0.5, len(self.snake_body), dtype=np.float32)for i in range(len(self.snake_body)):pos = self.snake_body[len(self.snake_body) - i - 1]color = [int(round(component * snakecolor[i])) for component in GREEN]pygame.draw.rect(screen, color, pygame.Rect(pos[0], pos[1], self.snakeCell, self.snakeCell))pygame.draw.rect(screen, RED, pygame.Rect(self.food_pos[0], self.food_pos[1], self.snakeCell, self.snakeCell))pygame.display.update()self.clock.tick(self.speed)def GetDic(self,p1,p2):return np.linalg.norm(np.array(p1) - np.array(p2))# 获取奖励def _get_reward(self,action):# 计算奖励self.curStep += 1  # 步数reward = 0terminated = Falseflag=0# 正向激励# 检查蛇是否吃到食物 ,吃到食物,就开始猛猛奖励if self.snake_next == self.food_pos:reward += 400 + pow(5, self.len)self.curStep = 0# print(reward)#print(action)# print(self.snake_body,self.food_pos,self.snake_next)# 负向激励# 检查蛇是否撞到墙壁或自身,游戏结束就负向奖励head = self.snake_nextif head[0] < 0 or head[0] > self.SCREEN_WIDTH-10  or head[1] < 0 or head[1] > self.SCREEN_HEIGHT-10  or self.snake_next in self.snake_body or self.curStep>500:reward -= 200 / self.lenterminated = Trueself.curStep = 0# 摸鱼步数超过一定值就开始负向奖励if self.curStep > 250 * self.len:reward -=  1 / self.len# 中向激励if  self.GetDic(self.snake_next,self.food_pos)< self.GetDic(self.snake_body[0],self.food_pos):reward += 4 / self.len * (self.SCREEN_WIDTH-self.GetDic(self.snake_next,self.food_pos)) /self.SCREEN_WIDTH # No upper limit might enable the agent to master shorter scenario faster and more firmly.elif self.curStep>50 and self.GetDic(self.snake_next,self.food_pos)>= self.GetDic(self.snake_body[0],self.food_pos):reward -= 2 / self.len * self.GetDic(self.snake_next,self.food_pos) /self.SCREEN_WIDTH#print(reward * 0.3)if reward<0:#print(reward * 0.2)pass# print(reward* 0.2)return reward * 0.2, terminated# 获取当前观察空间def _get_observation(self):obs = np.zeros((self.SCREEN_WIDTH//self.snakeCell, self.SCREEN_HEIGHT//self.snakeCell), dtype=np.uint8)# Set the snake body to gray with linearly decreasing intensity from head to tail.newsnake=np.array(self.snake_body)//7obs[tuple(np.transpose(newsnake))] = np.linspace(200, 50, len(newsnake), dtype=np.uint8)# Stack single layer into 3-channel-image.obs = np.stack((obs, obs, obs), axis=-1)# Set the snake head to green and the tail to blueobs[tuple(newsnake[0])] = [0, 255, 0]obs[tuple(newsnake[-1])] = [255, 0, 0]# Set the food to redobs[np.array(self.food_pos)//7] = [0, 0, 255]# Enlarge the observation to 84x84obs = np.repeat(np.repeat(obs, self.snakeCell, axis=0), self.snakeCell, axis=1)return obsdef main():snakeEnv = SnakeEnv()snakeEnv.reset()done = Falsewhile not done:# 获取事件for event in pygame.event.get():# 处理退出事件if event.type == pygame.QUIT:pygame.quit()done = True# 从动作空间随机获取一个动作action = snakeEnv.action_space.sample()screen, reward, truncated, done, _ = snakeEnv.step(action)snakeEnv.render()if __name__ == "__main__":main()

2、训练代码

核心是修改算法名,之前用MlpPolicy,现在改为CnnPolicy
其余不变

# 1、导入必要的库并创建环境:
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback
import os
import sys
# Linear schedulerfrom stable_baselines3 import PPO
def linear_schedule(initial_value, final_value=0.0):if isinstance(initial_value, str):initial_value = float(initial_value)final_value = float(final_value)assert (initial_value > 0.0)def scheduler(progress):return final_value + progress * (initial_value - final_value)return scheduler
LOG_DIR = "logs"
os.makedirs(LOG_DIR, exist_ok=True)
# 2、创建环境,例如 CartPole
env = gym.make('SnakeEnvcnn-test',render_mode="human")
# 3、创建 PPO 模型并指定环境:
lr_schedule = linear_schedule(2.5e-3, 2.5e-6)
clip_range_schedule = linear_schedule(0.15, 0.025)model = PPO(  "CnnPolicy",env,device="cuda",verbose=1,n_steps=2048,batch_size=512,n_epochs=4,gamma=0.94,learning_rate=lr_schedule,clip_range=clip_range_schedule,tensorboard_log=LOG_DIR)
# 4、训练模型:# Set the save directory
num=1
save_dir="trained_models_cnn"
while True:save_dir = "trained_models_cnn_{}".format(num)if not os.path.exists(save_dir):os.mkdir(save_dir)breakelse:num +=1checkpoint_interval = 30000  # checkpoint_interval * num_envs = total_steps_per_checkpoint
checkpoint_callback = CheckpointCallback(save_freq=checkpoint_interval, save_path=save_dir, name_prefix="ppo_snake")# Writing the training logs from stdout to a file
original_stdout = sys.stdout
log_file_path = os.path.join(save_dir, "training_log.txt")
print('开始训练'+save_dir)model.learn(total_timesteps=int(200000),callback=[checkpoint_callback]
)# Restore stdout
sys.stdout = original_stdout# Save the final model
model.save(os.path.join(save_dir, "ppo_snake_final.zip"))# 5、测试训练好的模型:
obs = env.reset()for i in range(1000):action, _states = model.predict(obs, deterministic=True)observation, reward, terminated, truncated, info = env.step(action)env.render()if terminated:obs = env.reset()
env.close()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/497728.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C++重点---STL简介

顾得泉&#xff1a;个人主页 个人专栏&#xff1a;《Linux操作系统》 《C从入门到精通》 《LeedCode刷题》 键盘敲烂&#xff0c;年薪百万&#xff01; 一、STL简介 STL&#xff08;Standard Template Library&#xff09;是C标准库中的一个重要组成部分&#xff0c;它提供了…

Flutter Dio进阶:使用Flutter Dio拦截器实现高效的API请求管理和身份验证刷新

Flutter笔记 使用Flutter Dio拦截器实现高效的API请求管理和身份验证刷新 - 文章信息 - Author: 李俊才 (jcLee95) Visit me at: https://jclee95.blog.csdn.netMy WebSite&#xff1a;http://thispage.tech/Email: 291148484163.com. Shenzhen ChinaAddress of this article…

js 面试 什么是WebSockets?HTTP和HTTPS有什么不同?web worker是什么?

概念&#xff1a; webSocket 是一种在客户端和服务端之间建立持久连接的协议&#xff0c;它提供全双工通信通道&#xff0c;是服务器可以主动向客户端推送数据&#xff0c;同时也可以接受客户端发送的数据。 1 webSocket与https区别&#xff1f; 在网络通信中&#xff0c;We…

websocket在django中的运用

14-2 聊天室实现思路&#xff1a;轮训、长轮训、websocket_哔哩哔哩_bilibili 参考大佬的B站学习笔记 https://www.cnblogs.com/wupeiqi/p/6558766.html 参考博客 https://www.cnblogs.com/wupeiqi/articles/9593858.html 参考博客 http协议: 是短连接&#xff0c;无状态…

C#,双向链表(Doubly Linked List)归并排序(Merge Sort)算法与源代码

1 双向链表 双向链表也叫双链表&#xff0c;是链表的一种&#xff0c;它的每个数据结点中都有两个指针&#xff0c;分别指向直接后继和直接前驱。所以&#xff0c;从双向链表中的任意一个结点开始&#xff0c;都可以很方便地访问它的前驱结点和后继结点。一般我们都构造双向循…

vscode 设置打开终端的默认工作目录/路径

vscode 设置打开终端的默认工作目录/路径** 文章目录 vscode 设置打开终端的默认工作目录/路径**打开vscode&#xff0c;打开设置UI 或是设置JSON文件&#xff0c;找到相关设置项方式1&#xff1a;通过打开settings.json的UI界面 设置:方式2&#xff1a;通过打开设置settings.j…

linux系统-----------搭建LNMP 架构

PHP(Hypertext Preprocessor 超文本预处理器)是通用服务器端脚本编程语言&#xff0c;主要用于web开发实现动态web页面&#xff0c;也是最早实现将脚本嵌入HTML源码文档中的服务器端脚本语言之一。同时&#xff0c;php还提供了一个命令行接口&#xff0c;因此&#xff0c;其也可…

搭建LNMP环境并配置个人博客系统

LNMP是Linux&#xff08;操作系统&#xff09;、Nginx&#xff08;Web服务器&#xff09;、MySQL&#xff08;数据库&#xff09;和PHP&#xff08;脚本解释器&#xff09;的组合&#xff0c;常用于部署高性能的动态网站&#xff0c;如WordPress等博客平台 一、安装Linux操作系…

计算机专业大学四年应该如何规划(Java方向)

计算机专业的学生&#xff0c;如何在大学四年内提高自己的竞争力&#xff0c;毕业之后直接进大厂工作&#xff1f; 以下将从大学四年计算机专业的学习规划、课程设置、能力提升、参考书籍等方面&#xff0c;为同学们提供一些建议和指导。 大一&#xff1a; 主攻技能学习并且达…

K8S存储卷与PV,PVC

一、前言 Kubernetes&#xff08;K8s&#xff09;中的存储卷是用于在容器之间共享数据的一种机制。存储卷可以在多个Pod之间共享数据&#xff0c;并且可以保持数据的持久性&#xff0c;即使Pod被重新调度或者删除&#xff0c;数据也不会丢失。 Kubernetes支持多种类型的存储卷…

使用全局事件总线实现任意组件间的通讯

本文以vue2中爷孙组件通讯为例&#xff0c;需求是点击孙组件的按钮&#xff0c;实现关闭爷组件的弹窗。 全局事件总线是通过Vue实例的事件系统来实现组件之间的通讯&#xff0c;可以方便地在任何组件中进行事件的触发和监听。 以下是使用全局事件总线实现爷孙组件通讯的步骤&a…

alembic

alembic是sqlalchemy的作者开发的。 用来做OMR模型与数据库的迁移与映射。 第一个&#xff0c;alembic的所有命令都是以alembic开头 第二&#xff0c;alembic的迁移文件也是通过版本进行控制的。首先&#xff0c;通过pip install alembic进行安装。以下将解释alembic的用法 方…