深度学习-pytorch-nerual network价格预测-004

# 1.导入相关模块
import torch
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import timefrom torchsummary import summary# 2.构建数据集
def create_dataset():# 使用pandas读取数据data = pd.read_csv('dataset/手机价格预测.csv')# 特征值和目标值x, y = data.iloc[:, :-1], data.iloc[:, -1]# 类型转换:特征值,目标值x = x.astype(np.float32)y = y.astype(np.int64)# 数据集划分x_train, x_valid, y_train, y_valid = train_test_split(x, y, train_size=0.8, random_state=88)# 构建数据集,转换为pytorch的形式train_dataset = TensorDataset(torch.from_numpy(x_train.values), torch.tensor(y_train.values))valid_dataset = TensorDataset(torch.from_numpy(x_valid.values), torch.tensor(y_valid.values))# 返回结果return train_dataset, valid_dataset, x_train.shape[1], len(np.unique(y))# 3.构建网络模型
class PhonePriceModel(nn.Module):def __init__(self, input_dim, output_dim):super(PhonePriceModel, self).__init__()# 1. 第一层:输入维度:20,输出维度:128self.linear1 = nn.Linear(input_dim, 128)# 2. 第二层:输入维度:128,输出维度:256self.linear2 = nn.Linear(128, 256)# 3. 第三层:输入维度:256,输出维度:4self.linear3 = nn.Linear(256, output_dim)def forward(self, x):# 前向传播过程x = torch.relu(self.linear1(x))x = torch.relu(self.linear2(x))output = self.linear3(x)# 获取数据结果return output# 4.模型训练
def train(train_dataset, input_dim, class_num, ):# 固定随机数种子torch.manual_seed(0)# 初始化模型model = PhonePriceModel(input_dim, class_num)# 损失函数criterion = nn.CrossEntropyLoss()# 优化方法optimizer = optim.SGD(model.parameters(), lr=1e-3)# 训练轮数num_epoch = 50# 遍历每个轮次的数据for epoch_idx in range(num_epoch):# 初始化数据加载器dataloader = DataLoader(train_dataset, shuffle=True, batch_size=8)# 训练时间start = time.time()# 计算损失total_loss = 0.0total_num = 1# 遍历每个batch数据进行处理for x, y in dataloader:# 将数据送入网络中进行预测output = model(x)# 计算损失loss = criterion(output, y)# 梯度归零optimizer.zero_grad()# 反向传播loss.backward()# 参数更新optimizer.step()# 损失计算total_num += 1total_loss += loss.item()# 打印损失变换结果print('epoch: %4s loss: %.2f, time: %.2fs' % (epoch_idx + 1, total_loss / total_num, time.time() - start))# 模型保存torch.save(model.state_dict(), 'model/phone.pth')def test(valid_dataset, input_dim, class_num):# 加载模型和训练好的网络参数model = PhonePriceModel(input_dim, class_num)model.load_state_dict(torch.load('model/phone.pth'))# 构建加载器dataloader = DataLoader(valid_dataset, batch_size=8, shuffle=False)# 评估测试集correct = 0# 遍历测试集中的数据for x, y in dataloader:# 将其送入网络中output = model(x)# 获取类别结果y_pred = torch.argmax(output, dim=1)# 获取预测正确的个数correct += (y_pred == y).sum()# 求预测精度print('Acc: %.5f' % (correct.item() / len(valid_dataset)))if __name__ == '__main__':# 1.获取数据train_dataset, valid_dataset, input_dim, class_num = create_dataset()print("输入特征数:", input_dim)print("分类个数:", class_num)# 2.模型实例化model = PhonePriceModel(input_dim, class_num)summary(model, input_size=(input_dim,), batch_size=16)# 3.模型训练# train(train_dataset, input_dim, class_num)# 4.模型预测test(valid_dataset, input_dim, class_num)

优化点:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/789272.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

财务知识-会计术语

财务知识-会计术语

selenium爬虫2

无头浏览器简介 无头浏览器(Headless Browser)是一种没有图形用户界面的浏览器,它在后台运行,不会显示任何窗口或界面。无头浏览器通常用于自动化任务,如网页抓取、自动化测试和性能监控等。 爬取票房比如我要爬取上图的2008--2024年的热门电影票房排名 from selenium imp…

Zotero设置

实现Zotero数据在不同电脑间的迁移1. 说明Zotero 中文社区 | 百度网盘使用 zotero 仅同步题录信息,使用其他云同步程序同步文献的附件,此处以坚果云为例进行演示。 准备:zotero 和 坚果云 注册账号 zotero 的插件 zotfile (国内汉化版) 坚果云客户端常用插件:zotfile jasmi…

CentOS 7.9 内核从 3.10 升级到 5.4

1.背景介绍: 环境需求:在搭建 Kubernetes (K8S) 环境时,内核版本最好大于 4.4 以支持 K8S 的所有特性。 当前内核版本:CentOS 7.9 的默认内核版本为 3.10.0-1160.el7.x86_64,不满足 K8S 的推荐内核版本要求。 2.查看内核版本及相关包: 使用命令 uname -r 查看当前内核版本…

基于LangChain手工测试用例转Web自动化测试生成工具

在传统编写 Web 自动化测试用例的过程中,基本都是需要测试工程师,根据功能测试用例转换为自动化测试的用例。市面上自动生成 Web 或 App 自动化测试用例的产品无非也都是通过录制的方式,获取操作人的行为操作,从而记录测试用例。整个过程类似于但是通常录制出来的用例可用性…

PbootCMS网站常见错误提示总结

一些初涉相关领域的新朋友在进行 pbootcms 的安装过程中,往往会频繁遭遇一些错误状况。接下来,为您详细罗列 pbootcms 于后台抑或前台所呈现出的各类问题以及相应的解决办法。1、Parse error: syntax error, unexpected :, expecting { in www\core\function\handle.php on l…

PbootCMS未检测到您服务器环境的sqlite3数据库扩展

在进行相关操作时,未能检测到您服务器环境中的 sqlite3 数据库扩展。在 PbootCms 的安装流程当中,“未检测到您服务器环境的 sqlite3 数据库扩展”这一问题的解决办法扫码添加技术【解决问题】专注中小企业网站建设、网站安全12年。熟悉各种CMS,精通PHP+MYSQL、HTML5、CSS3、…

PbootCMS验证码不显示怎么办

扫码添加技术【解决问题】专注中小企业网站建设、网站安全12年。熟悉各种CMS,精通PHP+MYSQL、HTML5、CSS3、Javascript等。承接:企业仿站、网站修改、网站改版、BUG修复、问题处理、二次开发、PSD转HTML、网站被黑、网站漏洞修复等。专业解决各种疑难杂症,您有任何网站问题都…

PbootCMS您访问路径含有非法字符,防注入系统提醒您请勿尝试非法操作!

您所访问的路径当中包含了非法字符,我们的防注入系统特此提醒您,千万不要尝试进行任何非法操作!扫码添加技术【解决问题】专注中小企业网站建设、网站安全12年。熟悉各种CMS,精通PHP+MYSQL、HTML5、CSS3、Javascript等。承接:企业仿站、网站修改、网站改版、BUG修复、问题…

PbootCMS前台显示留言条数统计

前台所呈现的内容为留言条数的统计情况。 无需进行二次开发,通过运用 sql 标签即可达成。在 PbootCMS 当中,sql 标签的使用实例之一便是对网站留言的总数进行统计。扫码添加技术【解决问题】专注中小企业网站建设、网站安全12年。熟悉各种CMS,精通PHP+MYSQL、HTML5、CSS3、J…

PbootCMS后台登录提示:”登录失败:数据库目录写入权限不足!“

后台登录提示:“登录失败:数据库目录写入权限不足!”通常来说,一般出现权限不足的情况,其中大多数状况都是由于文件夹权限不足所导致的。尤其是在使用 sqlite 的时候,必须要给根目录下的 data 文件夹设定 755 权限。扫码添加技术【解决问题】专注中小企业网站建设、网站安…