TensorFlow案例学习:对服装图像进行分类

前言

官方为我们提供了一个 对服装图像进行分类 的案例,方便我们快速学习

建议按顺序观看,这是一个小系列,适合像我这样的初学者入门

配置环境:windows环境下tensorflow安装

图片分类案例学习:TensorFlow案例学习:对服装图像进行分类

使用官方模型,并进行微调:TensorFlow学习:使用官方模型进行图像分类、使用自己的数据对模型进行微调

将模型转换,在前端使用:TensorFlow学习:在web前端如何使用Keras 模型

学习

预处理数据

案例中有下面这段代码

# 预处理数据,检查训练集中的第一个图像可以看到像素值处于0~255之间
plt.figure() # 创建图像窗口
plt.imshow(train_images[0]) # 显示图片
plt.colorbar()  # 在图像旁边添加颜色条
plt.grid(False) # 取消网格线
plt.show() # 显示图形窗口# 将值缩小至0~1之间,然后将其反馈到神经网络模型。训练集和测试集都需要处理
train_images = train_images / 255.0
test_images = test_images / 255.0

在这里插入图片描述

百度查了一下,将值缩小至0~1之间是为了

将训练集和测试集数据的值缩小到0~1之间是为了进行数据归一化(Normalization)。这是一个常见的预处理步骤,对于图像分类任务特别重要。
将图像的像素值缩放到0~1之间有几个好处:

  • 数值范围一致性:将所有像素值限制在0~1范围内可以确保不同样本的特征具有一致的数值区间。这有助于避免某些特征对模型训练产生过大的影响。
  • 梯度下降稳定性:在深度学习中,常用的优化算法如梯度下降依赖于权重的更新和损失函数的梯度计算。将像素值缩小到较小的范围可以使这些计算更加稳定,有助于加速模型的收敛。
  • 避免数值溢出:在一些激活函数和优化算法中,如果输入值太大,可能会导致数值溢出或不稳定的情况。将像素值限制在0~1之间可以减少这种情况的发生。

以后再遇见处理255时就明白这样做的目的了

构建模型

构建神经网络需要先配置模型的层,然后再编译模型。

设置层
神经网络的基本组成部分是层。层会从向其馈送的数据中提取表示形式。希望这些表示形式有助于解决手头上的问题。

大多数深度学习都包括将简单的层链接在一起。大多数层(如 tf.keras.layers.Dense)都具有在训练期间才会学习的参数

# 1、设置层
# tf.keras是TensorFlow中的高级API,用于构建和训练神经网络模型。它是一个基于Keras库的接口,提供了更简单、更高级的方式来定义、配置和训练神经网络模型。
# tf.keras.Sequential 用于按顺序堆叠各个神经网络层来构建模型,是一种简单的模型类型
model = tf.keras.Sequential([# 将图像格式从二维数组(28*28像素),转化为一维数组(28*28 = 784像素)。将该层视为图像中未堆叠的像素行并将其排列起来。该层没有要学习的参数,它只会重新格式化数据。tf.keras.layers.Flatten(input_shape=(28,28)), # 第二层,是一个具有128个神经元的全连接神经层tf.keras.layers.Dense(128,activation='relu'),# 第三层会返回一个长度为10的数组,每个都包含一个得分来表示当前图像属于10个类中的哪一个tf.keras.layers.Dense(10)
])

这段代码我相信很多人跟我一样都有些疑问,还好现在有gpt,不然都不知道上哪里去找答案。下面是我的一些疑问及gpt的回答:

  • 为什么只有三层。答:在神经网络中,层数的选择是一个灵活的设计选择,取决于特定问题的复杂性和数据集的特征。选择三层可能是为了简化模型或者问题本身不需要更多层
  • 第二层为什么是tf.keras.layers.Dense(128)。答:选择128个神经元是基于对问题复杂性的估计和经验。如果问题比较复杂或数据集较大,增加神经元数量可以增加模型的容量,提高模型的表示能力。
  • 第三层为什么是tf.keras.layers.Dense(10)。答:因为这是一个分类问题,这个案例中有10个分类。每个神经元对应一个类别,并输出相应类别的预测概率。
  • tf.keras.layers.Dense(128)是计算的来的吗。答:通常需要根据实际问题和数据集来进行调整。增加神经元的数量可以增加模型的容量和学习能力,但也可能导致过拟合。过拟合是指模型在训练数据上表现良好,但在新数据上表现较差。建议先从较小的数量开始,然后逐渐增加,直到模型的性能不再提高或开始出现过拟合为止。
  • 模型的最后一层是输出层吗。答:模型的最后一层通常是输出层。输出层的神经元数量通常与你要解决的问题相关。对于分类任务,输出层的神经元数量应该等于类别的数量。对于二分类任务,可以使用一个神经元来表示两个类别的概率。对于多分类任务,可以使用多个神经元,每个神经元表示一个类别的概率。在使用tf.keras``构建模型时,你可以使用tf.keras.layers.Dense`来定义输出层,并使用适当的激活函数来产生输出。

编译模型

在准备对模型进行训练之前,还需要再对其进行一些设置。以下内容是在模型的编译步骤中添加的:

  • 损失函数 - 测量模型在训练期间的准确程度。你希望最小化此函数,以便将模型“引导”到正确的方向上。
  • 优化器 - 决定模型如何根据其看到的数据和自身的损失函数进行更新。
  • 指标 - 用于监控训练和测试步骤。以下示例使用了准确率,即被正确分类的图像的比率。
# 2、编译模型
model.compile(optimizer='adam', # 指定优化器,adam是常用的优化器,可以自适应的调整学习率loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), # 指定损失函数,这里使用了稀疏分类交叉熵损失函数metrics=['accuracy'] # 指定评估模型性能的指标,这里使用准确率
)

训练模型

训练神经网络模型需要执行以下步骤:

  • 将训练数据馈送给模型。在本例中,训练数据位于 train_images 和 train_labels 数组中。
  • 模型学习将图像和标签关联起来。
  • 要求模型对测试集(在本例中为 test_images 数组)进行预测。
  • 验证预测是否与 test_labels 数组中的标签相匹配。
# 1、将训练数据反馈给模型
# model.fit用于将模型与训练数据进行拟合,这里是将所有样本迭代10次
model.fit(train_images,train_labels,epochs=10)

如下图:
在这里插入图片描述

# 2、在测试数据集上评估准确率,verbose=2参数表示以详细模式输出评估过程
test_loss,test_acc = model.evaluate(test_images,test_labels,verbose=2)
print("损失率:",test_loss,"准确率:",test_acc)

如下图:
在这里插入图片描述

进行预测

# 进行预测
# 模型经过训练后,您可以使用它对一些图像进行预测。附加一个 Softmax 层,将模型的线性输出 logits 转换成更容易理解的概率
probability_model = tf.keras.Sequential([model,tf.keras.layers.Softmax()])
# 预测图片
predictions = probability_model.predict(test_images)print("第一个预测结果:",predictions[0])

预测结果是一个包含 10 个数字的数组。它们代表模型对 10 种不同服装中每种服装的“置信度”。您可以看到哪个标签的置信度值最大:

np.argmax(predictions[0])

使用训练好的模型

现在模型已经训练好了,我们可以基于模型对单个图像进行预测

# 使用训练好的模型
# 加载图片
img = Image.open('pics/shirt.png') 
# 调整大小
img = img.resize((28,28))
# 将彩色图片转为灰度图片
img_gray = img.convert('L')
# 将图像转换为 NumPy 数组,并反转颜色
img_arr = np.array(img_gray)
img_arr = 255 - img_arr
# 将图像像素值归一化到0~1
img_arr = img_arr / 255.0
# 将图像形状调整为(128288)
img_arr = img_arr.reshape(1,28,28)
# 可以保存处理后的文件,也可以进行预测
# np.save('abc.npy',img_arr)
# tf.keras 模型经过了优化,可同时对一个批或一组样本进行预测。因此,即便您只使用一个图像,您也需要将其添加到列表中
#img_arr = tf.keras.preprocessing.image.img_to_array(img)res = probability_model.predict(img_arr)
print("预测结果是:",res,class_names[np.argmax(res[0])])# 可视化显示
font = FontProperties()
font.set_family('Microsoft YaHei')
plt.figure() # 创建图像窗口
plt.xticks([])
plt.yticks([])
plt.grid(False) # 取消网格线
plt.imshow(img_arr[0]) # 显示图片
plt.xlabel(class_names[np.argmax(res[0])],fontproperties=font)
plt.show() # 显示图形窗口

这块是最复杂的,搞了好久才成功。你加载的图片是彩色的,你必须将图片变成灰度的,并且是28*28像素的图片,也就是你的图片要处理成符合这个模型的图片才行。

但是最终结果其实也不是很准确,根本原因是你的图片处理后,能够获取的特征就很少了,这样会导致判断错误。

结果
在这里插入图片描述

遇到的问题

问题1
在执行(train_images, train_labels), (test_images,test_labels) = fashion_mnist.load_data()时提示

Exception: URL fetch failure on https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz: None – [WinError 10054] 远程主机强迫关闭了 一个现有的连接。

这是加载数据集时失败了,国内访问下载谷歌的数据总会出现这样的问题。

解决:
1、打开数据集官方网站 https://github.com/zalandoresearch/fashion-mnist,将下面这4个数据下载到本地放到项目里

在这里插入图片描述
2、加载本地数据

import gzip
import numpy as npdef load_data():# 加载训练集图像数据with gzip.open('train-images-idx3-ubyte.gz', 'rb') as f:train_images = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1, 28, 28)# 加载训练集标签数据with gzip.open('train-labels-idx1-ubyte.gz', 'rb') as f:train_labels = np.frombuffer(f.read(), np.uint8, offset=8)# 加载测试集图像数据with gzip.open('t10k-images-idx3-ubyte.gz', 'rb') as f:test_images = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1, 28, 28)# 加载测试集标签数据with gzip.open('t10k-labels-idx1-ubyte.gz', 'rb') as f:test_labels = np.frombuffer(f.read(), np.uint8, offset=8)return (train_images, train_labels), (test_images, test_labels)# 调用加载数据函数
(train_images, train_labels), (test_images, test_labels) = load_data()

问题2
验证前25个图像,设置中文乱码。教程中的使用的是英文,我这里尝试了一下中文,中文乱码
在这里插入图片描述
解决:设置中文字体

# 字体属性
from matplotlib.font_manager import FontProperties# 验证训练集中的前25个图像,并显示其名称
font = FontProperties()
font.set_family('Microsoft YaHei')
plt.figure(figsize=(10,10))
for i in range(25):plt.subplot(5,5,i+1) # 按照 5*5进行显示plt.xticks([])plt.yticks([])plt.grid(False)plt.imshow(train_images[i], cmap=plt.cm.binary)plt.xlabel(class_names[train_labels[i]],fontproperties=font)
plt.show()

在这里插入图片描述

完整代码

# 导入 TensorFlow 重命名
import tensorflow as tf# numpy是科学计算库,matplotlib是用于绘制图表和可视化数据的库
import numpy as np
import matplotlib.pylab as plt
# 字体属性
from matplotlib.font_manager import FontProperties# 用于加载文件
import gzip# 用于处理图片
from PIL import Image# 用于加载数据集的函数
def load_data():# 加载训练集图像数据with gzip.open('train-images-idx3-ubyte.gz', 'rb') as f:train_images = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1, 28, 28)# 加载训练集标签数据with gzip.open('train-labels-idx1-ubyte.gz', 'rb') as f:train_labels = np.frombuffer(f.read(), np.uint8, offset=8)# 加载测试集图像数据with gzip.open('t10k-images-idx3-ubyte.gz', 'rb') as f:test_images = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1, 28, 28)# 加载测试集标签数据with gzip.open('t10k-labels-idx1-ubyte.gz', 'rb') as f:test_labels = np.frombuffer(f.read(), np.uint8, offset=8)return (train_images, train_labels), (test_images, test_labels)print("tf版本:",tf.__version__)# 导入数据集,TensorFlow 内置的数据集
fashion_mnist = tf.keras.datasets.fashion_mnist
# 将训练数据、测试数据取出,保存的元组里
(train_images, train_labels), (test_images,test_labels) = load_data()# 映射标签类,用于后面绘制图像使用
class_names = ['T恤/上衣', '裤子', '套头衫', '连衣裙', '外套', '凉鞋', '衬衫', '运动鞋', '包', '短靴']# 会打印出(60000, 28, 28),官方文档解释为训练集中有 60,000 个图像,每个图像由 28 x 28 的像素表示
print("训练数据集数据:",train_images.shape)# 预处理数据,检查训练集中的第一个图像可以看到像素值处于0~255之间
# plt.figure() # 创建图像窗口
# plt.imshow(train_images[0]) # 显示图片
# plt.colorbar()  # 在图像旁边添加颜色条
# plt.grid(False) # 取消网格线
# plt.show() # 显示图形窗口# 将值缩小至0~1之间,然后将其反馈到神经网络模型。训练集和测试集都需要处理
train_images = train_images / 255.0
test_images = test_images / 255.0# 验证训练集中的前25个图像,并显示其名称
# font = FontProperties()
# font.set_family('Microsoft YaHei')
# plt.figure(figsize=(10,10))
# for i in range(25):
#     plt.subplot(5,5,i+1) # 按照 5*5进行显示
#     plt.xticks([])
#     plt.yticks([])
#     plt.grid(False)
#     plt.imshow(train_images[i], cmap=plt.cm.binary)
#     plt.xlabel(class_names[train_labels[i]],fontproperties=font)
# plt.show()# 构建模型# 1、设置层
# tf.keras是TensorFlow中的高级API,用于构建和训练神经网络模型。它是一个基于Keras库的接口,提供了更简单、更高级的方式来定义、配置和训练神经网络模型。
# tf.keras.Sequential 用于按顺序堆叠各个神经网络层来构建模型,是一种简单的模型类型
model = tf.keras.Sequential([# 将图像格式从二维数组(28*28像素),转化为一维数组(28*28 = 784像素)。将该层视为图像中未堆叠的像素行并将其排列起来。该层没有要学习的参数,它只会重新格式化数据。tf.keras.layers.Flatten(input_shape=(28,28)), # 第二层,是一个具有128个神经元的全连接神经层tf.keras.layers.Dense(128,activation='relu'),# 第三层会返回一个长度为10的数组,每个都包含一个得分来表示当前图像属于10个类中的哪一个tf.keras.layers.Dense(10)
])# 2、编译模型
model.compile(optimizer='adam', # 指定优化器,adam是常用的优化器,可以自适应的调整学习率loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), # 指定损失函数,这里使用了稀疏分类交叉熵损失函数metrics=['accuracy'] # 指定评估模型性能的指标,这里使用准确率
)# 训练模型# 1、将训练数据反馈给模型
# model.fit用于将模型与训练数据进行拟合,这里是将所有样本迭代10次
model.fit(train_images,train_labels,epochs=10)# 2、在测试数据集上评估准确率,verbose=2参数表示以详细模式输出评估过程
test_loss,test_acc = model.evaluate(test_images,test_labels,verbose=2)
print("损失率:",test_loss,"准确率:",test_acc)# 进行预测
# 模型经过训练后,您可以使用它对一些图像进行预测。附加一个 Softmax 层,将模型的线性输出 logits 转换成更容易理解的概率
probability_model = tf.keras.Sequential([model,tf.keras.layers.Softmax()])
# 预测图片
predictions = probability_model.predict(test_images)print("第一个预测结果:",predictions[0],'类别是:',class_names[np.argmax(predictions[0])])# 使用训练好的模型
# 加载图片
img = Image.open('pics/shirt.png') 
# 调整大小
img = img.resize((28,28))
# 将彩色图片转为灰度图片
img_gray = img.convert('L')
# 将图像转换为 NumPy 数组,并反转颜色
img_arr = np.array(img_gray)
img_arr = 255 - img_arr
# 将图像像素值归一化到0~1
img_arr = img_arr / 255.0
# 将图像形状调整为(128288)
img_arr = img_arr.reshape(1,28,28)
# 可以保存处理后的文件,也可以进行预测
# np.save('abc.npy',img_arr)res = probability_model.predict(img_arr)
print("预测结果是:",res,class_names[np.argmax(res[0])])# 可视化显示
font = FontProperties()
font.set_family('Microsoft YaHei')
plt.figure() # 创建图像窗口
plt.xticks([])
plt.yticks([])
plt.grid(False) # 取消网格线
plt.imshow(img_arr[0]) # 显示图片
plt.xlabel(class_names[np.argmax(res[0])],fontproperties=font)
plt.show() # 显示图形窗口

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/132043.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

沪深300期权一个点多少钱?

经中国证监会批准,深圳证券交易所于2019年12月23日上市嘉实沪深300ETF期权合约品种。该产品是以沪深300为标的物的嘉实沪深300ETF交易型指数基金为标的衍生的标准化合约,下文介绍沪深300期权一个点多少钱?本文来自:期权酱 一、沪深300期权涨…

推荐几个技术学习的网站

USB中文网 点击打开 USB中文网 - USB技术开发交流USB中文网是国内领先的专业USB技术网站,提供USB开发入门教程,USB设备开发,USB驱动开发,USB摄像头,USB麦克风,USB存储设备,USB-HID设备&#x…

华为认证 | HCIP-Datacom,这门认证正式发布新版本!

华为认证数通高级工程师HCIP-Datacom-Campus Network Planning and Deployment V1.5(中文版)自2023年9月28日起,正式在中国区发布。 01 发布概述 基于“平台生态”战略,围绕“云-管-端”协同的新ICT技术架构,华为公司…

随着 ChatGPT 凭借 GPT-4V(ision) 获得关注,多模态 AI 不断发展

原创 | 文 BFT机器人 在不断努力让人工智能更像人类的过程中,OpenAI的GPT模型不断突破界限GPT-4现在能够接受文本和图像的提示。 生成式人工智能中的多模态表示模型根据输入生成文本、图像或音频等各种输出的能力。这些模型经过特定数据的训练,学习底层模…

Vega Prime入门教程14.04:CDB测试

本文首发于:Vega Prime入门教程14.04:CDB测试 打开失败 打开vpcdb_yemen_urban.acf 会报错 点击确定后会显示默认界面 这个白天蓝海应该是默认场景。 开启服务 打开LP,点击菜单栏或者工具栏 显示管理界面 切换至rtp界面 点击Start RTP按…

牛客 明明的随机数

HJ3 明明的随机数 原题思路代码运行截图收获 原题 HJ3 明明的随机数 思路 如果是C的话直接用set结构体就可以自动排序GO&#xff1a;用一个501的数组存储是否出现&#xff0c;最后从头开始输出出现过的数字 代码 #include <iostream> #include <set> using na…

html与css知识点

html 元素分类 块级元素 1.独占一行&#xff0c;宽度为父元素宽度的100% 2.可以设置宽高 常见块级元素 h1~h6 div ul ol li dl dt dd table form header footer section nav article aside 行内元素 1.一行显示多个 2.不能设置宽高&#xff0c;宽高由元素内容撑开 常见行内…

【Linux服务端搭建及使用】

连接服务器的软件&#xff1a;mobaxterm 设置root 账号 sudo apt-get install passwd #安装passwd 设置方法 sudo passwd #设置root密码 su root #切换到root账户设置共享文件夹 一、强制删除原有环境 1.删除python rpm -qa|grep pytho…

SLAM从入门到精通(ROS和底盘Stm32的关系)

【 声明&#xff1a;版权所有&#xff0c;欢迎转载&#xff0c;请勿用于商业用途。 联系信箱&#xff1a;feixiaoxing 163.com】 学过Ros的同学&#xff0c;一般对subscribe、publish、话题、服务这些内容都比较熟悉。如果再熟悉一点的话&#xff0c;还会知道slam、move_base、…

macOS Sonoma 14.1beta3(23B5067a)发布

黑果魏叔10 月 11 日消息&#xff0c;苹果今日向 Mac 电脑用户推送了 macOS 14.1 开发者预览版 Beta 3 更新&#xff08;内部版本号&#xff1a;23B5067a&#xff09;&#xff0c;本次更新距离上次发布隔了 7 天。 根据官方发布的macOS Sonoma 14.1beta3更新日志&#xff0c;此…

Ubuntu 18.04 OpenCV3.4.5 + OpenCV3.4.5 Contrib 编译

目录 1 依赖安装 2 下载opencv3.4.5及opencv3.4.5 contrib版本 3 编译opencv3.4.5 opencv3.4.5_contrib及遇到的问题 1 依赖安装 首先安装编译工具CMake&#xff0c;命令安装即可&#xff1a; sudo apt install cmake 安装Eigen&#xff1a; sudo apt-get install libeigen3-…

Flink---14、Flink SQL(SQL-Client准备、流处理中的表、时间属性、DDL)

星光下的赶路人star的个人主页 你生而真实&#xff0c;而非完美 文章目录 1、Flink SQL1.1 SQL-Client准备1.1.1 基于yarn-session模式1.1.2 常用配置 1.2 流处理中的表1.2.1 动态表和持续查询1.2.2 将流转换为动态表1.2.3 用SQL持续查询1.2.4 将动态表转换为流 1.3 时间属性1.…