基于Pytorch实现图像分类——基于jupyter

分类任务

  • 网络基本构建与训练方法,常用函数解
  • torch.nn.functional模块
  • nn.Module模块

MNIST数据集下载

from pathlib import Path
import requestsDATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"PATH.mkdir(parents=True, exist_ok=True)URL = "http://deeplearning.net/data/mnist/"
FILENAME = "mnist.pkl.gz"if not (PATH / FILENAME).exists():content = requests.get(URL + FILENAME).content(PATH / FILENAME).open("wb").write(content)

解压数据集

import pickle
import gzipwith gzip.open((PATH / FILENAME).as_posix(), "rb") as f:((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")

查阅数据

from matplotlib import pyplot
import numpy as nppyplot.imshow(x_train[0].reshape((28, 28)), cmap="gray")
print(x_train.shape)
from matplotlib import pyplot
import numpy as np
pyplot.imshow(x_train[0].reshape((28, 28)), cmap="gray")
print(x_train.shape)

在这里插入图片描述

网络模型搭建

在这里插入图片描述

import torchx_train, y_train, x_valid, y_valid = map(torch.tensor, (x_train, y_train, x_valid, y_valid)
)
n, c = x_train.shape
x_train, x_train.shape, y_train.min(), y_train.max()
print(x_train, y_train)
print(x_train.shape)
print(y_train.min(), y_train.max())

在这里插入图片描述

常用函数介绍

import torch.nn.functional as Floss_func = F.cross_entropydef model(xb):return xb.mm(weights) + bias
bs = 64
xb = x_train[0:bs]  # a mini-batch from x
yb = y_train[0:bs]
weights = torch.randn([784, 10], dtype = torch.float,  requires_grad = True) 
bs = 64
bias = torch.zeros(10, requires_grad=True)print(loss_func(model(xb), yb))

模型搭建

from torch import nnclass Mnist_NN(nn.Module):def __init__(self):super().__init__()self.hidden1 = nn.Linear(784, 128)self.hidden2 = nn.Linear(128, 256)self.out  = nn.Linear(256, 10)def forward(self, x):x = F.relu(self.hidden1(x))x = F.relu(self.hidden2(x))x = self.out(x)return x
net = Mnist_NN()
print(net)

Mnist_NN(
(hidden1): Linear(in_features=784, out_features=128, bias=True)
(hidden2): Linear(in_features=128, out_features=256, bias=True)
(out): Linear(in_features=256, out_features=10, bias=True)
)

for name, parameter in net.named_parameters():print(name, parameter,parameter.size())

dataset数据接口

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoadertrain_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=bs * 2)def get_data(train_ds, valid_ds, bs):return (DataLoader(train_ds, batch_size=bs, shuffle=True),DataLoader(valid_ds, batch_size=bs * 2),)
  • 一般在训练模型时加上model.train(),这样会正常使用Batch Normalization和 Dropout
  • 测试的时候一般选择model.eval(),这样就不会使用Batch Normalization和 Dropout
import numpy as np
from torch import optim
def fit(steps, model, loss_func, opt, train_dl, valid_dl):for step in range(steps):model.train()for xb, yb in train_dl:loss_batch(model, loss_func, xb, yb, opt)model.eval()with torch.no_grad():losses, nums = zip(*[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl])val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)print('当前step:'+str(step), '验证集损失:'+str(val_loss))def get_model():model = Mnist_NN()return model, optim.SGD(model.parameters(), lr=0.001)
def loss_batch(model, loss_func, xb, yb, opt=None):loss = loss_func(model(xb), yb)if opt is not None:loss.backward()opt.step()opt.zero_grad()return loss.item(), len(xb)
train_dl, valid_dl = get_data(train_ds, valid_ds, bs)
model, opt = get_model()
fit(25, model, loss_func, opt, train_dl, valid_dl)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/616438.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

CSS核心样式-03-浮动+背景属性

目录 五、标准文档流 1. 微观现象 ①空白折叠现象 ②文字类的元素如果排在一行会出现一种高低不齐、底边对齐效果 ③自动换行 2. 元素等级 ①块级元素 ②行内元素 ③ 行内块元素 六、显示模式 display display四个属性值 脱离标准流 七、浮动属性(脱标…

超越传统Lambda函数:深入解析Out-of-line Lambdas的奇妙之处

超越传统函数:深入解析线外 Lambda函数 的奇妙之处 一、背景二、lambda 的捕获三、可能出现的警告四、lambda的广义捕获五、为每种情况进行重载六、总结 一、背景 Out-of-line Lambdas翻译过来就是“线外Lambda函数”或“离线Lambda函数”。Lambda 是使代码更具表现…

【opencv】示例-peopledetect.cpp HOG(方向梯度直方图)描述子和SVM(支持向量机)进行行人检测...

// 包含OpenCV项目所需的objdetect模块头文件 #include <opencv2/objdetect.hpp> // 包含OpenCV项目所需的highgui模块头文件&#xff0c;用于图像的显示和简单操作 #include <opencv2/highgui.hpp> // 包含OpenCV项目所需的imgproc模块头文件&#xff0c;用于图像…

在vue中配置样式 max-width:100px时,发现和width:100px一样没有对应的递增到最大宽度的效果?怎么回事?怎么解决?

原因&#xff1a; 可能时vue的样式大部分和display相关&#xff0c;有很多的联系&#xff0c;导致不生效 解决&#xff1a; 对设置max-width样式的元素设置display:inline-block;属性&#xff0c;即可生效&#xff0c;实现随着子元素的扩展而扩展并增加固定到最大的宽度

视频批量高效剪辑,支持将视频文件转换为音频文件,轻松掌握视频格式

在数字化时代&#xff0c;视频内容日益丰富&#xff0c;管理和编辑这些视频变得愈发重要。然而&#xff0c;传统的视频剪辑软件往往操作复杂&#xff0c;难以满足高效批量处理的需求。现在&#xff0c;一款全新的视频批量剪辑神器应运而生&#xff0c;它支持将视频文件一键转换…

【opencv】示例-phase_corr.cpp 捕获视频流并通过计算相位相关性来检测画面中的移动...

// 包含OpenCV库的头文件 #include "opencv2/core.hpp" // 包含OpenCV核心功能 #include "opencv2/videoio.hpp" // 包含视频IO功能 #include "opencv2/highgui.hpp" // 包含高级GUI功能&#xff0c;显示图像 #include "opencv2/imgproc.hp…

数据结构之单链表相关刷题

找往期文章包括但不限于本期文章中不懂的知识点&#xff1a; 个人主页&#xff1a;我要学编程(ಥ_ಥ)-CSDN博客 所属专栏&#xff1a;数据结构 数据结构之单链表的相关知识点及应用-CSDN博客 下面题目基于上面这篇文章&#xff1a; 下面有任何不懂的地方欢迎在评论区留言或…

Open CASCADE学习|BRepOffsetAPI_DraftAngle

BRepOffsetAPI_DraftAngle 是 Open CASCADE Technology (OCCT) 中用于创建带有草图斜面的几何体的类。草图斜面是一种在零件设计中常见的特征&#xff0c;它可以在零件的表面上创建一个倾斜的面&#xff0c;通常用于便于零件的脱模或是增加零件的强度。 本例演示了如何创建一个…

mp3怎样才能转换成wav格式?音频互相转换的方法

一&#xff0c;什么是WAV WAV&#xff0c;全称为波形音频文件&#xff08;Waveform Audio File Format&#xff09;&#xff0c;是一种由微软公司和IBM公司联合开发的音频文件格式。自1991年问世以来&#xff0c;WAV格式因其无损的音频质量和广泛的兼容性&#xff0c;成为了多…

在Linux驱动中,如何确保中断上下文的正确保存和恢复?

大家好&#xff0c;今天给大家介绍在Linux驱动中&#xff0c;如何确保中断上下文的正确保存和恢复&#xff1f;&#xff0c;文章末尾附有分享大家一个资料包&#xff0c;差不多150多G。里面学习内容、面经、项目都比较新也比较全&#xff01;可进群免费领取。 在Linux驱动中&am…

配置交换机端口安全

1、实验目的 通过本实验可以掌握&#xff1a; 交换机管理地址配置及接口配置。查看交换机的MAC地址表。配置静态端口安全、动态端口安全和粘滞端口安全的方法。 2、实验拓扑 配置交换机端口安全的实验拓扑如图所示。 配置交换机端口安全的实验拓扑 3、实验步骤 &#xff…

RetinalNet论文笔记

RetinalNet 概述1. 引言2. 相关工作3. 焦点损失4. RetinaNet Detector 检测器5. 实验6. 结论 3. Focal loss3.1. 平衡交叉熵3.2. 焦点损失定义3.3. 类别不平衡和模型初始化3.4. 类别不平衡和两阶段检测器 4. RetinaNet Detector特征金字塔网络骨干&#xff08;Feature Pyramid …