TensorFlow2实战-系列教程13:Resnet实战1

🧡💛💚TensorFlow2实战-系列教程 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Jupyter Notebook中进行
本篇文章配套的代码资源已经上传

Resnet实战1
Resnet实战2
Resnet实战3

1、残差连接

深度学习中出现了随着网络的堆叠效果下降的现象,Resnet使用残差连接的方法解决了这个问题,让深度学习从此变得深了起来。
残差连接的做法可以表示为:
y = f ( x ) + x ( s h o r t c u t ) y = f(x)+x(shortcut) y=f(x)+x(shortcut)
其中y表示网络最后的输出,x为输入,f(x)则表示输入经过几层网络的输出结果,比如三次卷积+三次批归一化+三次relu,一般情况下f(x)就是网络的最终输出,这里再加上x就是一个残差连接的操作。

残差连接的操作保证了,x经过网络后得到的y一定比f(x)更好的结果,最差是同等效果,也就是保证了不会出现效果降低的情况。

这个shortcut是什么意思呢?因为x经过几次卷积后,可能会出现多个特征图,也就是f(x)和x的通道数不一样了,这个时候就需要对x的通道数进行调整再与f(x)相加得到y,如果通道数一样就不需要调整了
在这里插入图片描述
上图就是通道数没有发生变化的情况, y = f ( x ) + x y = f(x)+x y=f(x)+x,x经过两次(卷积+批归一化+ReLU)和一次(卷积+批归一化)后得到f(x),再加上x后经过ReLU就得到了最终的y
在这里插入图片描述
上图就是通道数发生变化的情况, y = f ( x ) + C o n v 2 d ( x ) y = f(x)+Conv2d(x) y=f(x)+Conv2d(x),x经过两次(卷积+批归一化+ReLU)和一次(卷积+批归一化)后得到f(x),x再经过一次(二维卷积+批归一化),这个二维卷积的卷积核是1x1的,经过这个二维卷积的x再加上f(x)后经过ReLU就得到了最终的y

2、项目介绍

在这里插入图片描述

  1. dataset文件夹,将原始数据分割成训练、验证、测试三个数据集
  2. models构建模型的代码,包含resnet31、resnet50、resnet101、resnet152的构建代码,以及残差模块实现的代码
  3. original_dataset,原始数据,包含猫、狗、熊猫3个类别的数据,每个类别1000张图像
  4. save_model,训练模型保存的路径
  5. config.py 设置配置参数的代码
  6. evaluate.py 使用测试集对模型进行测试的代码
  7. prepare_data.py 数据预处理的辅助函数代码
  8. split_dataset.py 将原始数据集分割成训练集、验证集、测试集的代码
  9. train.py 训练验证的代码

3、训练脚本train.py解读------数据预处理

from __future__ import absolute_import, division, print_function
import tensorflow as tf
from models import resnet50, resnet101, resnet152, resnet34
import config
from prepare_data import generate_datasets
import mathif __name__ == '__main__':# GPU settingsgpus = tf.config.experimental.list_physical_devices('GPU')if gpus:for gpu in gpus:tf.config.experimental.set_memory_growth(gpu, True)

导入项目工具包和辅助函数
配置 TensorFlow 中的 GPU 内存,:

  1. tf.config.experimental.list_physical_devices('GPU'):这个函数调用列出了 TensorFlow 在你的机器上可用的所有 GPU
  2. if gpus: 这个检查用来确认是否有可用的 GPU。如果有,它将继续对每一个 GPU 进行配置
  3. 在循环内部,对每一个 GPU 调用 tf.config.experimental.set_memory_growth(gpu, True),这使得 GPU 上的内存增长被启用
# get the original_datasettrain_dataset, valid_dataset, test_dataset, train_count, valid_count, test_count = generate_datasets()
def generate_datasets():train_dataset, train_count = get_dataset(dataset_root_dir=config.train_dir)valid_dataset, valid_count = get_dataset(dataset_root_dir=config.valid_dir)test_dataset, test_count = get_dataset(dataset_root_dir=config.test_dir)# read the original_dataset in the form of batchtrain_dataset = train_dataset.shuffle(buffer_size=train_count).batch(batch_size=config.BATCH_SIZE)valid_dataset = valid_dataset.batch(batch_size=config.BATCH_SIZE)test_dataset = test_dataset.batch(batch_size=config.BATCH_SIZE)return train_dataset, valid_dataset, test_dataset, train_count, valid_count, test_count
def get_dataset(dataset_root_dir):all_image_path, all_image_label = get_images_and_labels(data_root_dir=dataset_root_dir)# print("image_path: {}".format(all_image_path[:]))# print("image_label: {}".format(all_image_label[:]))# load the dataset and preprocess imagesimage_dataset = tf.data.Dataset.from_tensor_slices(all_image_path).map(load_and_preprocess_image)label_dataset = tf.data.Dataset.from_tensor_slices(all_image_label)dataset = tf.data.Dataset.zip((image_dataset, label_dataset))image_count = len(all_image_path)return dataset, image_count
def get_images_and_labels(data_root_dir):# 得到所有图像路径data_root = pathlib.Path(data_root_dir)all_image_path = [str(path) for path in list(data_root.glob('*/*'))]# 得到标签名字label_names = sorted(item.name for item in data_root.glob('*/'))# 例如:{'cats': 0, 'dogs': 1, 'panda': 2}label_to_index = dict((index, label) for label, index in enumerate(label_names))# 每一个图像对应的标签all_image_label = [label_to_index[pathlib.Path(single_image_path).parent.name] for single_image_path in all_image_path]return all_image_path, all_image_labeldef load_and_preprocess_image(img_path):# read picturesimg_raw = tf.io.read_file(img_path)# decode picturesimg_tensor = tf.image.decode_jpeg(img_raw, channels=channels)# resizeimg_tensor = tf.image.resize(img_tensor, [image_height, image_width])img_tensor = tf.cast(img_tensor, tf.float32)# normalizationimg = img_tensor / 255.0return img

load_and_preprocess_image()函数:

  1. 通过读取一个图像的路径
  2. 返回Tensor
  3. 进去进行归一化

get_images_and_labels()函数:

  1. 通过数据集的地址,获取当前目录下的所有图像的名称
  2. 在加上前缀路径和文件后缀,得到当前所有图像的对应的地址
  3. 返回地址和标签

get_dataset()函数:

  1. 通过调用get_images_and_labels()函数,得到当前目录下的图像的对应的地址和标签
  2. 使用from_tensor_slices方法和load_and_preprocess_image()函数读取地址和标签转换为Tensor
  3. 返回标签和数据组成的Tensor以及数据量

generate_datasets()函数:

  1. 训练、验证、测试数据路径分别通过调用get_dataset()函数得到训练、验证、测试数据Tensor和数据量
  2. 对训练、验证、测试数据加上batch_size和shuffle参数
  3. 返回训练、验证、测试数据Tensor和数据量

Resnet实战1
Resnet实战2
Resnet实战3

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/444096.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

React Router 完美教程(下)

我们书接上回,继续我们的React Router 路由之路: 我们到目前为止都没有用到 state、useEffect、redux等状态管理器。但也达到了我们的设计目的。 注意,action 返回的结果 可以在组件中使用 useActionData() 来获取。就像 useLoaderData() 的使…

天软因子数据库新成员——天软指数基金因子库

新闻资讯 News information UPDATE 天软指数基金因子库 天软始终致力于构建完善而丰富的因子库服务体系,陆续推出了股票因子、基金因子、多因子系列等众多因子数据及评价数据。 2024年天软将继续推进各品种因子的深入研究,开年之际天软推出指数基金因子库…

基于EdgeWorkers的边缘应用如何进行单元测试?

随着各行各业数字化转型的持续深入,越来越多企业开始选择将一些应用程序放在距离最终用户更近的边缘位置来运行,借此降低延迟,提高应用程序响应速度,打造更出色的用户体验。 相比传统集中部署和运行的方式,这种边缘应…

769933-15-5,Biotin aniline,可以合成多种有机化合物和聚合物

您好,欢迎来到新研之家 文章关键词:769933-15-5,Biotin aniline,生物素苯胺,生物素-苯胺 一、基本信息 产品简介:Biotin Aniline,一种具有重要生物学功能的化合物,不仅参与了维生…

Fiddler修改https请求与响应 bug修复变灰了选不了等 Fiddle对夜神模拟器抓包设置

不要修改别人的东西,不要修改别人的东西,不要修改别人的东西 只用于自己的网站,自己安全调试。 fiddler修改https请求 1、打到要改的请求 2、替换请求内容 3、开启捕获。操作产生请求。 4、fiddler里查看请求或响应数据 ,确认成…

python给word插入脚注

1.需求 最近因为工作需要,需要给大量文本的脚注插入内容,我就写了个小程序。 2.实现 下面程序是我已经给所有脚注插入了两次文本“幸福”,给脚注2到4再插入文本“幸福” from win32com import clientdef add_text_to_specific_footnotes(…

服务器系统安装时无法进入引导界面处理【笔记】

服务器系统安装时无法进入安装界面处理; 服务器系统安装时无法进入安装界面处理; 1、按“e”键进入编辑模式。 2、在出现的编辑项里,找到quiet aplash后空格加入nomodeset; 3、然后按CTRLX或F10引导系统,启动之后…

地毯填补问题

地毯填补问题 题目描述 相传在一个古老的阿拉伯国家里,有一座宫殿。宫殿里有个四四方方的格子迷宫,国王选择驸马的方法非常特殊,也非常简单:公主就站在其中一个方格子上,只要谁能用地毯将除公主站立的地方外的所有地…

结构体的学习

结构体与共用体,枚举 1.数据类型复习: 2结构体. eg;统计全校同学信息 需要记录的点--- 姓名,班级,性别,成绩,年龄 统计名字:char s[ ] [ 100 ] { "Tmo" } …

41、WEB攻防——通用漏洞XMLXXE无回显DTD实体伪协议代码审计

文章目录 XXE原理&探针&利用XXE读取文件XXE带外测试XXE实体引用XXE挖掘XXE修复 参考资料:CTF XXE XXE原理&探针&利用 XXE用到的重点知识是XML,XML被设计为传输和存储数据,XML文档结构包括XML声明、DTD文档类型定义&#xf…

【MySQL】学习并使用聚合函数和DQL进行分组查询

🌈个人主页: Aileen_0v0 🔥热门专栏: 华为鸿蒙系统学习|计算机网络|数据结构与算法 ​💫个人格言:“没有罗马,那就自己创造罗马~” #mermaid-svg-t8K8tl6eNwqdFmcD {font-family:"trebuchet ms",verdana,arial,sans-serif;font-siz…

网络层 IP协议(1)

前置知识 主机:配有IP地址,但是不进行路由控制的设备 路由器:既配置了IP地址,又能进行路由控制的设备 节点:主机和路由器的总称 IP协议主要完成的任务就是 地址管理和路由选择 地址管理:使用一套地址体系,将网络设备的地址描述出来 路由选择:一个数据报如何从源地址到目的地址 …