TensorFlow2实战-系列教程6:猫狗识别3------迁移学习

🧡💛💚TensorFlow2实战-系列教程 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Jupyter Notebook中进行
本篇文章配套的代码资源已经上传

猫狗识别1
数据增强
猫狗识别2------数据增强
猫狗识别3------迁移学习

1、迁移学习

  • 用已经训练好模型的权重参数当做自己任务的模型权重初始化
  • 一般全连接层需要自己训练,可以选择是否训练已经训练好的特征提取层

一般情况下根据自己的任务,选择对那些网络进行微调和重新训练:
如果预训练模型的任务和自己任务非常接近,那可能只需要把最后的全连接层重新训练即可
如果自己任务的数据量比较小,那么应该选择重新训练少数层
如果自己任务的数据量比较大,可以适当多选择几层进行训练

2、猫狗识别

import os
import warnings
warnings.filterwarnings("ignore")
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import layers
from tensorflow.keras import Model
base_dir = './data/cats_and_dogs'
train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')train_cats_dir = os.path.join(train_dir, 'cats')
train_dogs_dir = os.path.join(train_dir, 'dogs')validation_cats_dir = os.path.join(validation_dir, 'cats')
validation_dogs_dir = os.path.join(validation_dir, 'dogs')

前面的内容和TensorFlow2实战-系列教程3:猫狗识别1完全一样

3、加载预训练模型

from tf.keras.applications.resnet import ResNet50
from tensorflow.keras.applications.resnet import ResNet101
from tensorflow.keras.applications.inception_v3 import InceptionV3

从keras中导入预训练模型,在TensorFlow的keras模块,有很多可以直接导入的预训练权重。

pre_trained_model = ResNet101(input_shape = (75, 75, 3),  include_top = False, weights = 'imagenet')
  • 加载导入的模型
  • input_shape 为输入大小
  • include_top为False就是表示不要最后的全连接层
  • 这段代码执行后,会自动进行下载

downloading data from
https://storage.googleapis.com/tensorflow/kerasapplications/resnet/resnet101_weights_tf_dim_ordering_tf_kernels_notop.h5
171446536/171446536 [==============================] - 15s 0us/step

for layer in pre_trained_model.layers:layer.trainable = False

选择要进行重新训练的层

4、callback模块

在 TensorFlow 中,回调(Callbacks)是一个强大的工具,用于在训练的不同阶段(例如在每个时代的开始和结束、在每个批次的处理前后)自定义和控制模型的行为,相当于一个监视器:

4.1 callback示例

callbacks = [
# 如果连续两个epoch还没降低就停止:tf.keras.callbacks.EarlyStopping(patience=2, monitor='val_loss'),
# 可以动态改变学习率:tf.keras.callbacks.LearningRateScheduler
# 保存模型:tf.keras.callbacks.ModelCheckpoint
# 自定义方法:tf.keras.callbacks.Callback
]

上面是一个模板,继续我们的猫狗识别的迁移学习项目:

4.2 定义callback

class myCallback(tf.keras.callbacks.Callback):def on_epoch_end(self, epoch, logs={}):if(logs.get('acc')>0.95):print("\nReached 95% accuracy so cancelling training!")self.model.stop_training = True
  1. 定义一个类,继承Callback
  2. 定义一个函数,传入epoch值和日志
  3. 从当前epoch的日志中取出准确率,如果准确率大于95%
  4. 打印信息
  5. 停止训练
from tensorflow.keras.optimizers import Adam
x = layers.Flatten()(pre_trained_model.output)
x = layers.Dense(1024, activation='relu')(x)
x = layers.Dropout(0.2)(x)                  
x = layers.Dense(1, activation='sigmoid')(x)           
model = Model(pre_trained_model.input, x) 
model.compile(optimizer = Adam(lr=0.001), loss = 'binary_crossentropy', metrics = ['acc'])
  1. 导入优化器
  2. 将预训练模型的输出展平为一维
  3. 定义一个1024的全连接层
  4. 在这层加入dropout
  5. 输出全连接层
  6. 构建模型
  7. 指定优化器、损失函数、验证方法等配置训练器

5、模型训练

定义需要重新训练的层

train_datagen = ImageDataGenerator(rescale = 1./255.,rotation_range = 40,width_shift_range = 0.2,height_shift_range = 0.2,shear_range = 0.2,zoom_range = 0.2,horizontal_flip = True)test_datagen = ImageDataGenerator( rescale = 1.0/255. )train_generator = train_datagen.flow_from_directory(train_dir,batch_size = 20,class_mode = 'binary', target_size = (75, 75))     validation_generator =  test_datagen.flow_from_directory( validation_dir,batch_size  = 20,class_mode  = 'binary', target_size = (75, 75))

前面的内容和TensorFlow2实战-系列教程3:猫狗识别1一样,制作数据

callbacks = myCallback()
history = model.fit_generator(train_generator,validation_data = validation_generator,steps_per_epoch = 100,epochs = 100,validation_steps = 50,verbose = 2,callbacks=[callbacks])

指定训练参数、数据、加入callback模块到模型中,执行训练,verbose = 2表示每次epoch记录一次日志

打印结果:

Epoch 99/100 100/100 - 76s - loss: 0.6138 - acc: 0.6655 - val_loss: 0.6570 - val_acc: 0.6900
Epoch 100/100 100/100 - 76s - loss: 0.5993 - acc: 0.6735 - val_loss: 0.7176 - val_acc: 0.6910

6、预测效果展示

import matplotlib.pyplot as plt
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']epochs = range(len(acc))plt.plot(epochs, acc, 'b', label='Training accuracy')
plt.plot(epochs, val_acc, 'r', label='Validation accuracy')
plt.title('Training and validation accuracy')
plt.legend()plt.figure()plt.plot(epochs, loss, 'b', label='Training Loss')
plt.plot(epochs, val_loss, 'r', label='Validation Loss')
plt.title('Training and validation loss')
plt.legend()
plt.show()

展示
在这里插入图片描述
猫狗识别1
数据增强
猫狗识别2------数据增强
猫狗识别3------迁移学习

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/448155.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【linux】校招中的“熟悉linux操作系统”一般是指达到什么程度?

这样,你先在网上找一套完整openssh升级方案(不是yum或apt的,要源码安装的),然后在虚拟机上反复安装测试,直到把他理解了、背下来。 面试的时候让你简单说说linux命令什么的,你就直接把这个方案…

uniapp中echart实例

1&#xff0c;自定义仪表盘 797_1706772047 index.vue import { useGaugeStore } from "/stores/utils"; const { currentValueEndAngle, currentSplitNumber } storeToRefs(useGaugeStore() ); const gaugeStore useGaugeStore();const wenduGauge ref<chart…

抖音老阳叭叭叭分享的temu蓝海项目优势有哪些?

在市场竞争日益激烈的今天&#xff0c;寻找一个具有发展潜力的蓝海项目已成为众多创业者的共同目标。近日&#xff0c;老阳分享了一个名为Temu的蓝海项目&#xff0c;引发了广泛关注。那么&#xff0c;这个项目究竟有何优势&#xff0c;能让人们对它青睐有加呢?以下几个方面将…

2024美国大学生数学建模C题网球运动中的势头详解思路+具体代码

2024美国大学生数学建模C题网球运动中的势头详解思路具体代码 E题数据已更新&#xff0c;做E题的小伙伴推荐看看博主的E题解析文章。那么废话不多说我们继续来做C题。 赛题分析 我们先阅题&#xff1a; 在2023年温布尔登男单决赛中&#xff0c;20岁的西班牙新星卡洛斯阿尔卡…

Linux 常用命令行

Linux (Ubuntu) 常用操作命令行 1. 打开终端&#xff1a;ctrl alt t; 2. 清屏&#xff1a;clear; 3. 进入目录&#xff1a;cd path;[/ 根目录&#xff1b;./ 当前目录&#xff1b;../ 上一级] 4. 返回上一级目录: cd ..; 5. 显示工作路径: pwd; 6. 列表显示文件、文件夹&…

【FPGA VerilogModelsim】 8bitBCD码60计数器

可私信获取整个项目文件 8bit 即有8位二进制 BCD码 ,全称Binary-Coded Decimal,简称BCD码或者二-十进制代码 利用四位二进制(0000-1111)16个中选择10个作为十进制0-9; 常见的BCD码是8421码 本项目使用两组BCD码(每组4bit,共8bit,故称为8bitBCD)(高位0-5,低位0-9…

定义HarmonyOS IDL接口

HarmonyOS IDL简介 HarmonyOS Interface Definition Language&#xff08;简称HarmonyOS IDL&#xff09;是HarmonyOS的接口描述语言。HarmonyOS IDL与其他接口语言类似&#xff0c;通过HarmonyOS IDL定义客户端与服务端均认可的编程接口&#xff0c;可以实现在二者间的跨进程…

深度学习入门笔记(四)函数与优化方法

深度学习有三大部分 模型表征(包括模型设计、网络表示等)模型评估(上一篇文章提到的准确召回和损失函数等)优化算法(模型如何学习或更新)本节我们就来介绍模型是如何学习或更新的。 4.1 损失函数 模型的学习,实际上就是对参数的学习。参数学习的过程需要一系列的约束,…

JavaScript基础五对象 内置对象 Math.random()

内置对象-生成任意范围随机数 Math.random() 随机数函数&#xff0c; 返回一个0 - 1之间&#xff0c;并且包括0不包括1的随机小数 [0, 1&#xff09; 如何生成0-10的随机数呢&#xff1f; Math.floor(Math.random() * (10 1)) 放大11倍再向下取整 如何生成5-10的随机数&…

100 C++内存高级话题 new 细节探秘,重载类内 operator new ,delete

一 new 内存分配细节探秘 我们以分配10个char为例&#xff0c;说明&#xff0c;观察内存发现&#xff0c;当delete 的时候&#xff0c;实际上很多内存都改变了。 实际上 new 内存不是一个简单的事情。为了记录和管理分配出去的内存&#xff0c;额外分配了不少内存&#xff0c;…

C++ 音视频流媒体浅谈

C流媒体开发 今天就浅浅聊一下C流媒体开发 流媒体开发中最常见的是FFmpeg&#xff08;编解码器&#xff09; 业务逻辑主要是播放器了&#xff08;如腾旭视频 爱奇艺等等&#xff09; FFmpeg是一个开源的音视频处理工具集&#xff0c;可以用于处理、转换和流媒体传输音视频…

微服务介绍

1. 什么是微服务 在介绍微服务时&#xff0c;首先得先理解什么是微服务&#xff0c;顾名思义&#xff0c;微服务得从两个方面去理解&#xff0c;什么是"微"、什么是"服务"&#xff0c; 微 狭义来讲就是体积小、著名的"2 pizza 团队"很好的诠释了…