paddle实现手写数字模型(一)

  1. 参考文档:paddle官网文档
  2. 环境:Python 3.12.2 ,pip 24.0 ,paddlepaddle 2.6.0
    python -m pip install paddlepaddle==2.6.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
  3. 调试代码如下:
    LeNet.py
import paddle
import paddle.nn.functional as Fclass LeNet(paddle.nn.Layer):def __init__(self):super().__init__()self.conv1 = paddle.nn.Conv2D(in_channels=1,out_channels=6,kernel_size=5,stride=1,padding=2)self.max_pool1 = paddle.nn.MaxPool2D(kernel_size=2,  stride=2)self.conv2 = paddle.nn.Conv2D(in_channels=6, out_channels=16, kernel_size=5, stride=1)self.max_pool2 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)self.linear1 = paddle.nn.Linear(in_features=16*5*5, out_features=120)self.linear2 = paddle.nn.Linear(in_features=120, out_features=84)self.linear3 = paddle.nn.Linear(in_features=84, out_features=10)def forward(self, x):x = self.conv1(x)x = F.relu(x)x = self.max_pool1(x)x = self.conv2(x)x = F.relu(x)x = self.max_pool2(x)x = paddle.flatten(x, start_axis=1,stop_axis=-1)x = self.linear1(x)x = F.relu(x)x = self.linear2(x)x = F.relu(x)x = self.linear3(x)return x

train.py


import paddle
from paddle.vision.transforms import Compose,Normalize,ToTensor
import paddle.vision.transforms as T  import numpy as np
import matplotlib.pyplot as plt
from paddle.metric import Accuracyfrom LeNet import LeNet
from PIL import Imageprint(paddle.__version__)
transform = Compose([Normalize(mean=[127.5],std=[127.5],data_format='CHW')])
print('下载和加载训练数据...')
train_dataset = paddle.vision.datasets.MNIST(mode='train',transform=transform)
test_dataset = paddle.vision.datasets.MNIST(mode='test',transform=transform)
print('load finished')train_data0,train_label_0 = train_dataset[0][0],train_dataset[0][1]
train_data0 = train_data0.reshape([28,28])
plt.figure(figsize=(2,2))
plt.imshow(train_data0,cmap=plt.cm.binary)
#plt.show()
print('train_data0 label is: '+str(train_label_0))model = paddle.Model(LeNet())   # 用Model封装模型
optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())# 配置模型
print('配置模型...')
model.prepare(optim,paddle.nn.CrossEntropyLoss(),Accuracy())
# 训练模型
print('训练模型...')
model.fit(train_dataset,epochs=2,batch_size=64,verbose=1)
# 保存模型  
model.save('./model/mnist_model')  # 默认保存模型结构和参数 #预测模型
print('预测模型...')
model.evaluate(test_dataset, batch_size=64, verbose=1)

predicted.py


import paddleimport numpy as npfrom LeNet import LeNet
from PIL import Image# 读取一张本地的样例图片,转变成模型输入的格式
def load_image(img_path):# 从img_path中读取图像,并转为灰度图im = Image.open(img_path).convert('L')#plt.imshow(im,cmap='gray')# print(np.array(im))im = im.resize((28, 28), Image.Resampling.LANCZOS)im = np.array(im).reshape(1, 1, 28, 28).astype(np.float32)# 图像归一化,保持和数据集的数据范围一致im = 1 - im / 255 return im# 加载训练好的模型参数
model = LeNet()
model.load_dict(paddle.load('./model/mnist_model.pdparams'))# 设置模型为评估模式
model.eval()# 准备一个MNIST样例图像
example_image = load_image("d:/8.png")# 转换为Tensor并进行推理
with paddle.no_grad():example_tensor = paddle.to_tensor(example_image)prediction = model(example_tensor)print(prediction)# 获取预测类别
predicted_class = np.argmax(prediction.numpy(), axis=1)[0]
print(f"Predicted class: {predicted_class}")

说明:先通过执行train.py训练数据集,将模型保存在model文件夹中,
然后运行predicted.py加载训练出来的数据集,推理出d:/8.png图片的结果。
结果图片如下:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/600900.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

1.网络编程-网络协议

目录 网络编程是什么 网络编程三要素 OSI七层网络模型 TCP/IP五层模型 SSL/TLS 是哪层协议 网络编程是什么 网络编程是计算机科学中的一个重要领域,它涉及到编写能够在网络环境中进行通信的程序。网络编程的核心目标是使不同的设备能够通过网络交换信息&#…

重建大师进行扫码认证了,接下来怎样才能正常使用?(如下图)

重建大师软件授权已经有了后,新建工程后设置任务目录和监控目录一致就可以运行了。 重建大师是一款专为超大规模实景三维数据生产而设计的集群并行处理软件,输入倾斜照片,激光点云,POS信息及像控点,输出高精度彩色网格…

linux创建文件、linux创建文件的几种方式、touch、echo、cat、vi、vim

文章目录 一、创建文件1.1、touch1.2、echo1.3、cat1.4、vi或vim 一、创建文件 1.1、touch touch命令:用于创建一个新的空文件或者更新已存在文件的访问和修改时间。 (1)如果目标文件不存在,则新建一个文件 touch demo.txt&am…

基于R语言lavaan结构方程模型(SEM)实践技术应用

原文链接:基于R语言lavaan结构方程模型(SEM)实践技术应用https://mp.weixin.qq.com/s?__bizMzUzNTczMDMxMg&mid2247596681&idx4&sn08753dd4d3e7bc492d750c0f06bba1b2&chksmfa823b6ecdf5b278ca0b94213391b5a222d1776743609cd…

DFS-0与异或问题,有奖问答,飞机降落

代码和解析 #include<bits/stdc.h> using namespace std; int a[5][5]{{1,0,1,0,1}}; //记录图中圆圈内的值&#xff0c;并初始化第1行 int gate[11]; //记录10个逻辑门的一种排列 int ans; //答案 int logic(int x, int y, int op){…

数仓调优实战:GUC参数调优

1. 前言 适用版本&#xff1a;【8.1.1及以上】 GaussDB(DWS)性能调优系列专题文章&#xff0c;介绍了数据库性能调优的思路和总体策略。在系统级调优中数据库全局的GUC参数对整体性能的提升至关重要&#xff0c;而在语句级调优中GUC参数可以调整估算模型&#xff0c;选择查询…

【超重磅牛市信号】减半倒计时12天!首波抛售潮接近尾声,大暴涨将如期而至!

3月&#xff0c;美国CPI环比出现小幅反弹由3.1%升至3.2%&#xff0c;美国制造业指数PMI反弹至50.3%呈现进入扩张期的态势&#xff0c;日本结束长达8年的负利率时代首次加息。这导致美国4月降息概率大幅下降&#xff0c;5月降息概率也跌至50%以下。 尽管如此&#xff0c;全球金融…

移动医保支付

传统就医流程中&#xff0c;涉及“三长一短”的难题&#xff0c;因此根据国家政策及互联网的能力支持&#xff0c;用户在微信或者支付宝上激活医保电子凭证之后&#xff0c;无需在医院窗口排队&#xff0c;即可通过微信小程序或者公众号、支付宝小程序缴纳医保挂号或医保门诊费…

vulnhub----natraj靶机

文章目录 一.信息收集1.网段探测2.端口扫描3.版本服务探测4.漏扫5.目录扫描 二.漏洞利用1.分析信息2..fuzz工具 三.getshell四.提权六.nmap提权 一.信息收集 1.网段探测 因为使用的是VMware&#xff0c;靶机的IP地址是192.168.9.84 ┌──(root㉿kali)-[~/kali/vulnhub] └─…

MYSQL快速入门

理解SQL 语句的执行过程 掌握SQL 语句的基本语法 掌握SQL 语句的增删改查操作 1.SQL 分类 MySQL 是关系型数据库系统&#xff0c;其中存储了大量的数据&#xff0c;通过SQL 管理数据库配置和数据。结构化查询语言&#xff08;SQL&#xff09;&#xff0c;对数据库进行操作的语…

【智能算法】蛾群算法(MSA)原理及实现

目录 1.背景2.算法原理2.1算法思想2.2算法过程 3.结果展示4.参考文献 1.背景 2017年&#xff0c;AAA Mohamed等人受到飞蛾趋光行为启发&#xff0c;提出了蛾群算法&#xff08;Moth Swarm Algorithm, MSA&#xff09;。 2.算法原理 2.1算法思想 MSA设待优化问题的可行解和适…

做抖店什么东西好卖?什么商品赚钱?抖音小店的选品标准来了!

哈喽~我是电商月月 做抖店&#xff0c;选品决定了一切&#xff01; 而从没接触过抖店的新手朋友&#xff0c;根本不知道什么样的商品才能算的上是好商品 在这里&#xff0c;我不敢告诉大家这个商品好卖&#xff0c;你们快去卖&#xff01;店铺的情况不同&#xff0c;运营方式…