训练 CNN 对 CIFAR-10 数据中的图像进行分类-keras实现

1. 加载 CIFAR-10 数据库

import keras
from keras.datasets import cifar10# 加载预先处理的训练数据和测试数据
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

2. 可视化前 24 个训练图像

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inlinefig = plt.figure(figsize=(20,5))
for i in range(36):ax = fig.add_subplot(3, 12, i + 1, xticks=[], yticks=[])ax.imshow(np.squeeze(x_train[i]))

3. 通过将每幅图像中的每个像素除以 255 来调整图像比例

事实上,代价函数的形状是一个碗,但如果特征的比例非常不同,它也可能是一个拉长的碗。下图显示了梯度下降法在特征 1 和特征 2 比例相同的训练集上的应用(左图),以及在特征 1 的值远小于特征 2 的训练集上的应用(右图)。

Tips : 使用梯度下降法时,应确保所有特征的比例相似,以加快训练速度,否则收敛时间会更长。

# rescale [0,255] --> [0,1]
x_train = x_train.astype('float32')/255
x_test = x_test.astype('float32')/255

 4. 将数据集分为训练集、测试集和验证集

from keras.utils import to_categorical# 对标签进行一次热编码
num_classes = len(np.unique(y_train))
y_train = to_categorical(y_train, num_classes)
y_test = to_categorical(y_test, num_classes)# 将训练集分为训练集和验证集
(x_train, x_valid) = x_train[5000:], x_train[:5000]
(y_train, y_valid) = y_train[5000:], y_train[:5000]# 打印训练集的形状
print('x_train shape:', x_train.shape)# 打印训练、验证和测试图像的数量
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')
print(x_valid.shape[0], 'validation samples')

5. 定义模型架构 

from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropoutmodel = Sequential()
model.add(Conv2D(filters=16, kernel_size=2, padding='same', activation='relu', input_shape=(32, 32, 3)))
model.add(MaxPooling2D(pool_size=2))
model.add(Conv2D(filters=32, kernel_size=2, padding='same', activation='relu'))
model.add(MaxPooling2D(pool_size=2))
model.add(Conv2D(filters=64, kernel_size=2, padding='same', activation='relu'))
model.add(MaxPooling2D(pool_size=2))
model.add(Dropout(0.3))
model.add(Flatten())
model.add(Dense(500, activation='relu'))
model.add(Dropout(0.4))
model.add(Dense(10, activation='softmax'))model.summary()

6. 编译模型 

# compile the model
model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])

7. 训练模型

from keras.callbacks import ModelCheckpoint   # 训练模型
checkpointer = ModelCheckpoint(filepath='model.weights.best.hdf5', verbose=1, save_best_only=True)hist = model.fit(x_train, y_train, batch_size=32, epochs=100,validation_data=(x_valid, y_valid), callbacks=[checkpointer], verbose=2, shuffle=True)

8. 加载验证精度最高的模型

# 加载验证精度最高的权重
model.load_weights('model.weights.best.hdf5')

 9. 计算测试集的分类精度

# 评估和打印测试精度
score = model.evaluate(x_test, y_test, verbose=0)
print('\n', 'Test accuracy:', score[1])

10. 可视化一些预测

这可能会让你对网络错误分类某些对象的原因有所了解。

# 在测试集上得到预测
y_hat = model.predict(x_test)# 定义文本标签 (source: https://www.cs.toronto.edu/~kriz/cifar.html)
cifar10_labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
# 绘制测试图像的随机样本、预测标签和基本真实图像
fig = plt.figure(figsize=(20, 8))
for i, idx in enumerate(np.random.choice(x_test.shape[0], size=32, replace=False)):ax = fig.add_subplot(4, 8, i + 1, xticks=[], yticks=[])ax.imshow(np.squeeze(x_test[idx]))pred_idx = np.argmax(y_hat[idx])true_idx = np.argmax(y_test[idx])ax.set_title("{} ({})".format(cifar10_labels[pred_idx], cifar10_labels[true_idx]),color=("green" if pred_idx == true_idx else "red"))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/236595.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

代码随想录第二十一天(一刷C语言)|回溯算法组合

创作目的:为了方便自己后续复习重点,以及养成写博客的习惯。 一、回溯算法 1、种类 排列、组合、分割、子集、棋盘问题 2、回溯步骤 (0)回溯抽象 回溯法解决的问题均可以抽象为树形结构(N叉树) &…

plist文件在线生成网页配置苹果ios系统ipa文件下载

您可以进入首页—工具箱—plist文件在线制作 您可以进入控制台—plist文件 ●也可以直接访问:咕噜分发内测平台-苹果ios系统应用安卓apk安全漏洞扫描提供商 ●应用名称_包体的bid-下载地址-图标地址 ●如果不知道怎么查看苹果包名 可以通过咕噜分发【工具箱】-【IOS…

基于AT89C51单片机四位加法计算器的设计

1.设计任务 利用AT89C51单片机为核心控制元件,设计一个四位加法计算器,设计的系统实用性强、操作简单,实现了智能化、数字化。 1)、通过4*4矩阵键盘输入数字及运算符; 2)、可以进行4位十进制数以内的加法…

MyEclipse控制台console不停的自动跳动控制台界面,解决方案

有时候Eclipse启动,控制台console不会自动跳出来,需要手工点击该选项卡才行,按下面的设置,可以让它自动跳出来(或不跳出来):方法: 一、windows -> preferences -> run/debug -> consol…

【shell】正则表达式和AWK

一.正则表达式 通配符匹配文件(而且是已存在的文件) 基本正则表达式扩展正则表达式 可以使用 man 手册帮助 正则表达式:匹配的是文章中的字符 通配符:匹配的是文件名 任意单个字符 1.元字符(字符匹配&…

API管理:smart-doc 与 新版 torna 集成

使用 docker-compose 搭建 torna 环境 torna 介绍 项目地址:https://gitee.com/durcframework/torna torna 是一个企业接口文档解决方案,目标是让文档管理变得更加方便、快捷。Torna采用团队协作的方式管理和维护项目API文档,将不同形式的文…

发现一个好用的搜索引擎【非凡搜索】无广告

# 非凡搜索,小众、无广告、简洁。号称“不收集、不传播任何个人信息”。 非凡搜索 作为一个程序员,经常需要搜索一下技术资料,在国内大部分搜索引擎中搜索时,往往夹带各种广告或目的不明确的结果,这个搜索引擎无疑是一…

chrome vue devTools安装

安装好后如下图所示: 一:下载vue devTools 下载链接https://download.csdn.net/download/weixin_44659458/13192207?spm1001.2101.3001.6661.1&utm_mediumdistribute.pc_relevant_t0.none-task-download-2%7Edefault%7ECTRLIST%7EPaid-1-13192207…

Clickhouse Join

ClickHouse中的Hash Join, Parallel Hash Join, Grace Hash Join https://www.cnblogs.com/abclife/p/17579883.html https://clickhouse.com/blog/clickhouse-fully-supports-joins-full-sort-partial-merge-part3 总结 本文描述并比较了ClickHouse中基于内存哈希表的3种连接…

申请免费的域名SSL证书

1.,选择SSL证书提供商 首先,您需要选择一个提供免费SSL证书的可信赖服务提供商。比如JoySSL就提供了永久免费版本的SSL证书。 2, 注册账户并登录 在选择了JoySSL,您需要在网站上注册一个账户。完成注册后,使用您的账…

uniapp微信小程序实现地图展示控件

最终实现效果: 地图上展示控件,并可以点击。 目录 一、前言 二、在地图上展示控件信息 点击后可进行绘制面图形 1.使用cover-view将控件在地图上展示 2.设置控件样式,使用好看的图标 3.控件绑定点击事件 一、前言 原本使用的是control…

Linux:服务器管理工具宝塔(bt)安装教程

一、简介 bt宝塔Linux面板是提升运维效率的服务器管理软件,支持一键LAMP/LNMP/集群/监控/网站/FTP/数据库/JAVA等多项服务的管理功能 二、安装 使用 SSH 连接工具,如堡塔SSH终端连接到您的 Linux 服务器后,挂载磁盘,根据系统执…