神经网络模型数字推理预测

MNIST数据集

MNIST是机器学习领域 最有名的数据集之一,被应用于从简单的实验到发表的论文研究等各种场合。 实际上,在阅读图像识别或机器学习的论文时,MNIST数据集经常作为实验用的数据出现。

MNIST数据集是由0到9的数字图像构成的。训练图像有6万张, 测试图像有1万张,这些图像可以用于学习和推理。MNIST数据集的一般使用方法是,先用训练图像进行学习,再用学习到的模型度量能在多大程度上对测试图像进行正确的分类。

1.png

MNIST的图像数据是28像素 × 28像素的灰度图像(1通道),各个像素的取值在0到255之间。每个图像数据都相应地标有"7" "2" "1"等标签。

使用如下脚本可以下载数据集

# coding: utf-8
try:import urllib.request
except ImportError:raise ImportError('You should use Python 3.x')
import os.path
import gzip
import pickle
import os
import numpy as npurl_base = 'http://yann.lecun.com/exdb/mnist/'
key_file = {'train_img':'train-images-idx3-ubyte.gz','train_label':'train-labels-idx1-ubyte.gz','test_img':'t10k-images-idx3-ubyte.gz','test_label':'t10k-labels-idx1-ubyte.gz'
}dataset_dir = os.path.dirname(os.path.abspath(__file__))
save_file = dataset_dir + "/mnist.pkl"train_num = 60000
test_num = 10000
img_dim = (1, 28, 28)
img_size = 784def _download(file_name):file_path = dataset_dir + "/" + file_nameif os.path.exists(file_path):returnprint("Downloading " + file_name + " ... ")urllib.request.urlretrieve(url_base + file_name, file_path)print("Done")def download_mnist():for v in key_file.values():_download(v)def _load_label(file_name):file_path = dataset_dir + "/" + file_nameprint("Converting " + file_name + " to NumPy Array ...")with gzip.open(file_path, 'rb') as f:labels = np.frombuffer(f.read(), np.uint8, offset=8)print("Done")return labelsdef _load_img(file_name):file_path = dataset_dir + "/" + file_nameprint("Converting " + file_name + " to NumPy Array ...")    with gzip.open(file_path, 'rb') as f:data = np.frombuffer(f.read(), np.uint8, offset=16)data = data.reshape(-1, img_size)print("Done")return datadef _convert_numpy():dataset = {}dataset['train_img'] =  _load_img(key_file['train_img'])dataset['train_label'] = _load_label(key_file['train_label'])    dataset['test_img'] = _load_img(key_file['test_img'])dataset['test_label'] = _load_label(key_file['test_label'])return datasetdef init_mnist():download_mnist()dataset = _convert_numpy()print("Creating pickle file ...")with open(save_file, 'wb') as f:pickle.dump(dataset, f, -1)print("Done!")def _change_one_hot_label(X):T = np.zeros((X.size, 10))for idx, row in enumerate(T):row[X[idx]] = 1return Tdef load_mnist(normalize=True, flatten=True, one_hot_label=False):"""读入MNIST数据集Parameters----------normalize : 将图像的像素值正规化为0.0~1.0one_hot_label : one_hot_label为True的情况下,标签作为one-hot数组返回one-hot数组是指[0,0,1,0,0,0,0,0,0,0]这样的数组flatten : 是否将图像展开为一维数组Returns-------(训练图像, 训练标签), (测试图像, 测试标签)"""if not os.path.exists(save_file):init_mnist()with open(save_file, 'rb') as f:dataset = pickle.load(f)if normalize:for key in ('train_img', 'test_img'):dataset[key] = dataset[key].astype(np.float32)dataset[key] /= 255.0if one_hot_label:dataset['train_label'] = _change_one_hot_label(dataset['train_label'])dataset['test_label'] = _change_one_hot_label(dataset['test_label'])if not flatten:for key in ('train_img', 'test_img'):dataset[key] = dataset[key].reshape(-1, 1, 28, 28)return (dataset['train_img'], dataset['train_label']), (dataset['test_img'], dataset['test_label']) if __name__ == '__main__':init_mnist()

load_mnist函数以"(训练图像 ,训练标签 ),(测试图像,测试标签 )"的多元组形式返回读入的MNIST数据。

load_mnist(normalize=True, flatten=True, one_hot_label=False) 这 样,设 置 3 个 参 数。
第 1 个参数normalize设置是否将输入图像正规化为0.0~1.0的值。如果将该参数设置为False,则输入图像的像素会保持原来的0~255。
第2个参数flatten设置是否展开输入图像(变成一维数组)。如果将该参数设置为False,则输入图像为1 × 28 × 28的三维数组;若设置为True,则输入图像会保存为由784个元素构成的一维数组。
第3个参数one_hot_label设置是否将标签保存为one-hot表示(one-hot representation)。one-hot表示是仅正确解标签为1,其余皆为0的数组,就像[0,0,1,0,0,0,0,0,0,0]这样。当one_hot_label为False时,只是像7、2这样简单保存正确解标签;当one_hot_label为True时,标签则 保存为one-hot表示。

可以通过如下代码读出下载的图片

# coding: utf-8
import sys, os
sys.path.append(os.pardir)  # 为了导入父目录的文件而进行的设定
import numpy as np
from DeepLearn_Base.dataset.mnist import load_mnist
from PIL import Imagedef img_show(img):pil_img = Image.fromarray(np.uint8(img))pil_img.show()(x_train, t_train), (x_test, t_test) = load_mnist(flatten=True, normalize=False)img = x_train[1]
label = t_train[1]
print(label)  # 5print(img.shape)  # (784,)
img = img.reshape(28, 28)  # 把图像的形状变为原来的尺寸
print(img.shape)  # (28, 28)img_show(img)

读出来的数据如下所示:

2.png

神经网络的推理

现在使用python的numpy结合神经网络的算法来推理图片的内容。整个流程其实就是两个部分:数据集准备、权重与偏置超参数准备。

数据集准备

使用如下代码块下载准备测试数据集:

def get_data():(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=True, one_hot_label=False)return x_test, t_test# 下载mnist数据集
# 分别下载测试图像包、测试标签包、训练图像包、训练标签包
x, t = get_data()

打印输出x, t参数shape

3.png

读取实现准备好的权重参数文件pkl,同时打印出来看看其参数shape

def init_network():with open("E:\\workcode\\code\\DeepLearn_Base\\ch03\\sample_weight.pkl", 'rb') as f:network = pickle.load(f)return network# 获取预训练好的权重与偏置参数
network = init_network()

4.png

可以看到,超参数分别是3个权重参数与3个偏置参数,为了方便,稍后再打印出其shape .

超参数文件 sample_weight.pkl 是预训练好的,本文主要是从神经网络的推理角度考虑,预训练文件的准备,暂不涉及。

推理

开始执行神经网络的推理,同时打印出其各个参数的shape

def predict(network, x):W1, W2, W3 = network['W1'], network['W2'], network['W3']b1, b2, b3 = network['b1'], network['b2'], network['b3']# 第一层计算a1 = np.dot(x, W1) + b1z1 = sigmoid(a1)# 第二层计算a2 = np.dot(z1, W2) + b2z2 = sigmoid(a2)a3 = np.dot(z2, W3) + b3# 输出层y = softmax(a3)return yaccuracy_cnt = 0
for i in range(len(x)):y = predict(network, x[i])p= np.argmax(y) # 获取概率最高的元素的索引if p == t[i]:accuracy_cnt += 1print("Accuracy:" + str(float(accuracy_cnt) / len(x)))

predict方法中,执行了推理过程,主要是各个数学公式的计算(sigmoid,softmax,线性计算),这些公式都是在numpy的基础上根据公式用程序语言表述出来的,具体的计算逻辑可以查阅functions.py文件。

看看各个参数的shape:

5.png

可以看看计算过程中的各个数据维度是否满足匹配:

6.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/236758.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

<Linux>(极简关键、省时省力)《Linux操作系统原理分析之linux存储管理(2)》(18)

《Linux操作系统原理分析之linux存储管理(1)》(17) 6 Linux存储管理6.2 选段符与段描述符6.2.1 选段符6.2.2 段描述符6.2.3 分段机制的存储保护 6.3 80x86 的分页机制6.3.180x86 的分页机制6.3.2 分页机制的地址转换6.3.3 页表目录…

⭐ Unity + ARKIT 介绍 以及 平面检测的实现

在AR插件中,ARKIT是比较特殊的一个,首先他在很多追踪上的效果要比其他的AR插件要好,但是只能在IOS系统设备上运行。 1.首先ARKIT在最新版Unity已经集成在AR Foundation中,那我们就需要ARSession 和ARSessionOrigin这两个重要组件…

年终好价节买什么好?高效实用、高性价比的的数码好物推荐

前段时间,“淘宝双12不再举办”的话题上了热搜,改成了“淘宝年终好价节”。从“双12”到“好价节”,背后意味着大众跳出了一味追求低价的“买买买”的怪圈,转变为更追寻价好质优的商品。错过双11的消费者可以趁这个时间抓紧入手收…

深度学习手势识别 - yolo python opencv cnn 机器视觉 计算机竞赛

文章目录 0 前言1 课题背景2 卷积神经网络2.1卷积层2.2 池化层2.3 激活函数2.4 全连接层2.5 使用tensorflow中keras模块实现卷积神经网络 3 YOLOV53.1 网络架构图3.2 输入端3.3 基准网络3.4 Neck网络3.5 Head输出层 4 数据集准备4.1 数据标注简介4.2 数据保存 5 模型训练5.1 修…

PHP开源问答网站平台源码系统 源码全部开源可二次开发 附带完整的搭建教程

目前,问答网站已经成为人们获取知识、交流思想的重要平台。然而,对于许多开发者来说,从头开始构建一个问答网站可能会面临各种挑战。今天,小编给大家介绍一款基于PHP的开源问答网站平台源码系统,它不仅源码全部开源&am…

NOIP2007提高组第二轮T3:矩阵取数游戏

题目链接 [NOIP2007 提高组] 矩阵取数游戏 题目描述 帅帅经常跟同学玩一个矩阵取数游戏:对于一个给定的 n m n \times m nm 的矩阵,矩阵中的每个元素 a i , j a_{i,j} ai,j​ 均为非负整数。游戏规则如下: 每次取数时须从每行各取走一…

学习笔记7——数据库基础知识以及mysql的查询语句

学习笔记系列开头惯例发布一些寻亲消息 链接:https://baobeihuijia.com/bbhj/contents/3/199913.html 数据库 三个概念区分 DB:数据库,存储数据的仓库,有组织的数据容器DBMS:数据库管理系统SQL:几乎所有的DBMS都支持…

人工智能_机器学习056_拉格朗日乘子法原理推导_公式由来详解_原理详解---人工智能工作笔记0096

https://blog.csdn.net/Soft_Po/article/details/118332454 这里有老师的一篇文章介绍拉格朗日乘子法的原理推导 结合老师的这篇文章我们来看一下详细的推导过程 可以看到上一节我们说,一个有条件的,函数,可以转换为一个,无条件的函数, 根据拉格朗日乘子法,可以创建出一个等…

python使用记录

1、VSCode添加多个python解释器 只需要将对应的python.exe的目录,添加到系统环境变量中即可,VSCode会自动识别及添加 2、pip 使用 pip常用命令和一些坑 查看已安装库的版本号 pip show 库名称 通过git 仓库安装第三方库 pip install git仓库地址

Linux系统常用指令

1.使用xshell登录到云服务器的Linux系统: ssh 用户名公网IP,例如: ssh root111.11.111. 2.添加用户 adduser 用户名,例如: adduser user 3.为用户设置密码 passwd 用户名,例如: passwd …

Cloudflare Email Routing 免费邮件发送服务

Cloudflare Email Routing 免费邮件发送(作为 Service 服务)用于 Workers/Pages 项目中。 原文链接: https://willin.wang/blog/cloudflare-send-email-service 准备工作 准备一个域名,例如 example.com。现在,在 cloudflare-dashboard 中添加一个网站并构建您的域名。这…

【JavaScript】3.4 JavaScript在现代前端开发中的应用

文章目录 1. 用户交互2. 动态内容3. 前端路由4. API 请求总结 JavaScript 是现代前端开发的核心。无论是交互效果,还是复杂的前端应用,JavaScript 都发挥着关键作用。在本章节中,我们将探讨 JavaScript 在现代前端开发中的应用,包…