深度学习技巧应用22-构建万能数据生成类的技巧,适用于CNN,RNN,GNN模型的调试与训练贯通

大家好,我是微学AI,今天给大家介绍一下深度学习技巧应用22-构建万能数据生成类的技巧,适用于CNN,RNN,GNN模型的调试与训练贯通。本文将实现了一个万能数据生成类的编写,并使用PyTorch框架训练CNN、RNN和GNN模型。

目录:
1.背景介绍
2.依赖库介绍
3.万能的数据生成器介绍
4.CNN,RNN,GNN模型搭建
5.数据生成与模型训练
6.训练结果与总结
在这里插入图片描述

1.背景介绍

在人工智能模型训练过程中,我们需要进行一些实验、测试或调试,我们可能需要一个具有特定形状和数量的数据集来验证我们的算法或模型。通过构建一个万能的数据生成器,我们可以灵活地生成各种形状和大小的数据集,无需手动制作和准备真实数据集。

其次,数据生成器还可以用于探索数据集的性质和特征。通过生成具有特定分布、特征或规律的数据集,我们可以更深入地了解数据之间的关系、特征之间的相互影响以及数据的结构等。这对于数据预处理、特征工程和模型选择都非常有帮助。

此外,数据生成器还可以用于实现数据增强技术。数据增强是指通过对原始数据进行一系列变换或扰动来生成新的训练样本,以增加训练数据的多样性和泛化能力。通过构建一个万能的数据生成器,我们可以定义各种数据增强方法,并在训练过程中动态地生成增强后的样本,从而提高模型的稳健性和可靠性。

2.依赖库介绍

首先,我们需要引入以下依赖库:

  • torch:PyTorch框架
  • torch.optim:torch.optim是PyTorch框架中的一个模块,用于优化模型的参数。它提供了各种优化算法,如随机梯度下降(SGD)、Adam、Adagrad等。通过选择适当的优化算法和调整参数,可以使模型在训练过程中更好地收敛并获得更好的性能。
  • torch.utils.data:torch.utils.data是PyTorch框架中的一个模块,用于处理数据集的工具类。它提供了一些常用的数据处理操作,如数据加载、批量处理、数据迭代和数据转换等。通过使用torch.utils.data,可以方便地将数据集加载到模型中进行训练,并且能够灵活地处理不同格式的数据。
  • numpy:numpy是一个Python库,主要用于进行数值计算和科学计算。它提供了多维数组对象(ndarray)和一系列用于操作数组的函数。numpy可以高效地进行数值运算,并且支持广播(broadcasting)和向量化操作,因此在科学计算、数据分析和机器学习等领域都得到广泛应用。在PyTorch中,numpy可以与torch.Tensor进行无缝的转换,方便进行数据的处理和转换。

3.万能的数据生成器介绍

首先我们需要定义了一个名为UniversalDataset的数据集类,用于生成具有特定形状和数量的数据和标签。

在类的初始化方法__init__中,我们传入了三个参数:data_shape表示数据的形状(一个元组),target_shape表示标签的形状(一个元组),num_samples表示数据集中样本的数量。通过这三个参数生成数据。

接着,我们实现了__len__方法,该方法返回数据集中样本的数量,即num_samples。

再定义__getitem__方法,该方法根据索引idx返回数据集中索引对应的数据和标签。在这个方法中,我们首先创建了一个与data_shape相同形状的全零张量data,以及一个与target_shape相同形状的全零张量target。

然后,我们分别计算了数据和标签的维度,即data_dims和target_dims。

本文使用torch.linspace函数在0和1之间生成长度为data_dim_size的等间隔数据范围data_range,并通过reshape方法将其重新塑形为data_shape_expanded形状的张量。然后,我们将这个塑形后的数据范围加到数据张量data上。

我们对标签也进行了类似的操作,生成了一个有规律的标签张量target。

最后,我们返回了数据张量data和标签张量target作为这个索引对应的样本。

通过这个类,我们可以根据需要生成具有指定形状和数量的数据集,并且数据和标签都是有规律的,方便进行后续的训练和评估。

import torch
from torch import nn
from torch.utils.data import DataLoader, Datasetclass UniversalDataset(Dataset):def __init__(self, data_shape, target_shape, num_samples):self.data_shape = data_shapeself.target_shape = target_shapeself.num_samples = num_samplesdef __len__(self):return self.num_samplesdef __getitem__(self, idx):# 生成数据和标签data = torch.zeros(self.data_shape)target = torch.zeros(self.target_shape)# 计算数据和标签的维度data_dims = len(self.data_shape)target_dims = len(self.target_shape)# 生成有规律的数据和标签for dim in range(data_dims):data_dim_size = self.data_shape[dim]data_range = torch.linspace(0, 1, data_dim_size)data_shape_expanded = [1] * data_dimsdata_shape_expanded[dim] = data_dim_sizedata += data_range.reshape(data_shape_expanded)for dim in range(target_dims):target_dim_size = self.target_shape[dim]target_range = torch.linspace(0, 1, target_dim_size)target_shape_expanded = [1] * target_dimstarget_shape_expanded[dim] = target_dim_sizetarget += target_range.reshape(target_shape_expanded)return data, target

4.CNN,RNN,GNN模型搭建

class CNNModel(nn.Module):def __init__(self, input_shape):super(CNNModel, self).__init__()self.conv1 = nn.Conv2d(input_shape[0], 16, kernel_size=3, stride=1, padding=1)self.fc = nn.Linear(16 * (input_shape[1] // 2) * (input_shape[2] // 2), 10)def forward(self, x):x = self.conv1(x)x = nn.functional.relu(x)x = nn.functional.max_pool2d(x, 2)x = x.view(x.size(0), -1)x = self.fc(x)return xclass RNNModel(nn.Module):def __init__(self, input_shape):super(RNNModel, self).__init__()self.rnn = nn.RNN(input_shape[1], 64, batch_first=True)self.fc = nn.Linear(64, 10)def forward(self, x):_, h_n = self.rnn(x)x = self.fc(h_n.squeeze(0))return xclass GNNModel(nn.Module):def __init__(self, input_shape):super(GNNModel, self).__init__()self.fc1 = nn.Linear(input_shape[1], 32)self.fc2 = nn.Linear(32, 10)def forward(self, x):x = torch.mean(x, dim=1)x = self.fc1(x)x = nn.functional.relu(x)x = self.fc2(x)return x

5.数据生成与模型训练

# 定义训练函数
def train(model, dataloader, criterion, optimizer):running_loss = 0.0correct = 0total = 0for inputs, labels in dataloader:optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)_, predicted = torch.max(outputs.data, dim=1)total += labels.size(0)correct += (predicted == labels.argmax(dim=1)).sum().item()loss.backward()optimizer.step()running_loss += loss.item()epoch_loss = running_loss / len(dataloader)epoch_acc = correct / totalreturn epoch_loss, epoch_acc# 设置参数
data_shape_cnn = (3, 32, 32)  # (channels, height, width)
target_shape = (10,)
num_samples = 1000
batch_size = 32
learning_rate = 0.001
num_epochs = 10# 创建数据集和数据加载器
dataset = UniversalDataset(data_shape_cnn, target_shape, num_samples)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)# 创建CNN模型、优化器和损失函数
cnn_model = CNNModel(data_shape_cnn)
cnn_optimizer = torch.optim.Adam(cnn_model.parameters(), lr=learning_rate)
cnn_criterion = nn.CrossEntropyLoss()print('CNN模型训练:')
# 训练CNN模型
for epoch in range(num_epochs):cnn_loss, cnn_acc = train(cnn_model, dataloader, cnn_criterion, cnn_optimizer)print(f'CNN - Epoch {epoch+1}/{num_epochs}, Loss: {cnn_loss:.4f}, Accuracy: {cnn_acc:.4f}')# 重新创建数据集和数据加载器
data_shape_rnn = (20, 32)  # (sequence_length, input_size, hidden_size)
dataset = UniversalDataset(data_shape_rnn, target_shape, num_samples)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)# 创建RNN模型、优化器和损失函数
rnn_model = RNNModel(data_shape_rnn)
rnn_optimizer = torch.optim.Adam(rnn_model.parameters(), lr=learning_rate)
rnn_criterion = nn.CrossEntropyLoss()print('RNN模型训练:')
# 训练RNN模型
for epoch in range(num_epochs):rnn_loss, rnn_acc = train(rnn_model, dataloader, rnn_criterion, rnn_optimizer)print(f'RNN - Epoch {epoch+1}/{num_epochs}, Loss: {rnn_loss:.4f}, Accuracy: {rnn_acc:.4f}')# 重新创建数据集和数据加载器
data_shape_gnn = (10, 100)  # (num_nodes, node_features)
dataset = UniversalDataset(data_shape_gnn, target_shape, num_samples)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)# 创建GNN模型、优化器和损失函数
gnn_model = GNNModel(data_shape_gnn)
gnn_optimizer = torch.optim.Adam(gnn_model.parameters(), lr=learning_rate)
gnn_criterion = nn.CrossEntropyLoss()print('GNN模型训练:')
# 训练GNN模型
for epoch in range(num_epochs):gnn_loss, gnn_acc = train(gnn_model, dataloader, gnn_criterion, gnn_optimizer)print(f'GNN - Epoch {epoch+1}/{num_epochs}, Loss: {gnn_loss:.4f}, Accuracy: {gnn_acc:.4f}')

6.训练结果与总结

运行结果:

CNN模型训练:
CNN - Epoch 1/10, Loss: 10.4031, Accuracy: 0.5840
CNN - Epoch 2/10, Loss: 10.2561, Accuracy: 1.0000
CNN - Epoch 3/10, Loss: 10.2503, Accuracy: 1.0000
CNN - Epoch 4/10, Loss: 10.2496, Accuracy: 1.0000
CNN - Epoch 5/10, Loss: 10.2495, Accuracy: 1.0000
CNN - Epoch 6/10, Loss: 10.2494, Accuracy: 1.0000
CNN - Epoch 7/10, Loss: 10.2493, Accuracy: 1.0000
CNN - Epoch 8/10, Loss: 10.2493, Accuracy: 1.0000
CNN - Epoch 9/10, Loss: 10.2493, Accuracy: 1.0000
CNN - Epoch 10/10, Loss: 10.2493, Accuracy: 1.0000
RNN模型训练:
RNN - Epoch 1/10, Loss: 10.3851, Accuracy: 0.9680
RNN - Epoch 2/10, Loss: 10.2606, Accuracy: 1.0000
RNN - Epoch 3/10, Loss: 10.2551, Accuracy: 1.0000
RNN - Epoch 4/10, Loss: 10.2531, Accuracy: 1.0000
RNN - Epoch 5/10, Loss: 10.2520, Accuracy: 1.0000
RNN - Epoch 6/10, Loss: 10.2513, Accuracy: 1.0000
RNN - Epoch 7/10, Loss: 10.2509, Accuracy: 1.0000
RNN - Epoch 8/10, Loss: 10.2506, Accuracy: 1.0000
RNN - Epoch 9/10, Loss: 10.2504, Accuracy: 1.0000
RNN - Epoch 10/10, Loss: 10.2502, Accuracy: 1.0000
GNN模型训练:
GNN - Epoch 1/10, Loss: 10.9591, Accuracy: 0.0400
GNN - Epoch 2/10, Loss: 10.3914, Accuracy: 1.0000
GNN - Epoch 3/10, Loss: 10.2818, Accuracy: 1.0000
GNN - Epoch 4/10, Loss: 10.2635, Accuracy: 1.0000
GNN - Epoch 5/10, Loss: 10.2569, Accuracy: 1.0000
GNN - Epoch 6/10, Loss: 10.2539, Accuracy: 1.0000
GNN - Epoch 7/10, Loss: 10.2524, Accuracy: 1.0000
GNN - Epoch 8/10, Loss: 10.2515, Accuracy: 1.0000
GNN - Epoch 9/10, Loss: 10.2509, Accuracy: 1.0000
GNN - Epoch 10/10, Loss: 10.2505, Accuracy: 1.0000

本文主要介绍了如何创建一个万能的数据生成类,可以根据输入的形状参数生成不同形状的数据。然后,将生成的数据和标签输入到CNN、RNN和GNN模型中进行训练,并打印出损失值和准确率。后续我们可以根据实际应用中可能需要根据具体任务做更多修改和扩展。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/13614.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Jupyter Notebook的内核添加新的虚拟环境

最近,在搭建环境的时候发现 Jupyter Notebook 的内核只有基础的python和pytorch,现在我想要在 Jupyter Notebook 中使用新的虚拟环境。 下面是解决的方法: (1)首先在Anaconda Prompt中激活虚拟环境,比如我…

LIN总线与RS485总线

LIN(Local Interconnect Network,局部互连网络)总线和RS485都是用于设备间通信的串行通信协议。下面我将分别列出它们的优势和劣势。 LIN总线的优势: 简单性:LIN总线的硬件和协议简单,易于实现和维护。成…

设计模式之装饰者模式-TS中装饰器介绍

装饰器的基本介绍 装饰器是一种特殊类型的声明,它能够被附加到类声明,方法,访问符,属性或参数上。 装饰器使用expression这种形式,expression求值后必须为一个函数,它会在运行时被调用,被装饰的…

SQL专家云回溯某时间段内的阻塞

背景 SQL专家云像“摄像头”一样,对环境、参数配置、服务器性能指标、活动会话、慢语句、磁盘空间、数据库文件、索引、作业、日志等几十个运行指标进行不同频率的实时采集,保存到SQL专家云自己的数据库中。因此可以随时对任何一个时间段进行回溯。 趋势…

vue项目打包并配置到iOS工程中

一、修改vue项目的配置文件 将config文件夹里面的index.js中的 assetsPublicPath的值修改为“./” Webpack.prod.conf.js 中output添加参数publicPath:./ 在webpack.base.conf.js里 publicPath: process.env.NODE_ENV 生产 ?./ config.build.assetsPublicPath :…

flutter聊天界面-Text富文本表情emoji、url、号码展示

flutter聊天界面-Text富文本表情emoji、url、号码展示 Text富文本表情emoji展示,主要通过实现Text.rich展示文本、emoji、自定义表情、URL等 一、Text及TextSpan Text用于显示简单样式文本 TextSpan它代表文本的一个“片段”,不同“片段”可按照不同的…

web-html的基本用法

web前端代码基本用法 <html> <head><meta charset"utf-8"><!-- charset 属性规定 HTML 文档的字符编码。要是没有规定字符编码的话是有可能乱码的 -->待到秋来九月八&#xff08;head&#xff09;<!-- 头部就是直接写在最上面的文字&…

尚无忧餐桌预订订桌包厢预订小程序源码

1.支持中餐、晚餐不同时间段桌位预定 2.支持包厢&#xff0c;大厅等不同区域预定 本系统后台tpvue 前端原生小程序 <!-- 导航栏 --> <!-- <van-nav-bar title"{{canteen}}" title-class"navbar" /> --> <van-nav-bar title"…

Spring Boot 中的服务发现

Spring Boot 中的服务发现 Spring Boot 是一个非常流行的 Java Web 开发框架&#xff0c;它提供了很多工具和组件来简化 Web 应用程序的开发。其中&#xff0c;服务发现是 Spring Boot 中的一个非常重要的组件&#xff0c;它可以帮助我们自动地发现和管理应用程序中的服务。 什…

树莓派(香橙派)交叉编译

目录 1、交叉编译是什么 2、为什么要交叉编译&#xff1f; 3、交叉编译需要用到什么工具&#xff1f; 4、&#xff08;香橙派&#xff09;交叉编译工具链的安装 5、 交叉编译服务端客户端 6、 带wiringPi库的交叉编译如何进行 1、交叉编译是什么 交叉编译是在一个平台上生…

盛最多水的容器(力扣)双指针 JAVA

给定一个长度为 n 的整数数组 height 。有 n 条垂线&#xff0c;第 i 条线的两个端点是 (i, 0) 和 (i, height[i]) 。 找出其中的两条线&#xff0c;使得它们与 x 轴共同构成的容器可以容纳最多的水。 返回容器可以储存的最大水量。 说明&#xff1a;你不能倾斜容器。 输入&…

JAVA开发( 腾讯云消息队列 RocketMQ使用总结 )

一、问题背景 之所以需要不停的总结是因为在java开发过程中使用到中间件实在太多了&#xff0c;久久不用就会慢慢变得生疏&#xff0c;有时候一个中间很久没使用&#xff0c;可能经过了很多版本的迭代&#xff0c;使用起来又有区别。所以还是得不断总结更新。最近博主就是在使用…