深度学习笔记_7经典网络模型LSTM解决FashionMNIST分类问题

1、 调用模型库,定义参数,做数据预处理

import numpy as np
import torch
from torchvision.datasets import FashionMNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
from torch import nn
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, f1_score, roc_curve, auc
import matplotlib.pyplot as plt# 检查 GPU 可用性
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)# 设置超参数
sequence_length = 28
input_size = 28  
hidden_size = 128
num_layers = 2 
num_classes = 10
batch_size = 64
learning_rate = 0.001
num_epochs = 50# 定义数据转换操作
transform = transforms.Compose([transforms.RandomRotation(degrees=[-30, 30]),   # 随机旋转transforms.RandomHorizontalFlip(),   # 随机水平翻转transforms.RandomCrop(size=28, padding=4),   # 随机裁剪transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),   # 颜色抖动transforms.ToTensor(),  # 将图像转换为张量transforms.Normalize((0.5,), (0.5,))
])

2、下载FashionMNIST训练集

# 下载FashionMNIST训练集
trainset = FashionMNIST(root='data', train=True,download=True, transform=transform)# 下载FashionMNIST测试集
testset = FashionMNIST(root='data', train=False,download=True, transform=transform)# 创建 DataLoader 对象
train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False)

3、定义LSTM模型

# 定义LSTM模型
class LSTM(nn.Module):def __init__(self, input_size, hidden_size, num_layers, num_classes):super(LSTM, self).__init__()self.hidden_size = hidden_size  # LSTM隐含层神经元数self.num_layers = num_layers  # LSTM层数self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)  # LSTM层self.fc = nn.Linear(hidden_size, num_classes)  # 全连接层def forward(self, x):h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)  # 初始化状态c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)out, _ = self.lstm(x, (h0, c0))  # LSTM前向传播out = self.fc(out[:, -1, :])  # 只取序列最后一个时间步的输出return F.log_softmax(out, dim=1)  # 使用log_softmax作为输出# 初始化模型、优化器和损失函数
model = LSTM(input_size, hidden_size, num_layers, num_classes).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()# 记录训练和测试过程中的损失和准确率
train_losses = []
test_losses = []
train_accuracies = []
test_accuracies = []conf_matrix_list = []
accuracy_list = []
error_rate_list = []
precision_list = []
recall_list = []
f1_score_list = []
roc_auc_list = []

4、 训练循环

for epoch in range(num_epochs):model.train()train_loss = 0.0correct = 0total = 0for batch_idx, (data, target) in enumerate(train_loader):optimizer.zero_grad()data, target = data.to(device), target.to(device)  # 将数据移到 GPU 上data = data.view(-1, sequence_length, input_size)output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()train_loss += loss.item()# 计算训练准确率_, predicted = output.max(1)total += target.size(0)correct += predicted.eq(target).sum().item()# 计算平均训练损失和训练准确率train_loss /= len(train_loader)train_accuracy = 100. * correct / totaltrain_losses.append(train_loss)train_accuracies.append(train_accuracy)# 测试模型model.eval()test_loss = 0.0correct = 0all_labels = []all_preds = []with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)  # 将数据移到 GPU 上data = data.view(-1, sequence_length, input_size)output = model(data)test_loss += criterion(output, target).item()pred = output.argmax(dim=1, keepdim=True)correct += pred.eq(target.view_as(pred)).sum().item()all_labels.extend(target.cpu().numpy())  # 将结果移到 CPU 上all_preds.extend(pred.cpu().numpy())  # 将结果移到 CPU 上# 计算平均测试损失和测试准确率test_loss /= len(test_loader)test_accuracy = 100. * correct / len(test_loader.dataset)test_losses.append(test_loss)test_accuracies.append(test_accuracy)# 计算额外的指标conf_matrix = confusion_matrix(all_labels, all_preds)conf_matrix_list.append(conf_matrix)accuracy = accuracy_score(all_labels, all_preds)accuracy_list.append(accuracy)error_rate = 1 - accuracyerror_rate_list.append(error_rate)precision = precision_score(all_labels, all_preds, average='weighted')recall = recall_score(all_labels, all_preds, average='weighted')f1 = f1_score(all_labels, all_preds, average='weighted')precision_list.append(precision)recall_list.append(recall)f1_score_list.append(f1)fpr, tpr, thresholds = roc_curve(all_labels, all_preds, pos_label=1)roc_auc = auc(fpr, tpr)roc_auc_list.append(roc_auc)# 打印每个 epoch 的指标print(f'Epoch [{epoch + 1}/{num_epochs}] -> Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%')
# 打印或绘制训练后的最终指标
print(f'Final Confusion Matrix:\n{conf_matrix_list[-1]}')
print(f'Final Accuracy: {accuracy_list[-1]:.2%}')
print(f'Final Error Rate: {error_rate_list[-1]:.2%}')
print(f'Final Precision: {precision_list[-1]:.2%}')
print(f'Final Recall: {recall_list[-1]:.2%}')
print(f'Final F1 Score: {f1_score_list[-1]:.2%}')
print(f'Final ROC AUC: {roc_auc_list[-1]:.2%}')

5、绘制Loss、Accuracy曲线图, 计算混淆矩阵

import seaborn as sns
# 绘制Loss曲线图
plt.figure()
plt.plot(train_losses, label='Train Loss', color='blue')
plt.plot(test_losses, label='Test Loss', color='red')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Loss Curve')
plt.grid(True)
plt.savefig('loss_curve.png')
plt.show()# 绘制Accuracy曲线图
plt.figure()
plt.plot(train_accuracies, label='Train Accuracy', color='red')  # 绘制训练准确率曲线
plt.plot(test_accuracies, label='Test Accuracy', color='green')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.title('Accuracy Curve')
plt.grid(True)
plt.savefig('accuracy_curve.png')
plt.show()# 计算混淆矩阵
class_labels = [str(i) for i in range(10)]
confusion_mat = confusion_matrix(all_labels, all_preds)
plt.figure()
sns.heatmap(confusion_mat, annot=True, fmt='d', cmap='Blues', cbar=False)
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.title('Confusion Matrix')
plt.savefig('confusion_matrix.png')
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/285237.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python 爬虫之简单的爬虫(三)

爬取动态网页(上) 文章目录 爬取动态网页(上)前言一、大致内容二、基本思路三、代码编写1.引入库2.加载网页数据3.获取指定数据 总结 前言 之前的两篇写的是爬取静态网页的内容,比较简单。接下来呢给大家讲一下如何去…

Linux之FTP 服务器

一、FTP服务器匿名账户服务器配置 1、测试是否已安装vsftp服务器: 2、启动vsftp服务器: 3、修改vsftp主配置文件,允许匿名登录 4、重新启动vsftpd服务,禁用防火墙 5、打开FTP服务的数据文件存放目录/var/ftp,复制若干文件到该目…

智慧养老:创新科技让老年生活更美好

智慧养老:创新科技让老年生活更美好 随着人口老龄化的加剧,智慧养老成为了关注焦点。智慧养老以创新科技为核心,旨在改善老年人的生活品质、促进健康、增强安全感和社会融入感。本文将详细介绍智慧养老的关键技术和应用场景,带您了…

回归预测 | MATLAB实现SABO-LSTM基于减法平均优化器优化长短期记忆神经网络的多输入单输出数据回归预测模型 (多指标,多图)

回归预测 | MATLAB实现SABO-LSTM基于减法平均优化器优化长短期记忆神经网络的多输入单输出数据回归预测模型 (多指标,多图) 目录 回归预测 | MATLAB实现SABO-LSTM基于减法平均优化器优化长短期记忆神经网络的多输入单输出数据回归预测模型 &a…

Xxl-job-admin 数据库使用DM8/达梦改造

Xxl-job 简介 XXL-JOB是一个分布式任务调度平台,其核心设计目标是开发迅速、学习简单、轻量级、易扩展。 XXL-JOB-ADMIN 是针对分布式定时任务管理的Web管理平台,默认使用的数据库是MySQL 8版本。 业务背景 在项目中使用分布式定时任务调度框架:xxl-…

Linux---Ubuntu软件卸载

1. 软件卸载的介绍 Ubuntu软件卸载有两种方式: 离线安装包的卸载(deb 文件格式卸载)在线安装包的卸载(apt-get 方式卸载) 2. deb 文件格式卸载 命令格式: sudo dpkg –r 安装包名 -r 选项表示安装的卸载 dpkg 卸载效果图: 3. apt-get 方式卸载 命令格式: …

mybatis.interceptor.exception.SqLValidateException:Ilegal SQL::......

现象:⬇️ 描述:执行 SQL 没问题,应用代码报错 ⬇️ .mybatis.interceptor.exception.SqLValidateException:Ilegal SQL::SELECT voucherNo FROM voucher ORDER BY CAST(SUBSTRING(voucherNo FROM LOCATE(_, voucherNo) 1) AS U…

门控网络简介

门控网络是一种循环神经网络 (RNN),它使用门来控制信息在时间步之间的流动。门是一种神经网络层,它可以选择性地允许或阻止信息通过。 门控网络的主要优点是它们可以解决传统 RNN 中存在的梯度消失问题。梯度消失是指随着时间步的增加,梯度会…

大模型(LLM)+词槽(slot)构建动态场景多轮对话系统

构建动态场景多轮对话系统 引言 在人工智能和自然语言处理领域,聊天机器人的开发一直是一个热点话题。近年来,随着大型语言模型(LLM)的进步,构建能够理解和响应各种用户需求的聊天机器人变得更加可行和强大。本文将介…

什么店生意好?C++采集美团商家信息做数据分析

最近遇到几个朋友,想要一起合伙投资一个实体店,不问类型,就看哪类产品相对比较受欢迎。抛除地址位置,租金的影响,我们之谈产品。因此,我熬了几个通宵,写了这么一段爬取美团商家商品信息的数据并…

Idea远程debugger调试

当我们服务部署在服务器上,我们想要像在本地一样debug,就可以使用idea自带的Remote JVM Debug 创建Remote JVM Debug服务器启动jar打断点进入断点 当我们服务部署在服务器上,我们想要像在本地一样debug,就可以使用idea自带的 Remote JVM Debug) 创建Rem…

AX7A200教程(9): ov5640摄像头输出显示720p视频

一,功能框图 ov5640摄像头视频通过ddr3缓存后,最后使用hdmi接口进行输出显示 二,摄像头硬件说明 2.1,像头硬件管脚 如下图所示,一共18个管脚 2.2,摄像头电源初始化时序 因这个ov5640摄像头是买的老摄像…