TenorFlow多层感知机识别手写体

文章目录

  • 数据准备
  • 建立模型
          • 建立输入层 x
          • 建立隐藏层h1
          • 建立隐藏层h2
          • 建立输出层
  • 定义训练方式
          • 建立训练数据label真实值 placeholder
          • 定义loss function
          • 选择optimizer
  • 定义评估模型的准确率
          • 计算每一项数据是否正确预测
          • 将计算预测正确结果,加总平均
  • 开始训练
          • 画出误差执行结果
          • 画出准确率执行结果
  • 评估模型的准确率
  • 进行预测
  • 找出预测错误

GITHUB地址https://github.com/fz861062923/TensorFlow
注意下载数据连接的是外网,有一股神秘力量让你403

数据准备

import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_datamnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\h5py\__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.from ._conv import register_converters as _register_convertersWARNING:tensorflow:From <ipython-input-1-2ee827ab903d>:4: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data/train-images-idx3-ubyte.gz
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data/train-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.one_hot on tensors.
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\base.py:252: _internal_retry.<locals>.wrap.<locals>.wrapped_fn (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please use urllib or similar directly.
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\Users\admin\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
print('train images     :', mnist.train.images.shape,'labels:'           , mnist.train.labels.shape)
print('validation images:', mnist.validation.images.shape,' labels:'          , mnist.validation.labels.shape)
print('test images      :', mnist.test.images.shape,'labels:'           , mnist.test.labels.shape)
train images     : (55000, 784) labels: (55000, 10)
validation images: (5000, 784)  labels: (5000, 10)
test images      : (10000, 784) labels: (10000, 10)

建立模型

def layer(output_dim,input_dim,inputs, activation=None):#激活函数默认为NoneW = tf.Variable(tf.random_normal([input_dim, output_dim]))#以正态分布的随机数建立并且初始化权重Wb = tf.Variable(tf.random_normal([1, output_dim]))XWb = tf.matmul(inputs, W) + bif activation is None:outputs = XWbelse:outputs = activation(XWb)return outputs
建立输入层 x
x = tf.placeholder("float", [None, 784])
建立隐藏层h1
h1=layer(output_dim=1000,input_dim=784,inputs=x ,activation=tf.nn.relu)  
建立隐藏层h2
h2=layer(output_dim=1000,input_dim=1000,inputs=h1 ,activation=tf.nn.relu)  
建立输出层
y_predict=layer(output_dim=10,input_dim=1000,inputs=h2,activation=None)

定义训练方式

建立训练数据label真实值 placeholder
y_label = tf.placeholder("float", [None, 10])#训练数据的个数很多所以设置为None
定义loss function
# 深度学习模型的训练中使用交叉熵训练的效果比较好
loss_function = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=y_predict , labels=y_label))
选择optimizer
optimizer = tf.train.AdamOptimizer(learning_rate=0.001) \.minimize(loss_function)
#使用Loss_function来计算误差,并且按照误差更新模型权重与偏差,使误差最小化

定义评估模型的准确率

计算每一项数据是否正确预测
correct_prediction = tf.equal(tf.argmax(y_label  , 1),tf.argmax(y_predict, 1))#将one-hot encoding转化为1所在的位数,方便比较
将计算预测正确结果,加总平均
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

开始训练

trainEpochs = 15#执行15个训练周期
batchSize = 100#每一批的数量为100
totalBatchs = int(mnist.train.num_examples/batchSize)#计算每一个训练周期应该执行的次数
epoch_list=[];accuracy_list=[];loss_list=[];
from time import time
startTime=time()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
for epoch in range(trainEpochs):#执行15个训练周期#每个训练周期执行550批次训练for i in range(totalBatchs):batch_x, batch_y = mnist.train.next_batch(batchSize)#用该函数批次读取数据sess.run(optimizer,feed_dict={x: batch_x,y_label: batch_y})#使用验证数据计算准确率loss,acc = sess.run([loss_function,accuracy],feed_dict={x: mnist.validation.images, #验证数据的featuresy_label: mnist.validation.labels})#验证数据的labelepoch_list.append(epoch)loss_list.append(loss);accuracy_list.append(acc)    print("Train Epoch:", '%02d' % (epoch+1), \"Loss=","{:.9f}".format(loss)," Accuracy=",acc)duration =time()-startTime
print("Train Finished takes:",duration)        
Train Epoch: 01 Loss= 133.117172241  Accuracy= 0.9194
Train Epoch: 02 Loss= 88.949943542  Accuracy= 0.9392
Train Epoch: 03 Loss= 80.701606750  Accuracy= 0.9446
Train Epoch: 04 Loss= 72.045913696  Accuracy= 0.9506
Train Epoch: 05 Loss= 71.911483765  Accuracy= 0.9502
Train Epoch: 06 Loss= 63.642936707  Accuracy= 0.9558
Train Epoch: 07 Loss= 67.192626953  Accuracy= 0.9494
Train Epoch: 08 Loss= 55.959281921  Accuracy= 0.9618
Train Epoch: 09 Loss= 58.867351532  Accuracy= 0.9592
Train Epoch: 10 Loss= 61.904548645  Accuracy= 0.9612
Train Epoch: 11 Loss= 58.283069611  Accuracy= 0.9608
Train Epoch: 12 Loss= 54.332244873  Accuracy= 0.9646
Train Epoch: 13 Loss= 58.152175903  Accuracy= 0.9624
Train Epoch: 14 Loss= 51.552104950  Accuracy= 0.9688
Train Epoch: 15 Loss= 52.803482056  Accuracy= 0.9678
Train Finished takes: 545.0556836128235
画出误差执行结果
%matplotlib inline
import matplotlib.pyplot as plt
fig = plt.gcf()#获取当前的figure图
fig.set_size_inches(4,2)#设置图的大小
plt.plot(epoch_list, loss_list, label = 'loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['loss'], loc='upper left')
<matplotlib.legend.Legend at 0x1edb8d4c240>

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

画出准确率执行结果
plt.plot(epoch_list, accuracy_list,label="accuracy" )
fig = plt.gcf()
fig.set_size_inches(4,2)
plt.ylim(0.8,1)
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend()
plt.show()

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

评估模型的准确率

print("Accuracy:", sess.run(accuracy,feed_dict={x: mnist.test.images, y_label: mnist.test.labels}))
Accuracy: 0.9643

进行预测

prediction_result=sess.run(tf.argmax(y_predict,1),feed_dict={x: mnist.test.images })
prediction_result[:10]
array([7, 2, 1, 0, 4, 1, 4, 9, 6, 9], dtype=int64)
import matplotlib.pyplot as plt
import numpy as np
def plot_images_labels_prediction(images,labels,prediction,idx,num=10):fig = plt.gcf()fig.set_size_inches(12, 14)if num>25: num=25 for i in range(0, num):ax=plt.subplot(5,5, 1+i)ax.imshow(np.reshape(images[idx],(28, 28)), cmap='binary')title= "label=" +str(np.argmax(labels[idx]))if len(prediction)>0:title+=",predict="+str(prediction[idx]) ax.set_title(title,fontsize=10) ax.set_xticks([]);ax.set_yticks([])        idx+=1 plt.show()
plot_images_labels_prediction(mnist.test.images,mnist.test.labels,prediction_result,0)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

y_predict_Onehot=sess.run(y_predict,feed_dict={x: mnist.test.images })
y_predict_Onehot[8]
array([-6185.544  , -5329.589  ,  1897.1707 , -3942.7764 ,   347.9809 ,5513.258  ,  6735.7153 , -5088.5273 ,   649.2062 ,    69.50408],dtype=float32)

找出预测错误

for i in range(400):if prediction_result[i]!=np.argmax(mnist.test.labels[i]):print("i="+str(i)+"   label=",np.argmax(mnist.test.labels[i]),"predict=",prediction_result[i])
i=8   label= 5 predict= 6
i=18   label= 3 predict= 8
i=149   label= 2 predict= 4
i=151   label= 9 predict= 8
i=233   label= 8 predict= 7
i=241   label= 9 predict= 8
i=245   label= 3 predict= 5
i=247   label= 4 predict= 2
i=259   label= 6 predict= 0
i=320   label= 9 predict= 1
i=340   label= 5 predict= 3
i=381   label= 3 predict= 7
i=386   label= 6 predict= 5
sess.close()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/473778.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Unity设备分级策略

Unity设备分级策略 前言 之前自己做的设备分级策略&#xff0c;在此做一个简单的记录和思路分享。希望能给大家带来帮助。 分级策略 根据拟定的评分标准&#xff0c;预生成部分已知机型的分级信息&#xff0c;且保存在包内&#xff1b;如果设备没有被评级过&#xff0c;则优…

多模态(三)--- BLIP原理与源码解读

1 BLIP简介 BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation 传统的Vision-Language Pre-training &#xff08;VLP&#xff09;任务大多是基于理解的任务或基于生成的任务&#xff0c;同时预训练数据多是从web获…

国产制造,欧美品质:爱可声助听器产品质量获国际认可

随着科技的发展和全球化的推进&#xff0c;越来越多的中国制造产品开始走向世界舞台。其中&#xff0c;爱可声助听器凭借其卓越的产品质量&#xff0c;成为了国产制造的骄傲。 国产制造指的是在中国境内生产的产品&#xff0c;欧美品质则是指产品在设计、生产、质量控制等方面…

Invalid DataSize: cannot convert ‘30Mb‘ to Long

Invalid DataSize: cannot convert 30Mb to Long servlet:multipart:max-file-size: 30MBmax-request-size: 30MB

得物面试:Redis用哈希槽,而不是一致性哈希,为什么?

尼恩说在前面 在40岁老架构师 尼恩的读者交流群(50)中&#xff0c;最近有小伙伴拿到了一线互联网企业如得物、阿里、滴滴、极兔、有赞、希音、百度、网易、美团的面试资格&#xff0c;遇到很多很重要的面试题&#xff1a; Redis为何用哈希槽而不用一致性哈希&#xff1f; 最近…

python学习24

前言&#xff1a;相信看到这篇文章的小伙伴都或多或少有一些编程基础&#xff0c;懂得一些linux的基本命令了吧&#xff0c;本篇文章将带领大家服务器如何部署一个使用django框架开发的一个网站进行云服务器端的部署。 文章使用到的的工具 Python&#xff1a;一种编程语言&…

009集——磁盘详解——电脑数据如何存储在磁盘

很多人也知道数据能够保存是由于设备中有一个叫做「硬盘」的组件存在&#xff0c;但也有很多人不知道硬盘是怎样储存这些数据的。这里给大家讲讲其中的原理。 首先我们要明白的是&#xff0c;计算机中只有0和1&#xff0c;那么我们存入硬盘的数据&#xff0c;实际上也就是一堆0…

PLC_博图系列☞LAD

PLC_博图系列☞LAD 文章目录 PLC_博图系列☞LAD背景介绍LAD优势局限 LAD元素 关键字&#xff1a; PLC、 西门子、 博图、 Siemens 、 LAD 背景介绍 这是一篇关于PLC编程的文章&#xff0c;特别是关于西门子的博图软件。我并不是专业的PLC编程人员&#xff0c;也不懂电路&a…

英文论文(sci)解读复现【NO.21】一种基于空间坐标的轻量级目标检测器无人机航空图像的自注意

此前出了目标检测算法改进专栏&#xff0c;但是对于应用于什么场景&#xff0c;需要什么改进方法对应与自己的应用场景有效果&#xff0c;并且多少改进点能发什么水平的文章&#xff0c;为解决大家的困惑&#xff0c;此系列文章旨在给大家解读发表高水平学术期刊中的 SCI论文&a…

简单DP算法(动态规划)

简单DP算法 算法思想例题1、01背包问题题目信息思路题解 2、摘花生题目信息思路题解 3、最长上升子序列题目信息思路题解 题目练习1、地宫取宝题目信息思路题解 2、波动数列题目信息思路题解 算法思想 从集合角度来分析DP问题 例如求最值、求个数 例题 1、01背包问题 题目…

Android EditText关于imeOptions的设置和响应

日常开发中&#xff0c;最绕不开的一个控件就是EditText&#xff0c;随之避免不了的则是对其软键盘事件的监听&#xff0c;随着需求的不同对用户输入的软键盘要求也不同&#xff0c;有的场景需要用户输入完毕后&#xff0c;有一个确认按钮&#xff0c;有的场景需要的是回车&…

深入解析域名短链接生成原理及其在Python/Flask中的实现策略:一篇全面的指南与代码示例

为了构建一个高效且用户友好的域名短链服务&#xff0c;我们可以将项目精简为以下核心功能板块&#xff1a; 1. 用户管理 注册与登录&#xff1a;允许用户创建账户并登录系统。 这部分内容可以参考另一片文章实现&#xff1a; 快速实现用户认证&#xff1a;使用Python和Flask…