创建自定义 gym env 教程

gym-0.26.1
pygame-2.1.2
自定义环境 GridWolrdEnv

教程参考 官网自定义环境 ,我把一些可能有疑惑的地方讲解下。

首先整体文件结构, 这里省略了wrappers

gym-examples/main.py    # 这个是测试自定义的环境setup.py  gym_examples/__init__.pyenvs/__init__.pygrid_world.py

先讲几个基础知识

  1. init.py 的作用
    最主要的作用是: 将所在的目录标记为 Python 包的一部分。
    在 Python 中,一个包是一个包含模块(即 .py 文件)的目录,
    init.py 文件表明这个目录可以被视为一个包,允许从这个目录导入模块或其他包。
  2. class里以 _ 开头的变量,说明是私有变量,以 _ 开头方法被视为私有方法。(默认的规定,但不强制)
  3. 实例的变量的初始化可以不在 __init__函数里,比如在这里有些变量就是 在 reset 函数里初始化。

grid_world.py

原版的英文注释已经很清楚了,所以我们这里就是沿用就好了

import gym
from gym import spaces
import pygame
import numpy as npclass GridWorldEnv(gym.Env):metadata = {"render_modes": ["human", "rgb_array"], "render_fps":4}def __init__(self, render_mode=None, size=5):super().__init__()self.size = size   # The size of the square gridself.window_size = 512  # The size of the PyGame window# Observations are dictionaries with the agent's and the target's location.# Each location is encoded as an element of {0, ..., `size`}^2, i.e. MultiDiscrete([size, size]).self.observation_space = spaces.Dict({"agent": spaces.Box(0, size - 1, shape=(2,), dtype=int),"target": spaces.Box(0, size - 1, shape=(2,), dtype=int)})# We have 4 actions, corresponding to "right", "up", "left", "down"self.action_space = spaces.Discrete(4)"""The following dictionary maps abstract actions from `self.action_space` to the direction we will walk in if that action is taken.I.e. 0 corresponds to "right", 1 to "up" etc."""self._action_to_direction = {0: np.array([1, 0]),1: np.array([0, 1]),2: np.array([-1, 0]),3: np.array([0, -1])}assert render_mode is None or render_mode in self.metadata["render_modes"]self.render_mode = render_mode"""If human-rendering is used, `self.window` will be a referenceto the window that we draw to. `self.clock` will be a clock that is usedto ensure that the environment is rendered at the correct framerate inhuman-mode. They will remain `None` until human-mode is used for thefirst time."""self.window = Noneself.clock = Nonedef _get_obs(self):return {"agent": self._agent_location, "target": self._target_location}def _get_info(self):return {"distance": np.linalg.norm(self._agent_location - self._target_location, ord=1)}def reset(self, seed=None, options=None):# We need the following line to seed self.np_randomsuper().reset(seed=seed)# Choose the agent's location uniformly at randomself._agent_location = self.np_random.integers(0, self.size, size=2, dtype=int)# We will sample the target's location randomly until it does not coincide with the agent's locationself._target_location = self._agent_locationwhile np.array_equal(self._target_location, self._agent_location):self._target_location = self.np_random.integers(0, self.size, size=2, dtype=int)observation = self._get_obs()info = self._get_info()if self.render_mode == "human":self._render_frame()return observation, infodef step(self, action):# Map the action (element of {0,1,2,3}) to the direction we walk indirection = self._action_to_direction[action]# We use `np.clip` to make sure we don't leave the gridself._agent_location = np.clip(self._agent_location + direction, 0, self.size - 1)# An episode is done iff the agent has reached the targetterminated = np.array_equal(self._agent_location, self._target_location)reward = 1 if terminated else 0observation = self._get_obs()info = self._get_info()if self.render_mode == "human":self._render_frame()# truncated = Falsereturn observation, reward, terminated, False, infodef render(self):if self.render_mode == "rgb_array":return self._render_frame()def _render_frame(self):if self.window is None and self.render_mode == "human":pygame.init()pygame.display.init()self.window = pygame.display.set_mode((self.window_size, self.window_size))if self.clock is None and self.render_mode == "human":self.clock = pygame.time.Clock()canvas = pygame.Surface((self.window_size, self.window_size))canvas.fill((255, 255, 255))pix_square_size = (self.window_size / self.size) # The size of a single grid square in pixels# First we draw the targetpygame.draw.rect(canvas,(255, 0, 0),pygame.Rect(pix_square_size * self._target_location,(pix_square_size, pix_square_size),))# Now we draw the agentpygame.draw.circle(canvas,(0, 0, 255),(self._agent_location + 0.5) * pix_square_size,pix_square_size / 3,)# Finally, add some gridlinesfor x in range(self.size + 1):pygame.draw.line(canvas,0,(0, pix_square_size * x),(self.window_size, pix_square_size * x),width=3)pygame.draw.line(canvas,0,(pix_square_size * x, 0),(pix_square_size * x, self.window_size),width=3)if self.render_mode == "human":# The following line copies our drawings from `canvas` to the visible windowself.window.blit(canvas, canvas.get_rect())pygame.event.pump()pygame.display.update()# We need to ensure that human-rendering occurs at the predefined framerate.# The following line will automatically add a delay to keep the framerate stable.self.clock.tick(self.metadata["render_fps"])else: # rgb_arrayreturn np.transpose(np.array(pygame.surfarray.pixels3d(canvas)),axes=(1, 0, 2))def close(self):if self.window is not None:pygame.display.quit()pygame.quit()

envs目录下的__init__.py

from gym_examples.envs.grid_world import GridWorldEnv

envs同级别的__init__.py

这里是必需要通过register先注册环境的

from gym.envs.registration import registerregister(id='gym_examples/GridWorld-v0',  # 可自定义,但是要唯一,不要与现有的有冲突entry_point='gym_examples.envs:GridWorldEnv', # 这个是根据包的路径和类名定义的max_episode_steps=300,
)

最外层的setup.py

主要的作用

  1. 定义包的元数据包括 包名和版本号。
  2. 管理依赖。
  3. 如果其他人想要使用你的 gym_examples 包,他们只需要克隆你的代码库,并在包的根目录下运行 pip install .。这会自动安装 gym_examples 包以及指定版本的 gym 和 pygame。

所以本地开发测试的话 不用setup.py也没有问题,它主要是负责定义和管理包的分发和安装。

from setuptools import setupsetup(name="gym_examples",version="0.0.1",install_requires=["gym==0.26.1", "pygame==2.1.2"],
)

测试的 main.py

import gym
import gym_examples  # 这个就是之前定义的包env = gym.make('gym_examples/GridWorld-v0', render_mode="human")observation, info = env.reset()
done, truncated = False, False
while not done and not truncated:action = env.action_space.sample()observation, reward, done, truncated, info = env.step(action)env.close()

实际效果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/287104.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【NI-RIO入门】扫描模式

于NI KB摘录 所有CompactRIO设备都可以访问CompactRIO扫描引擎和LabVIEW FPGA。 CompactRIO 904x 系列是第一个引入 DAQmx 功能的产品线。 扫描引擎(IO 变量) – 主要为迁移和初始开发而设计。控制循环频率高达 1 kHz1,性能控制器上的频率更…

【ArkTS】路由传参

传参 使用router.pushUrl(),router.push()官方不推荐再使用了。 格式: router.pushUrl({url: 路由地址,params:{参数名:值} )跳转时需要注意路由表中是否包含路由地址。 路由表路径: entry > src > main > resources &g…

专栏十六:bulk以及单细胞空转中的progeny通路分析

progeny本身有自己的R包,可以提取通路基因集信息,团队把他嵌入另一个R包decoupleR中完成富集分析。decoupleR自己有详细的针对bulk和scRNAseq的教程 简单安装一下 devtools::install_github(saezlab/OmnipathR) devtools::install_github("saezlab/progeny") Bio…

Sectigo DV多域名证书能保护几个域名

多域名SSL证书不限制受保护的域名的类型,可以时多个主域名或者子域名,多域名SSL证书都可以同时保护,比较灵活。但是,多域名https证书并不是免费无限制保护域名数量,一把的多域名SSL证书默认保护3-5个域名记录&#xff…

LVS+Keepalived 高可用集群

一.Keepalived工具介绍 1.支持故障自动切换(Failover) 2.支持节点健康状态检查(Health Checking) 3.基于vrrp协议完成地址流动 4.为vip地址所在的节点生成ipvs规则(在配置文件中预先定义) 5.为ipvs集群的各RS做健康状态检测 6.基于脚本调用接口完成脚本中定义的功能&…

沙盘模型3D打印加工服务建筑设计模型3D打印展览展示模型3D打印-CASAIM

随着3D打印技术的不断发展,沙盘模型3D打印已经成为建筑行业中的一项创新应用。这种技术能够将设计师的创意以实体形式呈现,为建筑项目的沟通和展示提供了更加直观和便捷的方式。本文将介绍CASAIM沙盘模型3D打印的优势和应用。 一、CASAIM沙盘模型3D打印的…

你在为其他知识付费平台做流量吗?

随着知识付费市场的蓬勃发展,越来越多的知识提供者选择将自己的课程放到各大知识付费平台上进行销售。然而,你是否意识到,你正在为这些平台做流量、做数据、做流水、做品牌,而卖出去的课程平台还要抽取你的佣金? 如果…

三菱PLC FX3U滑动平均值滤波

三菱PLC滑动平均值滤波其它相关写法,请参考下面文章链接: https://rxxw-control.blog.csdn.net/article/details/125044013https://rxxw-control.blog.csdn.net/article/details/125044013滑动平均值滤波程序总共分为三部分,第一步为:滑动采样。 第二步为:队列求和,第三…

Unity SRP 管线【第四讲:URP 阴影】

URP 全文源码解析参照 引入 在UniversalRenderer.cs/ line 505行处 此处已经准备好了所有渲染数据(所有数据全部存储在了renderingData中) 我们只用renderingData中的数据初设置mainLightShadows bool mainLightShadows m_MainLightShadowCasterPass…

天津仁爱学院专升本化学工程与工艺专业《化工原理》考试大纲

天津仁爱学院化学工程与工艺专业高职升本入学考试《化工原理》课程考试大纲 一.参考教材 《化工原理》(第3版)上、下册,陈常贵,柴诚敬编,天津大学出版社;ISBN:9787561833797&#…

mysql 2day 对表格的增删改查、对数据的增删改查、对内容进行操作

目录 mysql 配置文件授权 远程链接 (grant)数据库操作创建库(create)切换数据库(use)查看当前所在库 表操作创建一张员工表查看表结构修改表名称增加字段修改字段名修改字段类型以及约束条件删除字段 内容操…

随笔记录-springboot_LoggingApplicationListener+LogbackLoggingSystem

环境:springboot-2.3.1 加载日志监听器初始化日志框架 SpringApplication#prepareEnvironment SpringApplicationRunListeners#environmentPrepared EventPublishingRunListener#environmentPrepared SimpleApplicationEventMulticaster#multicastEvent(Applicati…