YOLOv8 Source Code Analysis (28)
.\yolov8\ultralytics\data\base.py
# Ultralytics YOLO 🚀, AGPL-3.0 license

import glob  # file-path pattern matching
import math
import os
import random
from copy import deepcopy
from multiprocessing.pool import ThreadPool
from pathlib import Path
from typing import Optional

import cv2  # OpenCV image processing
import numpy as np
import psutil  # process and system memory info
from torch.utils.data import Dataset  # PyTorch dataset base class

from ultralytics.data.utils import FORMATS_HELP_MSG, HELP_URL, IMG_FORMATS
from ultralytics.utils import DEFAULT_CFG, LOCAL_RANK, LOGGER, NUM_THREADS, TQDM


class BaseDataset(Dataset):
    """
    Base dataset class for loading and processing image data.

    Args:
        img_path (str): Path to the folder containing images.
        imgsz (int, optional): Image size. Defaults to 640.
        cache (bool, optional): Cache images to RAM or disk during training. Defaults to False.
        augment (bool, optional): If True, data augmentation is applied. Defaults to True.
        hyp (dict, optional): Hyperparameters to apply data augmentation. Defaults to None.
        prefix (str, optional): Prefix to print in log messages. Defaults to ''.
        rect (bool, optional): If True, rectangular training is used. Defaults to False.
        batch_size (int, optional): Size of batches. Defaults to None.
        stride (int, optional): Stride. Defaults to 32.
        pad (float, optional): Padding. Defaults to 0.0.
        single_cls (bool, optional): If True, single class training is used. Defaults to False.
        classes (list): List of included classes. Default is None.
        fraction (float): Fraction of dataset to utilize. Default is 1.0 (use all data).

    Attributes:
        im_files (list): List of image file paths.
        labels (list): List of label data dictionaries.
        ni (int): Number of images in the dataset.
        ims (list): List of loaded images.
        npy_files (list): List of numpy file paths.
        transforms (callable): Image transformation function.
    """

    def __init__(
        self,
        img_path,
        imgsz=640,
        cache=False,
        augment=True,
        hyp=DEFAULT_CFG,
        prefix="",
        rect=False,
        batch_size=16,
        stride=32,
        pad=0.5,
        single_cls=False,
        classes=None,
        fraction=1.0,
    ):
        """Initialize BaseDataset with given configuration and options."""
        super().__init__()
        self.img_path = img_path  # image path(s)
        self.imgsz = imgsz  # target image size
        self.augment = augment  # apply data augmentation
        self.single_cls = single_cls  # treat all classes as one
        self.prefix = prefix  # log-message prefix
        self.fraction = fraction  # fraction of the dataset to use
        self.im_files = self.get_img_files(self.img_path)  # collect all image file paths
        self.labels = self.get_labels()  # load labels
        self.update_labels(include_class=classes)  # apply single_cls and include_class filtering
        self.ni = len(self.labels)  # number of images
        self.rect = rect  # rectangular training
        self.batch_size = batch_size
        self.stride = stride
        self.pad = pad
        if self.rect:
            assert self.batch_size is not None  # rectangular training requires a batch size
            self.set_rectangle()

        # Buffer thread for mosaic images
        self.buffer = []  # buffer size = batch size
        # Max buffer length: min(number of images, 8 x batch size, 1000) when augmenting, else 0
        self.max_buffer_length = min((self.ni, self.batch_size * 8, 1000)) if self.augment else 0

        # Cache images (options are True, False, None, "ram", "disk")
        self.ims, self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni, [None] * self.ni
        self.npy_files = [Path(f).with_suffix(".npy") for f in self.im_files]  # per-image .npy cache paths
        self.cache = cache.lower() if isinstance(cache, str) else "ram" if cache is True else None
        if (self.cache == "ram" and self.check_cache_ram()) or self.cache == "disk":
            self.cache_images()

        # Transforms
        self.transforms = self.build_transforms(hyp=hyp)

    def get_img_files(self, img_path):
        """Read image files."""
        try:
            f = []  # image files
            for p in img_path if isinstance(img_path, list) else [img_path]:
                p = Path(p)  # os-agnostic
                if p.is_dir():  # dir: glob recursively for any file
                    f += glob.glob(str(p / "**" / "*.*"), recursive=True)
                    # F = list(p.rglob('*.*'))  # pathlib alternative
                elif p.is_file():  # file: read one path per line
                    with open(p) as t:
                        t = t.read().strip().splitlines()
                        parent = str(p.parent) + os.sep
                        # Convert local paths ("./x.jpg") to global by prepending the parent dir
                        f += [x.replace("./", parent) if x.startswith("./") else x for x in t]
                        # F += [p.parent / x.lstrip(os.sep) for x in t]  # pathlib alternative
                else:
                    raise FileNotFoundError(f"{self.prefix}{p} does not exist")
            # Keep only supported image formats, normalize separators, and sort
            im_files = sorted(x.replace("/", os.sep) for x in f if x.split(".")[-1].lower() in IMG_FORMATS)
            # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS])  # pathlib alternative
            assert im_files, f"{self.prefix}No images found in {img_path}. {FORMATS_HELP_MSG}"
        except Exception as e:
            raise FileNotFoundError(f"{self.prefix}Error loading data from {img_path}\n{HELP_URL}") from e
        if self.fraction < 1:
            im_files = im_files[: round(len(im_files) * self.fraction)]  # retain a fraction of the dataset
        return im_files

    def update_labels(self, include_class: Optional[list]):
        """Update labels to include only these classes (optional)."""
        include_class_array = np.array(include_class).reshape(1, -1)
        for i in range(len(self.labels)):
            if include_class is not None:
                cls = self.labels[i]["cls"]
                bboxes = self.labels[i]["bboxes"]
                segments = self.labels[i]["segments"]
                keypoints = self.labels[i]["keypoints"]
                j = (cls == include_class_array).any(1)  # rows whose class is in include_class
                self.labels[i]["cls"] = cls[j]
                self.labels[i]["bboxes"] = bboxes[j]
                if segments:
                    self.labels[i]["segments"] = [segments[si] for si, idx in enumerate(j) if idx]
                if keypoints is not None:
                    self.labels[i]["keypoints"] = keypoints[j]
            if self.single_cls:
                self.labels[i]["cls"][:, 0] = 0  # map every class to 0

    def load_image(self, i, rect_mode=True):
        """Loads 1 image from dataset index 'i', returns (im, resized hw)."""
        im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
        if im is None:  # not cached in RAM
            if fn.exists():  # load from the cached *.npy file
                try:
                    im = np.load(fn)
                except Exception as e:
                    # Warn and delete a corrupt *.npy file, then fall back to the original image
                    LOGGER.warning(f"{self.prefix}WARNING ⚠️ Removing corrupt *.npy image file {fn} due to: {e}")
                    Path(fn).unlink(missing_ok=True)
                    im = cv2.imread(f)  # BGR
            else:  # read image
                im = cv2.imread(f)  # BGR
            if im is None:
                raise FileNotFoundError(f"Image Not Found {f}")

            h0, w0 = im.shape[:2]  # orig hw
            if rect_mode:  # resize long side to imgsz while maintaining aspect ratio
                r = self.imgsz / max(h0, w0)  # ratio
                if r != 1:  # if sizes are not equal
                    w, h = (min(math.ceil(w0 * r), self.imgsz), min(math.ceil(h0 * r), self.imgsz))
                    im = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR)
            elif not (h0 == w0 == self.imgsz):  # resize by stretching image to square imgsz
                im = cv2.resize(im, (self.imgsz, self.imgsz), interpolation=cv2.INTER_LINEAR)

            # Add to buffer if training with augmentations
            if self.augment:
                self.ims[i], self.im_hw0[i], self.im_hw[i] = im, (h0, w0), im.shape[:2]  # im, hw_original, hw_resized
                self.buffer.append(i)
                if 1 < len(self.buffer) >= self.max_buffer_length:  # prevent empty buffer
                    j = self.buffer.pop(0)  # evict the oldest entry
                    if self.cache != "ram":
                        self.ims[j], self.im_hw0[j], self.im_hw[j] = None, None, None

            return im, (h0, w0), im.shape[:2]

        # Already cached in RAM: return the cached image and its original/resized shapes
        return self.ims[i], self.im_hw0[i], self.im_hw[i]

    def cache_images(self):
        """Cache images to memory or disk."""
        b, gb = 0, 1 << 30  # bytes of cached images, bytes per gigabyte
        # Pick the cache function and storage medium based on the cache option
        fcn, storage = (self.cache_images_to_disk, "Disk") if self.cache == "disk" else (self.load_image, "RAM")
        with ThreadPool(NUM_THREADS) as pool:
            results = pool.imap(fcn, range(self.ni))  # load or cache images in parallel
            pbar = TQDM(enumerate(results), total=self.ni, disable=LOCAL_RANK > 0)
            for i, x in pbar:
                if self.cache == "disk":
                    b += self.npy_files[i].stat().st_size  # accumulate cached file size
                else:  # 'ram'
                    self.ims[i], self.im_hw0[i], self.im_hw[i] = x  # im, hw_orig, hw_resized = load_image(self, i)
                    b += self.ims[i].nbytes
                pbar.desc = f"{self.prefix}Caching images ({b / gb:.1f}GB {storage})"
            pbar.close()

    def cache_images_to_disk(self, i):
        """Saves an image as an *.npy file for faster loading."""
        f = self.npy_files[i]
        if not f.exists():
            np.save(f.as_posix(), cv2.imread(self.im_files[i]), allow_pickle=False)

    def check_cache_ram(self, safety_margin=0.5):
        """Check image caching requirements vs available memory."""
        b, gb = 0, 1 << 30  # bytes of cached images, bytes per gigabyte
        n = min(self.ni, 30)  # sample at most 30 images
        for _ in range(n):
            im = cv2.imread(random.choice(self.im_files))  # sample a random image
            ratio = self.imgsz / max(im.shape[0], im.shape[1])  # resize ratio to imgsz
            b += im.nbytes * ratio**2  # estimated bytes after resizing
        mem_required = b * self.ni / n * (1 + safety_margin)  # bytes required to cache the whole dataset
        mem = psutil.virtual_memory()  # system memory info
        success = mem_required < mem.available  # enough RAM to cache?
        if not success:
            self.cache = None  # disable caching
            LOGGER.info(
                f"{self.prefix}{mem_required / gb:.1f}GB RAM required to cache images "
                f"with {int(safety_margin * 100)}% safety margin but only "
                f"{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, not caching images ⚠️"
            )
        return success

    def set_rectangle(self):
        """Sets the shape of bounding boxes for YOLO detections as rectangles."""
        bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int)  # batch index per image
        nb = bi[-1] + 1  # number of batches

        s = np.array([x.pop("shape") for x in self.labels])  # hw from the labels
        ar = s[:, 0] / s[:, 1]  # aspect ratio
        irect = ar.argsort()  # indices that sort by aspect ratio
        self.im_files = [self.im_files[i] for i in irect]  # reorder image files
        self.labels = [self.labels[i] for i in irect]  # reorder labels
        ar = ar[irect]

        # Set training image shapes
        shapes = [[1, 1]] * nb
        for i in range(nb):
            ari = ar[bi == i]  # aspect ratios within this batch
            mini, maxi = ari.min(), ari.max()
            if maxi < 1:
                shapes[i] = [maxi, 1]  # all wide images: shrink height
            elif mini > 1:
                shapes[i] = [1, 1 / mini]  # all tall images: shrink width

        # Round batch shapes up to a multiple of the stride
        self.batch_shapes = np.ceil(np.array(shapes) * self.imgsz / self.stride + self.pad).astype(int) * self.stride
        self.batch = bi  # batch index of each image

    def __getitem__(self, index):
        """Returns transformed label information for given index."""
        return self.transforms(self.get_image_and_label(index))

    def get_image_and_label(self, index):
        """Get and return label information from the dataset."""
        label = deepcopy(self.labels[index])  # requires deepcopy() https://github.com/ultralytics/ultralytics/pull/1948
        label.pop("shape", None)  # shape is for rect, remove it
        label["img"], label["ori_shape"], label["resized_shape"] = self.load_image(index)
        label["ratio_pad"] = (
            label["resized_shape"][0] / label["ori_shape"][0],
            label["resized_shape"][1] / label["ori_shape"][1],
        )  # for evaluation
        if self.rect:
            label["rect_shape"] = self.batch_shapes[self.batch[index]]
        return self.update_labels_info(label)

    def __len__(self):
        """Returns the length of the labels list for the dataset."""
        return len(self.labels)

    def update_labels_info(self, label):
        """Custom your label format here."""
        return label

    def build_transforms(self, hyp=None):
        """
        Users can customize augmentations here.

        Example:
            ```python
            if self.augment:
                # Training transforms
                return Compose([])
            else:
                # Val transforms
                return Compose([])
            ```
        """
        raise NotImplementedError

    def get_labels(self):
        """
        Users can customize their own format here.

        Note:
            Ensure output is a dictionary with the following keys:
            ```python
            dict(
                im_file=im_file,
                shape=shape,  # format: (height, width)
                cls=cls,
                bboxes=bboxes,  # xywh
                segments=segments,  # xy
                keypoints=keypoints,  # xy
                normalized=True,  # or False
                bbox_format="xyxy",  # or xywh, ltwh
            )
            ```
        """
        raise NotImplementedError
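Because `build_transforms()` and `get_labels()` raise `NotImplementedError`, every concrete dataset must supply both. Below is a minimal sketch (not part of the source; the path and the identity transform are placeholders) of the smallest subclass that satisfies the label-dict contract documented in `get_labels()`:

```python
# Minimal sketch of a BaseDataset subclass, assuming a hypothetical image folder.
import numpy as np
from ultralytics.data.base import BaseDataset


class ToyDataset(BaseDataset):
    def get_labels(self):
        # One dict per image, matching the keys listed in BaseDataset.get_labels().
        return [
            dict(
                im_file=f,
                shape=(640, 640),  # (height, width); normally read from the image itself
                cls=np.zeros((0, 1), dtype=np.float32),
                bboxes=np.zeros((0, 4), dtype=np.float32),  # xywh
                segments=[],
                keypoints=None,
                normalized=True,
                bbox_format="xywh",
            )
            for f in self.im_files
        ]

    def build_transforms(self, hyp=None):
        # Identity transform: __getitem__ returns the raw label dict unchanged.
        return lambda label: label


# ds = ToyDataset("path/to/images", imgsz=640, augment=False)  # hypothetical path
# sample = ds[0]  # dict with "img", "ori_shape", "resized_shape", "ratio_pad", ...
```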
.\yolov8\ultralytics\data\build.py
# Ultralytics YOLO 🚀, AGPL-3.0 license

import os
import random
from pathlib import Path

import numpy as np
import torch
from PIL import Image
from torch.utils.data import dataloader, distributed

# Custom dataset classes
from ultralytics.data.dataset import GroundingDataset, YOLODataset, YOLOMultiModalDataset
# Data loaders
from ultralytics.data.loaders import (
    LOADERS,
    LoadImagesAndVideos,
    LoadPilAndNumpy,
    LoadScreenshots,
    LoadStreams,
    LoadTensor,
    SourceTypes,
    autocast_list,
)
# Data utilities and constants
from ultralytics.data.utils import IMG_FORMATS, PIN_MEMORY, VID_FORMATS
from ultralytics.utils import RANK, colorstr
from ultralytics.utils.checks import check_file


class InfiniteDataLoader(dataloader.DataLoader):
    """
    Dataloader that reuses workers.

    Uses same syntax as vanilla DataLoader.
    """

    def __init__(self, *args, **kwargs):
        """Dataloader that infinitely recycles workers, inherits from DataLoader."""
        super().__init__(*args, **kwargs)
        # Wrap the batch sampler in _RepeatSampler so workers are recycled forever
        object.__setattr__(self, "batch_sampler", _RepeatSampler(self.batch_sampler))
        self.iterator = super().__iter__()

    def __len__(self):
        """Returns the length of the batch sampler's sampler."""
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        """Creates a sampler that repeats indefinitely."""
        for _ in range(len(self)):
            yield next(self.iterator)

    def reset(self):
        """
        Reset iterator.

        This is useful when we want to modify settings of dataset while training.
        """
        self.iterator = self._get_iterator()


class _RepeatSampler:
    """
    Sampler that repeats forever.

    Args:
        sampler (Dataset.sampler): The sampler to repeat.
    """

    def __init__(self, sampler):
        """Initializes an object that repeats a given sampler indefinitely."""
        self.sampler = sampler

    def __iter__(self):
        """Iterates over the 'sampler' and yields its contents."""
        while True:
            yield from iter(self.sampler)


def seed_worker(worker_id):  # noqa
    """Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader."""
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, stride=32, multi_modal=False):
    """Build YOLO Dataset."""
    dataset = YOLOMultiModalDataset if multi_modal else YOLODataset  # pick single- or multi-modal dataset
    return dataset(
        img_path=img_path,  # image path
        imgsz=cfg.imgsz,  # image size
        batch_size=batch,
        augment=mode == "train",  # augmentation only in training mode
        hyp=cfg,  # training hyperparameters
        rect=cfg.rect or rect,  # rectangular batches
        cache=cfg.cache or None,
        single_cls=cfg.single_cls or False,
        stride=int(stride),
        pad=0.0 if mode == "train" else 0.5,  # 0.0 padding when training, 0.5 otherwise
        prefix=colorstr(f"{mode}: "),  # log prefix with the mode name
        task=cfg.task,
        classes=cfg.classes,
        data=data,
        fraction=cfg.fraction if mode == "train" else 1.0,  # use the full set outside training
    )


def build_grounding(cfg, img_path, json_file, batch, mode="train", rect=False, stride=32):
    """Build YOLO Dataset."""
    return GroundingDataset(
        img_path=img_path,  # image directory
        json_file=json_file,  # JSON file with the annotations
        imgsz=cfg.imgsz,
        batch_size=batch,
        augment=mode == "train",
        hyp=cfg,
        rect=cfg.rect or rect,
        cache=cfg.cache or None,
        single_cls=cfg.single_cls or False,
        stride=int(stride),
        pad=0.0 if mode == "train" else 0.5,
        prefix=colorstr(f"{mode}: "),
        task=cfg.task,
        classes=cfg.classes,
        fraction=cfg.fraction if mode == "train" else 1.0,
    )


def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1):
    """Return an InfiniteDataLoader or DataLoader for training or validation set."""
    batch = min(batch, len(dataset))  # batch size cannot exceed dataset size
    nd = torch.cuda.device_count()  # number of CUDA devices
    nw = min(os.cpu_count() // max(nd, 1), workers)  # number of workers
    sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
    generator = torch.Generator()
    generator.manual_seed(6148914691236517205 + RANK)  # deterministic per-rank seed
    return InfiniteDataLoader(
        dataset=dataset,
        batch_size=batch,
        shuffle=shuffle and sampler is None,  # the DistributedSampler shuffles on its own
        num_workers=nw,
        sampler=sampler,
        pin_memory=PIN_MEMORY,
        collate_fn=getattr(dataset, "collate_fn", None),
        worker_init_fn=seed_worker,
        generator=generator,
    )


def check_source(source):
    """Check source type and return corresponding flag values."""
    webcam, screenshot, from_img, in_memory, tensor = False, False, False, False, False
    if isinstance(source, (str, int, Path)):  # int for local usb camera
        source = str(source)
        is_file = Path(source).suffix[1:] in (IMG_FORMATS | VID_FORMATS)  # supported image/video format?
        is_url = source.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://"))
        webcam = source.isnumeric() or source.endswith(".streams") or (is_url and not is_file)
        screenshot = source.lower() == "screen"
        if is_url and is_file:
            source = check_file(source)  # download
    elif isinstance(source, LOADERS):  # an existing loader object
        in_memory = True
    elif isinstance(source, (list, tuple)):
        source = autocast_list(source)  # convert all list elements to PIL or np arrays
        from_img = True
    elif isinstance(source, (Image.Image, np.ndarray)):
        from_img = True
    elif isinstance(source, torch.Tensor):
        tensor = True
    else:
        raise TypeError("Unsupported image type. For supported types see https://docs.ultralytics.com/modes/predict")

    return source, webcam, screenshot, from_img, in_memory, tensor


def load_inference_source(source=None, batch=1, vid_stride=1, buffer=False):
    """
    Loads an inference source for object detection and applies necessary transformations.

    Args:
        source (str, Path, Tensor, PIL.Image, np.ndarray): The input source for inference.
        batch (int, optional): Batch size for dataloaders. Default is 1.
        vid_stride (int, optional): The frame interval for video sources. Default is 1.
        buffer (bool, optional): Determines whether stream frames will be buffered. Default is False.

    Returns:
        dataset (Dataset): A dataset object for the specified input source.
    """
    source, stream, screenshot, from_img, in_memory, tensor = check_source(source)
    # Use the source's own type if it is an in-memory loader, otherwise derive it from the flags
    source_type = source.source_type if in_memory else SourceTypes(stream, screenshot, from_img, tensor)

    # Dataloader selection
    if tensor:
        dataset = LoadTensor(source)  # torch.Tensor input
    elif in_memory:
        dataset = source  # already a loader, use it directly
    elif stream:
        dataset = LoadStreams(source, vid_stride=vid_stride, buffer=buffer)  # video stream
    elif screenshot:
        dataset = LoadScreenshots(source)  # screen capture
    elif from_img:
        dataset = LoadPilAndNumpy(source)  # PIL image or np array
    else:
        dataset = LoadImagesAndVideos(source, batch=batch, vid_stride=vid_stride)  # image/video files

    # Attach the source type to the dataset
    setattr(dataset, "source_type", source_type)

    return dataset
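The three builders above are what the trainer calls in sequence. A hedged sketch of that wiring (this assumes the `coco8.yaml` demo dataset, which ultralytics resolves and downloads on first use; not a verbatim excerpt from the trainer):

```python
# Sketch: building a YOLO dataset and its infinite dataloader, as a trainer would.
from ultralytics.cfg import get_cfg
from ultralytics.data.build import build_dataloader, build_yolo_dataset
from ultralytics.data.utils import check_det_dataset

cfg = get_cfg()  # DEFAULT_CFG-backed namespace with imgsz, rect, cache, task, ...
data = check_det_dataset("coco8.yaml")  # resolves paths, class names, and splits

dataset = build_yolo_dataset(cfg, img_path=data["train"], batch=4, data=data, mode="train")
loader = build_dataloader(dataset, batch=4, workers=0, shuffle=True, rank=-1)  # rank=-1: no DDP sampler

batch = next(iter(loader))  # dict with "img", "cls", "bboxes", "batch_idx", ...
print(batch["img"].shape)   # e.g. torch.Size([4, 3, 640, 640])
```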
.\yolov8\ultralytics\data\converter.py
# Import the required libraries and modules
import json
from collections import defaultdict
from pathlib import Path

import cv2
import numpy as np

# Ultralytics logging and progress-bar utilities
from ultralytics.utils import LOGGER, TQDM
# Path-increment helper from the Ultralytics file utilities
from ultralytics.utils.files import increment_path


def coco91_to_coco80_class():
    """
    Converts 91-index COCO class IDs to 80-index COCO class IDs.

    Returns:
        (list): A list of 91 class IDs where the index represents the 91-index class ID and the value is the
            corresponding 80-index class ID (None where the class has no 80-index counterpart).
    """
    return [
        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, None, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, None, 24, 25,
        None, None, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, None, 40, 41, 42, 43, 44, 45, 46, 47,
        48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, None, 60, None, None, 61, None, 62, 63, 64, 65, 66, 67,
        68, 69, 70, 71, 72, None, 73, 74, 75, 76, 77, 78, 79, None,
    ]


def coco80_to_coco91_class():
    """
    Converts 80-index (val2014) to 91-index (paper).

    For details see https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/.

    Example:
        ```python
        import numpy as np

        a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
        b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
        x1 = [list(a[i] == b).index(True) + 1 for i in range(80)]  # darknet to coco
        x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)]  # coco to darknet
        ```
    """
    # Note the gaps (12, 26, 29, 30, ...): those paper IDs have no 80-class counterpart
    return [
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33,
        34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,
        62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90,
    ]


def convert_coco(
    labels_dir="../coco/annotations/",
    save_dir="coco_converted/",
    use_segments=False,
    use_keypoints=False,
    cls91to80=True,
    lvis=False,
):
    """
    Converts COCO dataset annotations to a YOLO annotation format suitable for training YOLO models.

    Args:
        labels_dir (str, optional): Path to directory containing COCO dataset annotation files.
        save_dir (str, optional): Path to directory to save results to.
        use_segments (bool, optional): Whether to include segmentation masks in the output.
        use_keypoints (bool, optional): Whether to include keypoint annotations in the output.
        cls91to80 (bool, optional): Whether to map 91 COCO class IDs to the corresponding 80 COCO class IDs.
        lvis (bool, optional): Whether to convert data in lvis dataset way.

    Example:
        ```python
        from ultralytics.data.converter import convert_coco

        convert_coco('../datasets/coco/annotations/', use_segments=True, use_keypoints=False, cls91to80=True)
        convert_coco('../datasets/lvis/annotations/', use_segments=True, use_keypoints=False, cls91to80=False, lvis=True)
        ```

    Output:
        Generates output files in the specified output directory.
    """
    # Create dataset directory
    save_dir = increment_path(save_dir)  # increment the path if the save directory already exists
    for p in save_dir / "labels", save_dir / "images":
        p.mkdir(parents=True, exist_ok=True)  # make dir

    # Convert classes
    coco80 = coco91_to_coco80_class()  # map 91-class COCO IDs to the 80-class set

    # ... (the per-annotation conversion loop is elided in this excerpt)

    LOGGER.info(f"{'LVIS' if lvis else 'COCO'} data converted successfully.\nResults saved to {save_dir.resolve()}")


def convert_dota_to_yolo_obb(dota_root_path: str):
    """
    Converts DOTA dataset annotations to YOLO OBB (Oriented Bounding Box) format.

    The function processes images in the 'train' and 'val' folders of the DOTA dataset. For each image, it reads the
    associated label from the original labels directory and writes new labels in YOLO OBB format to a new directory.

    Args:
        dota_root_path (str): The root directory path of the DOTA dataset.

    Example:
        ```python
        from ultralytics.data.converter import convert_dota_to_yolo_obb

        convert_dota_to_yolo_obb('path/to/DOTA')
        ```

    Notes:
        The directory structure assumed for the DOTA dataset:

            - DOTA
                ├─ images
                │   ├─ train
                │   └─ val
                └─ labels
                    ├─ train_original
                    └─ val_original

        After execution, the function will organize the labels into:

            - DOTA
                └─ labels
                    ├─ train
                    └─ val
    """
    dota_root_path = Path(dota_root_path)

    # Class names to indices mapping
    class_mapping = {
        "plane": 0,
        "ship": 1,
        "storage-tank": 2,
        "baseball-diamond": 3,
        "tennis-court": 4,
        "basketball-court": 5,
        "ground-track-field": 6,
        "harbor": 7,
        "bridge": 8,
        "large-vehicle": 9,
        "small-vehicle": 10,
        "helicopter": 11,
        "roundabout": 12,
        "soccer-ball-field": 13,
        "swimming-pool": 14,
        "container-crane": 15,
        "airport": 16,
        "helipad": 17,
    }

    def convert_label(image_name, image_width, image_height, orig_label_dir, save_dir):
        """Converts a single image's DOTA annotation to YOLO OBB format and saves it to a specified directory."""
        orig_label_path = orig_label_dir / f"{image_name}.txt"
        save_path = save_dir / f"{image_name}.txt"

        with orig_label_path.open("r") as f, save_path.open("w") as g:
            lines = f.readlines()
            for line in lines:
                parts = line.strip().split()
                if len(parts) < 9:
                    continue
                # Map the class name to its integer index
                class_name = parts[8]
                class_idx = class_mapping[class_name]
                # Normalize the 8 polygon coordinates by image width (x) and height (y)
                coords = [float(p) for p in parts[:8]]
                normalized_coords = [
                    coords[i] / image_width if i % 2 == 0 else coords[i] / image_height for i in range(8)
                ]
                # Format the coordinates with 6 significant digits
                formatted_coords = ["{:.6g}".format(coord) for coord in normalized_coords]
                g.write(f"{class_idx} {' '.join(formatted_coords)}\n")

    # Process the train and val phases
    for phase in ["train", "val"]:
        image_dir = dota_root_path / "images" / phase
        orig_label_dir = dota_root_path / "labels" / f"{phase}_original"
        save_dir = dota_root_path / "labels" / phase

        save_dir.mkdir(parents=True, exist_ok=True)  # create the save directory if needed

        image_paths = list(image_dir.iterdir())
        for image_path in TQDM(image_paths, desc=f"Processing {phase} images"):
            if image_path.suffix != ".png":  # skip non-PNG files
                continue
            image_name_without_ext = image_path.stem
            img = cv2.imread(str(image_path))
            h, w = img.shape[:2]
            convert_label(image_name_without_ext, w, h, orig_label_dir, save_dir)
def yolo_bbox2segment(im_dir, save_dir=None, sam_model="sam_b.pt"):
    """
    Converts an existing object detection dataset (bounding boxes) to segmentation data or oriented bounding box (OBB)
    data in YOLO format. Generates segmentation data using the SAM auto-annotator as needed.

    Args:
        im_dir (str | Path): Path to the image directory to convert.
        save_dir (str | Path): Path to save the generated labels; if None, labels are saved into `labels-segment`
            at the same directory level as `im_dir`. Default: None.
        sam_model (str): Segmentation model to use for intermediate segmentation data; optional.

    Notes:
        The input directory structure assumed for the dataset:

            - im_dir
                ├─ 001.jpg
                ├─ ..
                └─ NNN.jpg
            - labels
                ├─ 001.txt
                ├─ ..
                └─ NNN.txt
    """
    from tqdm import tqdm

    from ultralytics import SAM
    from ultralytics.data import YOLODataset
    from ultralytics.utils import LOGGER
    from ultralytics.utils.ops import xywh2xyxy

    # NOTE: add placeholder to pass class index check
    dataset = YOLODataset(im_dir, data=dict(names=list(range(1000))))
    if len(dataset.labels[0]["segments"]) > 0:  # segmentation data already present
        LOGGER.info("Segmentation labels detected, no need to generate new ones!")
        return

    LOGGER.info("Detection labels detected, generating segment labels by SAM model!")
    sam_model = SAM(sam_model)
    for label in tqdm(dataset.labels, total=len(dataset.labels), desc="Generating segment labels"):
        h, w = label["shape"]  # label image height and width
        boxes = label["bboxes"]
        if len(boxes) == 0:  # skip empty labels
            continue
        boxes[:, [0, 2]] *= w  # scale x coordinates to pixels
        boxes[:, [1, 3]] *= h  # scale y coordinates to pixels
        im = cv2.imread(label["im_file"])
        sam_results = sam_model(im, bboxes=xywh2xyxy(boxes), verbose=False, save=False)  # segment with SAM
        label["segments"] = sam_results[0].masks.xyn  # store the resulting polygons

    save_dir = Path(save_dir) if save_dir else Path(im_dir).parent / "labels-segment"
    save_dir.mkdir(parents=True, exist_ok=True)
    for label in dataset.labels:
        texts = []  # lines to write for this image
        lb_name = Path(label["im_file"]).with_suffix(".txt").name
        txt_file = save_dir / lb_name
        cls = label["cls"]
        for i, s in enumerate(label["segments"]):
            line = (int(cls[i]), *s.reshape(-1))  # class index followed by the flattened polygon
            texts.append(("%g " * len(line)).rstrip() % line)
        if texts:
            with open(txt_file, "a") as f:
                f.writelines(text + "\n" for text in texts)
    LOGGER.info(f"Generated segment labels saved in {save_dir}")
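A minimal usage sketch for `yolo_bbox2segment` (the image directory is a placeholder; per the Notes above, a YOLO-format `labels` directory must sit next to `im_dir`):

```python
from ultralytics.data.converter import yolo_bbox2segment

# Reads box labels from path/to/labels/, segments each box with SAM, and
# writes polygon labels to path/to/labels-segment/.
yolo_bbox2segment("path/to/images", save_dir=None, sam_model="sam_b.pt")
```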
.\yolov8\ultralytics\data\dataset.py
# Ultralytics YOLO 🚀, AGPL-3.0 license

import contextlib
import json
from collections import defaultdict
from itertools import repeat
from multiprocessing.pool import ThreadPool
from pathlib import Path

import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data import ConcatDataset

# Ultralytics utility functions and classes
from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr
from ultralytics.utils.ops import resample_segments
from ultralytics.utils.torch_utils import TORCHVISION_0_18

# Data augmentation modules
from .augment import (
    Compose,
    Format,
    Instances,
    LetterBox,
    RandomLoadText,
    classify_augmentations,
    classify_transforms,
    v8_transforms,
)
# Base dataset class and utilities
from .base import BaseDataset
from .utils import (
    HELP_URL,
    LOGGER,
    get_hash,
    img2label_paths,
    load_dataset_cache_file,
    save_dataset_cache_file,
    verify_image,
    verify_image_label,
)

# Ultralytics dataset *.cache version, >= 1.0.0 for YOLOv8
DATASET_CACHE_VERSION = "1.0.3"


class YOLODataset(BaseDataset):
    """
    Dataset class for loading object detection and/or segmentation labels in YOLO format.

    Args:
        data (dict, optional): A dataset YAML dictionary. Defaults to None.
        task (str): An explicit arg to point current task, Defaults to 'detect'.

    Returns:
        (torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
    """

    def __init__(self, *args, data=None, task="detect", **kwargs):
        """Initializes the YOLODataset with optional configurations for segments and keypoints."""
        # The task selects segmentation, keypoint, or oriented-box labels
        self.use_segments = task == "segment"
        self.use_keypoints = task == "pose"
        self.use_obb = task == "obb"
        self.data = data
        assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints."
        super().__init__(*args, **kwargs)

    def cache_labels(self, path=Path("./labels.cache")):
        """
        Cache dataset labels, check images and read shapes.

        Args:
            path (Path): Path where to save the cache file. Default is Path('./labels.cache').

        Returns:
            (dict): labels.
        """
        x = {"labels": []}
        nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages
        desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
        total = len(self.im_files)
        nkpt, ndim = self.data.get("kpt_shape", (0, 0))  # keypoint shape from the dataset YAML
        if self.use_keypoints and (nkpt <= 0 or ndim not in {2, 3}):
            raise ValueError(
                "'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of "
                "keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'"
            )
        with ThreadPool(NUM_THREADS) as pool:
            # Verify images and labels in parallel
            results = pool.imap(
                func=verify_image_label,
                iterable=zip(
                    self.im_files,
                    self.label_files,
                    repeat(self.prefix),
                    repeat(self.use_keypoints),
                    repeat(len(self.data["names"])),
                    repeat(nkpt),
                    repeat(ndim),
                ),
            )
            pbar = TQDM(results, desc=desc, total=total)
            for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
                nm += nm_f
                nf += nf_f
                ne += ne_f
                nc += nc_f
                if im_file:
                    x["labels"].append(
                        {
                            "im_file": im_file,
                            "shape": shape,
                            "cls": lb[:, 0:1],  # n, 1
                            "bboxes": lb[:, 1:],  # n, 4
                            "segments": segments,
                            "keypoints": keypoint,
                            "normalized": True,
                            "bbox_format": "xywh",
                        }
                    )
                if msg:
                    msgs.append(msg)
                pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
            pbar.close()

        if msgs:
            LOGGER.info("\n".join(msgs))
        if nf == 0:
            LOGGER.warning(f"{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}")
        x["hash"] = get_hash(self.label_files + self.im_files)  # dataset fingerprint for cache validation
        x["results"] = nf, nm, ne, nc, len(self.im_files)
        x["msgs"] = msgs  # warnings
        save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
        return x

    def get_labels(self):
        """Returns dictionary of labels for YOLO training."""
        self.label_files = img2label_paths(self.im_files)  # label file path per image
        cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")
        try:
            cache, exists = load_dataset_cache_file(cache_path), True  # attempt to load a *.cache file
            assert cache["version"] == DATASET_CACHE_VERSION  # matches current version
            assert cache["hash"] == get_hash(self.label_files + self.im_files)  # identical hash
        except (FileNotFoundError, AssertionError, AttributeError):
            cache, exists = self.cache_labels(cache_path), False  # loading failed, regenerate the label cache

        # Display cache
        nf, nm, ne, nc, n = cache.pop("results")  # found, missing, empty, corrupt, total
        if exists and LOCAL_RANK in {-1, 0}:
            d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
            TQDM(None, desc=self.prefix + d, total=n, initial=n)  # display results
            if cache["msgs"]:
                LOGGER.info("\n".join(cache["msgs"]))  # display warnings

        # Read cache
        [cache.pop(k) for k in ("hash", "version", "msgs")]  # remove items
        labels = cache["labels"]
        if not labels:
            LOGGER.warning(f"WARNING ⚠️ No images found in {cache_path}, training may not work correctly. {HELP_URL}")
        self.im_files = [lb["im_file"] for lb in labels]  # update im_files

        # Check if the dataset is all boxes or all segments
        lengths = ((len(lb["cls"]), len(lb["bboxes"]), len(lb["segments"])) for lb in labels)
        len_cls, len_boxes, len_segments = (sum(x) for x in zip(*lengths))
        if len_segments and len_boxes != len_segments:
            # Mismatched box/segment counts: warn and drop all segments
            LOGGER.warning(
                f"WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, "
                f"len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. "
                "To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset."
            )
            for lb in labels:
                lb["segments"] = []
        if len_cls == 0:
            LOGGER.warning(f"WARNING ⚠️ No labels found in {cache_path}, training may not work correctly. {HELP_URL}")
        return labels

    def build_transforms(self, hyp=None):
        """Builds and appends transforms to the list."""
        if self.augment:
            # Disable mosaic and mixup in rect mode
            hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0
            hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0
            transforms = v8_transforms(self, self.imgsz, hyp)
        else:
            transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)])
        transforms.append(
            Format(
                bbox_format="xywh",
                normalize=True,
                return_mask=self.use_segments,
                return_keypoint=self.use_keypoints,
                return_obb=self.use_obb,
                batch_idx=True,
                mask_ratio=hyp.mask_ratio,
                mask_overlap=hyp.overlap_mask,
                bgr=hyp.bgr if self.augment else 0.0,  # only affects training
            )
        )
        return transforms

    def close_mosaic(self, hyp):
        """Sets mosaic, copy_paste and mixup options to 0.0 and builds transformations."""
        hyp.mosaic = 0.0  # set mosaic ratio=0.0
        hyp.copy_paste = 0.0  # keep the same behavior as previous v8 close-mosaic
        hyp.mixup = 0.0  # keep the same behavior as previous v8 close-mosaic
        self.transforms = self.build_transforms(hyp)

    def update_labels_info(self, label):
        """
        Custom your label format here.

        Note:
            cls is not with bboxes now, classification and semantic segmentation need an independent cls label
            Can also support classification and semantic segmentation by adding or removing dict keys there.
        """
        bboxes = label.pop("bboxes")
        segments = label.pop("segments", [])
        keypoints = label.pop("keypoints", None)
        bbox_format = label.pop("bbox_format")
        normalized = label.pop("normalized")

        # Resample segments to 100 points for OBB, otherwise 1000
        segment_resamples = 100 if self.use_obb else 1000
        if len(segments) > 0:
            segments = np.stack(resample_segments(segments, n=segment_resamples), axis=0)
        else:
            segments = np.zeros((0, segment_resamples, 2), dtype=np.float32)
        label["instances"] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
        return label

    @staticmethod
    def collate_fn(batch):
        """Collates data samples into batches."""
        new_batch = {}
        keys = batch[0].keys()
        values = list(zip(*[list(b.values()) for b in batch]))  # transpose into per-key value lists
        for i, k in enumerate(keys):
            value = values[i]
            if k == "img":
                value = torch.stack(value, 0)  # stack images into one tensor
            if k in {"masks", "keypoints", "bboxes", "cls", "segments", "obb"}:
                value = torch.cat(value, 0)  # concatenate per-instance targets
            new_batch[k] = value
        new_batch["batch_idx"] = list(new_batch["batch_idx"])
        for i in range(len(new_batch["batch_idx"])):
            new_batch["batch_idx"][i] += i  # add target image index for build_targets()
        new_batch["batch_idx"] = torch.cat(new_batch["batch_idx"], 0)
        return new_batch


class YOLOMultiModalDataset(YOLODataset):
    """
    Dataset class for loading object detection and/or segmentation labels in YOLO format.

    Args:
        data (dict, optional): A dataset YAML dictionary. Defaults to None.
        task (str): An explicit arg to point current task, Defaults to 'detect'.

    Returns:
        (torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
    """

    def __init__(self, *args, data=None, task="detect", **kwargs):
        """Initializes a dataset object for object detection tasks with optional specifications."""
        super().__init__(*args, data=data, task=task, **kwargs)

    def update_labels_info(self, label):
        """Add texts information for multi-modal model training."""
        labels = super().update_labels_info(label)
        # NOTE: some categories are concatenated with its synonyms by `/`.
        labels["texts"] = [v.split("/") for _, v in self.data["names"].items()]
        return labels

    def build_transforms(self, hyp=None):
        """Enhances data transformations with optional text augmentation for multi-modal training."""
        transforms = super().build_transforms(hyp)
        if self.augment:
            # NOTE: hard-coded the args for now.
            transforms.insert(-1, RandomLoadText(max_samples=min(self.data["nc"], 80), padding=True))
        return transforms


class GroundingDataset(YOLODataset):
    """Handles object detection tasks by loading annotations from a specified JSON file, supporting YOLO format."""

    def __init__(self, *args, task="detect", json_file, **kwargs):
        """Initializes a GroundingDataset for object detection, loading annotations from a specified JSON file."""
        assert task == "detect", "`GroundingDataset` only support `detect` task for now!"
        self.json_file = json_file
        super().__init__(*args, task=task, data={}, **kwargs)

    def get_img_files(self, img_path):
        """The image files would be read in `get_labels` function, return empty list here."""
        return []

    def get_labels(self):
        """Loads annotations from a JSON file, filters, and normalizes bounding boxes for each image."""
        labels = []
        LOGGER.info("Loading annotation file...")
        with open(self.json_file, "r") as f:
            annotations = json.load(f)
        images = {f'{x["id"]:d}': x for x in annotations["images"]}  # image dict keyed by image id
        img_to_anns = defaultdict(list)
        for ann in annotations["annotations"]:
            img_to_anns[ann["image_id"]].append(ann)  # group annotations by image id
        for img_id, anns in TQDM(img_to_anns.items(), desc=f"Reading annotations {self.json_file}"):
            img = images[f"{img_id:d}"]
            h, w, f = img["height"], img["width"], img["file_name"]
            im_file = Path(self.img_path) / f
            if not im_file.exists():
                continue  # skip missing image files
            self.im_files.append(str(im_file))
            bboxes = []
            cat2id = {}  # category name -> class id
            texts = []
            for ann in anns:
                if ann["iscrowd"]:
                    continue  # skip crowd annotations
                box = np.array(ann["bbox"], dtype=np.float32)
                box[:2] += box[2:] / 2  # top-left corner to box center
                box[[0, 2]] /= float(w)  # normalize x
                box[[1, 3]] /= float(h)  # normalize y
                if box[2] <= 0 or box[3] <= 0:
                    continue  # skip degenerate boxes

                # Category name from the caption spans in tokens_positive
                cat_name = " ".join([img["caption"][t[0]:t[1]] for t in ann["tokens_positive"]])
                if cat_name not in cat2id:
                    cat2id[cat_name] = len(cat2id)  # assign a new class id
                    texts.append([cat_name])
                cls = cat2id[cat_name]
                box = [cls] + box.tolist()
                if box not in bboxes:
                    bboxes.append(box)
            lb = np.array(bboxes, dtype=np.float32) if len(bboxes) else np.zeros((0, 5), dtype=np.float32)
            labels.append(
                {
                    "im_file": im_file,
                    "shape": (h, w),
                    "cls": lb[:, 0:1],  # n, 1
                    "bboxes": lb[:, 1:],  # n, 4
                    "normalized": True,
                    "bbox_format": "xywh",
                    "texts": texts,
                }
            )
        return labels

    def build_transforms(self, hyp=None):
        """Configures augmentations for training with optional text loading; `hyp` adjusts augmentation intensity."""
        transforms = super().build_transforms(hyp)
        if self.augment:
            # NOTE: hard-coded the args for now.
            transforms.insert(-1, RandomLoadText(max_samples=80, padding=True))
        return transforms


class YOLOConcatDataset(ConcatDataset):
    """
    Dataset as a concatenation of multiple datasets.

    This class is useful to assemble different existing datasets.
    """

    @staticmethod
    def collate_fn(batch):
        """Collates data samples into batches."""
        return YOLODataset.collate_fn(batch)


# TODO: support semantic segmentation
class SemanticDataset(BaseDataset):
    """
    Semantic Segmentation Dataset.

    This class is responsible for handling datasets used for semantic segmentation tasks. It inherits functionalities
    from the BaseDataset class.

    Note:
        This class is currently a placeholder and needs to be populated with methods and attributes for supporting
        semantic segmentation tasks.
    """

    def __init__(self):
        """Initialize a SemanticDataset object."""
        super().__init__()


class ClassificationDataset:
    """
    Extends torchvision ImageFolder to support YOLO classification tasks, offering functionalities like image
    augmentation, caching, and verification. It's designed to efficiently handle large datasets for training deep
    learning models, with optional image transformations and caching mechanisms to speed up training.

    This class allows for augmentations using both torchvision and Albumentations libraries, and supports caching images
    in RAM or on disk to reduce IO overhead during training. Additionally, it implements a robust verification process
    to ensure data integrity and consistency.

    Attributes:
        cache_ram (bool): Indicates if caching in RAM is enabled.
        cache_disk (bool): Indicates if caching on disk is enabled.
        samples (list): A list of tuples, each containing the path to an image, its class index, path to its .npy cache
            file (if caching on disk), and optionally the loaded image array (if caching in RAM).
        torch_transforms (callable): PyTorch transforms to be applied to the images.
    """

    def __getitem__(self, i):
        """Returns subset of data and targets corresponding to given indices."""
        f, j, fn, im = self.samples[i]  # filename, index, filename.with_suffix('.npy'), image
        if self.cache_ram:
            if im is None:  # Warning: two separate if statements required here, do not combine this with previous line
                im = self.samples[i][3] = cv2.imread(f)
        elif self.cache_disk:
            if not fn.exists():  # save the npy cache on first access
                np.save(fn.as_posix(), cv2.imread(f), allow_pickle=False)
            im = np.load(fn)
        else:  # read image
            im = cv2.imread(f)  # BGR
        # Convert NumPy array to PIL image
        im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
        sample = self.torch_transforms(im)
        return {"img": sample, "cls": j}

    def __len__(self) -> int:
        """Return the total number of samples in the dataset."""
        return len(self.samples)

    def verify_images(self):
        """Verify all images in dataset."""
        desc = f"{self.prefix}Scanning {self.root}..."
        path = Path(self.root).with_suffix(".cache")  # *.cache file path

        # Try to load an existing cache; fall through on failure
        with contextlib.suppress(FileNotFoundError, AssertionError, AttributeError):
            cache = load_dataset_cache_file(path)  # attempt to load a *.cache file
            assert cache["version"] == DATASET_CACHE_VERSION  # matches current version
            assert cache["hash"] == get_hash([x[0] for x in self.samples])  # identical hash
            nf, nc, n, samples = cache.pop("results")  # found, missing, empty, corrupt, total
            if LOCAL_RANK in {-1, 0}:
                d = f"{desc} {nf} images, {nc} corrupt"
                TQDM(None, desc=d, total=n, initial=n)
                if cache["msgs"]:
                    LOGGER.info("\n".join(cache["msgs"]))  # display warnings
            return samples

        # Run a fresh scan if *.cache retrieval failed
        nf, nc, msgs, samples, x = 0, 0, [], [], {}
        with ThreadPool(NUM_THREADS) as pool:
            results = pool.imap(func=verify_image, iterable=zip(self.samples, repeat(self.prefix)))
            pbar = TQDM(results, desc=desc, total=len(self.samples))
            for sample, nf_f, nc_f, msg in pbar:
                if nf_f:  # keep the sample if the image is not corrupt
                    samples.append(sample)
                if msg:
                    msgs.append(msg)
                nf += nf_f
                nc += nc_f
                pbar.desc = f"{desc} {nf} images, {nc} corrupt"
            pbar.close()
        if msgs:
            LOGGER.info("\n".join(msgs))
        x["hash"] = get_hash([x[0] for x in self.samples])
        x["results"] = nf, nc, len(samples), samples
        x["msgs"] = msgs  # warnings
        save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
        return samples
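A short sketch tying the pieces together: the `task` flag chooses which extra targets `YOLODataset` emits, and `collate_fn` stacks samples the way `build_dataloader` expects (this assumes the auto-downloadable `coco8.yaml`, not a verbatim excerpt from the source):

```python
from ultralytics.cfg import get_cfg
from ultralytics.data.dataset import YOLODataset
from ultralytics.data.utils import check_det_dataset

cfg = get_cfg()  # provides the hyp fields used by build_transforms (mask_ratio, bgr, ...)
data = check_det_dataset("coco8.yaml")
ds = YOLODataset(img_path=data["train"], data=data, task="detect", imgsz=640,
                 batch_size=2, augment=False, hyp=cfg)

batch = YOLODataset.collate_fn([ds[0], ds[1]])
print(batch["img"].shape)      # torch.Size([2, 3, 640, 640])
print(batch["batch_idx"][:5])  # image index per box, consumed by build_targets()
```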
.\yolov8\ultralytics\data\explorer\explorer.py
data: Union[str, Path] = "coco128.yaml",model: str = "yolov8n.pt",uri: str = USER_CONFIG_DIR / "explorer",初始化方法,接受数据配置文件路径或字符串,默认为"coco128.yaml";模型文件名,默认为"yolov8n.pt";URI路径,默认为用户配置目录下的"explorer"。self.data = Path(data)self.model = Path(model)self.uri = Path(uri)将传入的数据路径、模型路径和URI路径转换为`Path`对象,并分别赋值给实例变量`self.data`、`self.model`和`self.uri`。self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")根据当前系统是否支持CUDA,选择使用GPU(如果可用)或CPU,并将设备类型赋值给实例变量`self.device`。self.model = YOLO(self.model).to(self.device).eval()使用`YOLO`类加载指定的YOLO模型文件,并将其移动到之前确定的设备(GPU或CPU),然后设置为评估模式(eval),覆盖之前定义的`self.model`。self.data = ExplorerDataset(self.data)使用`ExplorerDataset`类加载指定的数据配置文件,并赋值给实例变量`self.data`,以供后续数据集探索和操作使用。def embed_images(self, images: List[Union[np.ndarray, str, Path]]) -> List[np.ndarray]:"""Embeds a list of images into feature vectors using the initialized YOLO model."""定义一个方法`embed_images`,接受一个包含图像的列表(可以是`np.ndarray`、字符串路径或`Path`对象),返回一个包含特征向量的`np.ndarray`列表。embeddings = []for image in tqdm(images, desc="Embedding images"):if isinstance(image, (str, Path)):image = cv2.imread(str(image)) # BGRif image is None:LOGGER.error(f"Image Not Found {image}")embeddings.append(None)continueif isinstance(image, np.ndarray):image = torch.from_numpy(image).to(self.device).float() / 255.0else:embeddings.append(None)continueif image.ndimension() == 3:image = image.unsqueeze(0)with torch.no_grad():features = self.model(image)[0].cpu().numpy()embeddings.append(features)return embeddings遍历图像列表,对每张图像进行以下操作:如果图像是字符串路径或`Path`对象,使用OpenCV加载图像(格式为BGR);如果加载失败,记录错误并在嵌入列表中添加`None`;如果图像是`np.ndarray`,将其转换为`torch.Tensor`并移动到设备上,然后进行归一化处理;最后使用YOLO模型提取图像特征,将特征向量添加到嵌入列表中。def create_table(self, schema: dict) -> bool:"""Creates a table in LanceDB using the provided schema."""定义一个方法`create_table`,接受一个表示表结构的字典作为参数,返回布尔值表示是否成功创建表。success = Falsetry:success = get_table_schema(self.uri, schema)except Exception as e:LOGGER.error(f"Error creating table: {e}")return success尝试调用`get_table_schema`函数,使用提供的URI路径和表结构字典创建表格。如果出现异常,记录错误信息,并返回`False`;否则返回函数调用结果。def query_similarity(self, image: Union[np.ndarray, str, Path], threshold: float = 0.5) -> List[Tuple[str, float]]:"""Queries LanceDB for images similar to the provided image, using YOLO features."""定义一个方法`query_similarity`,接受一个图像(可以是`np.ndarray`、字符串路径或`Path`对象)和相似度阈值作为参数,返回一个包含(文件名,相似度得分)元组的列表。schema = get_sim_index_schema()调用`get_sim_index_schema`函数,获取相似度索引的模式。results = []try:image_embed = self.embed_images([image])[0]调用`embed_images`方法,将提供的图像转换为特征向量。if image_embed is None:return results如果特征向量为空,直接返回空结果列表。query_result = prompt_sql_query(self.uri, schema, image_embed, threshold)使用提供的URI路径、模式、特征向量和阈值,调用`prompt_sql_query`函数执行相似度查询。results = [(r[0], float(r[1])) for r in query_result]except Exception as e:LOGGER.error(f"Error querying similarity: {e}")return results遍历查询结果,将文件名和相似度得分组成元组,并将它们添加到结果列表中。如果出现异常,记录错误信息,并返回空的结果列表。) -> None:"""初始化 Explorer 类,设置数据集路径、模型和数据库连接的 URI。"""# 注意 duckdb==0.10.0 的 bug https://github.com/ultralytics/ultralytics/pull/8181checks.check_requirements(["lancedb>=0.4.3", "duckdb<=0.9.2"])import lancedb# 建立与数据库的连接self.connection = lancedb.connect(uri)# 设定表格名称,使用数据路径和模型名称的小写形式self.table_name = f"{Path(data).name.lower()}_{model.lower()}"# 设定相似度索引的基础名称,用于重用表格并添加阈值和 top_k 参数self.sim_idx_base_name = (f"{self.table_name}_sim_idx".lower()) # 使用这个名称并附加阈值和 top_k 以重用表格# 初始化 YOLO 模型self.model = YOLO(model)# 数据路径self.data = data # None# 选择集合为空self.choice_set = None# 表格为空self.table = None# 进度为 0self.progress = 0def create_embeddings_table(self, force: bool = False, split: 
str = "train") -> None:"""创建包含数据集中图像嵌入的 LanceDB 表格。如果表格已经存在,则会重用它。传入 force=True 来覆盖现有表格。Args:force (bool): 是否覆盖现有表格。默认为 False。split (str): 要使用的数据集拆分。默认为 'train'。Example:```pythonexp = Explorer()exp.create_embeddings_table()```py"""# 如果表格已存在且不强制覆盖,则返回if self.table is not None and not force:LOGGER.info("表格已存在。正在重用。传入 force=True 来覆盖它。")return# 如果表格名称在连接的表格列表中且不强制覆盖,则重用表格if self.table_name in self.connection.table_names() and not force:LOGGER.info(f"表格 {self.table_name} 已存在。正在重用。传入 force=True 来覆盖它。")self.table = self.connection.open_table(self.table_name)self.progress = 1return# 如果数据为空,则抛出 ValueErrorif self.data is None:raise ValueError("必须提供数据以创建嵌入表格")# 检查数据集的详细信息data_info = check_det_dataset(self.data)# 如果拆分参数不在数据集信息中,则抛出 ValueErrorif split not in data_info:raise ValueError(f"数据集中找不到拆分 {split}。数据集中可用的键为 {list(data_info.keys())}")# 获取选择集并确保其为列表形式choice_set = data_info[split]choice_set = choice_set if isinstance(choice_set, list) else [choice_set]self.choice_set = choice_set# 创建 ExplorerDataset 实例dataset = ExplorerDataset(img_path=choice_set, data=data_info, augment=False, cache=False, task=self.model.task)# 创建表格模式batch = dataset[0]# 获取嵌入向量的大小vector_size = self.model.embed(batch["im_file"], verbose=False)[0].shape[0]# 创建表格table = self.connection.create_table(self.table_name, schema=get_table_schema(vector_size), mode="overwrite")# 向表格添加数据table.add(self._yield_batches(dataset,data_info,self.model,exclude_keys=["img", "ratio_pad", "resized_shape", "ori_shape", "batch_idx"],))self.table = tabledef _yield_batches(self, dataset: ExplorerDataset, data_info: dict, model: YOLO, exclude_keys: List[str]):"""Generates batches of data for embedding, excluding specified keys."""# 遍历数据集中的每个样本for i in tqdm(range(len(dataset))):# 更新进度条self.progress = float(i + 1) / len(dataset)# 获取当前样本数据batch = dataset[i]# 排除指定的键for k in exclude_keys:batch.pop(k, None)# 对批次数据进行清洗batch = sanitize_batch(batch, data_info)# 使用模型对图像文件进行嵌入batch["vector"] = model.embed(batch["im_file"], verbose=False)[0].detach().tolist()# 生成包含当前批次的列表,并进行 yieldyield [batch]def query(self, imgs: Union[str, np.ndarray, List[str], List[np.ndarray]] = None, limit: int = 25) -> Any: # pyarrow.Table"""Query the table for similar images. Accepts a single image or a list of images.Args:imgs (str or list): Path to the image or a list of paths to the images.limit (int): Number of results to return.Returns:(pyarrow.Table): An arrow table containing the results. Supports converting to:- pandas dataframe: `result.to_pandas()`- dict of lists: `result.to_pydict()`Example:```pythonexp = Explorer()exp.create_embeddings_table()similar = exp.query(imgs=['https://ultralytics.com/images/zidane.jpg'])```py"""# 检查表格是否已创建if self.table is None:raise ValueError("Table is not created. Please create the table first.")# 如果 imgs 是单个字符串,则转换为列表if isinstance(imgs, str):imgs = [imgs]# 断言 imgs 类型为列表assert isinstance(imgs, list), f"img must be a string or a list of strings. Got {type(imgs)}"# 使用模型嵌入图像数据embeds = self.model.embed(imgs)# 如果传入多张图像,则计算平均嵌入向量embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy()# 使用嵌入向量进行查询,并限制结果数量return self.table.search(embeds).limit(limit).to_arrow()def sql_query(self, query: str, return_type: str = "pandas"):"""Execute an SQL query on the embedded data.Args:query (str): SQL query string.return_type (str): Type of the return data. 
Default is "pandas".Returns:Depending on return_type:- "pandas": Returns a pandas dataframe.- "arrow": Returns a pyarrow Table.- "dict": Returns a dictionary.Example:```pythonexp = Explorer()query_result = exp.sql_query("SELECT * FROM embeddings WHERE category='person'", return_type='arrow')```py"""# 执行 SQL 查询,并根据返回类型返回相应的数据结构if return_type == "pandas":return pd.read_sql_query(query, self.conn)elif return_type == "arrow":return pa.Table.from_pandas(pd.read_sql_query(query, self.conn))elif return_type == "dict":return pd.read_sql_query(query, self.conn).to_dict(orient='list')else:raise ValueError(f"Unsupported return_type: {return_type}. Choose from 'pandas', 'arrow', or 'dict'.")) -> Union[Any, None]: # pandas.DataFrame or pyarrow.Table"""Run a SQL-Like query on the table. Utilizes LanceDB predicate pushdown.Args:query (str): SQL query to run.return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'.Returns:(pyarrow.Table): An arrow table containing the results.Example:```pythonexp = Explorer()exp.create_embeddings_table()query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'"result = exp.sql_query(query)```py"""# Ensure the return_type is either 'pandas' or 'arrow'assert return_type in {"pandas","arrow",}, f"Return type should be either `pandas` or `arrow`, but got {return_type}"import duckdb# Raise an error if the table is not createdif self.table is None:raise ValueError("Table is not created. Please create the table first.")# Note: using filter pushdown would be a better long term solution. Temporarily using duckdb for this.# Convert the internal table representation to Arrow formattable = self.table.to_arrow() # noqa NOTE: Don't comment this. This line is used by DuckDB# Check if the query starts with correct SQL keywordsif not query.startswith("SELECT") and not query.startswith("WHERE"):raise ValueError(f"Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE "f"clause. found {query}")# If the query starts with WHERE, prepend it with SELECT * FROM 'table'if query.startswith("WHERE"):query = f"SELECT * FROM 'table' {query}"# Log the query being executedLOGGER.info(f"Running query: {query}")# Execute the SQL query using duckdbrs = duckdb.sql(query)# Return the result based on the specified return_typeif return_type == "arrow":return rs.arrow()elif return_type == "pandas":return rs.df()def plot_sql_query(self, query: str, labels: bool = True) -> Image.Image:"""Plot the results of a SQL-Like query on the table.Args:query (str): SQL query to run.labels (bool): Whether to plot the labels or not.Returns:(PIL.Image): Image containing the plot.Example:```pythonexp = Explorer()exp.create_embeddings_table()query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'"result = exp.plot_sql_query(query)```py"""# Execute the SQL query with return_type='arrow' to get the result as an Arrow tableresult = self.sql_query(query, return_type="arrow")# If no results are found, log and return Noneif len(result) == 0:LOGGER.info("No results found.")return None# Generate a plot based on the query result and return it as a PIL Imageimg = plot_query_result(result, plot_labels=labels)return Image.fromarray(img)def get_similar(self,img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,idx: Union[int, List[int]] = None,limit: int = 25,return_type: str = "pandas",) -> Any: # pandas.DataFrame or pyarrow.Table"""Query the table for similar images. 
    def get_similar(
        self,
        img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
        idx: Union[int, List[int]] = None,
        limit: int = 25,
        return_type: str = "pandas",
    ) -> Any:  # pandas.DataFrame or pyarrow.Table
        """
        Query the table for similar images. Accepts a single image or a list of images.

        Args:
            img (str or list): Path to the image or a list of paths to the images.
            idx (int or list): Index of the image in the table or a list of indexes.
            limit (int): Number of results to return. Defaults to 25.
            return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'.

        Returns:
            (pandas.DataFrame or pyarrow.Table): Depending on return_type, either a DataFrame or a Table.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            similar = exp.get_similar(img='https://ultralytics.com/images/zidane.jpg')
            ```
        """
        assert return_type in {"pandas", "arrow"}, f"Return type should be `pandas` or `arrow`, but got {return_type}"
        # Validate the img/idx arguments and normalize them to a list of images
        img = self._check_imgs_or_idxs(img, idx)
        # Query for similar images using the normalized img argument
        similar = self.query(img, limit=limit)

        if return_type == "arrow":
            # Return the query result as a pyarrow.Table
            return similar
        elif return_type == "pandas":
            # Convert the query result to a pandas DataFrame and return it
            return similar.to_pandas()

    def plot_similar(
        self,
        img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
        idx: Union[int, List[int]] = None,
        limit: int = 25,
        labels: bool = True,
    ) -> Image.Image:
        """
        Plot the similar images. Accepts images or indexes.

        Args:
            img (str or list): Path to the image or a list of paths to the images.
            idx (int or list): Index of the image in the table or a list of indexes.
            labels (bool): Whether to plot the labels or not.
            limit (int): Number of results to return. Defaults to 25.

        Returns:
            (PIL.Image): Image containing the plot.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            similar = exp.plot_similar(img='https://ultralytics.com/images/zidane.jpg')
            ```
        """
        # Retrieve similar images in Arrow format
        similar = self.get_similar(img, idx, limit, return_type="arrow")
        # If no similar images are found, log and return None
        if len(similar) == 0:
            LOGGER.info("No results found.")
            return None
        # Plot the query result and return it as a PIL Image
        img = plot_query_result(similar, plot_labels=labels)
        return Image.fromarray(img)
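`get_similar` and `plot_similar` accept either an image (path, URL, or array) or a row index into the table, never both, as `_check_imgs_or_idxs` enforces further below. A brief sketch under the same assumptions:

```python
from ultralytics import Explorer

exp = Explorer(data="coco128.yaml", model="yolov8n.pt")
exp.create_embeddings_table()

# Look up neighbours of the first row in the table by index
df = exp.get_similar(idx=0, limit=10)  # pandas.DataFrame by default
print(df[["im_file"]].head())

# Or render the matches as a single PIL image grid
grid = exp.plot_similar(idx=0, limit=10)
if grid is not None:  # None when the query returns no rows
    grid.show()
```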
    def similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> Any:  # pd.DataFrame
        """
        Calculate the similarity index of all the images in the table. Here, the index will contain the data points
        that are max_dist or closer to the image in the embedding space at a given index.

        Args:
            max_dist (float): Maximum L2 distance between the embeddings to consider. Defaults to 0.2.
            top_k (float): Percentage of the closest data points to consider when counting. Used to apply a limit
                when running vector search. Defaults to None.
            force (bool): Whether to overwrite the existing similarity index or not. Defaults to False.

        Returns:
            (pandas.DataFrame): A dataframe containing the similarity index. Each row corresponds to an image,
                and columns include indices of similar images and their respective distances.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            sim_idx = exp.similarity_index()
            ```
        """
        # Raise a ValueError if the table has not been created
        if self.table is None:
            raise ValueError("Table is not created. Please create the table first.")
        # Build the similarity-index table name from the max_dist and top_k parameters
        sim_idx_table_name = f"{self.sim_idx_base_name}_thres_{max_dist}_top_{top_k}".lower()
        # If the similarity index already exists and force is not set, log and return the existing table as a DataFrame
        if sim_idx_table_name in self.connection.table_names() and not force:
            LOGGER.info("Similarity matrix already exists. Reusing it. Pass force=True to overwrite it.")
            return self.connection.open_table(sim_idx_table_name).to_pandas()

        # Raise a ValueError if top_k is given but not in the range [0, 1]
        if top_k and not (1.0 >= top_k >= 0.0):
            raise ValueError(f"top_k must be between 0.0 and 1.0. Got {top_k}")
        # Raise a ValueError if max_dist is negative
        if max_dist < 0.0:
            raise ValueError(f"max_dist must be greater than or equal to 0. Got {max_dist}")

        # Compute the effective top_k count, never smaller than 1
        top_k = int(top_k * len(self.table)) if top_k else len(self.table)
        top_k = max(top_k, 1)
        # Extract the embedding vectors and image file names from the table
        features = self.table.to_lance().to_table(columns=["vector", "im_file"]).to_pydict()
        im_files = features["im_file"]
        embeddings = features["vector"]

        # Create the similarity-index table with the given name and schema
        sim_table = self.connection.create_table(sim_idx_table_name, schema=get_sim_index_schema(), mode="overwrite")

        def _yield_sim_idx():
            """Generates a dataframe with similarity indices and distances for images."""
            # Iterate over the embeddings with a progress bar
            for i in tqdm(range(len(embeddings))):
                # Search for the top_k most similar entries, keeping only those within max_dist
                sim_idx = self.table.search(embeddings[i]).limit(top_k).to_pandas().query(f"_distance <= {max_dist}")
                # Yield a record with the similarity information for the current image
                yield [
                    {
                        "idx": i,
                        "im_file": im_files[i],
                        "count": len(sim_idx),
                        "sim_im_files": sim_idx["im_file"].tolist(),
                    }
                ]

        # Add the similarity records to the similarity-index table
        sim_table.add(_yield_sim_idx())
        # Store the similarity index on the instance
        self.sim_index = sim_table
        # Return the similarity-index table as a pandas DataFrame
        return sim_table.to_pandas()
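Because each embedding is within distance 0 of itself, every `count` in the similarity index is at least 1; rows with `count > 1` therefore flag potential near-duplicates. A sketch of that use, under the same illustrative assumptions as the earlier examples:

```python
from ultralytics import Explorer

exp = Explorer(data="coco128.yaml", model="yolov8n.pt")
exp.create_embeddings_table()

# DataFrame with columns: idx, im_file, count, sim_im_files
sim_idx = exp.similarity_index(max_dist=0.2)

# count includes the image itself, so count > 1 means close neighbours exist
near_dupes = sim_idx[sim_idx["count"] > 1]
print(near_dupes[["im_file", "count"]].head())
```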
    def plot_similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> Image.Image:
        """
        Plot the similarity index of all the images in the table. Here, the index will contain the data points that
        are max_dist or closer to the image in the embedding space at a given index.

        Args:
            max_dist (float): Maximum L2 distance between the embeddings to consider. Defaults to 0.2.
            top_k (float): Percentage of closest data points to consider when counting. Used to apply a limit when
                running vector search. Defaults to None.
            force (bool): Whether to overwrite the existing similarity index or not. Defaults to False.

        Returns:
            (PIL.Image): Image containing the plot.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()

            similarity_idx_plot = exp.plot_similarity_index()
            similarity_idx_plot.show()  # view image preview
            similarity_idx_plot.save('path/to/save/similarity_index_plot.png')  # save contents to file
            ```
        """
        # Retrieve the similarity index for the given parameters
        sim_idx = self.similarity_index(max_dist=max_dist, top_k=top_k, force=force)
        # Extract the counts of similar images from the similarity index
        sim_count = sim_idx["count"].tolist()
        sim_count = np.array(sim_count)

        # Generate indices for the bar plot
        indices = np.arange(len(sim_count))

        # Create the bar plot with matplotlib
        plt.bar(indices, sim_count)

        # Customize the plot with labels and a title
        plt.xlabel("data idx")
        plt.ylabel("Count")
        plt.title("Similarity Count")

        # Save the plot to an in-memory PNG buffer
        buffer = BytesIO()
        plt.savefig(buffer, format="png")
        buffer.seek(0)

        # Open the image from the buffer with Pillow and return it
        return Image.fromarray(np.array(Image.open(buffer)))

    def _check_imgs_or_idxs(
        self, img: Union[str, np.ndarray, List[str], List[np.ndarray], None], idx: Union[None, int, List[int]]
    ) -> List[np.ndarray]:
        """Determines whether to fetch images or indexes based on provided arguments and returns image paths."""
        # Passing neither img nor idx is not allowed
        if img is None and idx is None:
            raise ValueError("Either img or idx must be provided.")
        # Passing both img and idx is also not allowed
        if img is not None and idx is not None:
            raise ValueError("Only one of img or idx must be provided.")
        # If idx is provided, fetch the corresponding image paths from the table
        if idx is not None:
            idx = idx if isinstance(idx, list) else [idx]
            img = self.table.to_lance().take(idx, columns=["im_file"]).to_pydict()["im_file"]

        # Return the image paths as a list
        return img if isinstance(img, list) else [img]

    def ask_ai(self, query):
        """
        Ask AI a question.

        Args:
            query (str): Question to ask.

        Returns:
            (pandas.DataFrame): A dataframe containing filtered results to the SQL query.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            answer = exp.ask_ai('Show images with 1 person and 2 dogs')
            ```
        """
        # Convert the natural-language question into a SQL query
        result = prompt_sql_query(query)
        try:
            # Run the generated SQL query and return the resulting dataframe
            return self.sql_query(result)
        except Exception as e:
            # Log the error and return None if the generated query is invalid
            LOGGER.error("AI generated query is not valid. Please try again with a different prompt")
            LOGGER.error(e)
            return None

    def visualize(self, result):
        """
        Visualize the results of a query. TODO.

        Args:
            result (pyarrow.Table): Table containing the results of a query.
        """
        # Not implemented yet
        pass

    def generate_report(self, result):
        """
        Generate a report of the dataset.

        TODO
        """
        # Not implemented yet
        pass
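`ask_ai` is only as reliable as the SQL produced by `prompt_sql_query`, which requires the LLM backend it wraps to be configured (an OpenAI API key in stock Ultralytics, to the best of my knowledge); the method deliberately returns None when the generated query fails to execute. A hedged sketch:

```python
from ultralytics import Explorer

exp = Explorer(data="coco128.yaml", model="yolov8n.pt")
exp.create_embeddings_table()

# The natural-language prompt is translated to SQL and run via sql_query()
answer = exp.ask_ai("show me images with exactly 2 persons")
if answer is None:
    print("Generated SQL was invalid; try rephrasing the prompt")
else:
    print(answer.head())
```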