【学习笔记】深度学习实战 | PyTorch 入门(MLP为例)

在这里插入图片描述

简要声明


  1. 学习相关网址
    1. [双语字幕]吴恩达深度学习deeplearning.ai
    2. Papers With Code
    3. Datasets
  2. 深度学习网络基于PyTorch学习架构,代码测试可跑。
  3. 本学习笔记单纯是为了能对学到的内容有更深入的理解,如果有错误的地方,恳请包容和指正。

参考文献


  1. PyTorch Tutorials [https://pytorch.org/tutorials/]
  2. PyTorch Docs [https://pytorch.org/docs/stable/index.html]

简要介绍


MLP (Multilayer Perceptron)

在这里插入图片描述

DatasetMNIST
Input (feature maps)32×32 (28×28)
CONV Layers0
FC Layers3
ActivationReLU
Output10

代码分析


函数库调用

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

处理数据

数据下载

# 从开放数据集中下载训练数据
train_data = datasets.MNIST(root="data",train=True,download=True,transform=ToTensor(),
)# 从开放数据集中下载测试数据
test_data = datasets.MNIST(root="data",train=False,download=True,transform=ToTensor(),
)print(f'Number of training examples: {len(train_data)}')
print(f'Number of testing examples: {len(test_data)}')

Number of training examples: 60000
Number of testing examples: 10000

数据加载器(可选)

batch_size = 64# 创建数据加载器
train_dataloader = DataLoader(train_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)for X, y in test_dataloader:print(f"Shape of X [N, C, H, W]: {X.shape}")print(f"Shape of y: {y.shape} {y.dtype}")break

Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
Shape of y: torch.Size([64]) torch.int64

创建模型

# 选择训练设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using {device} device")

Using cuda device

class MLP(nn.Module):def __init__(self, input_dim, output_dim):super().__init__()self.input_layer = nn.Sequential(nn.Linear(input_dim, 250),nn.ReLU())self.hidden_layer = nn.Sequential(nn.Linear(250, 100),nn.ReLU())self.output_layer = nn.Sequential(nn.Linear(100, output_dim))def forward(self, x):x = x.view(x.size(0), -1)x = self.input_layer(x)x = self.hidden_layer(x)x = self.output_layer(x)return xmodel = MLP(28*28, 10).to(device)
print(model)

MLP(
(input_layer): Sequential(
(0): Linear(in_features=784, out_features=250, bias=True)
(1): ReLU()
)
(hidden_layer): Sequential(
(0): Linear(in_features=250, out_features=100, bias=True)
(1): ReLU()
)
(output_layer): Sequential(
(0): Linear(in_features=100, out_features=10, bias=True)
)
)

训练模型

选择损失函数和优化器

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

训练循环

def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)model.train()for batch, (X, y) in enumerate(dataloader):X, y = X.to(device), y.to(device)# Compute prediction errorpred = model(X)loss = loss_fn(pred, y)# Backpropagationloss.backward()optimizer.step()optimizer.zero_grad()if batch % 100 == 0:loss, current = loss.item(), (batch + 1) * len(X)print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

测试循环

def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

训练模型

epochs = 5
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train(train_dataloader, model, loss_fn, optimizer)test(test_dataloader, model, loss_fn)
print("Done!")

Epoch 1

loss: 2.304649 [ 64/60000]
loss: 0.350683 [ 6464/60000]
loss: 0.267444 [12864/60000]
loss: 0.305221 [19264/60000]
loss: 0.200744 [25664/60000]
loss: 0.316856 [32064/60000]
loss: 0.156469 [38464/60000]
loss: 0.280946 [44864/60000]
loss: 0.291244 [51264/60000]
loss: 0.199387 [57664/60000]
Test Error:
Accuracy: 94.7%, Avg loss: 0.169173

模型处理

保存模型

torch.save(model.state_dict(), "model.pth")
print("Saved PyTorch Model State to model.pth")

Saved PyTorch Model State to model.pth

加载模型

model = MLP(28*28, 10).to(device)
model.load_state_dict(torch.load("model.pth"))

重要函数


torch.cuda.is_available()返回一个布尔值,指示 CUDA 当前是否可用
nn.Sequential用于存储 Module 的列表
nn.Linear线性变换
nn.ReLU修正线性单位函数
nn.CrossEntropyLoss交叉熵损失
torch.optim.AdamAdam 算法
torch.save保存模型
torch.load加载模型

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/496033.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SpringMVC了解

1.springMVC概述 Spring MVC(Model-View-Controller)是基于 Java 的 Web 应用程序框架,用于开发 Web 应用程序。它通过将应用程序分为模型(Model)、视图(View)和控制器(Controller&a…

求两个向量之间的夹角

求两个向量之间的夹角 介绍Unity的API求向量夹角Vector3.AngleVector3.SignedAngle 自定义获取方法0-360度的夹角 总结 介绍 求两个向量之间的夹角方法有很多,比如说Unity中的Vector3.Angle,Vector3.SignedAngle等方法,具体在什么情况下使用…

Swagger3 使用详解

Swagger3 使用详解 一、简介1 引入依赖2 开启注解3 增加一个测试接口4 启动服务报错1.5 重新启动6 打开地址:http://localhost:8093/swagger-ui/index.html 二、Swagger的注解1.注解Api和ApiOperation2.注解ApiModel和ApiModelProperty3.注解ApiImplicitParams和Api…

数据库管理-第156期 Oracle Vector DB AI-07(20240227)

数据库管理156期 2024-02-27 数据库管理-第156期 Oracle Vector DB & AI-07(20240227)1 Vector相关DDL操作可以在现有的表上新增vector数据类型的字段:可以删除包含vector数据类型的列:可以使用CTAS的方式,从其他有…

Mendix 10.7 发布- Go Mac It!

在我们上个月发布了硕果累累的 Mendix 10.6 MTS 之后,您是否还没有抚平激动的情绪?好吧,不管您是否已经准备好,本月将带来另一个您想知道的大亮点——Mac版Studio Pro!但这还不是全部。本月,我们还将推出Re…

jenkins+kubernetes+git+dockerhub构建devops云平台

Devops简介 k8s助力Devops在企业落地实践 传统方式部署项目为什么发布慢,效率低? 上线一个功能,有多少时间被浪费了? 如何解决发布慢,效率低的问题呢? 什么是Devops? 敏捷开发 提高开发效率&…

【通讯录案例-tabbarController结构 Objective-C语言】

一、接下来,我们来说一下,tabbarController的View结构 1.实际上,这个tabbarController的结构呢,跟这个导航控制器的结构,差不多, 它里边儿呢,首先,有一个tabbarController的View, tabbarController,实际上,里边儿,有一个View,是专门儿来放子控制器View的, nav…

计算机网络——IPV4数字报

1. IPv4数据报的结构 本结构遵循的是RFC 791规范,介绍了一个IPv4数据包头部的不同字段。 1.1 IPv4头部 a. 版本(Version):指明了IP协议的版本,IPv4表示为4。 b. 头部长度(IHL, Internet Header Length&…

【六袆-Golang】Golang:安装与配置Delve进行Go语言Debug调试

安装与配置Delve进行Go语言Debug调试 一、Delve简介二、win-安装Delve三、使用Delve调试Go程序[命令行的方式]四、使用Golang调试程序 Golang开发工具系列:安装与配置Delve进行Go语言Debug调试 摘要: 开发环境中安装和配置Delve,一个强大的G…

算法打卡day5|哈希表篇01|Leetcode 242.有效的字母异位词 、19.删除链表的倒数第N个节点、202. 快乐数、1. 两数之和

哈希表基础知识 哈希表 哈希表关键码就是数组的索引下标,然后通过下标直接访问数组中的元素;数组就是哈希表的一种 一般哈希表都是用来快速判断一个元素是否出现集合里。例如要查询一个名字是否在班级里: 要枚举的话时间复杂度是O(n)&…

Leetcoder Day25| 回溯part05:子集+排列

491.递增子序列 给定一个整型数组, 你的任务是找到所有该数组的递增子序列,递增子序列的长度至少是2。 示例: 输入:[4, 7, 6, 7]输出: [[4, 6], [4, 7], [4, 6, 7], [6, 7], [7,7], [4,7,7]] 说明: 给定数组的长度不会超过15。数组中的整数范围是 [-100,100]。给定数…

【Python】Code2flow学习笔记

1 Code2flow介绍 Code2flow是一个代码可视化工具库,旨在帮助开发人员更好地理解和分析代码: 可以将Python代码转换为流程图,以直观的方式展示代码的执行流程和逻辑结构。具有简单易用、高度可定制化和美观的特点,适用于各种代码…