TensorFlow2实战-系列教程2:神经网络分类任务

🧡💛💚TensorFlow2实战-系列教程 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Jupyter Notebook中进行
本篇文章配套的代码资源已经上传

1、Mnist数据集

下载mnist数据集:

%matplotlib inline
from pathlib import Path
import requestsDATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"
PATH.mkdir(parents=True, exist_ok=True)URL = "http://deeplearning.net/data/mnist/"
FILENAME = "mnist.pkl.gz"if not (PATH / FILENAME).exists():content = requests.get(URL + FILENAME).content(PATH / FILENAME).open("wb").write(content)

制作数据:

import pickle
import gzipwith gzip.open((PATH / FILENAME).as_posix(), "rb") as f:((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")

简单展示数据:

from matplotlib import pyplot
import numpy as nppyplot.imshow(x_train[0].reshape((28, 28)), cmap="gray")
print(x_train.shape)
print(y_train[0])

打印结果:

(50000, 784)
5

在这里插入图片描述

2、模型构建

在这里插入图片描述
在这里插入图片描述
输入为784神经元,经过隐层提取特征后为10个神经元,10个神经元的输出值经过softmax得到10个概率值,取出10个概率值中最高的一个就是神经网络的最后预测值

构建模型代码:

import tensorflow as tf
from tensorflow.keras import layers
model = tf.keras.Sequential()
model.add(layers.Dense(32, activation='relu'))
model.add(layers.Dense(32, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))

选择损失函数,损失函数是机器学习一个非常重要的部分,基本直接决定了这个算法的效果,这里是多分类任务,一般我们就直接选用多元交叉熵函数就好了:
TensorFlow损失函数API

编译模型:

model.compile(optimizer=tf.keras.optimizers.Adam(0.001),loss=tf.keras.losses.SparseCategoricalCrossentropy(),metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
  1. adam优化器,学习率为0.001
  2. 多元交叉熵损失函数
  3. 评价指标

模型训练:

model.fit(x_train, y_train, epochs=5, batch_size=64, validation_data=(x_valid, y_valid))

训练数据,训练标签,训练轮次,batch_size,验证集

打印结果:

Train on 50000 samples, validate on 10000 samples
Epoch 1/5 50000/50000  1s 29us
sample-loss: 115566 - sparse_categorical_accuracy: 0.1122 - val_loss: 364928.5786 - val_sparse_categorical_accuracy: 0.1064
Epoch 2/5 50000/50000 1s 21us
sample - loss: 837104 - sparse_categorical_accuracy: 0.1136 - val_loss: 1323287.7028 - val_sparse_categorical_accuracy: 0.1064
Epoch 3/5 50000/50000 1s 20us
sample - loss: 1892431 - sparse_categorical_accuracy: 0.1136 - val_loss: 2448062.2680 - val_sparse_categorical_accuracy: 0.1064
Epoch 4/5 50000/50000 1s 20us
sample - loss: 3131130 - sparse_categorical_accuracy: 0.1136 - val_loss: 3773744.5348 - val_sparse_categorical_accuracy: 0.1064
Epoch 5/5 50000/50000 1s 20us
sample - loss: 4527781 - sparse_categorical_accuracy: 0.1136 - val_loss: 5207194.3728 - val_sparse_categorical_accuracy: 0.1064
<tensorflow.python.keras.callbacks.History at 0x1d3eb9015f8>

模型保存:

model.save('Mnist_model.h5')

3、TensorFlow常用模块

3.1 Tensor格式转换

创建一组数据

import numpy as np
input_data = np.arange(16)
input_data

打印结果:
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])

转换成TensorFlow格式的数据:

dataset = tf.data.Dataset.from_tensor_slices(input_data)
for data in dataset:print (data)

将一个ndarray转换成
打印结果:
tf.Tensor(0, shape=(), dtype=int32)
tf.Tensor(1, shape=(), dtype=int32)

tf.Tensor(14, shape=(), dtype=int32)
tf.Tensor(15, shape=(), dtype=int32)

3.2repeat操作

dataset = tf.data.Dataset.from_tensor_slices(input_data)
dataset = dataset.repeat(2)
for data in dataset:print (data)

tf.Tensor(0, shape=(), dtype=int32)
tf.Tensor(1, shape=(), dtype=int32)

tf.Tensor(14, shape=(), dtype=int32)
tf.Tensor(15, shape=(), dtype=int32)
tf.Tensor(0, shape=(), dtype=int32)
tf.Tensor(1, shape=(), dtype=int32)

tf.Tensor(14, shape=(), dtype=int32)
tf.Tensor(15, shape=(), dtype=int32)

会将当前的输出重复一遍

3.3 batch操作

dataset = tf.data.Dataset.from_tensor_slices(input_data)
dataset = dataset.repeat(2).batch(4)
for data in dataset:print (data)

tf.Tensor([0 1 2 3], shape=(4,), dtype=int32)
tf.Tensor([4 5 6 7], shape=(4,), dtype=int32)
tf.Tensor([ 8 9 10 11], shape=(4,), dtype=int32)
tf.Tensor([12 13 14 15], shape=(4,), dtype=int32)
tf.Tensor([0 1 2 3], shape=(4,), dtype=int32)
tf.Tensor([4 5 6 7], shape=(4,), dtype=int32)
tf.Tensor([ 8 9 10 11], shape=(4,), dtype=int32)
tf.Tensor([12 13 14 15], shape=(4,), dtype=int32)

将原来的数据按照4个为一个批次

3.4 shuffle操作

dataset = tf.data.Dataset.from_tensor_slices(input_data).shuffle(buffer_size=10).batch(4)
for data in dataset:print (data)

tf.Tensor([ 9 8 11 3], shape=(4,), dtype=int32)
tf.Tensor([ 5 6 1 13], shape=(4,), dtype=int32)
tf.Tensor([14 15 4 2], shape=(4,), dtype=int32)
tf.Tensor([12 7 0 10], shape=(4,), dtype=int32)

shuffle操作,直接翻译过来就是洗牌,把当前的数据进行打乱操作
buffer_size=10,就是缓存10来进行打乱取数据

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/433202.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

蓝桥杯省赛无忧 编程13 肖恩的投球游戏

#include <iostream> #include <vector> using namespace std; int main() {int n, q;cin >> n >> q;vector<int> a(n 1);vector<int> diff(n 2, 0); // 初始化差分数组// 读取初始球数&#xff0c;构建差分数组for (int i 1; i < …

Redis 击穿、穿透、雪崩产生原因解决思路

大家都知道&#xff0c;计算机的瓶颈之一就是IO&#xff0c;为了解决内存与磁盘速度不匹配的问题&#xff0c;产生了缓存&#xff0c;将一些热点数据放在内存中&#xff0c;随用随取&#xff0c;降低连接到数据库的请求链接,避免数据库挂掉。需要注意的是&#xff0c;无论是击穿…

IDEA 安装阿里Java编码规范插件

1.File>Settings 2.安装之后重启 开发过程中如果有不符合规范的地方&#xff0c;会自动出现提示

硬件知识(1) 手机的长焦镜头

#灵感# 手机总是配备好几个镜头&#xff0c;研究一下 目录 手机常配备的摄像头&#xff0c;及效果举例 长焦的焦距 焦距的定义和示图&#xff1a; IPC的焦距和适用场景&#xff1a; 手机常配备的摄像头&#xff0c;及效果举例 以下是小米某个手机的摄像头介绍&#xff1a…

SpringBoot中集成XXL-JOB分布式任务调度平台,轻量级、低侵入实现定时任务

场景 XXL-JOB 分布式任务调度平台XXL-JOB XXL-JOB是一个分布式任务调度平台&#xff0c;其核心设计目标是开发迅速、学习简单、轻量级、易扩展。 特性&#xff1a; 1、简单&#xff1a;支持通过Web页面对任务进行CRUD操作&#xff0c;操作简单&#xff0c;一分钟上手&…

Redis 可以代替 MySQL 作为数据库吗?

文章目录 前言1.连接到Redis服务器&#xff1a;2.存储和获取数据&#xff1a;3.列表操作&#xff1a;4.有序集合操作&#xff1a;5.键过期和删除&#xff1a;a.发布和订阅消息&#xff1a;b.实现分布式锁&#xff1a;c.使用Redis实现缓存功能&#xff1a; 前言 当使用Redis作为…

GEM5 Garnet Standalone 命令行与stats.txt结果分析

简介 展示了不同的命令行与结果,作为初步的了解. 命令行 sim-cycles要大,不然没结果 ./build/NULL/gem5.debug configs/example/garnet_synth_traffic.py –num-cpus16 –num-dirs16 –networkgarnet –topologyMesh_XY –mesh-rows4 –sim-cycles1000000 --inj-vnet…

Kafka-服务端-PartitionStateMachine

PartitionStateMachine是Controller Leader用于维护分区状态的状态机。分区的状态是通过PartitionState接口定义的&#xff0c;它有四个子类分别代表了分区四种可能的状态&#xff0c;如表所示。 分区各个PartitionState之间的转换如图所示。 下面分析各个状态之间转换时&#…

使用代码取大量2*2像素图片各通道均值,存于Excel文件中。

任务是取下图RGB各个通道的均值及标签&#xff08;R, G&#xff0c;B&#xff0c;Label&#xff09;,其中标签由图片存放的文件夹标识。由于2*2像素图片较多&#xff0c;所以将结果放置于Excel表格中&#xff0c;之后使用SVM对他们进行分类。 from PIL import Image import os …

深入浅出hdfs-hadoop基本介绍

一、Hadoop基本介绍 hadoop最开始是起源于Apache Nutch项目&#xff0c;这个是由Doug Cutting开发的开源网络搜索引擎&#xff0c;这个项目刚开始的目标是为了更好的做搜索引擎&#xff0c;后来Google 发表了三篇未来持续影响大数据领域的三架马车论文&#xff1a; Google Fil…

Unity应用在车机上启动有概率黑屏的解决方案

问题描述 最近将游戏适配到车机上&#xff08;Android系统&#xff09;&#xff0c;碰到了一个严重bug&#xff0c;启动的时候有概率会遇到黑屏&#xff0c;表现就是全黑&#xff0c;无法进入Unity的场景。 经过查看LogCat日志&#xff0c;也没有任何报错&#xff0c;也没有任…

qt 坦克大战游戏 GUI绘制

关于本章节中使用的图形绘制类&#xff0c;如QGraphicsView、QGraphicsScene等的详细使用说明请参见我的另一篇文章&#xff1a; 《图形绘制QGraphicsView、QGraphicsScene、QGraphicsItem、Qt GUI-CSDN博客》 本文将模仿坦克大战游戏&#xff0c;目前只绘制出一辆坦克&#…