基于cifar数据集合成含开集、闭集噪声的数据集

news/2024/10/5 8:38:04/文章来源:https://www.cnblogs.com/zh-jp/p/18274899

前言

噪声标签学习下的一个任务是:训练集上存在开集噪声和闭集噪声;然后在测试集上对闭集样本进行分类。

训练集中被加入的开集样本,会被均匀得打上闭集样本的标签充当开集噪声;而闭集噪声的设置与一般的噪声标签学习一致,分为对称噪声:随机将闭集样本的标签替换为其他类别;和非对称噪声:将闭集样本的标签替换为特定的类别。

论文实验中,常用cifar数据集模拟这类任务。目前已知有两类方法:

第一类基于cifar100,将100个类的一部分,通常是20个类作为开集样本,将它们标签替换了前80个类作为开集噪声;然后对于后续80个类,选择部分样本设置为对称/非对称闭集噪声。CVPR2022的PNP: Robust Learning From Noisy Labels by Probabilistic Noise Prediction提供的代码中,使用了这种方法。

但是,如果要考虑非对称噪声,在cifar10上就很难实现,cifar10的类的顺序不像cifar100那样有规律,不好设置闭集噪声。

第二类方法适用cifar10和cifar100,保持原始数据集的样本数不变,使用额外的数据集(通常是imagenet32、places365)代替部分样本作为开集噪声,对于剩下的非开集噪声样本再设置闭集噪声。ECCV2022的Embedding contrastive unsupervised features to cluster in-and out-of-distribution noise in corrupted image datasets提供的代码使用了这种方式。

places365可以使用torchvision.datasets.Places365下载,由于训练集较大,通常是用它的验证集作为辅助数据集。imagenet32是imagnet的32x32版本,同样是1k类,但是类的具体含义的顺序与imagenet不同,imagenet32类的具体含义可见这里。image32下载地址在对应论文A downsampled variant of imagenet as an alternative to the cifar datasets提供的链接。

接下来是用第二种方法,辅助数据集使用imagenet32,基于cifar构造含开集闭集噪声的训练集。

实验

设计imagenet32数据集

import os
import pickle
import numpy as np
from PIL import Image
from torch.utils.data import Dataset_train_list = ['train_data_batch_1','train_data_batch_2','train_data_batch_3','train_data_batch_4','train_data_batch_5','train_data_batch_6','train_data_batch_7','train_data_batch_8','train_data_batch_9','train_data_batch_10']
_val_list = ['val_data']def get_dataset(transform_train, transform_test):# prepare datasets# Train settrain = Imagenet32(train=True, transform=transform_train)  # Load all 1000 classes in memory# Test settest = Imagenet32(train=False, transform=transform_test)  # Load all 1000 test classes in memoryreturn train, testclass Imagenet32(Dataset):def __init__(self, root='~/data/imagenet32', train=True, transform=None):if root[0] == '~':root = os.path.expanduser(root)self.transform = transformsize = 32# Now load the picked numpy arraysif train:data, labels = [], []for f in _train_list:file = os.path.join(root, f)with open(file, 'rb') as fo:entry = pickle.load(fo, encoding='latin1')data.append(entry['data'])labels += entry['labels']data = np.concatenate(data)else:f = _val_list[0]file = os.path.join(root, f)with open(file, 'rb') as fo:entry = pickle.load(fo, encoding='latin1')data = entry['data']labels = entry['labels']data = data.reshape((-1, 3, size, size))self.data = data.transpose((0, 2, 3, 1))  # Convert to HWClabels = np.array(labels) - 1self.labels = labels.tolist()def __getitem__(self, index):img, target = self.data[index], self.labels[index]img = Image.fromarray(img)if self.transform is not None:img = self.transform(img)return img, target, indexdef __len__(self):return len(self.data)

目录结构:

imagenet32
├─ train_data_batch_1
├─ train_data_batch_10
├─ train_data_batch_2
├─ train_data_batch_3
├─ train_data_batch_4
├─ train_data_batch_5
├─ train_data_batch_6
├─ train_data_batch_7
├─ train_data_batch_8
├─ train_data_batch_9
└─ val_data

设计cifar数据集

import torchvision
import numpy as np
from dataset.imagenet32 import Imagenet32class CIFAR10(torchvision.datasets.CIFAR10):def __init__(self, root='~/data', train=True, transform=None,r_ood=0.2, r_id=0.2, seed=0, corruption='imagenet', ):nb_classes = 10self.nb_classes = nb_classessuper().__init__(root, train=train, transform=transform)if train is False:returnnp.random.seed(seed)if r_ood > 0.:ids_ood = [i for i in range(len(self.targets)) if np.random.random() < r_ood]if corruption == 'imagenet':imagenet32 = Imagenet32(root='~/data/imagenet32', train=True)img_ood = imagenet32.data[np.random.permutation(range(len(imagenet32)))[:len(ids_ood)]]else:raise ValueError(f'Unknown corruption: {corruption}')self.ids_ood = ids_oodself.data[ids_ood] = img_oodif r_id > 0.:ids_not_ood = [i for i in range(len(self.targets)) if i not in ids_ood]ids_id = [i for i in ids_not_ood if np.random.random() < (r_id / (1 - r_ood))]for i, t in enumerate(self.targets):if i in ids_id:self.targets[i] = int(np.random.random() * nb_classes)self.ids_id = ids_idclass CIFAR100(torchvision.datasets.CIFAR100):def __init__(self, root='~/data', train=True, transform=None,r_ood=0.2, r_id=0.2, seed=0, corruption='imagenet', ):nb_classes = 100self.nb_classes = nb_classessuper().__init__(root, train=train, transform=transform)if train is False:returnnp.random.seed(seed)if r_ood > 0.:ids_ood = [i for i in range(len(self.targets)) if np.random.random() < r_ood]if corruption == 'imagenet':imagenet32 = Imagenet32(root='~/data/imagenet32', train=True)img_ood = imagenet32.data[np.random.permutation(range(len(imagenet32)))[:len(ids_ood)]]else:raise ValueError(f'Unknown corruption: {corruption}')self.ids_ood = ids_oodself.data[ids_ood] = img_oodif r_id > 0.:ids_not_ood = [i for i in range(len(self.targets)) if i not in ids_ood]ids_id = [i for i in ids_not_ood if np.random.random() < (r_id / (1 - r_ood))]for i, t in enumerate(self.targets):if i in ids_id:self.targets[i] = int(np.random.random() * nb_classes)self.ids_id = ids_id

查看统计结果

import pandas as pd
import altair as alt
from dataset.cifar import CIFAR10, CIFAR100# Initialize CIFAR10 dataset
cifar10 = CIFAR10(r_imb=0.)
cifar100 = CIFAR100(r_imb=0.)def statistics_samples(dataset):ids_ood = dataset.ids_oodids_id = dataset.ids_id# Collect statisticsstatistics = []for i in range(dataset.nb_classes):statistics.append({'class': i,'id': 0,'ood': 0,'clear': 0})for i, t in enumerate(dataset.targets):if i in ids_ood:statistics[t]['ood'] += 1elif i in ids_id:statistics[t]['id'] += 1else:statistics[t]['clear'] += 1df = pd.DataFrame(statistics)# Melt the DataFrame for Altairdf_melt = df.melt(id_vars='class', var_name='type', value_name='count')# Create the bar chartchart = alt.Chart(df_melt).mark_bar().encode(x=alt.X('class:O', title='Classes'),y=alt.Y('count:Q', title='Sample Count'),color='type:N')return chartchart1 = statistics_samples(cifar10)
chart2 = statistics_samples(cifar100)
chart1 = chart1.properties(title='cifar10',width=100,  # Adjust width to fit both charts side by sideheight=400
)
chart2 = chart2.properties(title='cifar100',width=800,height=400
)
combined_chart = alt.hconcat(chart1, chart2).configure_axis(labelFontSize=12,titleFontSize=14
).configure_legend(titleFontSize=14,labelFontSize=12
)
combined_chart

运行环境

# Name                    Version                   Build  Channel
altair                    5.3.0                    pypi_0    pypi
pytorch                   2.3.1           py3.12_cuda12.1_cudnn8_0    pytorch
pandas                    2.2.2                    pypi_0    pypi

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/733503.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

AbpVnext系列四 用户表设计

一、一个系统中,最开始要设计的是用户表,先简单的设计如下图。 二、类的实体/// <summary>/// /// </summary>/// <param name="id"></param>public User(long id){Id = id;Status = NormalLockedStatus.Normal;}/// <summary>/// …

oop第7~8次作业总结(第三次Blog)

oop第7~8次作业总结(第三次Blog) 一、前言 二、设计与分析 三、踩坑心得 四、改进建议 五、总结 一、前言 1、第七次作业 第六次作业只有一道题,与上次相比,本次迭代考虑多个并联电路串联在一起的情况,考虑一条串联电路中包含其他串联电路的情况。增加了互斥开关和受控窗帘…

Profinet IO从站数据 转EtherCAT项目案例

目录 1 案例说明 1 2 VFBOX网关工作原理 1 3 准备工作 2 4 使用PRONETA软件获取PROFINET IO从站的配置信息 2 5 设置网关采集PROFINETIO从站设备数据 5 6 启动ETHERCAT从站转发采集的数据 8 7 选择槽号和数据地址 9 8 选择子槽号 11 9 案例总结 12 1 案例说明设置网关采集Profi…

题目集7~8总结性blog

前言在我看来,题目集7~8知识点在于类的使用,方法的调用,通过不同类的ArrayList的使用将所有元件信息录入,通过连接方式将其变成一条电路,从而进行计算和使用,题量不大,只有一道题,在精不在多,难度适中,可以接受并且能够编程写出,完成功能的实现。设计与分析题目集七…

【日记】现在的孩子真是不怕大人呢(1975 字)

正文时间太晚了,而且想写的内容有点多,就不写在日记本上了。不过说内容多,其实也只有两件事情。其他的就一笔带过吧。一件关于灵,另一件事关于遇见的孩子。首先说说工作,今天真的如昨天预料的那样,特别忙。开一个户,上午就没了。倒不是因为有多难,也不是因为只有我一个…

OOP7-8次作业

OOP7-8次作业 一.前言: 1.第七次PTA: 根据之前的内容在,之前的内容上进行修改和扩增。 ①题目理解: 1.增加了互斥开关,互斥开关的难点在于他可以转换每条路的通路和断路,而且你切换的两个引脚他的电阻也是不一样的,这就增加了你获得电路电阻的难度。 2.对于在电路的获得…

修复Win10上ListView样式不正确的问题

在Windows 11下,使用WinUI2.6以上版本的ListView长这样:然而到了Win10上,尽管其他控件的样式没有改变,但ListViewItem变成了默认样式(初代Fluent) 最重大的问题是,Win10上的HorizontalAlignment未被设置成Stretch,可能造成严重的UI错位(隔壁livelyweather也有这个问题…

Asp.Net Core 使用IBrowserFile完成图片上传

Asp.Net Core 使用IBrowserFile完成图片上传 写在开头 前几天弄自己的项目时遇到的问题,发现了asp.net core 新增的IBrowserFile接口,事实上他可以满足大多数文件的上传,此处我仅以图片作为示例。 实现 添加一个帮助类FileHelper,遵循一种规范,即将静态文件资源存放至wwwro…

Transformer 能代替图神经网络吗?

当Transformer模型发布时,它彻底革新了机器翻译领域。虽然最初是为特定任务设计的,但这种革命性的架构显示出它可以轻松适应不同的任务。随后成为了Transformer一个标准,甚至用于它最初设计之外的数据(如图像和其他序列数据)。 然后人们也开始优化和寻找替代方案,主要是为…

从零开始的 DP 学习记录

为了补上我dp的短板(其实说真的dp约等于没学过,板都没有的那种),也为了以后复习dp不会再忘记dp怎么写,dp的各种思想是怎么来的,从零开始学习 dp ,并记录在此博客。 因为要补的东西也挺多的,就不多开文章了,直接在这里记录了。 当然也会记录日常生活 大概是首发于洛谷博…

【Linux系列】 Bash 重定向中 file 21 和 21 file 的区别

一、 写在前面 在 Bash 脚本和命令行操作中,输出重定向是一项基本且强大的功能。 它允许用户控制命令的输出流,将数据从一个地方转移到另一个地方,实现更加灵活和高效的工作流程。 本文旨在记录 Bash 中几种常见的输出重定向方法,包括:> file>file 2>&1 vs 2&…

c#实现定时从外部服务器获取文件并查重(MD5)

需求:需要定时去请求外部服务器的文件,看看每天是否有新的文件上传,如果有就下载到本地服务器,并记录数据。原来的文件重命名。 方案:这里通过文件的MD5和其他条件来判断文件是否存在。因为文件量过大,所以批量下载的时候有时候会出现部分文件没能下载成功,但是数据入库…