MLP实现fashion_mnist数据集分类(1)-模型构建、训练、保存与加载(tensorflow)

1、查看tensorflow版本

import tensorflow as tfprint('Tensorflow Version:{}'.format(tf.__version__))
print(tf.config.list_physical_devices())

在这里插入图片描述

2、fashion_mnist数据集下载与展示

(train_image,train_label),(test_image,test_label) = tf.keras.datasets.fashion_mnist.load_data()
print(train_image.shape)
print(train_label.shape)
print(test_image.shape)
print(test_label.shape)

在这里插入图片描述

import matplotlib.pyplot as plt
# plt.imshow(train_image[0])  # 此处为啥是彩色的?def plot_images_lables(images,labels,start_idx,num=5):fig = plt.gcf()fig.set_size_inches(12,14)for i in range(num):ax = plt.subplot(1,num,1+i)ax.imshow(images[start_idx+i],cmap='binary')title = 'label=' + str(labels[start_idx+i])ax.set_title(title,fontsize=10)ax.set_xticks([])ax.set_yticks([])plt.show()
plot_images_lables(train_image,train_label,0,5)
# plot_images_lables(test_image,test_label,0,5)

在这里插入图片描述

3、数据预处理

X_train,X_test = tf.cast(train_image/255.0,tf.float32),tf.cast(test_image/255.0,tf.float32) # 归一化
y_train,y_test = train_label,test_label # 此处对y没有做onehot处理,需要使用稀疏交叉损失函数

4、模型构建

from keras import Sequential
from keras.layers import Flatten,Dense,Dropout
from keras import Inputmodel = Sequential()
model.add(Input(shape=(28,28)))
model.add(Flatten())
model.add(Dense(units=256,kernel_initializer='normal',activation='relu'))
model.add(Dropout(rate=0.1))
model.add(Dense(units=64,kernel_initializer='normal',activation='relu'))
model.add(Dropout(rate=0.1))
model.add(Dense(units=10,kernel_initializer='normal',activation='softmax'))
model.summary()

在这里插入图片描述

5、模型配置

model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['acc'])

6、模型训练

H = model.fit(x=X_train,y=y_train,validation_split=0.2,# validation_data=(X_test,y_test),epochs=10,batch_size=128,verbose=1)

在这里插入图片描述

plt.plot(H.epoch, H.history['loss'], label='loss')
plt.plot(H.epoch, H.history['val_loss'], label='val_loss')
plt.legend()

在这里插入图片描述

plt.plot(H.epoch, H.history['acc'], label='acc')
plt.plot(H.epoch, H.history['val_acc'], label='val_acc')
plt.legend()

在这里插入图片描述

7、模型评估

model.evaluate(X_test,y_test)

在这里插入图片描述

8、模型预测

import numpy as np
import matplotlib.pyplot as pltdef pred_plot_images_lables(images,labels,start_idx,num=5):# 预测res = model.predict(images[start_idx:start_idx+num])res = np.argmax(res,axis=1)# 画图fig = plt.gcf()fig.set_size_inches(12,14)for i in range(num):ax = plt.subplot(1,num,1+i)ax.imshow(images[start_idx+i],cmap='binary')title = 'label=' + str(labels[start_idx+i]) + ', pred=' + str(res[i])ax.set_title(title,fontsize=10)ax.set_xticks([])ax.set_yticks([])plt.show()
pred_plot_images_lables(X_test,y_test,0,5)

在这里插入图片描述

9、模型保存与加载

import numpy as nptf.keras.models.save_model(model,"model.keras")
loaded_model = tf.keras.models.load_model("model.keras")
# assert np.allclose(model.predict(X_test[:5]), loaded_model.predict(X_test[:5]))
print(np.argmax(model.predict(X_test[:5]),axis=1))
print(np.argmax(loaded_model.predict(X_test[:5]),axis=1))

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/670954.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JS基础:变量的详解

你好,我是云桃桃。 一个希望帮助更多朋友快速入门 WEB 前端的程序媛。 云桃桃,大专生,一枚程序媛,感谢关注。回复 “前端基础题”,可免费获得前端基础 100 题汇总,回复 “前端基础路线”,可获取…

数据库SQL语言实战(七)

前言 这次的有一点点难~~~~~我也写了好久 练习题 题目一 在学生表pub.student中统计名字(姓名的第一位是姓氏,其余为名字,不考虑复姓)的使用的频率,将统计结果放入表test5_01中 create table test5_01(First_name…

Docker容器:Docker-Consul 的容器服务更新与发现

目录 前言 一、什么是服务注册与发现 二、 Docker-Consul 概述 1、Consul 概念 2、Consul 提供的一些关键特性 3、Consul 的优缺点 4、传统模式与自动发现注册模式的区别 4.1 传统模式 4.2 自动发现注册模式 5、Consul 核心组件 5.1 Consul-Template组件 5.2 Consu…

esp32+mqtt协议+paltformio+vscode+微信小程序+温湿度检测

花费两天时间完成了这个项目(不完全是,属于是在resnet模型训练和温湿度检测两头跑......模型跑不出来,又是第一次从头到尾独立玩硬件,属于是焦头烂额了......,完成这个项目后,我的第一反应是写个csdn&#…

如何delphi7中添加自定义的控件或组件

1.问题 为了编写方便,会自己生成一些自定义的控件。 XXX.PAS:源码文件,后续可以再调整 XXX.DCU:编译后的文件 2.解决办法 1.自定义控件的文件如下 2.将这些文件保留在程式代码中,万一丢失会导致程式开启后报是否忽略 3.选择菜单栏中的com…

简述前后端分离架构案例

Hello , 这里是小恒不会java 。今晚1点写写关于RESTful接口的使用案例,本文会通过django原生js前后端分离的案例简单讲解。本文带你认识一下简化版的前后端分离架构 代码 本文案例代码在GitHub上 https://github.com/lmliheng/fontend前后端分离 先说说什么是前后…

ue引擎游戏开发笔记(32)——为游戏添加新武器装备

1.需求分析: 游戏中角色不会只有一种武器,不同武器需要不同模型,甚至可能需要角色持握武器的不同位置,因此需要添加专门的武器类,方便武器后续更新,建立一个武器类。 2.操作实现: 1.在ue5中新建…

cmake进阶:文件操作

一. 简介 前面几篇文章学习了 cmake的文件操作,写文件,读文件。文章如下: cmake进阶:文件操作之写文件-CSDN博客 cmake进阶:文件操作之读文件-CSDN博客 本文继续学习文件操作。主要学习 文件重命名,删…

深入了解C/C++的内存区域划分

🔥个人主页:北辰水墨 🔥专栏:C学习仓 本节我们来讲解C/C的内存区域划分,文末会附加一道题目来检验成果(有参考答案) 一、大体有哪些区域?分别存放什么变量开辟的空间? …

深度神经网络中的不确定性研究综述

A.单一确定性方法 对于确定性神经网络,参数是确定的,每次向前传递的重复都会产生相同的结果。对于不确定性量化的单一确定性网络方法,我们总结了在确定性网络中基于单一正向传递计算预测y *的不确定性的所有方法。在文献中,可以找…

Fireworks AI和MongoDB:依托您的数据,借助优质模型,助力您开发高速AI应用

我们欣然宣布 MongoDB与 Fireworks AI 正携手合作 让客户能够利用生成式人工智能 (AI) 更快速、更高效、更安全地开展创新活动 Fireworks AI由 Meta旗下 PyTorch团队的行业资深人士于 2022 年底创立,他们在团队中主要负责优化性能、提升开发者体验以及大规模运…

【数据可视化-02】Seaborn图形实战宝典

Seaborn介绍 Seaborn是一个基于Python的数据可视化库,它建立在matplotlib的基础之上,为统计数据的可视化提供了高级接口。Seaborn通过简洁美观的默认样式和绘图类型,使数据可视化变得更加简单和直观。它特别适用于那些想要创建具有吸引力且信…