深度学习入门之手写数字识别

news/2025/1/14 12:06:29/文章来源:https://www.cnblogs.com/Undefined443/p/18670506

模型定义

我们使用 CNN 和 MLP 来定义模型:

import torch.nn as nnclass Model(nn.Module):def __init__(self):"""定义模型结构输入维度为 1 * 28 * 28 (C, H, W)"""super(Model, self).__init__()# 卷积层 1self.conv1 = nn.Sequential(# 二维卷积层,输入通道数为 1,输出通道数为 16,卷积核大小为 5,填充为 2nn.Conv2d(1, 16, kernel_size=5, padding=2),# ReLU 激活函数nn.ReLU(),# 最大池化层,池化窗口大小为 2nn.MaxPool2d(kernel_size=2)# 输出维度为 16 * 14 * 14 (C, H/2, W/2))# 卷积层 2self.conv2 = nn.Sequential(# 二维卷积层,输入通道数为 16,输出通道数为 32,卷积核大小为 5,填充为 2nn.Conv2d(16, 32, kernel_size=5, padding=2),# ReLU 激活函数nn.ReLU(),# 最大池化层,池化窗口大小为 2nn.MaxPool2d(kernel_size=2)# 输出维度为 32 * 7 * 7 (C, H/4, W/4))# 全连接层,输入维度为 32 * 7 * 7,输出维度为 10self.fc = nn.Linear(32 * 7 * 7, 10)def forward(self, x):"""前向传播函数,由 torch 自动调用"""x = self.conv1(x)x = self.conv2(x)x = x.reshape(x.size(0), -1)x = self.fc(x)return x

训练和测试函数

import torchdef train(model, train_loader, criterion, optimizer, device):# 设置模型为训练模式model.train()total_loss = 0correct = 0total = 0for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)# 前向传播output = model(data)  # 输出维度为 (batch_size, 10)# 计算损失loss = criterion(output, target)# 计算预测结果_, predicted = output.max(1)# 反向传播和优化optimizer.zero_grad()  # 清空梯度loss.backward()        # 反向传播optimizer.step()       # 更新参数# 统计total_loss += loss.item()total += target.size(0)correct += predicted.eq(target).sum().item()# 打印进度if (batch_idx + 1) % 100 == 0:  # 每 100 个 batch 打印一次print(f'Batch: {batch_idx + 1}/{len(train_loader)}, 'f'Loss: {loss.item():.4f}, 'f'Accuracy: {100. * correct / total:.2f}%')# 记录训练数据writer.add_scalar('Training Loss/Step',loss.item(),epoch * len(train_loader) + batch_idx)# 计算平均损失和准确率avg_loss = total_loss / len(train_loader)accuracy = 100. * correct / total# 计算平均损失和准确率return avg_loss, accuracydef test(model, test_loader, criterion, device):# 设置模型为评估模式model.eval()total_loss = 0correct = 0total = 0# 不计算梯度with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)# 前向传播output = model(data)  # 输出维度为 (batch_size, 10)# 计算预测结果_, predicted = output.max(1)  # 从维度为 1 的维度上取最大值# 计算损失loss = criterion(output, target)# 统计total_loss += loss.item()total += target.size(0)correct += predicted.eq(target).sum().item()# 计算平均损失和准确率avg_loss = total_loss / len(test_loader)accuracy = 100. * correct / totalreturn avg_loss, accuracy

主程序

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from model import Model
from train import train, test# 定义超参数
BATCH_SIZE = 64
EPOCHS = 10
LEARNING_RATE = 0.001def load_data():"""加载数据"""# 定义数据预处理transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])# 加载训练集和测试集train_dataset = torchvision.datasets.MNIST(root='./data', train=True,transform=transform,download=True)test_dataset = torchvision.datasets.MNIST(root='./data',train=False,transform=transform,download=True)# 创建数据加载器train_loader = DataLoader(dataset=train_dataset,batch_size=BATCH_SIZE,shuffle=True)test_loader = DataLoader(dataset=test_dataset,batch_size=BATCH_SIZE,shuffle=False)return train_loader, test_loaderdef main():device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')print(f"Using device: {device}")train_loader, test_loader = load_data()# 定义模型model = Model().to(device)# 定义损失函数criterion = nn.CrossEntropyLoss()# 定义优化器optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)best_accuracy = 0for epoch in range(EPOCHS):print(f'\nEpoch: {epoch + 1}/{EPOCHS}')# 训练阶段train_loss, train_acc = train(model, train_loader, criterion, optimizer, device, epoch, writer)print(f'Training - Average Loss: {train_loss:.4f}, Accuracy: {train_acc:.2f}%')# 测试阶段test_loss, test_acc = test(model, test_loader, criterion, device, epoch, writer)print(f'Testing - Average Loss: {test_loss:.4f}, Accuracy: {test_acc:.2f}%')# 保存最佳模型if test_acc > best_accuracy:best_accuracy = test_acctorch.save(model.state_dict(), 'mnist_model.pth')print(f'\nBest Test Accuracy: {best_accuracy:.2f}%')if __name__ == '__main__':main()

TensorBoard 可视化

安装依赖:

pip install tensorboard torch_tb_profiler

修改程序,写入训练日志:

timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter(f'runs/mnist_{timestamp}')sample_images, _ = next(iter(train_loader))
writer.add_graph(model, sample_images.to(device))writer.add_scalar('Testing Loss/Epoch', avg_loss, epoch)
writer.add_scalar('Testing Accuracy/Epoch', accuracy, epoch)if epoch == 0:images, labels = next(iter(test_loader))img_grid = torchvision.utils.make_grid(images[:25])writer.add_image('mnist_images', img_grid)

在 VS Code 中,可以使用下面的命令启动 TensorBoard:

> Python: Launch TensorBoard

image

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/869025.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

省选构造专题

省选构造专题 The same permutation 首先打个表,发现在 \(1\le n\le 5\) 之内的是否有合法方案的情况为 √√√ 大了打不出来了。 考虑一下 \(4,5\) 连续有解,注意到一个偶数有解,则这个偶数 \(+1\) 也必定有解。 考虑以下构造方法即对于某一个交换,可以在它前后各添加一个…

自动化进程如何优化敏捷开发中的工作流

一、敏捷开发管理工具的现状 1.1 敏捷开发管理工具的基本功能 目前,敏捷开发管理工具的主要功能包括任务管理、进度跟踪、团队协作、资源分配、需求变更管理等。这些工具通常采用看板、任务板、甘特图、Burndown图等形式,帮助团队成员可视化地管理任务、跟踪项目进度、协调跨…

不知道前端代码哪里报错了?我有七种方式去监控它!

大家好,我是桃子,前端小菜鸟一枚,在日常工作中,有时候是不是不知道前端代码哪里报错了今天我给大家分享七中方法去监控它 我们先来说说前端中的错误类型有哪一些 错误类型 1、SyntaxError SyntaxError 是解析时发生语法错误,这个错误是捕获不到的,因为它是发生在构建阶段,…

web.config站内301永久重定向代码示例

注:此代码只适用于IIS服务器,如需要将123.asp重定向到123.html,请使用以下代码。 修改说明: 在web.config文件中添加301重定向规则,将123.asp重定向到123.html。<?xml version="1.0" encoding="UTF-8"?> <configuration><system.web…

请问云服务器需要开放哪些常用端口?

云服务器需要开放的端口与具体使用环境是有关系的,开放的端口越多,存在的安全隐患也就越大,所以开放端口越少越好。服务类型 端口 说明Web服务 80(HTTP), 443(HTTPS) 提供网站访问服务。FTP 21(文件管理) 提供文件传输服务。注:21端口可以关闭或修改。远程连接服务 3…

Audacity 3.7 (Linux, macOS, Windows) - 开源音频编辑器和录音工具

Audacity 3.7 (Linux, macOS, Windows) - 开源音频编辑器和录音工具Audacity 3.7 (Linux, macOS, Windows) - 开源音频编辑器和录音工具 Audacity is the worlds most popular audio editing and recording app 请访问原文链接:https://sysin.org/blog/audacity/ 查看最新版。…

CAP:Serverless + AI 让应用开发更简单

AI 已被广泛视为推动行业进步的关键力量,其在各行业的落地步伐加快。企业在构建 AI 应用开发过程中经常会面临 AI 技术门槛过高、试错周期过长、GPU 资源昂贵且弹性能力不足、缺乏配套工具、业务与模型的开发运维过于割裂、缺乏定制化能力等挑战,成为企业构建 AI 应用的『绊脚…

【运维自动化-作业平台】如何使用全局变量之密文类型?

密文类型的全局变量使用场景相对较少,使用方式也是直接引用即可,目前仅支持shell。一起来看看如何使用实操演示 1、新建作业时创建一个密文类型的全局变量app_secret2、添加一个执行脚本的步骤,脚本里打印下这个全局变量3、调试执行更多应用场景 上面这个示例是用最简单的ec…

微信多开防撤回、防撤回PC版 | WeChat4.0.1.21

点击上方蓝字关注我 前言 很多使用微信电脑版的朋友可能都会遇到一个问题,那就是微信电脑版不能同时登录多个账号。这对于那些需要在电脑上同时管理多个微信账号的人来说,确实很不方便。还有时候,别人撤回了他们发的消息,而我们可能就错过了那些重要的内容。这个版本可以同…

【PCI】PCIe高级错误上报能力AER(十二)

AER AER(Advanced Error Reporting)是一种用于检测和报告PCIe设备中发生的错误的机制,它允许PCIe设备检测到并报告各种类型的错误。错误类型包含Correctable Errors 和Uncorrectable errors两种,其中Uncorrectable errors下面又分为ERR_FATAL和ERR_NONFATAL。Correctable Err…

鸿蒙开发 - 自定义组件 和 组件通信的方法

自定义组件的基本结构 @Entry @Component struct MyComponent {build(){// ...} }build()函数build()函数用于描述组件的UI界面,自定义组件必须定义build()函数 build() {Column() {Text(测试)Button(点击)} }struct 关键字strcut 用来声明数据结构 struct + 自定组件名 + { .…

RN/H5多设备自适应组件库来了,高效实现鸿蒙原生应用多设备精致体验

在原生鸿蒙应用开发中,华为针对ArkUI框架推出了一整套针对多设备适配的完善能力(如“一多”能力)以及高阶组件(如分栏、边看边评等),帮助开发者轻松实现“一次开发,多端部署”。然而,当前鸿蒙生态仍存在大量用跨平台框架开发的应用,部分页面采用React Native(RN)和H…