LeNet对MNIST 数据集中的图像进行分类--keras实现

我们将训练一个卷积神经网络来对 MNIST 数据库中的图像进行分类,可以与前面所提到的CNN实现对比CNN对 MNIST 数据库中的图像进行分类-CSDN博客

加载 MNIST 数据库

MNIST 是机器学习领域最著名的数据集之一。

  • 它有 70,000 张手写数字图像 - 下载非常简单 - 图像尺寸为 28x28 - 灰度图像
from keras.datasets import mnist# 使用 Keras 导入预洗牌 MNIST 数据库
(X_train, y_train), (X_test, y_test) = mnist.load_data()print("The MNIST database has a training set of %d examples." % len(X_train))
print("The MNIST database has a test set of %d examples." % len(X_test))

将前六个训练图像可视化 

import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib.cm as cm
import numpy as np# 绘制前六幅训练图像
fig = plt.figure(figsize=(20,20))
for i in range(6):ax = fig.add_subplot(1, 6, i+1, xticks=[], yticks=[])ax.imshow(X_train[i], cmap='gray')ax.set_title(str(y_train[i]))

 查看图像的更多细节

def visualize_input(img, ax):ax.imshow(img, cmap='gray')width, height = img.shapethresh = img.max()/2.5for x in range(width):for y in range(height):ax.annotate(str(round(img[x][y],2)), xy=(y,x),horizontalalignment='center',verticalalignment='center',color='white' if img[x][y]<thresh else 'black')fig = plt.figure(figsize = (12,12)) 
ax = fig.add_subplot(111)
visualize_input(X_train[0], ax)

预处理输入图像:通过将每幅图像中的每个像素除以 255 来调整图像比例

# normalize the data to accelerate learning
mean = np.mean(X_train)
std = np.std(X_train)
X_train = (X_train-mean)/(std+1e-7)
X_test = (X_test-mean)/(std+1e-7)print('X_train shape:', X_train.shape)
print(X_train.shape[0], 'train samples')
print(X_test.shape[0], 'test samples')

对标签进行预处理:使用单热方案对分类整数标签进行编码

from keras.utils import np_utilsnum_classes = 10 
# print first ten (integer-valued) training labels
print('Integer-valued labels:')
print(y_train[:10])# one-hot encode the labels
# convert class vectors to binary class matrices
y_train = np_utils.to_categorical(y_train, num_classes)
y_test = np_utils.to_categorical(y_test, num_classes)# print first ten (one-hot) training labels
print('One-hot labels:')
print(y_train[:10])

重塑数据以适应我们的 CNN(和 input_shape)

# input image dimensions 28x28 pixel images. 
img_rows, img_cols = 28, 28X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 1)
X_test = X_test.reshape(X_test.shape[0], img_rows, img_cols, 1)
input_shape = (img_rows, img_cols, 1)print('image input shape: ', input_shape)
print('x_train shape:', X_train.shape)

定义模型架构

论文地址:lecun-01a.pdf

要在 Keras 中实现 LeNet-5,请阅读原始论文并从第 6、7 和 8 页中提取架构信息。以下是构建 LeNet-5 网络的主要启示:

  • 每个卷积层的滤波器数量:从图中(以及论文中的定义)可以看出,每个卷积层的深度(滤波器数量)如下:C1 = 6、C3 = 16、C5 = 120 层。
  • 每个 CONV 层的内核大小:根据论文,内核大小 = 5 x 5
  • 每个卷积层之后都会添加一个子采样层(POOL)。每个单元的感受野是一个 2 x 2 的区域(即 pool_size = 2)。请注意,LeNet-5 创建者使用的是平均池化,它计算的是输入的平均值,而不是我们在早期项目中使用的最大池化层,后者传递的是输入的最大值。如果您有兴趣了解两者的区别,可以同时尝试。在本实验中,我们将采用论文架构。
  • 激活函数:LeNet-5 的创建者为隐藏层使用了 tanh 激活函数,因为对称函数被认为比 sigmoid 函数收敛更快。一般来说,我们强烈建议您为网络中的每个卷积层添加 ReLU 激活函数。

需要记住的事项

  • 始终为 CNN 中的 Conv2D 层添加 ReLU 激活函数。除了网络中的最后一层,密集层也应具有 ReLU 激活函数。
  • 在构建分类网络时,网络的最后一层应该是具有软最大激活函数的密集(FC)层。最终层的节点数应等于数据集中的类别总数。
from keras.models import Sequential
from keras.layers import Conv2D, AveragePooling2D, Flatten, Dense
#Instantiate an empty model
model = Sequential()# C1 Convolutional Layer
model.add(Conv2D(6, kernel_size=(5, 5), strides=(1, 1), activation='tanh', input_shape=input_shape, padding='same'))# S2 Pooling Layer
model.add(AveragePooling2D(pool_size=(2, 2), strides=2, padding='valid'))# C3 Convolutional Layer
model.add(Conv2D(16, kernel_size=(5, 5), strides=(1, 1), activation='tanh', padding='valid'))# S4 Pooling Layer
model.add(AveragePooling2D(pool_size=(2, 2), strides=2, padding='valid'))# C5 Fully Connected Convolutional Layer
model.add(Conv2D(120, kernel_size=(5, 5), strides=(1, 1), activation='tanh', padding='valid'))#Flatten the CNN output so that we can connect it with fully connected layers
model.add(Flatten())# FC6 Fully Connected Layer
model.add(Dense(84, activation='tanh'))# Output Layer with softmax activation
model.add(Dense(10, activation='softmax'))# print the model summary
model.summary()

编译模型

我们将使用亚当优化器

# the loss function is categorical cross entropy since we have multiple classes (10) # compile the model by defining the loss function, optimizer, and performance metric
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

训练模型

LeCun 和他的团队采用了计划衰减学习法,学习率的值按照以下时间表递减:前两个历元为 0.0005,接下来的三个历元为 0.0002,接下来的四个历元为 0.00005,之后为 0.00001。在论文中,作者对其网络进行了 20 个历元的训练。

from keras.callbacks import ModelCheckpoint, LearningRateScheduler# set the learning rate schedule as created in the original paper
def lr_schedule(epoch):if epoch <= 2:     lr = 5e-4elif epoch > 2 and epoch <= 5:lr = 2e-4elif epoch > 5 and epoch <= 9:lr = 5e-5else: lr = 1e-5return lrlr_scheduler = LearningRateScheduler(lr_schedule)# set the checkpointer
checkpointer = ModelCheckpoint(filepath='model.weights.best.hdf5', verbose=1, save_best_only=True)# train the model
hist = model.fit(X_train, y_train, batch_size=32, epochs=20,validation_data=(X_test, y_test), callbacks=[checkpointer, lr_scheduler], verbose=2, shuffle=True)

在验证集上加载分类准确率最高的模型

# load the weights that yielded the best validation accuracy
model.load_weights('model.weights.best.hdf5')

计算测试集的分类准确率

# evaluate test accuracy
score = model.evaluate(X_test, y_test, verbose=0)
accuracy = 100*score[1]# print test accuracy
print('Test accuracy: %.4f%%' % accuracy)

评估模型

import matplotlib.pyplot as pltf, ax = plt.subplots()
ax.plot([None] + hist.history['accuracy'], 'o-')
ax.plot([None] + hist.history['val_accuracy'], 'x-')
# 绘制图例并自动使用最佳位置: loc = 0。
ax.legend(['Train acc', 'Validation acc'], loc = 0)
ax.set_title('Training/Validation acc per Epoch')
ax.set_xlabel('Epoch')
ax.set_ylabel('acc')
plt.show()

import matplotlib.pyplot as pltf, ax = plt.subplots()
ax.plot([None] + hist.history['loss'], 'o-')
ax.plot([None] + hist.history['val_loss'], 'x-')# Plot legend and use the best location automatically: loc = 0.
ax.legend(['Train loss', "Val loss"], loc = 0)
ax.set_title('Training/Validation Loss per Epoch')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
plt.show()

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/259150.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

将项目代码上传到github

文章目录 1. 上传步骤1.1. 设置保存项目代码的文件夹1.2. 打开git1.3. 连接到github仓库1.4. 将本地文件上传到github 附录. git 常用命令 摘要&#xff1a;该文章主要从上传代码步骤讲起&#xff0c;关于git下载和其环境配置没有涉及到。 1. 上传步骤 1.1. 设置保存项目代码…

Vue脚手架 生命周期 组件化开发

Vue脚手架 & 生命周期 & 组件化开发 一、今日目标 1.生命周期 生命周期介绍生命周期的四个阶段生命周期钩子声明周期案例 2.综合案例-小黑记账清单 列表渲染添加/删除饼图渲染 3.工程化开发入门 工程化开发和脚手架项目运行流程组件化组件注册 4.综合案例-小兔…

【数据结构】链表OJ题(顺序表)(C语言实现)

✅✅✅✅✅✅✅✅✅✅✅✅✅✅✅✅ ✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨✨ &#x1f33f;&#x1f33f;&#x1f33f;&#x1f33f;&#x1f33f;&#x1f33f;&#x1f33f;&#x1f33f;&#x1f33f;&#x1f33f;&#x1f33f;&#x1f33f;&#x1f33f;&#x1f33f;&#x1…

Vue 一个简单的mixin的运用,对mixin的初步了解

刚学vue的时候&#xff0c;从一个大神口中老是说什么混入混入&#xff0c;觉得很神秘&#xff0c;后来一了解&#xff0c;原来如此&#xff1a; 其实从字面意思来理解&#xff0c;就是将代码里面的内容混在一起了&#xff0c;上一段代码可能比较好理解一点。 先定义一个简单混…

视频剪辑:视频转码实用技巧,批量将MP4转为MP3音频

随着数字媒体设备的普及&#xff0c;视频和音频文件已成为日常生活中的重要组成部分。有时&#xff0c;可能要将MP4视频文件转换为MP3音频文件&#xff0c;以提取其中的音频内容或者进行其他处理。这是耗费时间的任务&#xff0c;那要如何操作呢&#xff1f;本文详解云炫AI智剪…

2023年7月31日 Go生态洞察:探索项目模板实验

&#x1f337;&#x1f341; 博主猫头虎&#xff08;&#x1f405;&#x1f43e;&#xff09;带您 Go to New World✨&#x1f341; &#x1f984; 博客首页——&#x1f405;&#x1f43e;猫头虎的博客&#x1f390; &#x1f433; 《面试题大全专栏》 &#x1f995; 文章图文…

Flink入门之核心概念(三)

任务槽 TaskSlots: 任务槽&#xff0c;是TaskManager提供的用于执行Task的资源&#xff08;CPU 内存&#xff09; TaskManager提供的TaskSlots的个数&#xff1a;主要由Taskmanager所在机器的CPU核心数来决定&#xff0c;不能超过CPU的最大核心数 1.可以在flink/conf/flink-c…

Python configparser 模块:优雅处理配置文件的得力工具

更多资料获取 &#x1f4da; 个人网站&#xff1a;ipengtao.com 配置文件在软件开发中扮演着重要的角色&#xff0c;而Python中的 configparser 模块提供了一种优雅而灵活的方式来处理各种配置需求。本文将深入介绍 configparser 模块的各个方面&#xff0c;通过丰富的示例代码…

Power BI - 5分钟学习透视列

每天5分钟&#xff0c;今天介绍Power BI透视列功能 什么是透视列&#xff1f; 透视列就是把行数据转换成列数据&#xff0c;也就是大家在工作中常说的行转列。 如何进行逆透视操作&#xff1a; 1&#xff0c;导入的【Sales】表&#xff0c;样例内容如下&#xff1a; 2, 【Ho…

20 套监控平台统一成 1 套 Flashcat,国泰君安监控选型提效之路

author:宋庆羽-国泰君安期货 运维工作最重要的就是维护系统的稳定性&#xff0c;其中监控是保证系统稳定性很重要的一环。通过监控可以了解系统的运行状态&#xff0c;及时发现问题和系统隐患&#xff0c;有助于一线人员快速解决问题&#xff0c;提高业务系统的可用时长。 作为…

Javaweb之 IDEA集成Maven的详细解析

03. IDEA集成Maven 我们要想在IDEA中使用Maven进行项目构建&#xff0c;就需要在IDEA中集成Maven 3.1 配置Maven环境 3.1.1 当前工程设置 1、选择 IDEA中 File > Settings > Build,Execution,Deployment > Build Tools > Maven 2、设置IDEA使用本地安装的Maven…

多线程(初阶九:线程池)

目录 一、线程池的由来 二、线程池的简单介绍 1、ThreadPoolExecutor类 &#xff08;1&#xff09;核心线程数和最大线程数&#xff1a; &#xff08;2&#xff09;保持存活时间和存活时间的单位 &#xff08;3&#xff09;放任务的队列 &#xff08;4&#xff09;线程工…