Tensorflow2.0笔记 - 使用compile,fit,evaluate,predict简化流程

        本笔记主要用compile, fit, evalutate和predict来简化整体代码,使用这些高层API可以减少很多重复代码。具体内容请自行百度,本笔记基于FashionMnist的训练笔记,原始笔记如下:

Tensorflow2.0笔记 - FashionMnist数据集训练-CSDN博客文章浏览阅读347次。本笔记使用FashionMnist数据集,搭建一个5层的神经网络进行训练,并统计测试集的精度。本笔记中FashionMnist数据集是直接下载到本地加载的方式,不涉及用梯子。关于FashionMnist的介绍,请自行百度。https://blog.csdn.net/vivo01/article/details/136921592?spm=1001.2014.3001.5502        

#Fashion Mnist数据集本地下载和加载(不用梯子)
#https://blog.csdn.net/scar2016/article/details/115361245 (百度网盘)
#https://blog.csdn.net/weixin_43272781/article/details/110006990 (github)
import time
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metricstf.__version__#加载fashion mnist数据集
def load_mnist(path, kind='train'):import osimport gzipimport numpy as np"""Load MNIST data from `path`"""labels_path = os.path.join(path,'%s-labels-idx1-ubyte.gz'% kind)images_path = os.path.join(path,'%s-images-idx3-ubyte.gz'% kind)with gzip.open(labels_path, 'rb') as lbpath:labels = np.frombuffer(lbpath.read(), dtype=np.uint8,offset=8)with gzip.open(images_path, 'rb') as imgpath:images = np.frombuffer(imgpath.read(), dtype=np.uint8,offset=16).reshape(len(labels), 784)return images, labels#预处理数据
def preprocess(x, y):x = tf.cast(x, dtype=tf.float32)x = tf.convert_to_tensor(x, dtype=tf.float32) / 255.y = tf.cast(y, dtype=tf.int32)y = tf.convert_to_tensor(y, dtype=tf.int32)return x, y
#训练数据
train_data, train_labels = load_mnist("./datasets")
print(train_data.shape, train_labels.shape)
#测试数据
test_data, test_labels = load_mnist("./datasets", "t10k")
print(test_data.shape, test_labels.shape)batch_size = 128train_db = tf.data.Dataset.from_tensor_slices((train_data, train_labels))
train_db = train_db.map(preprocess).shuffle(10000).batch(batch_size)test_db = tf.data.Dataset.from_tensor_slices((test_data, test_labels))
test_db = test_db.map(preprocess).batch(batch_size)train_db_iter = iter(train_db)
sample = next(train_db_iter)
print('Batch:', sample[0].shape, sample[1].shape)test_db_iter = iter(test_db)
sample = next(test_db_iter)
print(sample[0].shape)
print(sample[1].shape)#定义网络模型
model = Sequential([#Layer 1: [b, 784] => [b, 256]layers.Dense(256, activation=tf.nn.relu),#Layer 2: [b, 256] => [b, 128]layers.Dense(128, activation=tf.nn.relu),#Layer 3: [b, 128] => [b, 64]layers.Dense(64, activation=tf.nn.relu),#Layer 4: [b, 64] => [b, 32]layers.Dense(32, activation=tf.nn.relu),#Layer 5: [b, 32] => [b, 10], 输出类别结果layers.Dense(10)
])
print("\n=====Building Model=========\n")
model.build(input_shape=(None, 28*28))
model.summary()total_epoches = 10
learn_rate = 0.01
#编译网络,使用compile(),传入优化器,损失函数和metrics度量
#https://blog.csdn.net/weixin_48169169/article/details/120793534
print("\n=====Compiling Model=========\n")
model.compile(optimizer=optimizers.Adam(learning_rate=learn_rate),loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['Accuracy'])
#使用fit()进行训练
print("\n=====Fitting Model=========\n")
model.fit(train_db, epochs=total_epoches, validation_data=test_db, validation_freq=2)
time.sleep(1)
#使用evaluate()进行模型验证,这里使用了test_db,实际中可以使用另外的数据集
print("\n=====Evaluating Model=========\n")
model.evaluate(test_db)
#使用predict()做数据预测
sample = next(iter(test_db))
#获得数据和标签
real_data = sample[0]
real_label = sample[1]
pred = model.predict(real_data)
pred = tf.argmax(pred, axis=1)
print("Predicted Labels:", pred.numpy())
print("Actual Labels:", real_label.numpy())

运行结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/572616.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Tunes不能读取iPhone的内容,请前往iPhone偏好设置的摘要选项卡,然后单击恢复以将此iPhone恢复为出厂设置

重启itunes: 参考链接: https://baijiahao.baidu.com/s?id1642568736254330322&wfrspider&forpc 人工智能学习网站: https://chat.xutongbao.top

实例:NX二次开发使用链表进行拉伸功能(链表相关功能练习)

一、概述 在进行批量操作时经常会利用链表进行存放相应特征的TAG值,以便后续操作,最常见的就是拉伸功能。这里我们以拉伸功能为例子进行说明。 二、常用链表相关函数 UF_MODL_create_list 创建一个链表,并返回链表的头指针。…

uniapp 微信小程序 canvas 手写板文字重复倾斜水印

核心逻辑 先将坐标系中心点通过ctx.translate(canvasw / 2, canvash / 2) 平移到canvas 中心,再旋转设置水印 假如不 translate 直接旋转,则此时的旋转中心为左上角原点,此时旋转示意如图所示 当translate到中心点之后再旋转,此…

图书推荐|Django+Vue.js商城项目实战

一线资深架构师 凝聚近十年大型系统开发经验 倾力打造 双色印刷 适合:项目演练求职应聘技术提升 全新:Django 4.x与Vue.js 3.x全栈技术 易学:娓娓道来图示指引原理剖析步骤解说代码详注 真实企业级项目技术细节完整揭秘,照着做就…

Vue 二次封装组件的艺术与实践

🤍 前端开发工程师、技术日更博主、已过CET6 🍨 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 🕠 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 🍚 蓝桥云课签约作者、上架课程《Vue.js 和 E…

逐步学习Go-协程goroutine

参考:逐步学习Go-协程goroutine – FOF编程网 什么是线程? 简单来说线程就是现代操作系统使用CPU的基本单元。线程基本包括了线程ID,程序计数器,寄存器和线程栈。线程共享进程的代码区,数据区和操作系统的资源。 线…

前端Web移动端学习day05

移动 Web 第五天 响应式布局方案 媒体查询Bootstrap框架 响应式网页指的是一套代码适配多端,一套代码适配各种大小的屏幕。 共有两种方案可以实现响应式网页,一种是媒体查询,另一种是使用bootstrap框架。 01-媒体查询 基本写法 max-wid…

day70 Mybatis使用mapper重构xml文件重新修改商品管理系统

day67 基于mysql数据库jdbcDruidjar包连接的商品管理用户购物系统-CSDN博客 1多表操作 2动态SQL 项目中使用的为商品管理系统的表 一 查询商品信息 编号,名称,单价,库存,类别 1表:商品表,类别表 n对1…

MySQL—存储引擎和索引

MySQL进阶 1. 存储引擎1.1 体系结构1.2 存储引擎1.3 存储引擎特点1.3.1 InnoDB1.3.2 MyISAM1.3.3 Memory1.3.4 区别及特点 1.4 存储引擎选择 2. 索引2.1 概述2.2 索引结构2.2.1 概述2.2.2 二叉树2.2.3 B-Tree2.2.4 BTree2.2.5 Hash 2.3 索引分类2.3.1 索引分类2.3.2 聚集索引&a…

Axure中后台系统原型模板,B端页面设计实例,高保真高交互54页

作品概况 页面数量:共 50 页(长期更新) 兼容版本:Axure RP 9/10,不支持低版本 应用领域:网页模板、网站后台、中台系统、B端系统 作品特色 本品为「web中后台系统页面设计实例模板」,默林原创…

libVLC 视频抓图

Windows操作系统提供了多种便捷的截图方式,常见的有以下几种: 全屏截图:通过按下PrtSc键(Print Screen),可以截取整个屏幕的内容。截取的图像会保存在剪贴板中,可以通过CtrlV粘贴到图片编辑工具…

律甲法务OA平台:信鸥科技引领法律行业新篇章

随着信息技术的飞速发展,法律行业也迎来了数字化转型的重要时刻。在这个信息化、智能化的时代,如何运用科技手段提升法律服务的质量和效率,成为法律行业亟待解决的问题。信鸥科技,作为业界的佼佼者,凭借其深厚的技术积…