Tensorflow2.0笔记 - metrics做损失和准确度信息度量

        本笔记主要记录metrics相关的内容,详细内容请参考代码注释,代码本身只使用了Accuracy和Mean。本节的代码基于上篇笔记FashionMnist的代码经过简单修改而来,上篇笔记链接如下:

Tensorflow2.0笔记 - FashionMnist数据集训练-CSDN博客文章浏览阅读339次。本笔记使用FashionMnist数据集,搭建一个5层的神经网络进行训练,并统计测试集的精度。本笔记中FashionMnist数据集是直接下载到本地加载的方式,不涉及用梯子。关于FashionMnist的介绍,请自行百度。https://blog.csdn.net/vivo01/article/details/136921592?spm=1001.2014.3001.5502

#Fashion Mnist数据集本地下载和加载(不用梯子)
#https://blog.csdn.net/scar2016/article/details/115361245 (百度网盘)
#https://blog.csdn.net/weixin_43272781/article/details/110006990 (github)
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metricstf.__version__#加载fashion mnist数据集
def load_mnist(path, kind='train'):import osimport gzipimport numpy as np"""Load MNIST data from `path`"""labels_path = os.path.join(path,'%s-labels-idx1-ubyte.gz'% kind)images_path = os.path.join(path,'%s-images-idx3-ubyte.gz'% kind)with gzip.open(labels_path, 'rb') as lbpath:labels = np.frombuffer(lbpath.read(), dtype=np.uint8,offset=8)with gzip.open(images_path, 'rb') as imgpath:images = np.frombuffer(imgpath.read(), dtype=np.uint8,offset=16).reshape(len(labels), 784)return images, labels#预处理数据
def preprocess(x, y):x = tf.cast(x, dtype=tf.float32)x = tf.convert_to_tensor(x, dtype=tf.float32) / 255.y = tf.cast(y, dtype=tf.int32)y = tf.convert_to_tensor(y, dtype=tf.int32)return x, y
#训练数据
train_data, train_labels = load_mnist("./datasets")
print(train_data.shape, train_labels.shape)
#测试数据
test_data, test_labels = load_mnist("./datasets", "t10k")
print(test_data.shape, test_labels.shape)batch_size = 128train_db = tf.data.Dataset.from_tensor_slices((train_data, train_labels))
train_db = train_db.map(preprocess).shuffle(10000).batch(batch_size)test_db = tf.data.Dataset.from_tensor_slices((test_data, test_labels))
test_db = test_db.map(preprocess).batch(batch_size)train_db_iter = iter(train_db)
sample = next(train_db_iter)
print('Batch:', sample[0].shape, sample[1].shape)#定义网络模型
model = Sequential([#Layer 1: [b, 784] => [b, 256]layers.Dense(256, activation=tf.nn.relu),#Layer 2: [b, 256] => [b, 128]layers.Dense(128, activation=tf.nn.relu),#Layer 3: [b, 128] => [b, 64]layers.Dense(64, activation=tf.nn.relu),#Layer 4: [b, 64] => [b, 32]layers.Dense(32, activation=tf.nn.relu),#Layer 5: [b, 32] => [b, 10], 输出类别结果layers.Dense(10)
])#编译网络
model.build(input_shape=[None, 28*28])
model.summary()#进行训练
total_epoches = 5
learn_rate = 0.01#Metrics统计
#参考资料:https://zhuanlan.zhihu.com/p/42438077
#1. 新建meter
#acc_meter = metrics.Accuracy()
#loss_meter = metrics.Mean()
#2. 更新状态, update_state()
#loss_meter.update_state(loss)
#acc_meter.update_state(y, pred)
#3.获取结果, result()
#print(step, 'loss:', loss_meter.result().numpy())
#print(step, 'Evaluate Acc:', total_correct/total, acc_meter.result().numpy())
#4.清除度量信息,reset_states()
#loss_meter.reset_states()
#acc_meter.reset_states()#新建准确度和loss度量对象
acc_meter = metrics.Accuracy()
loss_meter = metrics.Mean()optimizer = optimizers.Adam(learning_rate = learn_rate)
for epoch in range(total_epoches):for step, (x,y) in enumerate(train_db):with tf.GradientTape() as tape:logits = model(x)y_onehot = tf.one_hot(y, depth=10)#使用交叉熵作为lossloss_ce = tf.reduce_mean(tf.losses.categorical_crossentropy(y_onehot, logits, from_logits=True))#调用update_state更新loss度量信息loss_meter.update_state(loss_ce)#计算梯度grads = tape.gradient(loss_ce, model.trainable_variables)#更新梯度optimizer.apply_gradients(zip(grads, model.trainable_variables))if step % 100 == 0:print("Epoch[", epoch, "]: step-", step, "\tloss: ", loss_meter.result().numpy())loss_meter.reset_states()#使用测试集进行验证total_correct = 0total_num = 0#清除准确度的统计信息acc_meter.reset_states()for x,y in test_db:logits = model(x)#使用softmax得到各个类别的概率prob = tf.nn.softmax(logits, axis=1)#求出概率最大的结果参数位置,作为预测的分类结果pred = tf.cast(tf.argmax(prob, axis=1), dtype=tf.int32)#比较结果correct = tf.equal(pred, y)correct = tf.reduce_sum(tf.cast(correct, dtype=tf.int32))#计算精度total_correct += int(correct)total_num += x.shape[0]#使用metircs的update_state进行更新acc_meter.update_state(y, pred)acc = total_correct / total_numprint("Epoch[", epoch, "] Manual Accuracy:", acc, " Metrics Accuracy:", acc_meter.result().numpy())

运行结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/576211.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

HTTP,Servlet

HTTP 概念:HyperTextTransferProtocol,超文本传输协议,规定了浏览器和服务器之间数据传输的规则 HTTP协议特点: 1.基于TCP协议:面向连接,安全 2.基于请求-响应模型的:一次请求对应一次响应 …

java数据结构与算法刷题-----LeetCode278. 第一个错误的版本

java数据结构与算法刷题目录(剑指Offer、LeetCode、ACM)-----主目录-----持续更新(进不去说明我没写完):https://blog.csdn.net/grd_java/article/details/123063846 文章目录 二分查找 二分查找 解题思路:时间复杂度O( l o g 2 …

Linux文件IO(2):使用标准IO进行文件的打开、关闭、读写、流定位等相关操作

目录 前言 文件的打开和关闭的概念 文件的打开 文件的打开函数 文件打开的模式 文件的关闭 文件的关闭函数 注意事项 字符的输入(读单个字符) 字符输入的函数 注意事项 字符的输出(写单个字符) 字符输出的函数 注意…

【Web应用技术基础】CSS(6)——使用 HTML/CSS 实现 Educoder 顶部导航栏

第一题&#xff1a;使用flex布局实现Educoder顶部导航栏容器布局 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>Educoder</title><script src"https://cdn.staticfile.org/jquery/1.1…

论文笔记:分层问题-图像共注意力问答

整理了2017 Hierarchical Question-Image Co-Attention for Visual Question Answering&#xff09;论文的阅读笔记 背景模型问题定义模型结构平行共注意力交替共注意力 实验可视化 背景 视觉问答(VQA)的注意力模型在此之前已经有了很多工作&#xff0c;这种模型生成了突出显示…

机器学习优化算法(深度学习)

目录 预备知识 梯度 Hessian 矩阵&#xff08;海森矩阵&#xff0c;或者黑塞矩阵&#xff09; 拉格朗日中值定理 柯西中值定理 泰勒公式 黑塞矩阵&#xff08;Hessian矩阵&#xff09; Jacobi 矩阵 优化方法 梯度下降法&#xff08;Gradient Descent&#xff09; 随机…

Python列表、元组、字典及集合

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 一、列表定义方式&#xff1a; 二、元组1、定义方式&#xff1a;2、元组中的物理存储地址不可修改,如果修改则会报错&#xff0c;但是元组中的列表、字典项等却可以…

2024环境,资源与绿色能源国际会议(ICERGE2024)

2024环境&#xff0c;资源与绿色能源国际会议(ICERGE2024) 会议简介 2024环境、资源与绿色能源国际会议(ICERGE2024)将于2024年在三亚举行。该会议是一个围绕环境、资源与绿色能源研究领域的国际学术交流活动。 会议主题包括但不限于环境科学、环境工程、资源利用、绿色能源开…

tomcat和Servlet开发小案例

在上篇文章中,我已经正确安装了tomcat和利用servlet写了个基础的hello world程序,明白了这两个东西到底是啥玩意. 接下来,需要写个登录的小案例来进一步熟悉下基于servlet开发的流程. 一,新建项目 我们新建的maven项目其实里面是空的。所以我们需要给他变成一个servlet项目。 …

Netty对Channel事件的处理以及空轮询Bug的解决

继续上一篇Netty文章&#xff0c;这篇文章主要分析Netty对Channel事件的处理以及空轮询Bug的解决 当Netty中采用循环处理事件和提交的任务时 由于此时我在客户端建立连接&#xff0c;此时服务端没有提交任何任务 此时select方法让Selector进入无休止的阻塞等待 此时selectCnt进…

天途又获证书:甲级互联网地图服务企业、一级软件开发服务企业、一级信息技术服务运行维护服务企业。

10月24日&#xff0c;天途收到了三项重量级证书&#xff0c;甲级互联网地图服务企业资质证书、一级软件开发服务企业资质证书、一级信息技术服务运行维护服务企业证书。 互联网地图&#xff0c;是指登载在互联网上或者通过互联网发送的基于服务器地理信息数据库形成的具有实时…

【一种基于改进A*算法和CSA-APF算法的混合路径规划方法】—— 论文阅读

论文题目&#xff1a;A Hybrid Path Planning Method Based on Improved A∗ and CSA-APF Algorithms 1 摘要 大问题&#xff1a;复杂动态环境下全局路径规划难以避开动态障碍物&#xff0c;且局部路径容易陷入局部最优的问题 问题1&#xff1a;针对A*算法产生冗余路径节点和…