学习Pytorch深度学习运行AlexNet代码时关于在Pycharm中解决 “t >= 0 t < n_classes” 的断言错误方法

在学习深度学习的过程中,遇到了一个报错:

 这跑的代码是AlexNet的代码实现。

运行时出现报错:

C:\cb\pytorch_1000000000000\work\aten\src\ATen\native\cuda\Loss.cu:257: block: [0,0,0], thread: [4,0,0] Assertion `t >= 0 && t < n_classes` failed.

解决方案:

当你遇到 CUDA error: device-side assert triggered 和具体的断言错误 t >= 0 && t < n_classes,这通常指示在 CUDA 上运行的某些操作遇到了问题,大多数情况下是由于标签值 t 超出了期望的范围。这里的问题发生在执行损失函数计算时,具体来说是在 PyTorch 的底层 CUDA 代码中。

问题的主要源头:

由于错误提示 t >= 0 && t < n_classes,你需要确保所有标签值都在正确的范围内。对于分类任务,标签值 t 应该是一个非负整数,并且小于类别总数 n_classes。如果你的数据集标签不是从 0 开始的,你需要将它们转换为从 0 开始。

在合适的位置定义:确保 n_classes 在你尝试使用它进行断言检查之前已经被定义。这通常意味着你需要在加载数据集、初始化数据加载器之前,或在定义模型之处确定 n_classes 的值。

这里我的代码中已经指定:

由于错误报告显示断言失败发生在损失计算时,一个可能的原因是某些样本的标签不在 [0, n_classes-1] 的范围内。你可以添加一些代码在损失计算之前检查标签值:

增添代码段:

        if not (labels.min() >= 0 and labels.max() < n_classes):print("不满足条件的标签值:", labels[labels < 0], labels[labels >= n_classes])

 以及:

n_classes = 102  # 根据你的具体任务设置这个值, 这里对应num_classes=102

继续运动代码进行测试,出现如下报错:

 从提供的输出信息来看,断言错误是因为存在标签值等于 102,这超出了预期的类别范围 [0, n_classes-1]。假设 n_classes 应该是 102(意味着有效的标签范围是从 0101),标签值 102 显然是无效的,因为它等于类别总数,超出了最大有效索引。

解决方案

  1. 校正类别总数:首先确认 n_classes 的值是否正确。如果你的任务确实有 102 个类别(例如,Flowers102 数据集),那么 n_classes 应该设置为 102,并且你需要确保所有标签都在 [0, 101] 的范围内。

  2. 修正数据标签:由于出现了 102 作为标签值,这可能是由于数据标签在某个步骤中被错误地分配或转换。你需要回溯到数据处理的步骤,找出为什么会有 102 这样的标签值出现,并进行修正。如果是因为数据集自带的标签从 1 开始计数,那么你需要将所有标签减 1 以转换为从 0 开始计数:

增添代码:

# 假设 `labels` 是你的标签张量
labels = labels - 1

测试和验证

在进行了上述修正之后,再次运行你的代码,并使用之前添加的打印语句来验证所有标签值是否都在正确的范围内。如果没有进一步的断言错误,那么这意味着问题已经被解决。如果问题依然存在,可能需要进一步调查数据处理流程中的每一个步骤,确保在任何地方都没有引入标签错误。

再一次运行代码,发现没有报错,代码运行正常:

模型开始了训练,问题得到了解决!!!

最后给出优化后的完整python代码:

文件:main_AlexNet.py

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from tqdm import *
import numpy as np
import matplotlib.pyplot as plt
import sys
from AlexNet import AlexNet
import os
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'# 设备检测,若未检测到cuda设备则在CPU上运行
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 设置随机种子
torch.manual_seed(0)# 定义模型、优化器、损失函数
model = AlexNet(num_classes=102).to(device)
optimizer = optim.SGD(model.parameters(), lr=0.002, momentum=0.9)
criterion = nn.CrossEntropyLoss()# 设置训练集的数据变换,进行数据增强
transform_train = transforms.Compose([transforms.RandomRotation(30),  # 随机旋转 -30度到30度之间transforms.RandomResizedCrop((224, 224)),  # 随机比例裁剪并进行resizetransforms.RandomHorizontalFlip(p=0.5),  # 随机水平翻转transforms.RandomVerticalFlip(p=0.5),  # 随机垂直翻转transforms.ToTensor(),  # 将数据转换为张量# 对三通道数据进行归一化(均值,标准差), 数值是从ImageNet数据集上的百万张图片中随机抽样计算得到transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 对数据进行归一化
])# 设置测试集的数据变换,进行数据增强
transform_test = transforms.Compose([transforms.Resize((224, 224)),  # resizetransforms.ToTensor(),  # 将数据转化为张量# 对三通道数据进行归一化(均值,标准差),数值是从ImageNet数据集上的百万张图片中随机抽样计算得到transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加载训练数据,需要特别注意的是Flowers102数据集,test簇的数据量较多些,所以这里使用"test"作为训练集
train_dataset = datasets.Flowers102(root='./data/flowers102', split="test",download=False, transform=transform_train)
# 实例化训练数据加载器
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=6, drop_last=False)
# 加载测试数据,使用“train”作为测试集
test_dataset = datasets.Flowers102(root='./data/flowers102', split="train",download=False, transform=transform_test)
# 实例化测试数据加载器
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=6, drop_last=False)# 设置epoch数并开始训练
num_epochs = 500  # 设置epoch数
n_classes = 102  # 根据你的具体任务设置这个值, 这里对应num_classes=102
loss_history = []  # 创建损失历史记录列表
acc_history = []  # 创建准确率历史记录列表# tqdm用于显示进度条并评估任务时间开销
for epoch in tqdm(range(num_epochs), file=sys.stdout):# 记录损失和预测正确数total_loss = 0total_correct = 0# 批量训练model.train()for inputs, labels in train_loader:labels = labels - 1if not (labels.min() >= 0 and labels.max() < n_classes):print("不满足条件的标签值:", labels[labels < 0], labels[labels >= n_classes])# 将数据转换到指定计算资源设备上inputs = inputs.to(device)labels = labels.to(device)# 预测、损失函数、反向传播optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()# 记录训练集losstotal_loss += loss.item()# 测试模型,不计算梯度model.eval()with torch.no_grad():for inputs, labels in test_loader:# 将数据转换到指定计算资源设备上inputs = inputs.to(device)labels = labels.to(device)# 预测outputs = model(inputs)# 记录测试集预测正确数total_correct += (outputs.argmax(1) == labels).sum().item()# 记录训练集损失和测试集准确率loss_history.append(np.log10(total_loss))  # 将损失加入损失历史记录列表,由于数值有时较大,这里取对数acc_history.append(total_correct / len(test_dataset))  # 将准确率加入准确率历史记录列表# 打印中间值# 每50个epoch打印一次中间值if epoch % 50 == 0:tqdm.write("Epoch: {0} Loss: {1} Acc: {2}".format(epoch, loss_history[-1], acc_history[-1]))# 使用Matplotlib绘制损失和准确率的曲线图
plt.plot(loss_history, label='loss')
plt.plot(acc_history, label=' ')
plt.legend()
plt.show()# 输出准确率
print("Accuracy:", acc_history[-1])

文件:AlexNet.py

import torch
import torch.nn as nn
from torchinfo import summary# 定义AlexNet的网络结构
class AlexNet(nn.Module):def __init__(self, num_classes=1000, dropout=0.5):super().__init__()# 定义卷积层self.features = nn.Sequential(nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),nn.Conv2d(64, 192, kernel_size=5, padding=2),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),nn.Conv2d(192, 384, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(384, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(256, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),)# 定义全连接层self.classifier = nn.Sequential(nn.Dropout(p=dropout),nn.Linear(256 * 6 * 6, 4096),nn.ReLU(inplace=True),nn.Dropout(p=dropout),nn.Linear(4096, 4096),nn.ReLU(inplace=True),nn.Linear(4096, num_classes),)def forward(self, x):x = self.features(x)x = torch.flatten(x, 1)x = self.classifier(x)return x# 查看模型结构以及参数量,input_size表示示例输入数据的维度信息
# summary(AlexNet(), input_size=(1,3,224,224))

将epoch改为100,得到如下训练结果:

 可见模型还未收敛,大家可以自行调节参数来尝试

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/462305.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

寒假作业2024.2.6

1.现有无序序列数组为23,24,12,5,33,5347&#xff0c;请使用以下排序实现编程 函数1:请使用冒泡排序实现升序排序 函数2:请使用简单选择排序实现升序排序 函数3:请使用直接插入排序实现升序排序 函数4:请使用插入排序实现升序排序 #include <stdio.h> #include <stdl…

nvm安装node后,npm无效

类似报这种问题&#xff0c;是因为去github下载npm时下载失败&#xff0c; Please visit https://github.com/npm/cli/releases/tag/v6.14.17 to download npm. 第一种方法&#xff1a;需要复制这里面的地址爬梯子去下载&#xff08;github有时不用梯子能直接下载&#xff0c;有…

联合体知识点解析

联合体&#xff1a; 联合体也是一种自定义类型&#xff0c; 特点是成员变量公用一块空间。所以也叫共用体。 联合体的性质 先定义一个联合体&#xff1a; 然后我创建一个联合体变量&#xff1a; 现在探究当修改一个成员变量的值时&#xff0c; 其他成员变量的值能否被修改&am…

[day0] 借着“ai春晚”开个场

1 文思ai笔记-新的开始 今天是2024年2月29日&#xff0c;也是传统农历的除夕夜。早起在ai圈看到一个比较新奇的消息&#xff0c;ai春晚今日举办&#xff0c;竟然有一点小小的激动。这些年确实好久没看过春晚了&#xff0c;自己对于春晚的映像还停留在“白云黑土”、“今天&…

【MySQL】MySQL表的增删改查(基础)

MySQL表的增删改查&#xff08;基础&#xff09; 1. CRUD2. 新增&#xff08;Create&#xff09;2.1 单行数据全列插入2.2 多行数据 指定列插入 3. 查询&#xff08;Retrieve&#xff09;3.1 全列查询3.2 指定列查询3.3 查询字段为表达式3.4 别名3.5 去重&#xff1a;DISTINCT…

【多模态MLLMs+图像编辑】MGIE:苹果开源基于指令和大语言模型的图片编辑神器(24.02.03开源)

项目主页&#xff1a;https://mllm-ie.github.io/ 论文 :基于指令和多模态大语言模型图片编辑 2309.Guiding Instruction-based Image Editing via Multimodal Large Language Models &#xff08;加州大学圣巴拉分校苹果&#xff09; 代码&#xff1a;https://github.com/appl…

Flask基础学习

1.debug、host、port 模式修改 1) debug模式 默认debug模式是off&#xff0c;在修改代码调试过程中需要暂停重启使用&#xff0c;这时可修改on模式解决。 同时在debug模式开启下可看到出错信息。 下面有关于Pycharm社区版和专业版修改debug模式的区别 专业版 社区版&#…

redis-sentinel(哨兵模式)

目录 1、哨兵简介:Redis Sentinel 2、作用 3、工作模式 4、主观下线和客观下线 5、配置哨兵模式 希望能够帮助到大家&#xff01;&#xff01;&#xff01; 1、哨兵简介:Redis Sentinel Sentinel(哨兵)是用于监控redis集群中Master状态的工具&#xff0c;其已经被集成在re…

SQL--图形化界面工具

1.图形化界面工具 上述&#xff0c;我们已经讲解了通过DDL语句&#xff0c;如何操作数据库、操作表、操作表中的字段&#xff0c;而通过DDL语句执 行在命令进行操作&#xff0c;主要存在以下两点问题&#xff1a; 1).会影响开发效率 ; 2). 使用起来&#xff0c;并不直观&…

docker安装Yapi

docker安装Yapi 我试了很多次按照网上安装&#xff0c;但是看时间都是2022年之前的&#xff0c;所以我下载的mogodb都是last版本不是报错就是在报错的路上&#xff0c;后来一想那就换成2022年那些版本&#xff0c;也可能是last版本不兼容或者是比较低的版本。 我将mogodb换成…

Red Panda Dev C++ Maker 使用说明

https://download.csdn.net/download/HappyStarLap/88804678https://download.csdn.net/download/HappyStarLap/88804678 下载https://download.csdn.net/download/HappyStarLap/88804678&#xff1a; ​ 这个&#xff0c;就是我们将运行的文件。 ​ 里面加了许多我…

(每日持续更新)信息系统项目管理(第四版)(高级项目管理)考试重点整理第10章 项目进度管理(四)

博主2023年11月通过了信息系统项目管理的考试&#xff0c;考试过程中发现考试的内容全部是教材中的内容&#xff0c;非常符合我学习的思路&#xff0c;因此博主想通过该平台把自己学习过程中的经验和教材博主认为重要的知识点分享给大家&#xff0c;希望更多的人能够通过考试&a…