神经网络案例实战

🔎我们通过一个案例详细使用PyTorch实战 ,案例背景:你创办了一家手机公司,不知道如何估算手机产品的价格。为了解决这个问题,收集了多家公司的手机销售数据:这些数据维度可以包括RAM、存储容量、屏幕尺寸、摄像头像素等。

在这个问题中,我们不需要预测实际价格,而是一个价格范围,它的范围使用 0、1、2、3 来表示,所以该问题也是一个分类问题。

🔎思路:

  1. 数据预处理:对收集到的数据进行清洗和预处理,确保数据的质量和一致性。这包括处理缺失值、异常值和重复数据等。

  2. 特征工程:从原始数据中提取有用的特征,以便用于建模。这可以包括对连续型特征进行归一化或标准化,对分类特征进行编码等。

  3. 模型选择:选择一个适合的机器学习算法来建立模型,这里我们使用神经网络模型。

  4. 模型训练:将收集到的数据划分为训练集和测试集。使用训练集来训练模型,通过调整模型的参数来最小化预测误差。

  5. 模型评估:使用测试集来评估模型的性能。常用的评估指标包括均方误差(MSE)、均方根误差(RMSE)和平均绝对误差(MAE)等。

  6. 价格预测:使用训练好的模型来预测新手机的价格。输入新手机的功能特征,模型将输出预测的价格。

构建数据集 

💬数据共有 2000 条, 其中 1600 条数据作为训练集, 400 条数据用作测试集。 我们使用 sklearn 的数据集划分工作来完成。并使用 PyTorch 的 TensorDataset 来将数据集构建为 Dataset 对象,方便后期构造数据集加载对象。

def create_dataset():data = pd.read_csv('predict.csv')# 特征值和目标值x, y = data.iloc[:, :-1], data.iloc[:, -1]x = x.astype(np.float32)y = y.astype(np.int64)# 数据集划分x_train, x_valid, y_train, y_valid = train_test_split(x, y, train_size=0.8, random_state=88, stratify=y)# 构建数据集train_dataset = TensorDataset(torch.tensor(x_train.values), torch.tensor(y_train.values))valid_dataset = TensorDataset(torch.tensor(x_valid.values), torch.tensor(y_valid.values))return train_dataset, valid_dataset, x_train.shape[1], len(np.unique(y))train_dataset, valid_dataset, input_dim, class_num = create_dataset()

💬其中 train_test_split 方法中的stratify=y参数的作用是在划分训练集和验证集时,保持类别的比例相同。这样可以确保在训练集和验证集中各类别的比例与原始数据集中的比例相同,有助于提高模型的泛化能力,防止出现一份中某个类别只有几个。

构建分类网络模型

# 构建网络模型
class PhonePriceModel(nn.Module):def __init__(self, input_dim, output_dim):super(PhonePriceModel, self).__init__()self.linear1 = nn.Linear(input_dim, 128) # 输入层到隐藏层的线性变换self.linear2 = nn.Linear(128, 256)  # 隐藏层到输出层的线性变换self.linear3 = nn.Linear(256, output_dim)  # 输出层到最终输出的线性变换def _activation(self, x):return torch.sigmoid(x)def forward(self, x):x = self._activation(self.linear1(x))# 通过第一层线性变换和激活函数x = self._activation(self.linear2(x))# 通过第二层线性变换和激活函数output = self.linear3(x)# 通过第三层线性变换得到最终输出return output
  • self.linear1self.linear2之间的线性变换将输入维度从input_dim映射到128个神经元,然后再将128个神经元映射到256个神经元。
  • 我们通过self._activation方法对输入数据进行激活函数处理。具体来说,它使用Sigmoid激活函数对输入数据进行非线性变换,将线性变换后的输出映射到0和1之间。
  • 第三层: 输入为维度为 256, 输出维度为: 4

编写训练函数 

def train():# 固定随机数种子torch.manual_seed(66)model = PhonePriceModel(input_dim, class_num)# 损失函数criterion = nn.CrossEntropyLoss()# 优化方法optimizer = optim.SGD(model.parameters(), lr=1e-3)num_epoch = 50for epoch_idx in range(num_epoch):# 初始化数据加载器dataloader = DataLoader(train_dataset, shuffle=True, batch_size=8)# 训练时间start = time.time()# 计算损失total_loss = 0.0total_num = 1# 准确率correct = 0for x, y in dataloader:output = model(x)# 计算损失loss = criterion(output, y)# 梯度清零optimizer.zero_grad()# 反向传播loss.backward()# 参数更新optimizer.step()total_num += len(y)total_loss += loss.item() * len(y)print('epoch: %4s loss: %.2f, time: %.2fs' %(epoch_idx + 1, total_loss / total_num, time.time() - start))# 模型保存torch.save(model.state_dict(), 'price-model.bin')
  • 💬要在PyTorch中查看随机数种子,可以使用torch.random.initial_seed()函数。这个函数会返回当前设置的随机数种子。 

评估函数

def test():# 加载模型model = PhonePriceModel(input_dim, class_num)model.load_state_dict(torch.load('price-model.bin'))# 构建加载器dataloader = DataLoader(valid_dataset, batch_size=8, shuffle=False)# 评估测试集correct = 0for x, y in dataloader:output = model(x)y_pred = torch.argmax(output, dim=1)correct += (y_pred == y).sum()print('Acc: %.5f' % (correct.item() / len(valid_dataset)))

网络性能调优

  1. 对输入数据进行标准化
  2. 调整优化方法
  3. 调整学习率
  4. 增加批量归一化层
  5. 增加网络层数、神经元个数
  6. 增加训练轮数

🔎我们可以自行将优化方法由 SGD 调整为 Adam , 学习率由 1e-3 调整为 1e-4 ,对数据数据进行标准化 ,增加网络深度 。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/679220.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

STC8增强型单片机开发——库函数

一、使用库函数点灯 导入库函数。 下载STC8H的库函数:📎STC8G-STC8H-LIB-DEMO-CODE_2023.07.17_优化版.zip 来到库函数的目录下,拷贝以下文件: Config.hType_def.hGPIO.hGPIO.c 新建项目,将拷贝的4个文件放到项目目录…

Ps中 饱和度 和 自然饱和度 的区别?

1.饱和度(Saturation):在Photoshop中,饱和度是一个全局性调整,它影响图像中所有颜色的鲜艳程度。当你增加饱和度时,所有的颜色都会变得更浓烈、更鲜艳;相反,减小饱和度会使图像整体变…

嵌入式Linux的QT项目CMake工程模板分享及使用指南

在嵌入式linux开发板上跑QT应用,不同于PC上的开发过程。最大的区别就是需要交叉编译,才能在板子上运行。 这里总结下嵌入式linux环境下使用CMake,嵌入式QT的CMake工程模板配置及如何使用,分享给有需要的小伙伴,有用到的…

关于ssrf

首先,先介绍一下ssrf。ssrf即服务器端请求伪造,是一种由攻击者构造形成由服务端发起请求的一个安全漏洞。一般情况下,SSRF攻击的目标是从外网无法访问的内部系统。而且因为请求是由服务端发起的,所以服务端能请求到与自身相连而与…

路由策略与路由控制

1.路由控制工具 匹配工具1:访问控制列表 (1)通配符 当进行IP地址匹配的时候,后面会跟着32位掩码位,这32位称为通配符。 通配符,也是点分十进制格式,换算成二进制后,“0”表示“匹配…

速卖通揭秘:aliexpress.item_get API商品详情返回值全解析

速卖通(AliExpress)是阿里巴巴旗下的一个面向全球市场的B2C电商平台,为卖家提供了一个向全球消费者销售商品的平台。对于开发者来说,速卖通提供了API接口来方便地进行数据交互和集成。其中,item_get API是用于获取商品…

AI换脸原理(6)——人脸分割介绍

一、介绍 人脸分割是计算机视觉和图像处理领域的一项重要任务,它主要涉及到将图像中的人脸区域从背景或其他非人脸区域中分离出来。这一技术具有广泛的应用场景,如人脸识别、图像编辑、虚拟背景替换等。 在计算机视觉(CV)领域,经典的分割技术可以主要划分为三类:语义分…

Python基础详解三

一,函数的多返回值 def methodReturn():return 1,2x,ymethodReturn() print(x,y) 1 2 二,函数的多种参数使用形式 缺省参数: def method7(name,age,address"淄博"):print("name:"name",age"str(age)&quo…

Git系列:Git Stash临时保存与恢复工作进度

💝💝💝欢迎莅临我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:「stormsha的主页」…

超越传统游戏:生成式人工智能对游戏的变革性影响

人工智能(AI)在游戏中的应用 游戏产业是一个充满活力、不断发展的领域,人工智能(AI)的融入对其产生了重大影响。这一技术进步彻底改变了游戏的开发、玩法和体验方式。本文分析的重点是传统人工智能和生成式人工智能在游…

通讯录项目—顺序表实现

在上次我介绍顺序表后相信大家对顺序表有了一定的了解,现在就让我们来练练如何用它,这篇是在顺序表基础上新增的(建议看看线性表—顺序表实现-CSDN博客)。 目录 通讯录简介 创建用户信息 适配和理解通讯录 功能实现 初始化通讯录 销毁通讯录 增加…

java回调机制

目录 一、简介二、示例2.1 同步回调2.2 异步回调2.3 二者区别 三、应用场景 一、简介 在Java中,回调是一种常见的编程模式,它允许一个对象将某个方法作为参数传递给另一个对象,以便在适当的时候调用该方法。 以类A调用类B方法为例: 在类A中…