使用强化学习进行神经网络结构搜索的代码以及修改

目录

 代码一(Using TensorFlow):

代码二(Using TensorFlow):

代码三(Using PyTorch):

参考:


本人在网上找了三个相关的代码,但是都有问题,这里记录一下修改哪些地方之后可以跑通。

 代码一(Using TensorFlow):

代码地址: 

https://github.com/wallarm/nascell-automl

 这个代码有详细的说明:

The First Step-by-Step Guide for Implementing Neural Architecture Search with Reinforcement Learning Using TensorFlow

代码一和代码二使用TensorFlow,有很多版本的问题,基本上可以用以下代码解决:

import tensorflow.compat.v1 as tf

还有些小的问题:

需要在环境中的tensorflow文件夹下加一个example文件夹:

链接: https://pan.baidu.com/s/1mjIsDxr2TCh6wop0-99Cyw?pwd=gb6s

提取码: gb6s 

可以在TensorFlow官网搜索相关函数,查看正确用法,比如查看NASCell。

可以发现NASCell(4 * max_layers)的用法在tensorflow-addons,需要安装这个包,并且导入。

import tensorflow_addons as tfa

cnn.py中的这行初始化的代码也进行了修改。

tf.initializers.glorot_normal()

修改后代码:NAS with RL(Using TensorFlow)-CSDN博客

代码二(Using TensorFlow):

代码地址:
https://github.com/titu1994/neural-architecture-search?tab=readme-ov-file 

也是有很多版本的问题,解决思路和代码一类似,除了上面的操作,还有:

在train.py中代码开始的地方添加了下面的代码:

tf.compat.v1.disable_eager_execution()

controller.py中下面两行代码有问题:

            _, loss, summary, global_step = self.policy_session.run([self.train_op, self.total_loss, self.summaries_op,self.global_step],feed_dict=feed_dict)self.summary_writer.add_summary(self.summaries_op, global_step)

self.summaries_op这个变量好像有问题,无法fetch到,我也没看懂,好像是跟可视化有关(tensorboard),索性直接把这块删了,改成下面这块:

            _, loss, global_step = self.policy_session.run([self.train_op, self.total_loss,self.global_step],feed_dict=feed_dict)# self.summary_writer.add_summary(self.summaries_op, global_step)

修改后代码:NAS with RL(Using TensorFlow)-CSDN博客

代码三(Using PyTorch):

代码地址:

https://github.com/Longcodedao/NAS-With-RL

修改1:

把controller类下forward中注释掉的

self.total_layer = torch.randint(1, self.max_layer, (1,)).item()

移到初始化__init__下面了。

修改2:

将代码中play_episode函数下的unsqueeze和squeeze的参数由1改为0。

对其整理和解析的博客如下:

NAS with RL(使用强化学习进行神经网络架构搜索,基于pytorch框架)-CSDN博客

参考:

超详细No module named ‘tensorflow.examples’报错解决方法,详细有效!_no module named 'tensorflow.examples-CSDN博客

【问题解决】pytorch: RuntimeError: DataLoader worker (pid(s) 27292) exited unexpectedly_runtimeerror: dataloader worker (pid(s) 25676, 116-CSDN博客

Python-squeeze()、unsqueeze()函数的理解_python squeeze-CSDN博客

torch.gather/torch.scatter_size does not match previous size-CSDN博客

Tensorflow报错:TypeError: Fetch argument None has invalid type class ‘NoneType’_typeerror: fetch argument none has invalid type <c-CSDN博客

TensorFlow报错:tf.placeholder() is not compatible with eager execution.-CSDN博客

tensorflow_addons(tfa)安装与使用-CSDN博客

AttributeError: module ‘keras.backend‘ has no attribute ‘set_session‘_module 'keras.backend' has no attribute 'set_sessi-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/427724.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Web--HTML基础

文章目录 安装环境HTMLhtml框架html基础标签语义标签html特殊符号 安装环境 安装vscode后 安装插件 可以先不写后台直接将前度界面展示出来 自动补全tag&#xff0c;同时修改tag时自动改另一半 在设置里将保存自动格式化的选项勾上 创建一个index.htm文件&#xff0c;这个…

基于sentinel-2 遥感数据的水体提取(水体指数法)

本文框架设置如下&#xff1a; 简单介绍senintel-2数据&#xff1b;如何利用sentinel-2数据获取水体边界/范围 1 Sentinel-2数据介绍及下载方式 有Sentinel-2A/2B两颗卫星&#xff0c;其参数基本一致&#xff0c;因此两颗卫星的数据联合使用很方便。 分辨率有&#xff1a;1…

springboot114基于多维分类的知识管理系统

简介 【毕设源码推荐 javaweb 项目】基于springbootvue 的基于多维分类的知识管理系统 适用于计算机类毕业设计&#xff0c;课程设计参考与学习用途。仅供学习参考&#xff0c; 不得用于商业或者非法用途&#xff0c;否则&#xff0c;一切后果请用户自负。 看运行截图看 第五章…

黑马Java——面向对象进阶(static继承)

1.static静态变量 静态变量是随着类的加载而加载的&#xff0c;优先与对象出现的

Feature Pyramid Grids 原理与代码解析

paper&#xff1a;Feature Pyramid Grids third-party implementation&#xff1a;https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/necks/fpg.py 存在的问题 基于NAS得到的特征金字塔结构如NAS-FPN展现了良好的性能表现&#xff0c;但用NAS寻找改进结…

如何给openai的assistant添加Functions

我的chatgpt网站&#xff1a; https://chat.xutongbao.top/ {"name": "get_current_datetime","description": "获取北京时间&#xff0c;当前时间&#xff0c;当前日期","parameters": {"type": "object&q…

《GreenPlum系列》GreenPlum初级教程-GreenPlum详细入门教程

文章目录 GreenPlum详细入门教程第一章 GreenPlum介绍1.MPP架构介绍2.GreenPlum介绍3.GreenPlum数据库架构4.GreenPlum数据库优缺点 第二章 GreenPlum单节点安装1.Docker创建centos容器1.1 拉取centos7镜像1.2 创建容器1.3 进入容器1.4 容器和服务器免密操作1.4.1 生成密钥1.4.…

1.8 万 Star!这款 Nginx 可视化配置工具太强了

NginxConfig简介 Nginx Config 是一个强大的 Nginx 配置文件生成器&#xff0c;号称配置 Nginx 服务器所需的唯一工具。 正因为 Nginx 功能强大&#xff0c;所以针对其各个功能的配置项会显得特别多&#xff0c;对于我们来说要记住那么多配置是一件十分头疼的事&#xff0c;甚…

仰暮计划|“去咱们那的小坝上吹吹风,看看黄河的水势有没有上涨…”

从来不觉得时间过得有多快&#xff0c;只是日月不断的更替。到了今天&#xff0c;我才不得不承认时间已经过去了很久很久&#xff0c;我的爷爷也已不再年轻。我是爷爷奶奶带大的&#xff0c;自从我记事起&#xff0c;他们就一直陪伴着我了。那时候爸爸妈妈外出打工&#xff0c;…

spring-framework6.x版本源码构建

6.x.修改gradle仓库构建 IDEA版本及gradle构建设置 在gradle指定仓库地址/wrapper/dists/找到与gradle wrapper相对应的gradle版本&#xff0c;在gradle的init.d/目录下新建init.gradle文件&#xff0c;内容如下&#xff1a; allprojects{repositories {mavenLocal()maven { …

41.while语句

目录 一.什么是while语句 二.语法 三.执行流程图 四.举例 五.视频教程 一.什么是while语句 只要条件为真&#xff0c;while循环中的语句会一直重复执行。 二.语法 while&#xff08;表达式&#xff09;{//代码块 } 三.执行流程图 从流程图可以看出&#xff0c;while循环…

【技术】SpringBoot 接口怎么加密解密

1. 介绍 在我们日常的Java开发中&#xff0c;免不了和其他系统的业务交互&#xff0c;或者微服务之间的接口调用 如果我们想保证数据传输的安全&#xff0c;对接口出参加密&#xff0c;入参解密。 但是不想写重复代码&#xff0c;我们可以提供一个通用starter&#xff0c;提…