Yolov8-源码解析-四十四-

news/2024/11/15 16:25:40/文章来源:https://www.cnblogs.com/apachecn/p/18398145

Yolov8 源码解析(四十四)

.\yolov8\ultralytics\utils\triton.py

# Ultralytics YOLO 🚀, AGPL-3.0 license# 引入必要的类型
from typing import List
# 解析 URL 的组件
from urllib.parse import urlsplit# 引入 NumPy 库
import numpy as npclass TritonRemoteModel:"""用于与远程 Triton 推理服务器模型进行交互的客户端类。Attributes:endpoint (str): Triton 服务器上模型的名称。url (str): Triton 服务器的 URL。triton_client: Triton 客户端(可以是 HTTP 或 gRPC)。InferInput: Triton 客户端的输入类。InferRequestedOutput: Triton 客户端的输出请求类。input_formats (List[str]): 模型输入的数据类型。np_input_formats (List[type]): 模型输入的 NumPy 数据类型。input_names (List[str]): 模型输入的名称列表。output_names (List[str]): 模型输出的名称列表。"""def __init__(self, url: str, endpoint: str = "", scheme: str = ""):"""初始化 TritonRemoteModel。参数可以单独提供,也可以从形如 <scheme>://<netloc>/<endpoint>/<task_name> 的 'url' 参数中解析。Args:url (str): Triton 服务器的 URL。endpoint (str): Triton 服务器上模型的名称。scheme (str): 通信协议('http' 或 'grpc')。"""if not endpoint and not scheme:  # 从 URL 字符串中解析所有参数splits = urlsplit(url)endpoint = splits.path.strip("/").split("/")[0]scheme = splits.schemeurl = splits.netlocself.endpoint = endpointself.url = url# 根据通信协议选择 Triton 客户端if scheme == "http":import tritonclient.http as client  # noqaself.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)config = self.triton_client.get_model_config(endpoint)else:import tritonclient.grpc as client  # noqaself.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)config = self.triton_client.get_model_config(endpoint, as_json=True)["config"]# 按字母顺序对输出名称进行排序,例如 'output0', 'output1' 等。config["output"] = sorted(config["output"], key=lambda x: x.get("name"))# 定义模型属性type_map = {"TYPE_FP32": np.float32, "TYPE_FP16": np.float16, "TYPE_UINT8": np.uint8}self.InferRequestedOutput = client.InferRequestedOutputself.InferInput = client.InferInputself.input_formats = [x["data_type"] for x in config["input"]]self.np_input_formats = [type_map[x] for x in self.input_formats]self.input_names = [x["name"] for x in config["input"]]self.output_names = [x["name"] for x in config["output"]]# 定义一个特殊方法 __call__,允许将实例对象像函数一样调用,接受多个 numpy 数组作为输入,并返回多个 numpy 数组作为输出def __call__(self, *inputs: np.ndarray) -> List[np.ndarray]:"""Call the model with the given inputs.Args:*inputs (List[np.ndarray]): Input data to the model.Returns:(List[np.ndarray]): Model outputs."""# 初始化一个空列表,用于存储推理输入infer_inputs = []# 获取输入数组的数据类型,假设所有输入数组的数据类型相同input_format = inputs[0].dtype# 遍历所有输入数组for i, x in enumerate(inputs):# 如果当前输入数组的数据类型与预期的输入数据类型不匹配,将其转换为预期的数据类型if x.dtype != self.np_input_formats[i]:x = x.astype(self.np_input_formats[i])# 创建一个推理输入对象,指定输入名称、形状和数据类型infer_input = self.InferInput(self.input_names[i], [*x.shape], self.input_formats[i].replace("TYPE_", ""))# 将 numpy 数组的数据复制到推理输入对象中infer_input.set_data_from_numpy(x)# 将推理输入对象添加到推理输入列表中infer_inputs.append(infer_input)# 根据输出名称列表创建推理输出对象列表infer_outputs = [self.InferRequestedOutput(output_name) for output_name in self.output_names]# 调用 Triton 客户端进行推理,传入模型名称、输入列表和输出列表,并获取推理结果outputs = self.triton_client.infer(model_name=self.endpoint, inputs=infer_inputs, outputs=infer_outputs)# 将每个输出结果转换回原始输入数据类型,并存储在列表中返回return [outputs.as_numpy(output_name).astype(input_format) for output_name in self.output_names]

.\yolov8\ultralytics\utils\tuner.py

# 导入 subprocess 模块,用于执行外部命令
import subprocess# 从 ultralytics 的 cfg 模块中导入 TASK2DATA、TASK2METRIC、get_save_dir 函数
from ultralytics.cfg import TASK2DATA, TASK2METRIC, get_save_dir
# 从 ultralytics 的 utils 模块中导入 DEFAULT_CFG、DEFAULT_CFG_DICT、LOGGER、NUM_THREADS、checks 函数
from ultralytics.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, NUM_THREADS, checks# 定义函数 run_ray_tune,用于运行 Ray Tune 进行超参数调优
def run_ray_tune(model, space: dict = None, grace_period: int = 10, gpu_per_trial: int = None, max_samples: int = 10, **train_args
):"""Runs hyperparameter tuning using Ray Tune.Args:model (YOLO): Model to run the tuner on.space (dict, optional): The hyperparameter search space. Defaults to None.grace_period (int, optional): The grace period in epochs of the ASHA scheduler. Defaults to 10.gpu_per_trial (int, optional): The number of GPUs to allocate per trial. Defaults to None.max_samples (int, optional): The maximum number of trials to run. Defaults to 10.train_args (dict, optional): Additional arguments to pass to the `train()` method. Defaults to {}.Returns:(dict): A dictionary containing the results of the hyperparameter search.Example:```pythonfrom ultralytics import YOLO# Load a YOLOv8n modelmodel = YOLO('yolov8n.pt')# Start tuning hyperparameters for YOLOv8n training on the COCO8 datasetresult_grid = model.tune(data='coco8.yaml', use_ray=True)```"""# 在日志中输出一条信息,引导用户了解 RayTune 的文档链接LOGGER.info("💡 Learn about RayTune at https://docs.ultralytics.com/integrations/ray-tune")# 如果 train_args 为 None,则将其设为一个空字典if train_args is None:train_args = {}try:# 尝试安装 Ray Tune 相关依赖subprocess.run("pip install ray[tune]".split(), check=True)  # do not add single quotes here# 导入必要的 Ray Tune 模块import rayfrom ray import tunefrom ray.air import RunConfigfrom ray.air.integrations.wandb import WandbLoggerCallbackfrom ray.tune.schedulers import ASHASchedulerexcept ImportError:# 若导入失败,则抛出 ModuleNotFoundError 异常raise ModuleNotFoundError('Ray Tune required but not found. To install run: pip install "ray[tune]"')try:# 尝试导入 wandb 模块,并确保其有 __version__ 属性import wandbassert hasattr(wandb, "__version__")except (ImportError, AssertionError):# 若导入失败或缺少 __version__ 属性,则将 wandb 设为 Falsewandb = False# 使用 checks 模块检查 ray 的版本是否符合要求checks.check_version(ray.__version__, ">=2.0.0", "ray")# 默认的超参数搜索空间,包含各种超参数的取值范围或概率分布default_space = {# 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),"lr0": tune.uniform(1e-5, 1e-1),  # 初始学习率范围"lrf": tune.uniform(0.01, 1.0),  # 最终OneCycleLR学习率 (lr0 * lrf)"momentum": tune.uniform(0.6, 0.98),  # SGD动量/Adam beta1"weight_decay": tune.uniform(0.0, 0.001),  # 优化器权重衰减"warmup_epochs": tune.uniform(0.0, 5.0),  # 预热 epochs 数量(允许小数)"warmup_momentum": tune.uniform(0.0, 0.95),  # 预热初始动量"box": tune.uniform(0.02, 0.2),  # 盒损失增益"cls": tune.uniform(0.2, 4.0),  # 分类损失增益(与像素缩放比例相关)"hsv_h": tune.uniform(0.0, 0.1),  # 图像HSV-Hue增强(比例)"hsv_s": tune.uniform(0.0, 0.9),  # 图像HSV-Saturation增强(比例)"hsv_v": tune.uniform(0.0, 0.9),  # 图像HSV-Value增强(比例)"degrees": tune.uniform(0.0, 45.0),  # 图像旋转角度范围(+/- 度)"translate": tune.uniform(0.0, 0.9),  # 图像平移范围(+/- 比例)"scale": tune.uniform(0.0, 0.9),  # 图像缩放范围(+/- 增益)"shear": tune.uniform(0.0, 10.0),  # 图像剪切范围(+/- 度)"perspective": tune.uniform(0.0, 0.001),  # 图像透视变换范围(+/- 比例),范围 0-0.001"flipud": tune.uniform(0.0, 1.0),  # 图像上下翻转概率"fliplr": tune.uniform(0.0, 1.0),  # 图像左右翻转概率"bgr": tune.uniform(0.0, 1.0),  # 图像通道BGR转换概率"mosaic": tune.uniform(0.0, 1.0),  # 图像混合(mosaic)概率"mixup": tune.uniform(0.0, 1.0),  # 图像混合(mixup)概率"copy_paste": tune.uniform(0.0, 1.0),  # 分割复制粘贴概率}# 将模型放入Ray存储task = model.taskmodel_in_store = ray.put(model)def _tune(config):"""使用指定的超参数和额外的参数训练YOLO模型。Args:config (dict): 用于训练的超参数字典。Returns:dict: 训练结果字典。"""model_to_train = ray.get(model_in_store)  # 从Ray存储中获取要调整的模型model_to_train.reset_callbacks()  # 重置回调函数config.update(train_args)  # 更新训练参数results = model_to_train.train(**config)  # 使用给定的超参数进行训练return results.results_dict  # 返回训练结果字典# 获取搜索空间if not space:space = default_space  # 如果未提供搜索空间,则使用默认搜索空间LOGGER.warning("WARNING ⚠️ 没有提供搜索空间,使用默认搜索空间。")# 获取数据集data = train_args.get("data", TASK2DATA[task])  # 获取数据集space["data"] = data  # 将数据集添加到搜索空间if "data" not in train_args:LOGGER.warning(f'WARNING ⚠️ 没有提供数据集,使用默认数据集 "data={data}".')# 定义带有分配资源的可训练函数trainable_with_resources = tune.with_resources(_tune, {"cpu": NUM_THREADS, "gpu": gpu_per_trial or 0})# 定义 ASHA 调度器用于超参数搜索asha_scheduler = ASHAScheduler(time_attr="epoch",  # 指定时间属性为 epochmetric=TASK2METRIC[task],  # 设置优化的指标为任务对应的指标mode="max",  # 设定优化模式为最大化max_t=train_args.get("epochs") or DEFAULT_CFG_DICT["epochs"] or 100,  # 设置最大的训练轮数grace_period=grace_period,  # 设置优雅期间reduction_factor=3,  # 设置收敛因子)# 定义用于超参数搜索的回调函数tuner_callbacks = [WandbLoggerCallback(project="YOLOv8-tune")] if wandb else []# 创建 Ray Tune 的超参数搜索调谐器tune_dir = get_save_dir(DEFAULT_CFG, name="tune").resolve()  # 获取保存调优结果的绝对路径,确保目录存在tune_dir.mkdir(parents=True, exist_ok=True)  # 创建保存调优结果的目录,如果不存在则创建tuner = tune.Tuner(trainable_with_resources,  # 可训练函数param_space=space,  # 参数空间tune_config=tune.TuneConfig(scheduler=asha_scheduler, num_samples=max_samples),  # 调优配置run_config=RunConfig(callbacks=tuner_callbacks, storage_path=tune_dir),  # 运行配置)# 运行超参数搜索tuner.fit()# 返回超参数搜索的结果return tuner.get_results()

.\yolov8\ultralytics\utils\__init__.py

# Ultralytics YOLO 🚀, AGPL-3.0 licenseimport contextlib
import importlib.metadata
import inspect
import logging.config
import os
import platform
import re
import subprocess
import sys
import threading
import time
import urllib
import uuid
from pathlib import Path
from types import SimpleNamespace
from typing import Unionimport cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import yaml
from tqdm import tqdm as tqdm_originalfrom ultralytics import __version__# PyTorch Multi-GPU DDP Constants
RANK = int(os.getenv("RANK", -1))  # 获取环境变量 RANK 的整数值,默认为 -1
LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1))  # 获取环境变量 LOCAL_RANK 的整数值,默认为 -1,用于 PyTorch Elastic运行# Other Constants
ARGV = sys.argv or ["", ""]  # 获取命令行参数列表,若为空则初始化为包含两个空字符串的列表
FILE = Path(__file__).resolve()  # 获取当前脚本文件的绝对路径
ROOT = FILE.parents[1]  # 获取当前脚本文件的父目录的父目录路径,即 YOLO 的根目录
ASSETS = ROOT / "assets"  # 默认图像文件目录的路径
DEFAULT_CFG_PATH = ROOT / "cfg/default.yaml"  # 默认配置文件路径
NUM_THREADS = min(8, max(1, os.cpu_count() - 1))  # YOLO 多进程线程数,至少为 1,最多为 CPU 核心数减 1
AUTOINSTALL = str(os.getenv("YOLO_AUTOINSTALL", True)).lower() == "true"  # 全局自动安装模式,默认为 True
VERBOSE = str(os.getenv("YOLO_VERBOSE", True)).lower() == "true"  # 全局详细模式,默认为 True
TQDM_BAR_FORMAT = "{l_bar}{bar:10}{r_bar}" if VERBOSE else None  # tqdm 进度条显示格式,如果详细模式开启则使用指定格式,否则为 None
LOGGING_NAME = "ultralytics"  # 日志记录器名称
MACOS, LINUX, WINDOWS = (platform.system() == x for x in ["Darwin", "Linux", "Windows"])  # 操作系统类型的布尔值
ARM64 = platform.machine() in {"arm64", "aarch64"}  # ARM64 架构的布尔值
PYTHON_VERSION = platform.python_version()  # Python 版本号
TORCH_VERSION = torch.__version__  # PyTorch 版本号
TORCHVISION_VERSION = importlib.metadata.version("torchvision")  # torchvision 的版本号,比直接导入速度更快
HELP_MSG = """Usage examples for running YOLOv8:1. Install the ultralytics package:pip install ultralytics2. Use the Python SDK:from ultralytics import YOLO# Load a modelmodel = YOLO('yolov8n.yaml')  # 从头开始构建一个新的模型model = YOLO("yolov8n.pt")  # 加载预训练模型(推荐用于训练)# Use the modelresults = model.train(data="coco8.yaml", epochs=3)  # 训练模型results = model.val()  # 在验证集上评估模型性能results = model('https://ultralytics.com/images/bus.jpg')  # 对图像进行预测success = model.export(format='onnx')  # 将模型导出为 ONNX 格式
"""# 使用命令行界面 (CLI):YOLOv8 的 'yolo' CLI 命令遵循以下语法:yolo TASK MODE ARGS其中   TASK (可选) 可以是 [detect, segment, classify] 中的一个MODE (必需) 可以是 [train, val, predict, export] 中的一个ARGS (可选) 是任意数量的自定义 'arg=value' 对,如 'imgsz=320',用于覆盖默认设置。可以在 https://docs.ultralytics.com/usage/cfg 或通过 'yolo cfg' 查看所有 ARGS。- 训练一个检测模型,使用 coco8.yaml 数据集,模型为 yolov8n.pt,训练 10 个 epochs,初始学习率为 0.01yolo detect train data=coco8.yaml model=yolov8n.pt epochs=10 lr0=0.01- 使用预训练的分割模型预测 YouTube 视频,图像尺寸为 320:yolo segment predict model=yolov8n-seg.pt source='https://youtu.be/LNwODJXcvt4' imgsz=320- 在批量大小为 1 和图像尺寸为 640 的情况下,验证预训练的检测模型:yolo detect val model=yolov8n.pt data=coco8.yaml batch=1 imgsz=640- 将 YOLOv8n 分类模型导出为 ONNX 格式,图像尺寸为 224x128 (不需要 TASK 参数)yolo export model=yolov8n-cls.pt format=onnx imgsz=224,128- 运行特殊命令:yolo helpyolo checksyolo versionyolo settingsyolo copy-cfgyolo cfg文档链接:https://docs.ultralytics.com社区链接:https://community.ultralytics.comGitHub 仓库链接:https://github.com/ultralytics/ultralytics
# 设置和环境变量# 设置 Torch 的打印选项,包括行宽、精度和默认配置文件
torch.set_printoptions(linewidth=320, precision=4, profile="default")# 设置 NumPy 的打印选项,包括行宽和浮点数格式
np.set_printoptions(linewidth=320, formatter={"float_kind": "{:11.5g}".format})  # format short g, %precision=5# 设置 OpenCV 的线程数为 0,以防止与 PyTorch DataLoader 的多线程不兼容
cv2.setNumThreads(0)# 设置 NumExpr 的最大线程数为 NUM_THREADS
os.environ["NUMEXPR_MAX_THREADS"] = str(NUM_THREADS)# 设置 CUBLAS 的工作空间配置为 ":4096:8",用于确定性训练
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"# 设置 TensorFlow 的最小日志级别为 "3",以在 Colab 中抑制冗长的 TF 编译器警告
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"# 设置 Torch 的 C++ 日志级别为 "ERROR",以抑制 "NNPACK.cpp could not initialize NNPACK" 警告
os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR"# 设置 Kineto 的日志级别为 "5",以在计算 FLOPs 时抑制冗长的 PyTorch 分析器输出
os.environ["KINETO_LOG_LEVEL"] = "5"class TQDM(tqdm_original):"""自定义的 Ultralytics tqdm 类,具有不同的默认参数设置。Args:*args (list): 传递给原始 tqdm 的位置参数。**kwargs (any): 关键字参数,应用自定义默认值。"""def __init__(self, *args, **kwargs):"""初始化自定义的 Ultralytics tqdm 类,具有不同的默认参数设置。注意,这些参数在调用 TQDM 时仍然可以被覆盖。"""kwargs["disable"] = not VERBOSE or kwargs.get("disable", False)  # 逻辑 'and' 操作,使用默认值(如果传递)。kwargs.setdefault("bar_format", TQDM_BAR_FORMAT)  # 如果传递,则覆盖默认值。super().__init__(*args, **kwargs)class SimpleClass:"""Ultralytics SimpleClass 是一个基类,提供更简单的调试和使用方法,包括有用的字符串表示、错误报告和属性访问方法。"""def __str__(self):"""返回对象的人类可读字符串表示形式。"""attr = []for a in dir(self):v = getattr(self, a)if not callable(v) and not a.startswith("_"):if isinstance(v, SimpleClass):# 对于子类,仅显示模块和类名s = f"{a}: {v.__module__}.{v.__class__.__name__} object"else:s = f"{a}: {repr(v)}"attr.append(s)return f"{self.__module__}.{self.__class__.__name__} object with attributes:\n\n" + "\n".join(attr)def __repr__(self):"""返回对象的机器可读字符串表示形式。"""return self.__str__()def __getattr__(self, attr):"""自定义属性访问错误消息,提供有用的信息。"""name = self.__class__.__name__raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")class IterableSimpleNamespace(SimpleNamespace):"""Ultralytics IterableSimpleNamespace 是 SimpleNamespace 的扩展类,添加了可迭代功能并支持与 dict() 和 for 循环一起使用。"""# 返回一个迭代器,迭代命名空间的属性键值对def __iter__(self):"""Return an iterator of key-value pairs from the namespace's attributes."""return iter(vars(self).items())# 返回对象的可读字符串表示,每行显示属性名=属性值def __str__(self):"""Return a human-readable string representation of the object."""return "\n".join(f"{k}={v}" for k, v in vars(self).items())# 自定义属性访问错误消息,提供帮助信息def __getattr__(self, attr):"""Custom attribute access error message with helpful information."""name = self.__class__.__name__raise AttributeError(f"""'{name}' object has no attribute '{attr}'. This may be caused by a modified or out of date ultralytics'default.yaml' file.\nPlease update your code with 'pip install -U ultralytics' and if necessary replace{DEFAULT_CFG_PATH} with the latest version fromhttps://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/default.yaml""")# 返回指定键的值,如果不存在则返回默认值def get(self, key, default=None):"""Return the value of the specified key if it exists; otherwise, return the default value."""return getattr(self, key, default)
# 定义一个函数 plt_settings,用作装饰器,临时设置绘图函数的 rc 参数和后端def plt_settings(rcparams=None, backend="Agg"):"""Decorator to temporarily set rc parameters and the backend for a plotting function.Example:decorator: @plt_settings({"font.size": 12})context manager: with plt_settings({"font.size": 12}):Args:rcparams (dict): Dictionary of rc parameters to set.backend (str, optional): Name of the backend to use. Defaults to 'Agg'.Returns:(Callable): Decorated function with temporarily set rc parameters and backend. This decorator can beapplied to any function that needs to have specific matplotlib rc parameters and backend for its execution."""# 如果未提供 rcparams,则使用默认字典设置 {"font.size": 11}if rcparams is None:rcparams = {"font.size": 11}# 定义装饰器函数 decoratordef decorator(func):"""Decorator to apply temporary rc parameters and backend to a function."""# 定义 wrapper 函数,用于设置 rc 参数和后端,调用原始函数,并恢复设置def wrapper(*args, **kwargs):"""Sets rc parameters and backend, calls the original function, and restores the settings."""# 获取当前的后端original_backend = plt.get_backend()# 如果指定的后端与当前不同,则关闭所有图形,并切换到指定的后端if backend.lower() != original_backend.lower():plt.close("all")  # auto-close()ing of figures upon backend switching is deprecated since 3.8plt.switch_backend(backend)# 使用指定的 rc 参数上下文管理器with plt.rc_context(rcparams):result = func(*args, **kwargs)# 如果使用了不同的后端,则关闭所有图形,并恢复原始后端if backend != original_backend:plt.close("all")plt.switch_backend(original_backend)return resultreturn wrapperreturn decorator# 定义函数 set_logging,为给定名称设置日志记录,支持 UTF-8 编码,并确保在不同环境中的兼容性
def set_logging(name="LOGGING_NAME", verbose=True):"""Sets up logging for the given name with UTF-8 encoding support, ensuring compatibility across differentenvironments."""# 根据 verbose 参数设置日志级别,多 GPU 训练中考虑到 RANK 的情况level = logging.INFO if verbose and RANK in {-1, 0} else logging.ERROR  # rank in world for Multi-GPU trainings# 配置控制台(stdout)的编码为 UTF-8,以确保兼容性formatter = logging.Formatter("%(message)s")  # Default formatter# 如果在 Windows 环境下,并且 sys.stdout 具有 "encoding" 属性,并且编码不是 "utf-8"if WINDOWS and hasattr(sys.stdout, "encoding") and sys.stdout.encoding != "utf-8":# 定义一个定制的日志格式化器 CustomFormatterclass CustomFormatter(logging.Formatter):def format(self, record):"""Sets up logging with UTF-8 encoding and configurable verbosity."""# 返回格式化后的日志记录,包括表情符号处理return emojis(super().format(record))try:# 尝试重新配置 stdout 使用 UTF-8 编码(如果支持)if hasattr(sys.stdout, "reconfigure"):sys.stdout.reconfigure(encoding="utf-8")# 对于不支持 reconfigure 的环境,用 TextIOWrapper 包装 stdoutelif hasattr(sys.stdout, "buffer"):import iosys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8")else:# 创建自定义格式化器以应对非 UTF-8 环境formatter = CustomFormatter("%(message)s")except Exception as e:# 如果出现异常,创建适应非 UTF-8 环境的自定义格式化器print(f"Creating custom formatter for non UTF-8 environments due to {e}")formatter = CustomFormatter("%(message)s")# 创建并配置流处理器 StreamHandler,使用适当的格式化器和日志级别stream_handler = logging.StreamHandler(sys.stdout)stream_handler.setFormatter(formatter)stream_handler.setLevel(level)# 设置日志记录器 loggerlogger = logging.getLogger(name)logger.setLevel(level)logger.addHandler(stream_handler)# 禁止日志传播到父记录器logger.propagate = False# 返回配置好的 logger 对象return logger
# 设置日志记录器
LOGGER = set_logging(LOGGING_NAME, verbose=VERBOSE)  # 在全局定义日志记录器,用于 train.py、val.py、predict.py 等
# 设置 "sentry_sdk" 和 "urllib3.connectionpool" 的日志级别为 CRITICAL + 1
for logger in "sentry_sdk", "urllib3.connectionpool":logging.getLogger(logger).setLevel(logging.CRITICAL + 1)def emojis(string=""):"""返回与平台相关的安全的字符串版本,支持表情符号。"""return string.encode().decode("ascii", "ignore") if WINDOWS else stringclass ThreadingLocked:"""用于确保函数或方法的线程安全执行的装饰器类。可以作为装饰器使用,以确保如果从多个线程调用装饰的函数,则只有一个线程能够执行该函数。Attributes:lock (threading.Lock): 管理对装饰函数访问的锁对象。Example:```pyfrom ultralytics.utils import ThreadingLocked@ThreadingLocked()def my_function():# 在此处编写代码```"""def __init__(self):"""初始化装饰器类,用于函数或方法的线程安全执行。"""self.lock = threading.Lock()def __call__(self, f):"""执行函数或方法的线程安全装饰器。"""from functools import wraps@wraps(f)def decorated(*args, **kwargs):"""应用线程安全性到装饰的函数或方法。"""with self.lock:return f(*args, **kwargs)return decorateddef yaml_save(file="data.yaml", data=None, header=""):"""将数据保存为 YAML 格式到文件中。Args:file (str, optional): 文件名。默认为 'data.yaml'。data (dict): 要保存的数据,以 YAML 格式。header (str, optional): 要添加的 YAML 头部。Returns:(None): 数据保存到指定文件中。"""if data is None:data = {}file = Path(file)if not file.parent.exists():# 如果父目录不存在,则创建父目录file.parent.mkdir(parents=True, exist_ok=True)# 将路径对象转换为字符串valid_types = int, float, str, bool, list, tuple, dict, type(None)for k, v in data.items():if not isinstance(v, valid_types):data[k] = str(v)# 将数据以 YAML 格式写入文件with open(file, "w", errors="ignore", encoding="utf-8") as f:if header:f.write(header)yaml.safe_dump(data, f, sort_keys=False, allow_unicode=True)def yaml_load(file="data.yaml", append_filename=False):"""从文件中加载 YAML 格式的数据。Args:file (str, optional): 文件名。默认为 'data.yaml'。append_filename (bool): 是否将 YAML 文件名添加到 YAML 字典中。默认为 False。Returns:(dict): YAML 格式的数据和文件名。"""assert Path(file).suffix in {".yaml", ".yml"}, f"Attempting to load non-YAML file {file} with yaml_load()"# 使用指定的编码打开文件,忽略解码错误,返回文件对象 fwith open(file, errors="ignore", encoding="utf-8") as f:s = f.read()  # 将文件内容读取为字符串 s# 如果字符串 s 中存在不可打印字符,则通过正则表达式去除特殊字符if not s.isprintable():s = re.sub(r"[^\x09\x0A\x0D\x20-\x7E\x85\xA0-\uD7FF\uE000-\uFFFD\U00010000-\U0010ffff]+", "", s)# 使用 yaml.safe_load() 加载字符串 s,转换为 Python 字典;若文件为空则返回空字典data = yaml.safe_load(s) or {}# 如果 append_filename 为真,则将文件名以字符串形式添加到字典 data 中if append_filename:data["yaml_file"] = str(file)# 返回加载后的数据字典 datareturn data
# 漂亮地打印一个 YAML 文件或 YAML 格式的字典
def yaml_print(yaml_file: Union[str, Path, dict]) -> None:# 如果 yaml_file 是字符串或 Path 对象,使用 yaml_load 加载 YAML 文件yaml_dict = yaml_load(yaml_file) if isinstance(yaml_file, (str, Path)) else yaml_file# 将 yaml_dict 转换成 YAML 格式的字符串,不排序键,允许 Unicode,宽度为无限大dump = yaml.dump(yaml_dict, sort_keys=False, allow_unicode=True, width=float("inf"))# 记录打印信息到日志,包含文件名(加粗黑色),以及 YAML 格式的字符串LOGGER.info(f"Printing '{colorstr('bold', 'black', yaml_file)}'\n\n{dump}")# 默认配置字典,从 DEFAULT_CFG_PATH 加载 YAML 文件得到
DEFAULT_CFG_DICT = yaml_load(DEFAULT_CFG_PATH)
# 将值为字符串且为 "None" 的项改为 None
for k, v in DEFAULT_CFG_DICT.items():if isinstance(v, str) and v.lower() == "none":DEFAULT_CFG_DICT[k] = None
# 获取默认配置字典的键集合
DEFAULT_CFG_KEYS = DEFAULT_CFG_DICT.keys()
# 用 DEFAULT_CFG_DICT 创建一个 IterableSimpleNamespace 对象作为默认配置
DEFAULT_CFG = IterableSimpleNamespace(**DEFAULT_CFG_DICT)def read_device_model() -> str:"""从系统中读取设备型号信息,并缓存以便快速访问。被 is_jetson() 和 is_raspberrypi() 使用。Returns:(str): 如果成功读取,返回设备型号文件内容,否则返回空字符串。"""# 尝试打开 "/proc/device-tree/model" 文件,读取并返回其内容with contextlib.suppress(Exception):with open("/proc/device-tree/model") as f:return f.read()# 如果发生异常(如文件不存在),返回空字符串return ""def is_ubuntu() -> bool:"""检查当前操作系统是否为 Ubuntu。Returns:(bool): 如果操作系统是 Ubuntu,则返回 True,否则返回 False。"""# 尝试打开 "/etc/os-release" 文件,检查是否包含 "ID=ubuntu" 的行with contextlib.suppress(FileNotFoundError):with open("/etc/os-release") as f:return "ID=ubuntu" in f.read()# 如果文件未找到或未包含 "ID=ubuntu" 行,则返回 Falsereturn Falsedef is_colab():"""检查当前脚本是否运行在 Google Colab 笔记本环境中。Returns:(bool): 如果运行在 Colab 笔记本中,则返回 True,否则返回 False。"""# 检查环境变量中是否包含 "COLAB_RELEASE_TAG" 或 "COLAB_BACKEND_VERSION"return "COLAB_RELEASE_TAG" in os.environ or "COLAB_BACKEND_VERSION" in os.environdef is_kaggle():"""检查当前脚本是否运行在 Kaggle 内核中。Returns:(bool): 如果运行在 Kaggle 内核中,则返回 True,否则返回 False。"""# 检查当前工作目录是否为 "/kaggle/working",并且检查 KAGGLE_URL_BASE 是否为 "https://www.kaggle.com"return os.environ.get("PWD") == "/kaggle/working" and os.environ.get("KAGGLE_URL_BASE") == "https://www.kaggle.com"def is_jupyter():"""检查当前脚本是否运行在 Jupyter Notebook 中。在 Colab、Jupyterlab、Kaggle、Paperspace 环境中验证通过。Returns:(bool): 如果运行在 Jupyter Notebook 中,则返回 True,否则返回 False。"""# 尝试导入 IPython 中的 get_ipython 函数,如果导入成功并返回不为 None,则说明在 Jupyter 环境中with contextlib.suppress(Exception):from IPython import get_ipythonreturn get_ipython() is not None# 如果导入过程中发生异常,则返回 Falsereturn Falsedef is_docker() -> bool:"""判断当前脚本是否运行在 Docker 容器中。Returns:(bool): 如果运行在 Docker 容器中,则返回 True,否则返回 False。"""# 尝试打开 "/proc/self/cgroup" 文件,检查是否包含 "docker" 字符串with contextlib.suppress(Exception):with open("/proc/self/cgroup") as f:return "docker" in f.read()# 如果发生异常(如文件不存在),返回 Falsereturn Falsedef is_raspberrypi() -> bool:"""判断当前设备是否为 Raspberry Pi。Returns:(bool): 如果设备为 Raspberry Pi,则返回 True,否则返回 False。"""# 检查当前 Python 环境是否运行在树莓派上,通过检查设备模型信息# 返回值为布尔类型:如果在树莓派上运行则返回 True,否则返回 False"""检查 Python 环境的设备模型信息中是否包含"Raspberry Pi""""return "Raspberry Pi" in PROC_DEVICE_MODEL
def is_jetson() -> bool:"""Determines if the Python environment is running on a Jetson Nano or Jetson Orin device by checking the device modelinformation.Returns:(bool): True if running on a Jetson Nano or Jetson Orin, False otherwise."""return "NVIDIA" in PROC_DEVICE_MODEL  # 检查 PROC_DEVICE_MODEL 是否包含 "NVIDIA",表示运行在 Jetson Nano 或 Jetson Orindef is_online() -> bool:"""Check internet connectivity by attempting to connect to a known online host.Returns:(bool): True if connection is successful, False otherwise."""with contextlib.suppress(Exception):assert str(os.getenv("YOLO_OFFLINE", "")).lower() != "true"  # 检查环境变量 YOLO_OFFLINE 是否为 "True"import socketfor dns in ("1.1.1.1", "8.8.8.8"):  # 检查 Cloudflare 和 Google DNSsocket.create_connection(address=(dns, 80), timeout=2.0).close()return Truereturn Falsedef is_pip_package(filepath: str = __name__) -> bool:"""Determines if the file at the given filepath is part of a pip package.Args:filepath (str): The filepath to check.Returns:(bool): True if the file is part of a pip package, False otherwise."""import importlib.util# 获取模块的规范spec = importlib.util.find_spec(filepath)# 返回规范不为 None 且 origin 不为 None(表示是一个包)return spec is not None and spec.origin is not Nonedef is_dir_writeable(dir_path: Union[str, Path]) -> bool:"""Check if a directory is writeable.Args:dir_path (str | Path): The path to the directory.Returns:(bool): True if the directory is writeable, False otherwise."""return os.access(str(dir_path), os.W_OK)def is_pytest_running():"""Determines whether pytest is currently running or not.Returns:(bool): True if pytest is running, False otherwise."""return ("PYTEST_CURRENT_TEST" in os.environ) or ("pytest" in sys.modules) or ("pytest" in Path(ARGV[0]).stem)def is_github_action_running() -> bool:"""Determine if the current environment is a GitHub Actions runner.Returns:(bool): True if the current environment is a GitHub Actions runner, False otherwise."""return "GITHUB_ACTIONS" in os.environ and "GITHUB_WORKFLOW" in os.environ and "RUNNER_OS" in os.environdef get_git_dir():"""Determines whether the current file is part of a git repository and if so, returns the repository root directory. Ifthe current file is not part of a git repository, returns None.Returns:(Path | None): Git root directory if found or None if not found."""for d in Path(__file__).parents:if (d / ".git").is_dir():return ddef is_git_dir():"""Determines whether the current file is part of a git repository. If the current file is not part of a gitrepository, returns False.Returns:(bool): True if the current file is part of a git repository, False otherwise."""for d in Path(__file__).parents:if (d / ".git").is_dir():return Truereturn False# 检查全局变量 GIT_DIR 是否为 Nonereturn GIT_DIR is not None
def get_git_origin_url():"""Retrieves the origin URL of a git repository.Returns:(str | None): The origin URL of the git repository or None if not git directory."""# 检查当前是否在 Git 仓库中if IS_GIT_DIR:# 使用 subprocess 模块调用 git 命令获取远程仓库的 URLwith contextlib.suppress(subprocess.CalledProcessError):origin = subprocess.check_output(["git", "config", "--get", "remote.origin.url"])# 将获取到的字节流解码成字符串,并去除首尾的空白字符return origin.decode().strip()def get_git_branch():"""Returns the current git branch name. If not in a git repository, returns None.Returns:(str | None): The current git branch name or None if not a git directory."""# 检查当前是否在 Git 仓库中if IS_GIT_DIR:# 使用 subprocess 模块调用 git 命令获取当前分支名with contextlib.suppress(subprocess.CalledProcessError):origin = subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"])# 将获取到的字节流解码成字符串,并去除首尾的空白字符return origin.decode().strip()def get_default_args(func):"""Returns a dictionary of default arguments for a function.Args:func (callable): The function to inspect.Returns:(dict): A dictionary where each key is a parameter name, and each value is the default value of that parameter."""# 使用 inspect 模块获取函数的签名信息signature = inspect.signature(func)# 构建并返回参数名与默认参数值组成的字典return {k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty}def get_ubuntu_version():"""Retrieve the Ubuntu version if the OS is Ubuntu.Returns:(str): Ubuntu version or None if not an Ubuntu OS."""# 检查当前操作系统是否为 Ubuntuif is_ubuntu():# 尝试打开 /etc/os-release 文件并匹配版本号信息with contextlib.suppress(FileNotFoundError, AttributeError):with open("/etc/os-release") as f:# 使用正则表达式搜索并提取版本号return re.search(r'VERSION_ID="(\d+\.\d+)"', f.read())[1]def get_user_config_dir(sub_dir="Ultralytics"):"""Return the appropriate config directory based on the environment operating system.Args:sub_dir (str): The name of the subdirectory to create.Returns:(Path): The path to the user config directory."""# 根据不同的操作系统返回相应的用户配置目录if WINDOWS:path = Path.home() / "AppData" / "Roaming" / sub_direlif MACOS:  # macOSpath = Path.home() / "Library" / "Application Support" / sub_direlif LINUX:path = Path.home() / ".config" / sub_direlse:# 如果不支持当前操作系统,抛出 ValueError 异常raise ValueError(f"Unsupported operating system: {platform.system()}")# 对于 GCP 和 AWS Lambda,只有 /tmp 目录可写入if not is_dir_writeable(path.parent):# 如果父目录不可写,输出警告信息并使用备选路径 /tmp 或当前工作目录LOGGER.warning(f"WARNING ⚠️ user config directory '{path}' is not writeable, defaulting to '/tmp' or CWD.""Alternatively you can define a YOLO_CONFIG_DIR environment variable for this path.")path = Path("/tmp") / sub_dir if is_dir_writeable("/tmp") else Path().cwd() / sub_dir# 如果路径不存在,则创建相应的子目录path.mkdir(parents=True, exist_ok=True)return path# Define constants (required below)
PROC_DEVICE_MODEL = read_device_model()  # is_jetson() and is_raspberrypi() depend on this constant
# 检查当前是否在线
ONLINE = is_online()# 检查当前是否在Google Colab环境中
IS_COLAB = is_colab()# 检查当前是否在Docker容器中
IS_DOCKER = is_docker()# 检查当前是否在NVIDIA Jetson设备上
IS_JETSON = is_jetson()# 检查当前是否在Jupyter环境中
IS_JUPYTER = is_jupyter()# 检查当前是否在Kaggle环境中
IS_KAGGLE = is_kaggle()# 检查当前代码是否安装为pip包
IS_PIP_PACKAGE = is_pip_package()# 检查当前是否在树莓派环境中
IS_RASPBERRYPI = is_raspberrypi()# 获取当前Git仓库的目录
GIT_DIR = get_git_dir()# 检查当前目录是否是Git仓库
IS_GIT_DIR = is_git_dir()# 获取用户配置目录,默认为环境变量YOLO_CONFIG_DIR,或者使用系统默认配置目录
USER_CONFIG_DIR = Path(os.getenv("YOLO_CONFIG_DIR") or get_user_config_dir())  # Ultralytics settings dir# 设置配置文件路径为用户配置目录下的settings.yaml文件
SETTINGS_YAML = USER_CONFIG_DIR / "settings.yaml"Ultralytics TryExcept class. Use as @TryExcept() decorator or 'with TryExcept():' context manager.Examples:As a decorator:>>> @TryExcept(msg="Error occurred in func", verbose=True)>>> def func():>>>    # Function logic here>>>     passAs a context manager:>>> with TryExcept(msg="Error occurred in block", verbose=True):>>>     # Code block here>>>     pass"""# 定义 TryExcept 类,用于处理异常,可以作为装饰器或上下文管理器使用def __init__(self, msg="", verbose=True):"""Initialize TryExcept class with optional message and verbosity settings."""self.msg = msgself.verbose = verbose# 初始化 TryExcept 类,可以设置错误消息和详细输出选项def __enter__(self):"""Executes when entering TryExcept context, initializes instance."""# 进入上下文管理器时执行的方法,初始化实例passdef __exit__(self, exc_type, value, traceback):"""Defines behavior when exiting a 'with' block, prints error message if necessary."""# 定义退出上下文管理器时的行为,如果需要,打印错误消息if self.verbose and value:print(emojis(f"{self.msg}{': ' if self.msg else ''}{value}"))# 如果设置了详细输出并且出现了异常,打印错误消息return True# 返回 True 表示已经处理了异常,不会向上层抛出异常
class Retry(contextlib.ContextDecorator):"""Retry class for function execution with exponential backoff.Can be used as a decorator to retry a function on exceptions, up to a specified number of times with anexponentially increasing delay between retries.Examples:Example usage as a decorator:>>> @Retry(times=3, delay=2)>>> def test_func():>>>     # Replace with function logic that may raise exceptions>>>     return True"""def __init__(self, times=3, delay=2):"""Initialize Retry class with specified number of retries and delay."""self.times = times  # 设置重试次数self.delay = delay  # 设置初始重试延迟时间self._attempts = 0  # 记录当前重试次数def __call__(self, func):"""Decorator implementation for Retry with exponential backoff."""def wrapped_func(*args, **kwargs):"""Applies retries to the decorated function or method."""self._attempts = 0  # 重试次数初始化为0while self._attempts < self.times:  # 循环直到达到最大重试次数try:return func(*args, **kwargs)  # 调用被装饰的函数或方法except Exception as e:self._attempts += 1  # 增加重试次数计数print(f"Retry {self._attempts}/{self.times} failed: {e}")  # 打印重试失败信息if self._attempts >= self.times:  # 如果达到最大重试次数raise e  # 抛出异常time.sleep(self.delay * (2**self._attempts))  # 按指数增加的延迟时间return wrapped_funcdef threaded(func):"""Multi-threads a target function by default and returns the thread or function result.Use as @threaded decorator. The function runs in a separate thread unless 'threaded=False' is passed."""def wrapper(*args, **kwargs):"""Multi-threads a given function based on 'threaded' kwarg and returns the thread or function result."""if kwargs.pop("threaded", True):  # 如果未指定 'threaded' 参数或为 Truethread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)  # 创建线程对象thread.start()  # 启动线程return thread  # 返回线程对象else:return func(*args, **kwargs)  # 直接调用函数或方法return wrapperdef set_sentry():"""Initialize the Sentry SDK for error tracking and reporting. Only used if sentry_sdk package is installed andsync=True in settings. Run 'yolo settings' to see and update settings YAML file.Conditions required to send errors (ALL conditions must be met or no errors will be reported):- sentry_sdk package is installed- sync=True in YOLO settings- pytest is not running- running in a pip package installation- running in a non-git directory- running with rank -1 or 0- online environment- CLI used to run package (checked with 'yolo' as the name of the main CLI command)The function also configures Sentry SDK to ignore KeyboardInterrupt and FileNotFoundErrorexceptions and to exclude events with 'out of memory' in their exception message."""# 设置 Sentry 事件的自定义标签和用户信息def before_send(event, hint):"""根据特定的异常类型和消息修改事件,然后发送给 Sentry。Args:event (dict): 包含错误信息的事件字典。hint (dict): 包含额外错误信息的字典。Returns:dict: 修改后的事件字典,如果不发送事件到 Sentry 则返回 None。"""# 如果 hint 中包含异常信息if "exc_info" in hint:exc_type, exc_value, tb = hint["exc_info"]# 如果异常类型是 KeyboardInterrupt、FileNotFoundError,或者异常消息包含"out of memory"if exc_type in {KeyboardInterrupt, FileNotFoundError} or "out of memory" in str(exc_value):return None  # 不发送事件# 设置事件的标签event["tags"] = {"sys_argv": ARGV[0],  # 系统参数列表中的第一个参数"sys_argv_name": Path(ARGV[0]).name,  # 第一个参数的文件名部分"install": "git" if IS_GIT_DIR else "pip" if IS_PIP_PACKAGE else "other",  # 安装来源是 git、pip 还是其他"os": ENVIRONMENT,  # 环境变量中的操作系统信息}return event  # 返回修改后的事件字典# 如果满足一系列条件,则配置 Sentryif (SETTINGS["sync"]  # 同步设置为 Trueand RANK in {-1, 0}  # 运行时的进程等级为 -1 或 0and Path(ARGV[0]).name == "yolo"  # 第一个参数的文件名为 "yolo"and not TESTS_RUNNING  # 没有正在运行的测试and ONLINE  # 处于联机状态and IS_PIP_PACKAGE  # 安装方式为 pipand not IS_GIT_DIR  # 不是从 git 安装):# 如果 sentry_sdk 包未安装,则返回try:import sentry_sdk  # 导入 sentry_sdk 包except ImportError:return# 初始化 Sentry SDKsentry_sdk.init(dsn="https://5ff1556b71594bfea135ff0203a0d290@o4504521589325824.ingest.sentry.io/4504521592406016",  # Sentry 项目的 DSNdebug=False,  # 调试模式设为 Falsetraces_sample_rate=1.0,  # 所有跟踪数据采样率设为 100%release=__version__,  # 使用当前应用版本号environment="production",  # 环境设置为生产环境before_send=before_send,  # 设置发送前的处理函数为 before_sendignore_errors=[KeyboardInterrupt, FileNotFoundError],  # 忽略的错误类型列表)# 设置 Sentry 用户信息,使用 SHA-256 匿名化的 UUID 哈希sentry_sdk.set_user({"id": SETTINGS["uuid"]})
class SettingsManager(dict):"""Manages Ultralytics settings stored in a YAML file.Args:file (str | Path): Path to the Ultralytics settings YAML file. Default is USER_CONFIG_DIR / 'settings.yaml'.version (str): Settings version. In case of local version mismatch, new default settings will be saved."""def __init__(self, file=SETTINGS_YAML, version="0.0.4"):"""Initialize the SettingsManager with default settings, load and validate current settings from the YAMLfile."""import copy  # 导入用于深拷贝对象的模块import hashlib  # 导入用于哈希计算的模块from ultralytics.utils.checks import check_version  # 导入版本检查函数from ultralytics.utils.torch_utils import torch_distributed_zero_first  # 导入分布式训练相关的函数root = GIT_DIR or Path()  # 设置根目录为环境变量GIT_DIR的值或当前路径datasets_root = (root.parent if GIT_DIR and is_dir_writeable(root.parent) else root).resolve()  # 设置数据集根目录路径self.file = Path(file)  # 将传入的文件路径转换为Path对象self.version = version  # 设置版本号self.defaults = {  # 设置默认配置字典"settings_version": version,"datasets_dir": str(datasets_root / "datasets"),"weights_dir": str(root / "weights"),"runs_dir": str(root / "runs"),"uuid": hashlib.sha256(str(uuid.getnode()).encode()).hexdigest(),  # 计算当前机器的唯一标识符"sync": True,"api_key": "","openai_api_key": "","clearml": True,  # 各种集成配置"comet": True,"dvc": True,"hub": True,"mlflow": True,"neptune": True,"raytune": True,"tensorboard": True,"wandb": True,}self.help_msg = (f"\nView settings with 'yolo settings' or at '{self.file}'""\nUpdate settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'. ""For help see https://docs.ultralytics.com/quickstart/#ultralytics-settings.")super().__init__(copy.deepcopy(self.defaults))  # 调用父类构造函数并使用深拷贝的默认配置with torch_distributed_zero_first(RANK):  # 在分布式环境中,仅主节点执行以下操作if not self.file.exists():  # 如果配置文件不存在,则保存默认配置self.save()self.load()  # 载入配置文件中的设置correct_keys = self.keys() == self.defaults.keys()  # 检查载入的配置键与默认配置键是否一致correct_types = all(type(a) is type(b) for a, b in zip(self.values(), self.defaults.values()))  # 检查各项设置的类型是否与默认设置一致correct_version = check_version(self["settings_version"], self.version)  # 检查配置文件中的版本与当前版本是否一致if not (correct_keys and correct_types and correct_version):LOGGER.warning("WARNING ⚠️ Ultralytics settings reset to default values. This may be due to a possible problem "f"with your settings or a recent ultralytics package update. {self.help_msg}")self.reset()  # 将设置重置为默认值if self.get("datasets_dir") == self.get("runs_dir"):  # 如果数据集目录与运行目录相同,则发出警告LOGGER.warning(f"WARNING ⚠️ Ultralytics setting 'datasets_dir: {self.get('datasets_dir')}' "f"must be different than 'runs_dir: {self.get('runs_dir')}'. "f"Please change one to avoid possible issues during training. {self.help_msg}")def load(self):"""Loads settings from the YAML file."""  # 从YAML文件加载设置super().update(yaml_load(self.file))  # 调用父类方法,使用yaml_load函数加载配置文件内容def save(self):"""将当前设置保存到YAML文件中。"""# 调用yaml_save函数,将当前对象转换为字典并保存到文件中yaml_save(self.file, dict(self))def update(self, *args, **kwargs):"""更新当前设置中的一个设置值。"""# 遍历关键字参数kwargs,检查每个设置项的有效性for k, v in kwargs.items():# 如果设置项不在默认设置中,则引发KeyError异常if k not in self.defaults:raise KeyError(f"No Ultralytics setting '{k}'. {self.help_msg}")# 获取默认设置项k的类型t = type(self.defaults[k])# 如果传入的值v不是预期的类型t,则引发TypeError异常if not isinstance(v, t):raise TypeError(f"Ultralytics setting '{k}' must be of type '{t}', not '{type(v)}'. {self.help_msg}")# 调用父类的update方法,更新设置项super().update(*args, **kwargs)# 更新后立即保存设置到文件self.save()def reset(self):"""将设置重置为默认值并保存。"""# 清空当前设置self.clear()# 使用默认设置更新当前设置self.update(self.defaults)# 保存更新后的设置到文件self.save()
# 发出弃用警告的函数,用于提示已弃用的参数,建议使用更新的参数
def deprecation_warn(arg, new_arg):# 使用 LOGGER 发出警告消息,指出已弃用的参数和建议的新参数LOGGER.warning(f"WARNING ⚠️ '{arg}' is deprecated and will be removed in in the future. " f"Please use '{new_arg}' instead.")# 清理 URL,去除授权信息,例如 https://url.com/file.txt?auth -> https://url.com/file.txt
def clean_url(url):# 使用 Path 对象将 URL 转换成 POSIX 格式的字符串,并替换 Windows 下的 "://" 为 "://"url = Path(url).as_posix().replace(":/", "://")  # Pathlib turns :// -> :/, as_posix() for Windows# 对 URL 解码并按 "?" 进行分割,保留问号前的部分作为最终的清理后的 URLreturn urllib.parse.unquote(url).split("?")[0]  # '%2F' to '/', split https://url.com/file.txt?auth# 将 URL 转换为文件名,例如 https://url.com/file.txt?auth -> file.txt
def url2file(url):# 清理 URL 后使用 Path 对象获取文件名部分作为结果return Path(clean_url(url)).name# 在 utils 初始化过程中运行以下代码 ------------------------------------------------------------------------------------# 检查首次安装步骤
PREFIX = colorstr("Ultralytics: ")  # 设置日志前缀
SETTINGS = SettingsManager()  # 初始化设置管理器
DATASETS_DIR = Path(SETTINGS["datasets_dir"])  # 全局数据集目录
WEIGHTS_DIR = Path(SETTINGS["weights_dir"])  # 全局权重目录
RUNS_DIR = Path(SETTINGS["runs_dir"])  # 全局运行目录
# 确定当前环境,根据不同情况设置 ENVIRONMENT 变量
ENVIRONMENT = ("Colab"if IS_COLABelse "Kaggle"if IS_KAGGLEelse "Jupyter"if IS_JUPYTERelse "Docker"if IS_DOCKERelse platform.system()
)
TESTS_RUNNING = is_pytest_running() or is_github_action_running()  # 检查是否正在运行测试
set_sentry()  # 初始化 Sentry 错误监控# 应用 Monkey Patch
from ultralytics.utils.patches import imread, imshow, imwrite, torch_load, torch_savetorch.load = torch_load  # 覆盖默认的 torch.load 函数
torch.save = torch_save  # 覆盖默认的 torch.save 函数
if WINDOWS:# 对于 Windows 平台,应用 cv2 的补丁以支持图像路径中的非 ASCII 和非 UTF 字符cv2.imread, cv2.imwrite, cv2.imshow = imread, imwrite, imshow

.\yolov8\ultralytics\__init__.py

# Ultralytics YOLO 🚀, AGPL-3.0 license# 定义模块版本号
__version__ = "8.2.69"# 导入操作系统模块
import os# 设置环境变量(放置在导入语句之前)
# 设置 OpenMP 线程数为 1,以减少训练过程中的 CPU 使用率
os.environ["OMP_NUM_THREADS"] = "1"# 从 ultralytics.data.explorer.explorer 模块中导入 Explorer 类
from ultralytics.data.explorer.explorer import Explorer
# 从 ultralytics.models 模块中导入 NAS、RTDETR、SAM、YOLO、FastSAM、YOLOWorld 类
from ultralytics.models import NAS, RTDETR, SAM, YOLO, FastSAM, YOLOWorld
# 从 ultralytics.utils 模块中导入 ASSETS 和 SETTINGS
from ultralytics.utils import ASSETS, SETTINGS
# 从 ultralytics.utils.checks 模块中导入 check_yolo 函数,并将其命名为 checks
from ultralytics.utils.checks import check_yolo as checks
# 从 ultralytics.utils.downloads 模块中导入 download 函数
from ultralytics.utils.downloads import download# 将 SETTINGS 赋值给 settings 变量
settings = SETTINGS# 定义 __all__ 变量,包含了可以通过 `from package import *` 导入的名字
__all__ = ("__version__","ASSETS","YOLO","YOLOWorld","NAS","SAM","FastSAM","RTDETR","checks","download","settings","Explorer",
)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/792523.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Yolov8-源码解析-四十三-

Yolov8 源码解析(四十三) .\yolov8\ultralytics\utils\patches.py # Ultralytics YOLO 🚀, AGPL-3.0 license """Monkey patches to update/extend functionality of existing functions."""import time from pathlib import Pathimport cv2…

Yolov8-源码解析-四十二-

Yolov8 源码解析(四十二) .\yolov8\ultralytics\utils\loss.py # 导入PyTorch库中需要的模块 import torch import torch.nn as nn import torch.nn.functional as F# 从Ultralytics工具包中导入一些特定的功能 from ultralytics.utils.metrics import OKS_SIGMA from ultral…

Yolov8-源码解析-二十六-

Yolov8 源码解析(二十六) .\yolov8\tests\test_engine.py # 导入所需的模块和库 import sys # 系统模块 from unittest import mock # 导入 mock 模块# 导入自定义模块和类 from tests import MODEL # 导入 tests 模块中的 MODEL 对象 from ultralytics import YOLO # 导…

Yolov8-源码解析-二十八-

Yolov8 源码解析(二十八) .\yolov8\ultralytics\data\base.py # Ultralytics YOLO 🚀, AGPL-3.0 licenseimport glob # 导入用于获取文件路径的模块 import math # 导入数学函数模块 import os # 导入操作系统功能模块 import random # 导入生成随机数的模块 from copy…

Yolov8-源码解析-八-

Yolov8 源码解析(八)comments: true description: Learn how to manage and optimize queues using Ultralytics YOLOv8 to reduce wait times and increase efficiency in various real-world applications. keywords: queue management, YOLOv8, Ultralytics, reduce wait …

FLUX 源码解析(全)

.\flux\demo_gr.py # 导入操作系统相关模块 import os # 导入时间相关模块 import time # 从 io 模块导入 BytesIO 类 from io import BytesIO # 导入 UUID 生成模块 import uuid# 导入 PyTorch 库 import torch # 导入 Gradio 库 import gradio as gr # 导入 NumPy 库 import …

【优技教育】Oracle 19c OCP 082题库(第13题)- 2024年修正版

【优技教育】Oracle 19c OCP 082题库(Q 13题)- 2024年修正版 考试科目:1Z0-082 考试题量:90 通过分数:60% 考试时间:150min 本文为(CUUG 原创)整理并解析,转发请注明出处,禁止抄袭及未经注明出处的转载。 原文地址:http://www.cuug.com.cn/ocp/082kaoshitiku/3817564823…

最快捷查看电脑启动项内容

很多人好奇很多电脑的默认启动项从哪里的看,其实就在运行窗口开两个命令就行了。 第一个,看先用户端设置的启动项: shell:Startup 这个是针对当前登录用户的。 第二个,查看电脑最高权限的通用启动项shell:Common Startup 这个是针对所有用户的。 操作的方式很简单 就是把要…

react中使用echarts关系图

一,工作需求,展示几类数据关系,可缩放大小,可拖拽位置,在节点之间的连线上展示相关日期,每个节点展示本身信息,并且要求每个关系节点能点击。 实现情况如图所示:二,实现过程中遇到的问题: 关系图完美呈现,但关系节点点击后,整个关系图会杂乱无章的浮动,导致不知道…

易基因:中国农大田见晖教授团队揭示DNA甲基化保护早期胚胎线粒体基因组稳定性|项目文章

大家好,这里是专注表观组学十余年,领跑多组学科研服务的易基因。 在早期哺乳动物胚胎中,线粒体氧化代谢增强是着床后生存和发育的重要特征;着床前期的线粒体重塑是正常胚胎发生的关键事件。在这些变化中,氧化磷酸化(OXPHOS)增强对于支持着床后胚胎的高能量需求至关重要,…

WebShell流量特征检测_哥斯拉篇

80后用菜刀,90后用蚁剑,95后用冰蝎和哥斯拉,以phpshell连接为例,本文主要是对这四款经典的webshell管理工具进行流量分析和检测。 什么是一句话木马? 1、定义 顾名思义就是执行恶意指令的木马,通过技术手段上传到指定服务器并可以正常访问,将我们需要服务器执行的命令上…

IPv6基于策略的地址分配

IPv6基于策略的地址分配 RA的周期性发送使用的是组播方式,但是针对 RS 的回复使用组播和单播两种可能;如果 RA 都是以组播方式发送,那么同一个广播域下的所有终端都可以收到,如果要基于终端mac/link-local地址来控制分配策略,则应该使用单播方式回复,以限制RA被接收的范围…