TensorFlow中slim包的具体用法

TensorFlow中slim包的具体用法

  • 1、训练脚本文件(该文件包含数据下载打包、模型训练,模型评估流程)
  • 3、模型训练
    • 1、数据集相关模块:
    • 2、设置网络模型模块
    • 3、数据预处理模块
    • 4、定义损失loss
    • 5、定义优化器模块

本次使用的TensorFlow版本是1.13.0
地址:https://github.com/tensorflow/models/tree/r1.13.0
到tensorflow-models的GitHub下载research下面的slim这个包到本地
在这里插入图片描述

TensorFlow中slim包的目录结构:

-- slim|-- BUILD|-- README.md|-- WORKSPACE|-- __init__.py|-- datasets|   |-- __init__.py|   |-- __pycache__|   |   |-- __init__.cpython-37.pyc|   |   |-- dataset_utils.cpython-37.pyc|   |   |-- download_and_convert_cifar10.cpython-37.pyc|   |   |-- download_and_convert_flowers.cpython-37.pyc|   |   `-- download_and_convert_mnist.cpython-37.pyc|   |-- build_imagenet_data.py|   |-- cifar10.py|   |-- dataset_factory.py|   |-- dataset_utils.py|   |-- download_and_convert_cifar10.py|   |-- download_and_convert_flowers.py|   |-- download_and_convert_imagenet.sh|   |-- download_and_convert_mnist.py|   |-- download_imagenet.sh|   |-- flowers.py|   |-- imagenet.py|   |-- imagenet_2012_validation_synset_labels.txt|   |-- imagenet_lsvrc_2015_synsets.txt|   |-- imagenet_metadata.txt|   |-- mnist.py|   |-- preprocess_imagenet_validation_data.py|   `-- process_bounding_boxes.py|-- deployment|   |-- __init__.py|   |-- model_deploy.py|   `-- model_deploy_test.py|-- download_and_convert_data.py    # 下载相应的数据集,并将数据打包成TF-record的格式|-- eval_image_classifier.py        # 测试模型分类效果|-- export_inference_graph.py|-- export_inference_graph_test.py|-- nets|   |-- __init__.py|   |-- alexnet.py|   |-- alexnet_test.py|   |-- cifarnet.py|   |-- cyclegan.py|   |-- cyclegan_test.py|   |-- dcgan.py|   |-- dcgan_test.py|   |-- i3d.py|   |-- i3d_test.py|   |-- i3d_utils.py|   |-- inception.py|   |-- inception_resnet_v2.py|   |-- inception_resnet_v2_test.py|   |-- inception_utils.py|   |-- inception_v1.py|   |-- inception_v1_test.py|   |-- inception_v2.py|   |-- inception_v2_test.py|   |-- inception_v3.py|   |-- inception_v3_test.py|   |-- inception_v4.py|   |-- inception_v4_test.py|   |-- lenet.py|   |-- mobilenet|   |   |-- README.md|   |   |-- __init__.py|   |   |-- conv_blocks.py|   |   |-- madds_top1_accuracy.png|   |   |-- mnet_v1_vs_v2_pixel1_latency.png|   |   |-- mobilenet.py|   |   |-- mobilenet_example.ipynb|   |   |-- mobilenet_v2.py|   |   `-- mobilenet_v2_test.py|   |-- mobilenet_v1.md|   |-- mobilenet_v1.png|   |-- mobilenet_v1.py|   |-- mobilenet_v1_eval.py|   |-- mobilenet_v1_test.py|   |-- mobilenet_v1_train.py|   |-- nasnet|   |   |-- README.md|   |   |-- __init__.py|   |   |-- nasnet.py|   |   |-- nasnet_test.py|   |   |-- nasnet_utils.py|   |   |-- nasnet_utils_test.py|   |   |-- pnasnet.py|   |   `-- pnasnet_test.py|   |-- nets_factory.py|   |-- nets_factory_test.py|   |-- overfeat.py|   |-- overfeat_test.py|   |-- pix2pix.py|   |-- pix2pix_test.py|   |-- resnet_utils.py|   |-- resnet_v1.py|   |-- resnet_v1_test.py|   |-- resnet_v2.py|   |-- resnet_v2_test.py|   |-- s3dg.py|   |-- s3dg_test.py|   |-- vgg.py|   `-- vgg_test.py|-- preprocessing|   |-- __init__.py|   |-- cifarnet_preprocessing.py|   |-- inception_preprocessing.py|   |-- lenet_preprocessing.py|   |-- preprocessing_factory.py|   `-- vgg_preprocessing.py|-- scripts                     # gqr:存储的是相关的模型训练脚本                |   |-- export_mobilenet.sh|   |-- finetune_inception_resnet_v2_on_flowers.sh|   |-- finetune_inception_v1_on_flowers.sh|   |-- finetune_inception_v3_on_flowers.sh|   |-- finetune_resnet_v1_50_on_flowers.sh|   |-- train_cifarnet_on_cifar10.sh|   `-- train_lenet_on_mnist.sh|-- setup.py|-- slim_walkthrough.ipynb`-- train_image_classifier.py    # 训练模型的脚本

1、训练脚本文件(该文件包含数据下载打包、模型训练,模型评估流程)

scripts/finetune_resnet_v1_50_on_flowers.sh

#!/bin/bash
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# This script performs the following operations:
# 1. Downloads the Flowers dataset
# 2. Fine-tunes a ResNetV1-50 model on the Flowers training set.
# 3. Evaluates the model on the Flowers validation set.
#
# Usage:
# cd slim
# ./slim/scripts/finetune_resnet_v1_50_on_flowers.sh
set -e# Where the pre-trained ResNetV1-50 checkpoint is saved to.
PRETRAINED_CHECKPOINT_DIR=/tmp/checkpoints   # gqr:预训练模型存放路径# Where the training (fine-tuned) checkpoint and logs will be saved to.
TRAIN_DIR=/tmp/flowers-models/resnet_v1_50# Where the dataset is saved to.
DATASET_DIR=/tmp/flowers    # gqr:数据集存放路径# Download the pre-trained checkpoint.
if [ ! -d "$PRETRAINED_CHECKPOINT_DIR" ]; thenmkdir ${PRETRAINED_CHECKPOINT_DIR}
fi
if [ ! -f ${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt ]; thenwget http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gztar -xvf resnet_v1_50_2016_08_28.tar.gzmv resnet_v1_50.ckpt ${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckptrm resnet_v1_50_2016_08_28.tar.gz
fi# Download the dataset
python download_and_convert_data.py \--dataset_name=flowers \--dataset_dir=${DATASET_DIR}# Fine-tune only the new layers for 3000 steps.
python train_image_classifier.py \--train_dir=${TRAIN_DIR} \--dataset_name=flowers \--dataset_split_name=train \--dataset_dir=${DATASET_DIR} \--model_name=resnet_v1_50 \--checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt \--checkpoint_exclude_scopes=resnet_v1_50/logits \--trainable_scopes=resnet_v1_50/logits \--max_number_of_steps=3000 \--batch_size=32 \--learning_rate=0.01 \--save_interval_secs=60 \--save_summaries_secs=60 \--log_every_n_steps=100 \--optimizer=rmsprop \--weight_decay=0.00004# Run evaluation.
python eval_image_classifier.py \--checkpoint_path=${TRAIN_DIR} \--eval_dir=${TRAIN_DIR} \--dataset_name=flowers \--dataset_split_name=validation \--dataset_dir=${DATASET_DIR} \--model_name=resnet_v1_50# Fine-tune all the new layers for 1000 steps.
python train_image_classifier.py \--train_dir=${TRAIN_DIR}/all \--dataset_name=flowers \--dataset_split_name=train \--dataset_dir=${DATASET_DIR} \--checkpoint_path=${TRAIN_DIR} \--model_name=resnet_v1_50 \--max_number_of_steps=1000 \--batch_size=32 \--learning_rate=0.001 \--save_interval_secs=60 \--save_summaries_secs=60 \--log_every_n_steps=100 \--optimizer=rmsprop \--weight_decay=0.00004# Run evaluation.
python eval_image_classifier.py \--checkpoint_path=${TRAIN_DIR}/all \--eval_dir=${TRAIN_DIR}/all \--dataset_name=flowers \--dataset_split_name=validation \--dataset_dir=${DATASET_DIR} \--model_name=resnet_v1_50

以上文件以下载并打包flowers数据集为例会调用slim/datasets下的****download_and_convert_flowers.py
在这里插入图片描述
代码43行:_NUM_VALIDATION = 350值的意思的测试数据集的数量,我们一般2,8分数据集,这里只用填写测试集的数据代码会自动吧总数据集分成2部分
代码48行:_NUM_SHARDS = 1这个的意思是生成几个tfrecord文件,这个数量是根据你数据量来划分
在这里插入图片描述
代码190行:dataset_utils.download_and_uncompress_tarball(_DATA_URL, dataset_dir) 函数为下载数据集函数,如果本地已经存在数据集,可将将其注释掉
在这里插入图片描述
代码210行:_clean_up_temporary_files(dataset_dir) 函数为打包完毕后删除下载的数据集文件,如果需要下载的数据集可以将其注释掉

上述文件执行完毕后,会得到以下文件
在这里插入图片描述

3、模型训练

模型训练文件为
在这里插入图片描述
以下是该文件中各个模块相关内容

1、数据集相关模块:

在这里插入图片描述

2、设置网络模型模块

在这里插入图片描述

3、数据预处理模块

在这里插入图片描述

4、定义损失loss

在这里插入图片描述

5、定义优化器模块

在这里插入图片描述

运行训练指令:

python train_image_classifier.py \--train_dir=./data/flowers-models/resnet_v1_50\--dataset_name=flowers \--dataset_split_name=train \--dataset_dir=./data/flowers \--model_name=resnet_v1_50 \--checkpoint_path=./data/checkpoints/resnet_v1_50.ckpt \--checkpoint_exclude_scopes=resnet_v1_50/logits \--trainable_scopes=resnet_v1_50/logits \--max_number_of_steps=3000 \ --batch_size=32 \--learning_rate=0.01 \--save_interval_secs=60 \--save_summaries_secs=60 \--log_every_n_steps=100 \--optimizer=rmsprop \--weight_decay=0.00004

–dataset_name=指定模板
–model_name=指定预训练模板
–dataset_dir=指定训练集目录
–checkpoint_exclude_scopes=指定忘记那几层的参数,不带进训练里面,记住提取特征的部分
–train_dir=训练参数存放地址
–trainable_scopes=设定只对那几层变量进行调整,其他层都不进行调整,不设定就会对所有层训练(所以是必须要给定的)
–learning_rate=学习率
–optimizer=优化器
–max_number_of_steps=训练步数
–batch_size=一次训练所选取的样本数。 (Batch Size的大小影响模型的优化程度和速度。同时其直接影响到GPU内存的使用情况,假如你GPU内存不大,该数值最好设置小一点。)
–weight_decay=即模型中所有参数的二次正则化超参数(这个的加入就是为了防止过拟合加入正则项,weight_decay 是乘在正则项的前面,控制正则化项在损失函数中所占权重的)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/84006.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MathType7MAC中文版数学公式编辑器下载安装教程

如今许多之前需要手写的内容都可以在计算机中完成了。以前我们可以通过word输入一些简单的数学公式,但现在通过数学公式编辑器便可以完成几乎所有数学公式的写作。许多简单的数学公式,我们可以使用输入法一个个找到特殊符号并输入,但是对于高…

LeetCode--HOT100题(42)

目录 题目描述:108. 将有序数组转换为二叉搜索树(简单)题目接口解题思路代码 PS: 题目描述:108. 将有序数组转换为二叉搜索树(简单) 给你一个整数数组 nums ,其中元素已经按 升序 排列&#xf…

单片机学习-蜂鸣器电子元件

蜂鸣器是有什么作用的? 蜂鸣器 是 一种 一体化结构 的电子训响器,可以发出声音的电子元器件 蜂鸣器分类? ①压电式蜂鸣器(图左) 称: 无源蜂鸣器 ②电磁式蜂鸣器(图右) 称&#xf…

Android平台RTMP|RTSP直播播放器功能进阶探讨

我们需要怎样的直播播放器? 很多开发者在跟我聊天的时候,经常问我,为什么一个RTMP或RTSP播放器,你们需要设计那么多的接口,真的有必要吗?带着这样的疑惑,我们今天聊聊Android平台RTMP、RTSP播放…

MAVEN利器:一文带你了解IDEA中如何使用Maven

前言: 强大的构建工具——Maven。作为Java生态系统中的重要组成部分,Maven为开发人员提供了一种简单而高效的方式来构建、管理和发布Java项目。无论是小型项目还是大型企业级应用,Maven都能帮助开发人员轻松处理依赖管理、编译、测试和部署等…

17.4 【Linux】systemctl 针对 timer 的配置文件

有时候,某些服务你想要定期执行,或者是开机后执行,或者是什么服务启动多久后执行等等的。在过去,我们大概都是使用 crond 这个服务来定期处理, 不过,既然现在有一直常驻在内存当中的 systemd 这个好用的东西…

neo4jd3拓扑节点显示为节点标签(自定义节点显示)

需求描述:如下图所示,我的拓扑图中有需要不同类型的标签节点,我希望每个节点中显示的是节点的标签 在官方示例中,我们可以看到,节点里面是可以显示图标的,现在我们想将下面的图标换成我们自定义的内容 那…

基于Python3 的 简单股票 可转债 提醒逻辑

概述 通过本地的定时轮训,结合本地建议数据库。检查股票可转债价格的同事,进行策略化提醒 详细 前言 为什么会有这么个东西出来呢,主要是因为炒股软件虽然有推送,但是设置了价格之后,看到推送也未必那么及时&#…

SpringCloud学习笔记(四)_ZooKeeper注册中心

基于Spring Cloud实现服务的发布与调用。而在18年7月份,Eureka2.0宣布停更了,将不再进行开发,所以对于公司技术选型来说,可能会换用其他方案做注册中心。本章学习便是使用ZooKeeper作为注册中心。 本章使用的zookeeper版本是 3.6…

Ansible 自动化安装软件

例子如下: 创建一个名为/ansible/package.yml 的 playbook : 将 php 和 mariadb 软件包安装到 dev、test 和 prod 主机组中的主机上 将 RPM Development Tools 软件包组安装到 dev 主机组中的主机上 将 dev 主机组中主机上的所有软件包更新为最新版本 --- - name:…

Python自动化测试代理程序可用性

在网络爬虫和数据采集过程中,代理服务器扮演着重要的角色。然而,代理服务器的可用性经常会受到影响,给爬虫工作带来一定的挑战。本文将介绍如何使用Python自动化测试代理程序的可用性,为您提供具备实际操作价值的解决方案。让我们…

【业务功能篇83】微服务SpringCloud-ElasticSearch-Kibanan-docke安装-应用层实战

五、ElasticSearch应用 1.ES 的Java API两种方式 Elasticsearch 的API 分为 REST Client API(http请求形式)以及 transportClient API两种。相比来说transportClient API效率更高,transportClient 是通过Elasticsearch内部RPC的形式进行请求…