学习笔记17:DenseNet实现多分类(卷积基特征提取)

news/2025/3/26 8:01:46/文章来源:https://www.cnblogs.com/gongzb/p/18230180

转自:https://www.cnblogs.com/miraclepbc/p/14378379.html

数据集描述

总共200200类图像,每一类图像都存放在一个以类别名称命名的文件夹下,每张图片的命名格式如下图:

数据预处理

首先分析一下我们在数据预处理阶段的目标和工作流程

  • 获取每张图像以及对应的标签

  • 划分测试集和训练集

  • 通过写数据集类的方式,获取数据集并进一步获得DataLoader

  • 打印图片,验证效果

获取图像及标签

all_imgs_path = glob.glob(r'E:\birds\birds\*\*.jpg') # 获取所有图像路径列表
all_labels_name = [i.split('\\')[3].split('.')[1] for i in all_imgs_path] # 获取每张图像的标签名
label_to_index = dict([(v, k) for k, v in enumerate(unique_labels)]) # 将标签名映射到数值
# 获取每张图片的数值标签
all_labels = []
for img in all_imgs_path:for k, v in label_to_index.items():if k in img:all_labels.append(v)

划分测试集和训练集

以下代码可以作为模板来用,不做额外解释

np.random.seed(2021)
index = np.random.permutation(len(all_imgs_path))
all_imgs_path = np.array(all_imgs_path)[index]
all_labels = np.array(all_labels)[index]
s = int(len(all_imgs_path) * 0.8)train_path = all_imgs_path[:s]
train_labels = all_labels[:s]
test_path = all_imgs_path[s:]
test_labels = all_labels[s:]

通过写数据集类的方式,获取数据集并进一步获得DataLoader

以下代码可以作为模板来用,不做额外解释

transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor()
])class BirdsDataset(data.Dataset):def __init__(self, img_paths, labels, transform):self.imgs = img_pathsself.labels = labelsself.transforms = transformdef __getitem__(self, index):img = self.imgs[index]label = self.labels[index]pil_img = Image.open(img)pil_img = pil_img.convert('RGB') # 这一句是专门用来解决一种RuntimeError的np_img = np.array(pil_img, dtype = np.uint8)if np_img.shape == 2:img_data = np.repeat(np_img[:, :, np.newaxis], 3, axis = 2)pil_data = Image.fromarray(img_data)data = self.transforms(pil_img)return data, labeldef __len__(self):return len(self.imgs)train_ds = BirdsDataset(train_path, train_labels, transform)
test_ds = BirdsDataset(test_path, test_labels, transform)
train_dl = data.DataLoader(train_ds, batch_size = 32) # 这里只是提取卷积基,不做训练,因此不用shuffle
test_dl = data.DataLoader(test_ds, batch_size = 32)

结果查看

取出一个批次的数据,绘图

img_batch, label_batch = next(iter(train_dl))
plt.figure(figsize = (12, 8)) # 定义画布大小
index_to_label = dict([(k, v) for k, v in enumerate(unique_labels)])
for i, (img, label) in enumerate(zip(img_batch[:3], label_batch[:3])):img = img.permute(1, 2, 0).numpy() # 将channel放在最后一维plt.subplot(1, 3, i + 1)plt.title(index_to_label.get(label.item()))plt.imshow(img)

结果如下:

提取卷积基

这一阶段的工作流程如下:

  • 获取DenseNet预训练模型,使用feature部分

  • 使用卷积基提取图像特征,并存放在列表中

预训练模型获取

my_densenet = models.densenet121(pretrained = True).featuresif torch.cuda.is_available():my_densenet = my_densenet.cuda()for p in my_densenet.parameters():p.requires_grad = False

提取图像特征

train_features = []
train_features_labels = []
for im, la in train_dl:out = my_densenet(im.cuda())out = out.view(out.size(0), -1) # 这里需要进行扁平化操作,因为后面要进行线性模型预测train_features.extend(out.cpu().data) # 这里注意是extend,extend可以将一个列表加到另一个列表的后面train_features_labels.extend(la)test_features = []
test_features_labels = []
for im, la in test_dl:out = my_densenet(im.cuda())out = out.view(out.size(0), -1)test_features.extend(out.cpu().data)test_features_labels.extend(la)

重新定义数据集

因为后面要通过线性模型来预测,因此之前的图像数据集就不好用了

因此需要用刚刚提取到的特征,重新制作数据集

class FeatureDataset(data.Dataset):def __init__(self, feature_list, label_list):self.feature_list = feature_listself.label_list = label_listdef __getitem__(self, index):return self.feature_list[index], self.label_list[index]def __len__(self):return len(self.feature_list)train_feature_ds = FeatureDataset(train_features, train_features_labels)
test_feature_ds = FeatureDataset(test_features, test_features_labels)
train_feature_dl = data.DataLoader(train_feature_ds, batch_size = 32, shuffle = True)
test_feature_dl = data.DataLoader(test_feature_ds, batch_size = 32)

模型定义与预测

这里定义一个线性模型即可

模型定义

class FCModel(nn.Module):def __init__(self, in_size, out_size):super().__init__()self.linear = nn.Linear(in_size, out_size)def forward(self, input):return self.linear(input)in_feature_size = train_features[0].shape[0]
net = FCModel(in_feature_size, 200)device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.to(device)loss_func = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr = 0.00001)
epochs = 30

模型训练

def fit(epoch, model, trainloader, testloader):correct = 0total = 0running_loss = 0model.train()for x, y in trainloader:y = torch.tensor(y, dtype = torch.long)x, y = x.to(device), y.to(device)y_pred = model(x)loss = loss_func(y_pred, y)optimizer.zero_grad()loss.backward()optimizer.step()with torch.no_grad():y_pred = torch.argmax(y_pred, dim = 1)correct += (y_pred == y).sum().item()total += y.size(0)running_loss += loss.item()epoch_acc = correct / totalepoch_loss = running_loss / len(trainloader.dataset)test_correct = 0test_total = 0test_running_loss = 0model.eval()with torch.no_grad():for x, y in testloader:y = torch.tensor(y, dtype = torch.long)x, y = x.to(device), y.to(device)y_pred = model(x)loss = loss_func(y_pred, y)y_pred = torch.argmax(y_pred, dim = 1)test_correct += (y_pred == y).sum().item()test_total += y.size(0)test_running_loss += loss.item()epoch_test_acc = test_correct / test_totalepoch_test_loss = test_running_loss / len(testloader.dataset)print('epoch: ', epoch, 'loss: ', round(epoch_loss, 3),'accuracy: ', round(epoch_acc, 3),'test_loss: ', round(epoch_test_loss, 3),'test_accuracy: ', round(epoch_test_acc, 3))return epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acctrain_loss = []
train_acc = []
test_loss = []
test_acc = []
for epoch in range(epochs):epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc = fit(epoch, net, train_feature_dl, test_feature_dl)train_loss.append(epoch_loss)train_acc.append(epoch_acc)test_loss.append(epoch_test_loss)test_acc.append(epoch_test_acc)

训练结果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/719711.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

学习笔记19:图像定位

转自:https://www.cnblogs.com/miraclepbc/p/14385623.html 图像定位的直观理解 不仅需要我们知道图片中的对象是什么,还要在对象的附近画一个边框,确定该对象所处的位置。 也就是最终输出的是一个四元组,表示边框的位置 图像定位网络架构 可以将图像定位任务看作是一个回归…

学习笔记13:微调模型

转自:https://www.cnblogs.com/miraclepbc/p/14360807.html resnet预训练模型 resnet模型与之前笔记中的vgg模型不同,需要我们直接覆盖掉最后的全连接层先看一下resnet模型的结构: 我们需要先将所有的参数都设置成requires_grad = False然后再重新定义fc层,并覆盖掉原来的。…

成熟的双向同步方案,能够解决哪些同步问题?

在企业的数据流转管控过程中,经常会遇到频繁的数据备份、同步,人工重复这样的工作程序,既繁琐又容易出错。因此对于企业而言,选择一款高效且安全的同步软件成为了企业运营中的关键一环,不仅能够提高工作效率,还能确保数据的安全性。在选择双向同步方案时,首先要明确自己…

从数据库设计到性能调优,全面掌握openGemini应用开发最佳实践

据库设计和性能调优最重要的干货都在这里了!本文分享自华为云社区《DTSE Tech Talk openGemini :从数据库设计到性能调优,全面掌握openGemini应用开发最佳实践》,作者:华为云开源。 在本期《从数据库设计到性能调优,全面掌握openGemini应用开发最佳实践》的主题直播中,…

MBD闲谈 第03期:MBD的“禁区”——底层驱动

转载自:autoMBD, 版权归autoMBD所有,转载请注明作者和来源 原文链接:http://www.360doc.com/content/22/0820/17/15913066_1044626106.shtml全文约3562字,你将看到以下内容:底层驱动的那些事底层驱动为啥是MBD“禁区” 底层驱动与模型集成下期预告1 底层驱动的那些事 先…

allure的suites(测试套)中未显示返回值参数,显示No information about test execution is available.(转自大佬,亲测有用)

转自大佬:https://blog.csdn.net/sbdxmnz/article/details/137016423ExecutionNo information about test execution is available.解决方法: 添加代码,因为pytest输出文本形式测试报告时未存储响应内容 # 将接口响应的文本内容附加到Allure报告中 allure.attach(接口响应.…

学习笔记9:卷积神经网络实现MNIST分类(GPU加速)

转自:https://www.cnblogs.com/miraclepbc/p/14345342.html 相关包导入 import torch import pandas as pd import numpy as np import matplotlib.pyplot as plt from torch import nn import torch.nn.functional as F from torch.utils.data import TensorDataset from tor…

笔记2:张量简介

张量生成方法 转自:https://www.cnblogs.com/miraclepbc/p/14329476.html张量的形状及类型张量的计算张量的梯度手写线性回归张量生成方法 张量的形状及类型 张量的计算 张量的梯度 手写线性回归

笔记3:逻辑回归(分批次训练)

转自:https://www.cnblogs.com/miraclepbc/p/14332084.html 相关库导入 import torch import pandas as pd import numpy as np import matplotlib.pyplot as plt from torch import nn %matplotlib inline数据读入及预处理 data = pd.read_csv(E:/datasets/dataset/credit-a.…

【深度好文】到底什么是质量意识?如何衡量,如何提升?

大家好,我是狂师! 在软件测试中,质量意识是一个核心且至关重要的概念。相信大家,经常会听到:"这个家伙质量意识很强,某某某要提升质量意识“之类的话语。 在企业中,“质量意识”不仅关乎产品和服务的优劣,更是企业竞争力和可持续发展的关键因素。那么,到底什么是…

域名

顶级域名、二级域名与三级域名互联网名称与数字地址分配机构(ICANN)负责管理和协调国际互联网络域名系统。根据ICANN的定义,一个完整的域名至少有两个部分,各部分之间用“.”来分隔,最后一个“.”的右边部分称为顶级域名,也称为一级域名;最后一个“.”的左边部分称为二级…

异构数据源同步之数据同步 → DataX 使用细节

开心一刻 中午我妈微信给我消息 妈:儿子啊,妈电话欠费了,能帮妈充个话费吗 我:妈,我知道了,我帮你充 当我帮我妈把话费充好,正准备回微信的时候,我妈微信给我发消息了 妈:等会儿子,不用充了,刚刚有个二臂帮妈充上了 我输入框中的(妈,充好了)是发还是不发?简单使…