机器学习-10-基于paddle实现神经网络

文章目录

    • 总结
    • 参考
    • 本门课程的目标
    • 机器学习定义
    • 第一步:数据准备
    • 第二步:定义网络
    • 第三步:训练网络
    • 第四步:测试训练好的网络

总结

本系列是机器学习课程的系列课程,主要介绍基于paddle实现神经网络。

参考

MNIST 训练_副本

本门课程的目标

完成一个特定行业的算法应用全过程:

懂业务+会选择合适的算法+数据处理+算法训练+算法调优+算法融合
+算法评估+持续调优+工程化接口实现

机器学习定义

关于机器学习的定义,Tom Michael Mitchell的这段话被广泛引用:
对于某类任务T性能度量P,如果一个计算机程序在T上其性能P随着经验E而自我完善,那么我们称这个计算机程序从经验E中学习
在这里插入图片描述

使用MNIST数据集训练和测试模型。

第一步:数据准备

MNIST数据集

import paddle
from paddle.vision.datasets import MNIST
from paddle.vision.transforms import ToTensortrain_dataset = MNIST(mode='train', transform=ToTensor())
test_dataset = MNIST(mode='test', transform=ToTensor())

展示数据集图片

import matplotlib.pyplot as plt
import numpy as nptrain_data0, train_label_0 = train_dataset[0][0], train_dataset[0][1]
train_data0 = train_data0.reshape([28, 28])
plt.figure(figsize=(2, 2))
plt.imshow(train_data0, cmap=plt.cm.binary)
print('train_data0 的标签为: ' + str(train_label_0))
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/__init__.py:107: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop workingfrom collections import MutableMapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/rcsetup.py:20: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop workingfrom collections import Iterable, Mapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/colors.py:53: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop workingfrom collections import Sized
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2349: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop workingif isinstance(obj, collections.Iterator):
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2366: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop workingreturn list(data) if isinstance(data, collections.MappingView) else data
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/image.py:425: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() insteada_min = np.asscalar(a_min.astype(scaled_dtype))
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/image.py:426: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() insteada_max = np.asscalar(a_max.astype(scaled_dtype))train_data0 的标签为: [5]

第二步:定义网络

import paddle
import paddle.nn.functional as F
from paddle.nn import Conv2D, MaxPool2D, Linearclass MyModel(paddle.nn.Layer):def __init__(self):super(MyModel, self).__init__()self.conv1 = paddle.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)self.max_pool1 = MaxPool2D(kernel_size=2, stride=2)self.conv2 = Conv2D(in_channels=6, out_channels=16, kernel_size=5, stride=1)self.max_pool2 = MaxPool2D(kernel_size=2, stride=2)self.linear1 = Linear(in_features=16*5*5, out_features=120)self.linear2 = Linear(in_features=120, out_features=84)self.linear3 = Linear(in_features=84, out_features=10)def forward(self, x):x = self.conv1(x)x = F.relu(x)x = self.max_pool1(x)x = F.relu(x)x = self.conv2(x)x = self.max_pool2(x)x = paddle.flatten(x, start_axis=1, stop_axis=-1)x = self.linear1(x)x = F.relu(x)x = self.linear2(x)x = F.relu(x)x = self.linear3(x)return x

模型可视化

import paddle
mnist = MyModel()
paddle.summary(mnist, (1, 1, 28, 28))
---------------------------------------------------------------------------Layer (type)       Input Shape          Output Shape         Param #    
===========================================================================Conv2D-1       [[1, 1, 28, 28]]      [1, 6, 28, 28]          156      MaxPool2D-1     [[1, 6, 28, 28]]      [1, 6, 14, 14]           0       Conv2D-2       [[1, 6, 14, 14]]     [1, 16, 10, 10]         2,416     MaxPool2D-2    [[1, 16, 10, 10]]      [1, 16, 5, 5]            0       Linear-1          [[1, 400]]            [1, 120]           48,120     Linear-2          [[1, 120]]            [1, 84]            10,164     Linear-3          [[1, 84]]             [1, 10]              850      
===========================================================================
Total params: 61,706
Trainable params: 61,706
Non-trainable params: 0
---------------------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.06
Params size (MB): 0.24
Estimated Total Size (MB): 0.30
---------------------------------------------------------------------------{'total_params': 61706, 'trainable_params': 61706}

第三步:训练网络

import paddle
from paddle.metric import Accuracy
from paddle.static import InputSpecinputs = InputSpec([None, 784], 'float32', 'x')
labels = InputSpec([None, 10], 'float32', 'x')# 用Model封装模型
model = paddle.Model(MyModel(), inputs, labels)# 定义损失函数
optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())# 配置模型
model.prepare(optim, paddle.nn.CrossEntropyLoss(), Accuracy())
# 训练模型
model.fit(train_dataset, test_dataset, epochs=3, batch_size=64, save_dir='mnist_checkpoint', verbose=1)
The loss value printed in the log is the current step, and the metric is the average value of previous steps.
Epoch 1/3
step 938/938 [==============================] - loss: 0.0208 - acc: 0.9456 - 34ms/step          
save checkpoint at /home/aistudio/mnist_checkpoint/0
Eval begin...
step 157/157 [==============================] - loss: 0.0041 - acc: 0.9777 - 19ms/step          
Eval samples: 10000
Epoch 2/3
step 938/938 [==============================] - loss: 0.0021 - acc: 0.9820 - 34ms/step          
save checkpoint at /home/aistudio/mnist_checkpoint/1
Eval begin...
step 157/157 [==============================] - loss: 2.1037e-04 - acc: 0.9858 - 19ms/step      
Eval samples: 10000
Epoch 3/3
step 938/938 [==============================] - loss: 0.0126 - acc: 0.9876 - 34ms/step          
save checkpoint at /home/aistudio/mnist_checkpoint/2
Eval begin...
step 157/157 [==============================] - loss: 4.7168e-04 - acc: 0.9884 - 19ms/step      
Eval samples: 10000
save checkpoint at /home/aistudio/mnist_checkpoint/final

第四步:测试训练好的网络

import paddle
import numpy as np
import matplotlib.pyplot as plt
from paddle.metric import Accuracy
from paddle.static import InputSpecinputs = InputSpec([None, 784], 'float32', 'x')
labels = InputSpec([None, 10], 'float32', 'x')
model = paddle.Model(MyModel(), inputs, labels)
model.load('./mnist_checkpoint/final')
model.prepare(optim, paddle.nn.CrossEntropyLoss(), Accuracy())# results = model.evaluate(test_dataset, batch_size=64, verbose=1)
# print(results)results = model.predict(test_dataset, batch_size=64)test_data0, test_label_0 = test_dataset[0][0], test_dataset[0][1]
test_data0 = test_data0.reshape([28, 28])
plt.figure(figsize=(2,2))
plt.imshow(test_data0, cmap=plt.cm.binary)print('test_data0 的标签为: ' + str(test_label_0))
print('test_data0 预测的数值为:%d' % np.argsort(results[0][0])[0][-1])
Predict begin...
step 157/157 [==============================] - 27ms/step          
Predict samples: 10000
test_data0 的标签为: [7]
test_data0 预测的数值为:7/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/image.py:425: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() insteada_min = np.asscalar(a_min.astype(scaled_dtype))
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/image.py:426: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() insteada_max = np.asscalar(a_max.astype(scaled_dtype))

在这里插入图片描述


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/640736.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

直播美颜工具与视频美颜SDK:技术深入探索

直播美颜工具和视频美颜SDK的出现,为直播平台和应用开发者提供了丰富的选择。本文将深入探讨这些技术的原理、应用和发展趋势。 一、美颜算法 直播美颜工具的核心在于其先进的美颜算法。这些算法通过对图像进行分析和处理,实时地修饰主播的面部特征&am…

机器学习-10-神经网络python实现-从零开始

文章目录 总结参考本门课程的目标机器学习定义从零构建神经网络手写数据集MNIST介绍代码读取数据集MNIST神经网络实现测试手写的图片 带有反向查询的神经网络实现 总结 本系列是机器学习课程的系列课程,主要介绍基于python实现神经网络。 参考 BP神经网络及pytho…

stm32——GPIO学习

对于许多刚入门stm32的同学们来说,GPIO是我们的第一课,初出茅庐的我们会对GPIO的配置感到疑惑不解,也是劝退我们的第一课,今天我们就来一起学习一下stm32的GPIO,提振一下信心。好的,发车了小卷卷们&#xf…

力扣146. LRU 缓存

Problem: 146. LRU 缓存 文章目录 题目描述思路复杂度Code 题目描述 思路 主要说明大致思路,具体实现看代码。 1.为了实现题目中的O(1)时间复杂度的get与put方法,我们利用哈希表和双链表的结合,将key作为键,对应的链表的节点作为…

【前后端】django前后端交互

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、django是什么二、django前后端交互指引三、总结 前言 随着开发语言及人工智能工具的普及,使得越来越多的人会主动学习使用一些开发语言&#x…

零售数据分析方案:深度剖析人、货、场

人,即会员分析、用户分析,通过分析获得直观的用户画像,了解目标用户群体的消费水平、喜好、频率,为销售营销决策提供必要的数据支持;货,即商品分析,包括但不限于分析商品结构、分析销售top10商品…

构建云原生湖仓:Apache Iceberg与Amoro的结合实践

随着大数据技术的快速发展,企业对数据的处理和分析需求日益增长。传统的数据仓库已逐渐无法满足现代业务对数据多样性和实时性的要求,这促使了数据湖和数据仓库的融合,即湖仓一体架构的诞生。在云原生技术的推动下,构建云原生湖仓…

Opensbi初始化分析:设备初始化

Opensbi初始化分析:设备初始化 设备初始化sbi_init函数coldinit,冷启动初始化sbi_scratch_init函数sbi_domain_init函数sbi_hsm_initsbi_platform_early_initsbi_hart_initsbi_console_initsbi_platform_irqchip_init中断控制器的初始化sbi_ipi_init函数…

C++奇迹之旅:从0开始实现日期时间计算器

文章目录 📝前言🌠 头文件Date.h🌉日期计算函数🌠前后置🌉前后置-- 🌠两对象日期相减🌉自定义流输入和输出 🌉 代码🌉 头文件Date.h🌠Date.cpp🌉 …

数据挖掘实验(Apriori,fpgrowth)

Apriori:这里做了个小优化,比如abcde和adcef自连接出的新项集abcdef,可以用abcde的位置和f的位置取交集,这样第n项集的计算可以用n-1项集的信息和数字本身的位置信息计算出来,只需要保存第n-1项集的位置信息就可以提速…

在Qt creator中使用多光标

2024年4月22日,周一下午 Qt Creator 支持多光标模式。 多光标模式允许你在同一时间在多个光标位置进行编辑,从而可以更快地进行一些重复性的编辑操作。 要启用多光标模式,请按住 Alt 键,并用鼠标左键在文本编辑器中选择多个光标…

标题Selenium IDE 常见错误笔记

Selenium IDE 常见错误笔记 错误1:Failed:Exceeded waiting time for new window to appear 2000ms 这个错误通常出现在第一次运行时,有两个原因: Firefox阻止了弹出式窗口,在浏览器设置里允许这个操作即可。 有些网站设置了反…