TensorFlow2实战-系列教程6:迁移学习实战

🧡💛💚TensorFlow2实战-系列教程 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Jupyter Notebook中进行
本篇文章配套的代码资源已经上传

1、迁移学习

  • 用已经训练好模型的权重参数当做自己任务的模型权重初始化
  • 一般全连接层需要自己训练,可以选择是否训练已经训练好的特征提取层

一般情况下根据自己的任务,选择对那些网络进行微调和重新训练:
如果预训练模型的任务和自己任务非常接近,那可能只需要把最后的全连接层重新训练即可
如果自己任务的数据量比较小,那么应该选择重新训练少数层
如果自己任务的数据量比较大,可以适当多选择几层进行训练

2、猫狗识别

import os
import warnings
warnings.filterwarnings("ignore")
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import layers
from tensorflow.keras import Model
base_dir = './data/cats_and_dogs'
train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')train_cats_dir = os.path.join(train_dir, 'cats')
train_dogs_dir = os.path.join(train_dir, 'dogs')validation_cats_dir = os.path.join(validation_dir, 'cats')
validation_dogs_dir = os.path.join(validation_dir, 'dogs')

前面的内容和TensorFlow2实战-系列教程3:猫狗识别1完全一样

3、加载预训练模型

from tf.keras.applications.resnet import ResNet50
from tensorflow.keras.applications.resnet import ResNet101
from tensorflow.keras.applications.inception_v3 import InceptionV3

从keras中导入预训练模型,在TensorFlow的keras模块,有很多可以直接导入的预训练权重。

pre_trained_model = ResNet101(input_shape = (75, 75, 3),  include_top = False, weights = 'imagenet')
  • 加载导入的模型
  • input_shape 为输入大小
  • include_top为False就是表示不要最后的全连接层
  • 这段代码执行后,会自动进行下载

downloading data from
https://storage.googleapis.com/tensorflow/kerasapplications/resnet/resnet101_weights_tf_dim_ordering_tf_kernels_notop.h5
171446536/171446536 [==============================] - 15s 0us/step

for layer in pre_trained_model.layers:layer.trainable = False

选择要进行重新训练的层

4、callback模块

在 TensorFlow 中,回调(Callbacks)是一个强大的工具,用于在训练的不同阶段(例如在每个时代的开始和结束、在每个批次的处理前后)自定义和控制模型的行为,相当于一个监视器:

4.1 callback示例

callbacks = [
# 如果连续两个epoch还没降低就停止:tf.keras.callbacks.EarlyStopping(patience=2, monitor='val_loss'),
# 可以动态改变学习率:tf.keras.callbacks.LearningRateScheduler
# 保存模型:tf.keras.callbacks.ModelCheckpoint
# 自定义方法:tf.keras.callbacks.Callback
]

上面是一个模板,继续我们的猫狗识别的迁移学习项目:

4.2 定义callback

class myCallback(tf.keras.callbacks.Callback):def on_epoch_end(self, epoch, logs={}):if(logs.get('acc')>0.95):print("\nReached 95% accuracy so cancelling training!")self.model.stop_training = True
  1. 定义一个类,继承Callback
  2. 定义一个函数,传入epoch值和日志
  3. 从当前epoch的日志中取出准确率,如果准确率大于95%
  4. 打印信息
  5. 停止训练
from tensorflow.keras.optimizers import Adam
x = layers.Flatten()(pre_trained_model.output)
x = layers.Dense(1024, activation='relu')(x)
x = layers.Dropout(0.2)(x)                  
x = layers.Dense(1, activation='sigmoid')(x)           
model = Model(pre_trained_model.input, x) 
model.compile(optimizer = Adam(lr=0.001), loss = 'binary_crossentropy', metrics = ['acc'])
  1. 导入优化器
  2. 将预训练模型的输出展平为一维
  3. 定义一个1024的全连接层
  4. 在这层加入dropout
  5. 输出全连接层
  6. 构建模型
  7. 指定优化器、损失函数、验证方法等配置训练器

5、模型训练

定义需要重新训练的层

train_datagen = ImageDataGenerator(rescale = 1./255.,rotation_range = 40,width_shift_range = 0.2,height_shift_range = 0.2,shear_range = 0.2,zoom_range = 0.2,horizontal_flip = True)test_datagen = ImageDataGenerator( rescale = 1.0/255. )train_generator = train_datagen.flow_from_directory(train_dir,batch_size = 20,class_mode = 'binary', target_size = (75, 75))     validation_generator =  test_datagen.flow_from_directory( validation_dir,batch_size  = 20,class_mode  = 'binary', target_size = (75, 75))

前面的内容和TensorFlow2实战-系列教程3:猫狗识别1一样,制作数据

callbacks = myCallback()
history = model.fit_generator(train_generator,validation_data = validation_generator,steps_per_epoch = 100,epochs = 100,validation_steps = 50,verbose = 2,callbacks=[callbacks])

指定训练参数、数据、加入callback模块到模型中,执行训练,verbose = 2表示每次epoch记录一次日志

打印结果:

Epoch 99/100 100/100 - 76s - loss: 0.6138 - acc: 0.6655 - val_loss: 0.6570 - val_acc: 0.6900
Epoch 100/100 100/100 - 76s - loss: 0.5993 - acc: 0.6735 - val_loss: 0.7176 - val_acc: 0.6910

6、预测效果展示

import matplotlib.pyplot as plt
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']epochs = range(len(acc))plt.plot(epochs, acc, 'b', label='Training accuracy')
plt.plot(epochs, val_acc, 'r', label='Validation accuracy')
plt.title('Training and validation accuracy')
plt.legend()plt.figure()plt.plot(epochs, loss, 'b', label='Training Loss')
plt.plot(epochs, val_loss, 'r', label='Validation Loss')
plt.title('Training and validation loss')
plt.legend()
plt.show()

展示
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/439105.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Time Series】LSTM代码实战

一、简介 还是那句话,"时间序列金融"是一个很有"钱"景的话题,还是想尝试采用Stock时间序列预测任务DeepLearning。本文提供了LSTM预测股票的源代码。 二、代码 运行代码时的注意事项:按照配置项创建好对应的文件夹&#…

Android SystemUI 介绍

目录 一、什么是SystemUI 二、SystemUI应用源码 三、学习 SystemUI 的核心组件 四、修改状态与导航栏测试 本篇文章,主要科普的是Android SystemUI , 下一篇文章我们将介绍如何把Android SystemUI 应用转成Android Studio 工程项目。 一、什么是Syst…

Hadoop3.x基础(1)

来源:B站尚硅谷 这里写目录标题 大数据概论大数据概念大数据特点(4V)大数据应用场景 Hadoop概述Hadoop是什么Hadoop发展历史(了解)Hadoop三大发行版本(了解)Hadoop优势(4高)Hadoop组成&#xf…

349. 两个数组的交集(力扣LeetCode)

文章目录 349. 两个数组的交集题目描述数组解题set容器解题该思路数组版解题 349. 两个数组的交集 题目描述 给定两个数组 nums1 和 nums2 ,返回 它们的交集 。输出结果中的每个元素一定是 唯一 的。我们可以 不考虑输出结果的顺序 。 示例 1: 输入&a…

C#,贝尔数(Bell Number)的计算方法与源程序

1 埃里克坦普尔贝尔 贝尔数是组合数学中的一组整数数列,以埃里克坦普尔贝尔(Eric Temple Bell)命名, 埃里克坦普尔贝尔(生于1883年2月7日,苏格兰阿伯丁郡阿伯丁,于1960年12月21日在美国加利福尼…

RK3568平台开发系列讲解(Linux系统篇)device 资源的获取

🚀返回专栏总目录 文章目录 一、platform_device 结构体二、platform_get_resource() 获取沉淀、分享、成长,让自己和他人都能有所收获!😄 一、platform_device 结构体 struct platform_driver 结构体继承了 struct device_driver 结构体, 因此可以直接访问 struct devi…

SeaTunnel集群安装

环境准备 服务器节点 节点名称 IP bigdata1 192.168.1.250 bigdata4 192.168.1.251 bigdata5 192.168.1.252 Java环境(三个节点都需要) java1.8 注意:在安装SeaTunnel集群时,最好是现在一个节点上将所有配置都修改完&a…

jenkins pipeline配置maven可选参数

1、在Manage Jenkins下的Global Tool Configuration下对应的maven项添加我们要用得到的不同版本的maven安装项 2、pipeline文件内容具体如下 我们maven是单一的,所以我们都是配置单选参数 pipeline {agent anyparameters {gitParameter(name: BRANCH_TAG, type: …

(五)MySQL的备份及恢复

1、MySQL日志管理 在数据库保存数据时,有时候不可避免会出现数据丢失或者被破坏,这样情况下,我们必须保证数据的安全性和完整性,就需要使用日志来查看或者恢复数据了 数据库中数据丢失或被破坏可能原因: 误删除数据…

jQuery 下载 使用

1. jQuery 网址&#xff1a;jQuery CDN 2. 下载, 右击在新的tab打开 3.ctrls另存为jquery-1.12.4.min.js&#xff0c;并放到自己项目的js目录下 4.请参考目录 5.编写第一个使用jQuery的代码 <script src"js/jquery-1.12.4.min.js"></script> <!DOCT…

数学知识第五期 扩展欧几里得算法

前言 扩展欧几里得算法也是重要的数论&#xff0c;数论固然就是数学知识的理论&#xff0c;希望大家能够熟练掌握&#xff01;&#xff01;&#xff01; 一、扩展欧几里得算法的基本内容 扩展欧几里得算法&#xff08;英语&#xff1a;Extended Euclidean algorithm&#xf…

非阿里云注册域名如何在云解析DNS设置解析?

概述 非阿里云注册域名使用云解析DNS&#xff0c;按照如下步骤&#xff1a; 添加域名。 添加解析记录。 修改DNS服务器。 DNS服务器变更全球同步&#xff0c;等待48小时。 添加解析记录 登录云解析DNS产品控制台。 在 域名解析 页面中&#xff0c;单击 添加域名 。 在 …