使用python处理MNIST数据集

文章目录

  • 一. MNIST数据集
    • 1.1 什么是MNIST数据集
    • 1.2MNIST数据集文件格式
    • 1.3使用python访问MNIST数据集文件内容
  • 附录
    • 程序源码

一. MNIST数据集

1.1 什么是MNIST数据集

MNIST数据集是入门机器学习/识别模式的最经典数据集之一。最早于1998年Yan Lecun在论文:[Gradient-based learning applied to document recognition]中提出。该数据集包含了0-9共10类手写数字图片,每张图片都做了尺寸归一化,都是28x28大小的灰度图。每张图片中的像素大小在0-255之间,其中0是黑色,255是白色。如下图所示:

在这里插入图片描述

MNIST共包含70000张手写数字图片,其中有60000张用作训练集,10000张用作测试集。元数据集可以在MNIST官网下载。下载之后得到4个压缩文件:

train-images-idx3-ubyte.gz #60000张训练集图片
train-labels-idx1-ubyte.gz #60000张训练集图片对应的标签
t10k-images-idx3-ubyte.gz  #10000张测试集图片
t10k-labels-idx1-ubyte.gz  #10000张测试集图片对应的标签

将上面的4个压缩文件分别解压,得到:

train-images-idx3-ubyte #60000张训练集图片的idx3-ubyte格式文件
train-labels-idx1-ubyte #60000张训练集图片对应的标签的idx3-ubyte格式文件
t10k-images-idx3-ubyte  #10000张测试集图片的idx3-ubyte格式文件
t10k-labels-idx1-ubyte  #10000张测试集图片对应的标签的idx3-ubyte格式文件

1.2MNIST数据集文件格式

解压得到的4个文件都是二进制格式的文件,为了获取其中的信息,需要先了解MNIST二进制文件的存储格式。格式描述如下:
在这里插入图片描述

  • 第1-4个byte(字节,1byte=8bit),即前32bit存的是文件的magic number,对应的十进制大小是2051;
  • 第5-8个byte存的是number of images,即图像数量60000;
  • 第9-12个byte存的是每张图片行数/高度,即28;
  • 第13-16个byte存的是每张图片的列数/宽度,即28。
  • 从第17个byte开始,每个byte存储一张图片中的一个像素点的值。

1.3使用python访问MNIST数据集文件内容

知道了MNIST二进制文件的存储方式,下面介绍如何使用python访问文件内容。同样以训练集标签文件train-labels-idx1-ubyte和训练集图像文件train-images-idx3-ubyte为例:

import numpy as np
from PIL import ImageMNIST_labels_path = 'G:\\mnist_dataset\\train-labels-idx1-ubyte\\train-labels.idx1-ubyte'  # 下载的MNIST数据集文件地址
MNIST_images_path = 'G:\\mnist_dataset\\train-images-idx3-ubyte\\train-images.idx3-ubyte'  # 下载的MNIST数据集文件地址with open(MNIST_labels_path, 'rb') as f:file_labels = f.read()  # 读入标签二进制文件
with open(MNIST_images_path, 'rb') as f:file_images = f.read()  # 读入照片二进制文件magic_number_labels = int.from_bytes(file_labels[0:4], 'big')  # 读取二进制文件的第1-4个byte( 1byte = 8bit )即magic number,并转换成10进制
number_items = int.from_bytes(file_labels[4:8], 'big')  # 读取二进制文件的第5-8个byte( 1byte = 8bit ),即number of images,并转换成10进制
print('labels: magic number = ', magic_number_labels)
print('labels: number of items = ', number_items)magic_number = int.from_bytes(file_images[0:4], 'big')  # 读取二进制文件的第1-4个byte( 1byte = 8bit )即magic number,并转换成10进制
number_images = int.from_bytes(file_images[4:8], 'big')  # 读取二进制文件的第5-8个byte( 1byte = 8bit ),即number of images,并转换成10进制
number_rows = int.from_bytes(file_images[8:12], 'big')  # 读取二进制文件的第9-12个byte( 1byte = 8bit ),即number of rows,并转换成10进制
number_columns = int.from_bytes(file_images[12:16], 'big')  # 读取二进制文件的第13-16个byte( 1byte = 8bit ),即number of columns,并转换成10进制
print('images: magic number = ', magic_number)
print('images: number of images = ', number_images)
print('images: number of rows = ', number_rows)
print('images: number of columns = ', number_columns)

使用with open() as 函数读取文件,并使用int.from_bytes()方法将文件的magic number, number of items, number of images, number of rows, number of columns,等数据读入,将字节数据转换成整数数据,从而查看图像数量、图像高度和图像宽度信息。
运行结果:

在这里插入图片描述

通过以下程序,可以将MNIST数据集二进制文件中的照片提取出来并以.png格式保存在文件夹中:

# 将二进制的图像文件中的图像提取出来并保存在文件夹中
for i in range(1, 60001):image = [item for item in file_images[16 + 28 * 28 * (i - 1):16 + 28 * 28 * i]]image_np = np.array(image, dtype=np.uint8).reshape(28, 28)im = Image.fromarray(image_np)im.save("G:\\mnist_dataset\\train-images" + "\\" + str(i) + ".png")

输出的部分照片如下所示:
在这里插入图片描述

通过以下程序,将二进制标签文件中的部分标签信息打印出来,可以发现,标签中的数据正对应于图像中的手写数字信息。

# 将二进制的标签文件中的部分标签信息打印出来
for i in range(40, 53):labels = int.from_bytes(file_labels[8 + i - 1:8 + i], 'big')print('labels' + str(i) + '=' + str(labels))

在这里插入图片描述

附录

程序源码

import numpy as np
from PIL import ImageMNIST_labels_path = 'G:\\mnist_dataset\\train-labels-idx1-ubyte\\train-labels.idx1-ubyte'  # 下载的MNIST数据集文件地址
MNIST_images_path = 'G:\\mnist_dataset\\train-images-idx3-ubyte\\train-images.idx3-ubyte'  # 下载的MNIST数据集文件地址with open(MNIST_labels_path, 'rb') as f:file_labels = f.read()  # 读入标签二进制文件
with open(MNIST_images_path, 'rb') as f:file_images = f.read()  # 读入照片二进制文件magic_number_labels = int.from_bytes(file_labels[0:4], 'big')  # 读取二进制文件的第1-4个byte( 1byte = 8bit )即magic number,并转换成10进制
number_items = int.from_bytes(file_labels[4:8], 'big')  # 读取二进制文件的第5-8个byte( 1byte = 8bit ),即number of images,并转换成10进制
print('labels: magic number = ', magic_number_labels)
print('labels: number of items = ', number_items)magic_number = int.from_bytes(file_images[0:4], 'big')  # 读取二进制文件的第1-4个byte( 1byte = 8bit )即magic number,并转换成10进制
number_images = int.from_bytes(file_images[4:8], 'big')  # 读取二进制文件的第5-8个byte( 1byte = 8bit ),即number of images,并转换成10进制
number_rows = int.from_bytes(file_images[8:12], 'big')  # 读取二进制文件的第9-12个byte( 1byte = 8bit ),即number of rows,并转换成10进制
number_columns = int.from_bytes(file_images[12:16], 'big')  # 读取二进制文件的第13-16个byte( 1byte = 8bit ),即number of columns,并转换成10进制
print('images: magic number = ', magic_number)
print('images: number of images = ', number_images)
print('images: number of rows = ', number_rows)
print('images: number of columns = ', number_columns)# 将二进制的图像文件中的图像提取出来并保存在文件夹中
for i in range(1, 60001):image = [item for item in file_images[16 + 28 * 28 * (i - 1):16 + 28 * 28 * i]]image_np = np.array(image, dtype=np.uint8).reshape(28, 28)im = Image.fromarray(image_np)im.save("G:\\mnist_dataset\\train-images" + "\\" + str(i) + ".png")# 将二进制的标签文件中的部分标签信息打印出来
for i in range(40, 53):labels = int.from_bytes(file_labels[8 + i - 1:8 + i], 'big')print('labels' + str(i) + '=' + str(labels))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/116279.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

yolo增加slide loss,改善样本不平衡问题

slide loss的主要作用是让模型更加关注难例,可以轻微的改善模型在难例检测上的效果 论文地址:https://arxiv.org/pdf/2208.02019.pdf 代码:GitHub - Krasjet-Yu/YOLO-FaceV2: YOLO-FaceV2: A Scale and Occlusion Aware Face Detector 样本不…

2023年“羊城杯”网络安全大赛 决赛 AWDP [Break+Fix] Web方向题解wp 全

终于迎来了我的第一百篇文章。 这次决赛赛制是AWDP。BreakFix,其实就是CTFFix,Fix规则有点难崩。Break和Fix题目是一样的。 总结一下:败北,还是太菜了得继续修炼一下。 一、Break ezSSTI 看到是SSTI,焚靖直接一把梭…

软件设计模式系列之十三——享元模式

1 模式的定义 享元模式(Flyweight Pattern)是一种结构型设计模式,它旨在减少内存占用或计算开销,通过共享大量细粒度对象来提高系统的性能。这种模式适用于存在大量相似对象实例,但它们的状态可以外部化(e…

2023华为杯数学建模竞赛E题

一、前言 颅内出血(ICH)是由多种原因引起的颅腔内出血性疾病,既包括自发性出血,又包括创伤导致的继发性出血,诊断与治疗涉及神经外科、神经内科、重症医学科、康复科等多个学科,是临床医师面临的重要挑战。…

免费获取独立ChatGPT账户!!

GPT对于每个科研人员已经成为不可或缺的辅助工具,不同的研究领域和项目具有不同的需求。如在科研编程、绘图领域:1、编程建议和示例代码: 无论你使用的编程语言是Python、R、MATLAB还是其他语言,都可以为你提供相关的代码示例。2、数据可视化…

2023 年 Android 毕业设计选题推荐,200 道 Android 毕业设计题目,避免踩坑

前言 选择一个Android毕业设计题目是一个重要的决策,它将影响你未来几个月的工作。以下是一些关于如何选择一个合适的Android毕业设计题目以及如何避免踩坑的建议: 兴趣和热情:首先,选择你真正感兴趣的领域。如果你对某个领域充…

Python:Django框架的Hello wrold示例

Django是Python的目前很常用的web框架,遵循MVC设计模式。 以下介绍如何安装Django框架,并生成最简单的项目,输出Hello world。(开发工具VScode) 一、安装Django 在VScode终端控制台执行以下指令安装Django python install django 如果要查…

相机有俯仰角时如何将像素坐标正确转换到其他坐标系

一般像素坐标系转相机坐标系都是默认相机是水平的,没有考虑相机有俯仰角的情况,大致的过程是:像素坐标系统-->图像坐标系-->相机坐标系 ->世界坐标系或雷达坐标系: 像素坐标系 像素坐标系(u,v)是…

AIX360-CEMExplainer: MNIST Example

CEMExplainer: MNIST Example 这一部分屁话有点多,导包没问题的话可以跳过加载MNIST数据集加载经过训练的MNIST模型加载经过训练的卷积自动编码器模型(可选)初始化CEM解释程序以解释模型预测解释输入实例获得相关否定(Pertinent N…

停车场系统源码

源码下载地址(小程序开源地址):停车场系统小程序,新能源电动车充电系统,智慧社区物业人脸门禁小程序: 【涵盖内容】:城市智慧停车系统,汽车新能源充电,两轮电动车充电,物…

基于Android+OpenCV+CNN+Keras的智能手语数字实时翻译——深度学习算法应用(含Python、ipynb工程源码)+数据集(三)

目录 前言总体设计系统整体结构图系统流程图 运行环境模块实现1. 数据预处理2. 数据增强3. 模型构建4. 模型训练及保存1)模型训练2)模型保存 5. 模型评估 相关其它博客工程源代码下载其它资料下载 前言 本项目依赖于Keras深度学习模型,旨在对…

雷达编程实战之静态杂波滤除与到达角估计

雷达中经过混频的中频信号常常混有直流分量等一系列硬件设计引入的固定频率杂波,我们称之位静态杂波,雷达信号处理需要把这些静态杂波滤除从而有效的提高信噪比,实现准确的目标检测功能。 目标的到达角估计作为常规车载雷达信号处理的末端&am…