YOLOv5目标检测学习(5):源码解析之:推理部分dectet.py

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档

文章目录

  • 前言
  • 一、导入相关包与路径、模块配置
    • 1.1 导入相关的python包
    • 1.2 获取当前文件的相对路径
    • 1.3 加载自定义模块
    • 1.4 总结
  • 二、执行主体的main函数
    • 所以执行推理代码,核心就是两个函数:pares_opt()函数和run()函数
  • 三、pares_opt()函数
    • 3.1 参数设置部分
    • 3.2 py语法部分
    • 3.3 opt变量
  • 四、run()函数
    • 4.1 载入参数
    • 4.2 判断source的参数及类型
    • 4.3 保存目录
    • 4.4 载入模型
    • 4.5 载入模型
    • 4.6 核心推理代码
      • 4.6.1 数据的预热
      • 4.6.2 可视化和预测结果处理
      • 4.6.3 非极大值抑制和CSV文件操作
      • 4.6.4 预测的过程
        • ①对每张图像的预测结果进行遍历处理,更新计数器并根据不同情况处理图像信息
        • ②对路径进行处理并生成保存路径和文本文件路径,输出图像尺寸信息,进行坐标归一化处理,以便后续保存图像文件、标签文件和处理边界框坐标等操作
        • ③程序遍历检测结果中的每个类别,统计每个类别的检测数量,并将类别名称和对应的检测数量添加到字符串s中,用于打印输出检测结果
      • 4.6.5 打印目标检测结果
      • 4.6.6 流式展示检测结果
      • 4.6.7 保存检测后的图像及视频流
      • 4.6.8 打印推断时间、打印结果、保存结果以及更新模型
  • 五、 对于具体推理部分run()函数的代码总结


前言

为了完成一个深度学习目标检测的全过程,会按照以下顺序进行:

①配置部分(yolov5s.yaml):用于确定yolov5的网络结构和参数

②工具部分(yolo.py、common.py):提供一些函数,用于辅助后续部分

③训练部分(train.py):首先,使用训练部分的代码(train.py)来训练模型。

④验证部分(val.py):使用验证部分的代码(val.py)来评估模型在验证数据集上的性能。

⑤推理部分(detect.py):在完成训练和验证后,可以使用推理部分的代码(detect.py)来对新的图像或视频进行目标检测。

今天要学习的是detect.py。通常这个文件是用来预测一张图片或者一个视频的,也可以预测一个图片文件夹或者是一些网络流.下载后直接运行默认是对date/images文件夹下的两张照片进行检测识别。即默认运行后得到一个带预测框的图片。
在这里插入图片描述

一、导入相关包与路径、模块配置

1.1 导入相关的python包

import argparse
import csv
import os
import platform
import sys
from pathlib import Path
import torch

这段代码是一个Python脚本的开头部分,主要包括了导入一些必要的库和模块。解释一下这些导入的内容:

argparse:argparse是Python标准库中用于解析命令行参数和选项的模块。它可以编写用户友好的命令行界面,解析命令行参数并生成帮助信息。

csv:csv是Python标准库中用于读写CSV文件(逗号分隔值文件)的模块。它提供了一种简单的方式来处理CSV文件中的数据。

os:os模块提供了与操作系统交互的功能,包括文件和目录操作、进程管理等。通过os模块,可以执行各种操作系统相关的任务。

platform:platform模块提供了访问平台特定属性(如操作系统、硬件架构等)的功能。它可以帮助编写跨平台的代码。

sys:sys模块提供了与Python解释器交互的功能。它包含了一些与Python解释器和环境相关的变量和函数。

pathlib:pathlib模块提供了一种面向对象的方式来操作文件路径和文件系统。它可以简化文件路径的处理和操作

torch:torch是PyTorch深度学习框架的主要模块。通过导入torch,可以使用PyTorch提供的各种功能和类来构建和训练深度学习模型

1.2 获取当前文件的相对路径

'''=====================2.获取当前文件的相对路径=============================='''
FILE = Path(__file__).resolve()  # __file__指的是当前文件(即detect.py),FILE最终保存着当前文件的绝对路径,比如D://yolov5/detect.py
ROOT = FILE.parents[0]  # YOLOv5 root directory  ROOT保存着当前项目的父目录,比如 D://yolov5
if str(ROOT) not in sys.path:  # sys.path即当前python环境可以运行的路径,假如当前项目不在该路径中,就无法运行其中的模块,所以就需要加载路径sys.path.append(str(ROOT))  # add ROOT to PATH  就把ROOT添加到运行路径上
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative ROOT设置为相对路径

这段代码主要做了以下几件事情:

1 获取文件路径和根目录:
FILE = Path(file).resolve():获取当前脚本文件的绝对路径。
ROOT = FILE.parents[0]:通过获取父目录,确定YOLOv5的根目录。

2 将根目录添加到系统路径中:
if str(ROOT) not in sys.path::检查根目录是否已经在系统路径中。
sys.path.append(str(ROOT)):如果根目录不在系统路径中,则将根目录添加到系统路径中。

3 计算相对路径:
ROOT = Path(os.path.relpath(ROOT, Path.cwd())):计算根目录相对于当前工作目录的相对路径。

这段代码的作用是将YOLOv5的根目录添加到系统路径中,以便在后续的代码中可以方便地引用根目录下的模块和文件。通过计算相对路径,可以确保在不同环境中都能正确地定位到根目录。

这里要提到相对路径与绝对路径的区别:

  1. 绝对路径: 绝对路径是从文件系统的根目录开始描述文件或目录位置的方式。它提供了完整的路径信息,包括所有父目录直到目标文件或目录的路径。
    例如,在Unix/Linux系统中,绝对路径可能类似于/home/user/documents/file.txt,而在Windows系统中可能类似于C:\Users\User\Documents\File.txt。绝对路径是唯一确定文件或目录位置的方式,不受当前工作目录的影响。

  2. 相对路径: 相对路径是相对于当前工作目录或其他参考位置描述文件或目录位置的方式。它不包含完整的路径信息,而是相对于某个基准位置的路径。
    例如,相对路径可能是…/documents/file.txt,表示目标文件在当前目录的父目录下的documents文件夹中。相对路径依赖于当前工作目录或其他参考位置,因此在不同环境中可能会有不同的解释。

1.3 加载自定义模块

'''=====================3..加载自定义模块============================='''
from ultralytics.utils.plotting import Annotator, colors, save_one_boxfrom models.common import DetectMultiBackend
from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams
from utils.general import (LOGGER,Profile,check_file,check_img_size,check_imshow,check_requirements,colorstr,cv2,increment_path,non_max_suppression,print_args,scale_boxes,strip_optimizer,xyxy2xywh,
)
from utils.torch_utils import select_device, smart_inference_mode

这段代码是在Python脚本中导入了一系列自定义模块和函数,让我为您解释一下这些导入的内容:

  1. ultralytics.utils.plotting:
    从ultralytics.utils.plotting模块中导入了Annotator、colors和save_one_box等函数或类。这些函数可能用于绘制标注、处理颜色、保存检测框等可视化操作。

  2. models.common:
    从models.common模块中导入了DetectMultiBackend类。这个类可能包含了一些用于多后端检测的通用功能。

  3. utils.dataloaders:
    从utils.dataloaders模块中导入了IMG_FORMATS、VID_FORMATS、LoadImages、LoadScreenshots和LoadStreams等类或函数。这些类和函数可能用于加载图像、视频以及数据预处理。

  4. utils.general:
    从utils.general模块中导入了一系列函数和类,包括LOGGER、Profile、check_file、check_img_size、check_imshow、check_requirements等。这些函数可能用于日志记录、性能分析、文件检查、图像尺寸检查、参数打印等通用功能。

  5. utils.torch_utils:
    从utils.torch_utils模块中导入了select_device和smart_inference_mode等函数。这些函数可能用于选择设备(CPU或GPU)以及设置智能推理模式等PyTorch相关功能。

    通过导入这些模块和函数,脚本可以利用这些功能来实现目标检测中的各种操作,包括数据加载、模型推理、结果可视化等

    另外,这些包具有如下作用:
    在这里插入图片描述

1.4 总结

这段代码的作用可以总结如下:

  • 导入Python相关包:
    导入了一系列Python标准库和第三方库,包括argparse、csv、os、platform、sys、Path和torch等,用于后续代码中的各种功能和操作。

  • 获取当前文件的绝对路径:
    获取当前脚本文件的绝对路径,并确定YOLOv5的根目录。将根目录添加到系统路径中,并计算根目录相对于当前工作目录的相对路径。

  • 加载自定义的模块: 导入了一系列自定义模块和函数,包括可视化模块、通用模块、数据加载模块以及PyTorch工具模块等。

综合来看,这段代码的主要作用是准备工作,包括导入必要的库和模块、确定根目录路径以及加载自定义模块和函数,为后续的目标检测任务提供必要的基础支持和功能扩展

二、执行主体的main函数

def main(opt):"""Executes YOLOv5 model inference with given options, checking requirements before running the model."""check_requirements(ROOT / "requirements.txt", exclude=("tensorboard", "thop"))run(**vars(opt))if __name__ == "__main__":opt = parse_opt()main(opt)

main函数:

  • main(opt)函数是程序的主要执行逻辑。它执行YOLOv5模型推理,并在运行模型之前检查所需的依赖项。

  • check_requirements(ROOT / “requirements.txt”, exclude=(“tensorboard”,
    “thop”)):检查项目所需的依赖项,排除了"tensorboard"和"thop"这两个依赖。

  • run(**vars(opt)):运行模型推理,其中opt是通过parse_opt()函数解析得到的命令行参数。

程序入口:

  • if name == “main”::这是Python中的惯用写法,表示当脚本直接运行时(而不是被导入为模块时),以下代码块将被执行。
  • opt = parse_opt():调用parse_opt()函数,解析命令行参数并将其存储在opt变量中。
  • main(opt):调用main函数,传入解析后的命令行参数opt,开始执行YOLOv5模型推理。

所以执行推理代码,核心就是两个函数:pares_opt()函数和run()函数

三、pares_opt()函数

这段代码是一个Python 脚本中的一个函数,用于解析命令行参数并返回这些参数的值。
主要功能是为模型进行推理时提供参数。我们说的调参调参,调的就是这个参。

def parse_opt():"""Parses command-line arguments for YOLOv5 detection, setting inference options and model configurations."""parser = argparse.ArgumentParser()parser.add_argument("--weights", nargs="+", type=str, default=ROOT / "yolov5s.pt", help="model path or triton URL")parser.add_argument("--source", type=str, default=ROOT / "data/images", help="file/dir/URL/glob/screen/0(webcam)")parser.add_argument("--data", type=str, default=ROOT / "data/coco128.yaml", help="(optional) dataset.yaml path")parser.add_argument("--imgsz", "--img", "--img-size", nargs="+", type=int, default=[640], help="inference size h,w")parser.add_argument("--conf-thres", type=float, default=0.25, help="confidence threshold")parser.add_argument("--iou-thres", type=float, default=0.45, help="NMS IoU threshold")parser.add_argument("--max-det", type=int, default=1000, help="maximum detections per image")parser.add_argument("--device", default="", help="cuda device, i.e. 0 or 0,1,2,3 or cpu")parser.add_argument("--view-img", action="store_true", help="show results")parser.add_argument("--save-txt", action="store_true", help="save results to *.txt")parser.add_argument("--save-csv", action="store_true", help="save results in CSV format")parser.add_argument("--save-conf", action="store_true", help="save confidences in --save-txt labels")parser.add_argument("--save-crop", action="store_true", help="save cropped prediction boxes")parser.add_argument("--nosave", action="store_true", help="do not save images/videos")parser.add_argument("--classes", nargs="+", type=int, help="filter by class: --classes 0, or --classes 0 2 3")parser.add_argument("--agnostic-nms", action="store_true", help="class-agnostic NMS")parser.add_argument("--augment", action="store_true", help="augmented inference")parser.add_argument("--visualize", action="store_true", help="visualize features")parser.add_argument("--update", action="store_true", help="update all models")parser.add_argument("--project", default=ROOT / "runs/detect", help="save results to project/name")parser.add_argument("--name", default="exp", help="save results to project/name")parser.add_argument("--exist-ok", action="store_true", help="existing project/name ok, do not increment")parser.add_argument("--line-thickness", default=3, type=int, help="bounding box thickness (pixels)")parser.add_argument("--hide-labels", default=False, action="store_true", help="hide labels")parser.add_argument("--hide-conf", default=False, action="store_true", help="hide confidences")parser.add_argument("--half", action="store_true", help="use FP16 half-precision inference")parser.add_argument("--dnn", action="store_true", help="use OpenCV DNN for ONNX inference")parser.add_argument("--vid-stride", type=int, default=1, help="video frame-rate stride")opt = parser.parse_args()opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1  # expandprint_args(vars(opt))return opt

这段代码是一个Python 脚本中的一个函数,用于解析命令行参数并返回这些参数的值。

3.1 参数设置部分

主要功能是为模型进行推理时提供参数。下面简要解释每个参数的作用和默认值:

1. --weights:训练的权重路径,可以使用自己训练的权重,也可以使用官网提供的权重。默认官网的权重yolov5s.pt(yolov5n.ptlyolov5s.ptlyolov5m.ptlyolov5l.ptlyolov5x.pt/区别在于网络的宽度和深度以此增加)
2. --source:测试数据,可以是图片/视频路径,也可以是’0(电脑自带摄像头),也可以是rtsp等视频流,默认data/images
3. --data:配置数据文件路径,包括imagellabel/classes等信息,训练自己的文件,需要作相应更改,可以不用管
4. —imgsz:预测时网络输入图片的尺寸,默认值为[640]
5. . --conf-thres:置信度阈值,默认为0.50
6. . --iou-thres:非极大抑制时的loU阈值,默认为0.45
7. . --max-det:保留的最大检测框数量,每张图片中检测目标的个数最多为1000类
8. .–device:使用的设备,可以是cuda设备的ID(例如0、0,1,2,3)或者是
‘cpu’,默认为’0’–view-img:是否展示预测之后的图片/视频,默认False
9. --save-txt:是否将预测的框坐标以trt文件形式保存,默认False,使用–save-txt在路径
10. runsldetectlexp’labels * txt下生成每张图片预测的txt文件
11. --save-conf:是否保存检测结果的置信度到 txt文件,默认为False
12. --save-crop:是否保存裁剪预测框图片,默认为False,使用–save-crop在runs/detectlexp1lcrop/剪切类别文件夹/路径下会保存每个接下来的目标
13. --nosave:不保存图片、视频,要保存图片,不设置
14. --nosave在runs/detectlexp
/会出现预测的结果
15. --classes:仅检测指定类别,默认为None
16. . --agnostic-nms:是否使用类别不敏感的非极大抑制(即不考虑类别信息),默认为False-
17. --augment:是否使用数据增强进行推理,默认为False
18. . --visualize:是否可视化特征图,默认为False
19. --update:如果为True,则对所有模型进行strip_optimizer操作,去除pt文件中的优化器等信息,默认为F alse
20. --project:结果保存的项目目录路径,默认为’ROOT/runs/detect’
21. --name:结果保存的子目录名称,默认为’exp’
22. --exist-ok:是否覆盖已有结果,默认为False
23. -line-thickness:画bounding box时的线条宽度,默认为3.
24. --hide-labels:是否隐藏标签信息,默认为False
25. --hide-conf:是否隐藏置信度信息,默认为False.
26. --half:是否使用FP16半精度进行推理,默认为False
27. --dnn:是否使用OpenCV DNN进行ONNX 推理,默认为False
*

这一部分的详细参数调整,我后面会专门写一篇学习笔记。

3.2 py语法部分

  1. 解析命令行参数: 使用argparse.ArgumentParser()创建一个命令行参数解析器
    ** 添加了一系列命令行参数**,包括模型权重路径、数据源路径、数据集配置文件路径、推理图像尺寸、置信度阈值、NMS、IoU阈值等各种推理选项和模型配置参数。 设置了各个参数的默认值、类型以及帮助信息,以便用户了解每个参数的作用和用法。
  2. 解析参数并处理: 使用parser.parse_args()解析命令行参数,将解析结果存储在opt变量中
    对opt.imgsz进行了处理,如果opt.imgsz长度为1,则将其值扩展为原来的两倍,以确保推理尺寸正确。
    调用print_args(vars(opt))函数,打印解析后的参数信息,方便用户查看。
  3. 返回解析结果: 将解析后的参数opt返回给调用者,供后续程序使用

3.3 opt变量

在Python中,opt变量是一个命名空间(Namespace)对象,它包含了通过命令行解析器argparse.ArgumentParser()解析得到的命令行参数及其取值。命名空间对象是一个简单的容器,可以将多个属性(参数)存储在其中,并通过属性名(参数名)来访问和操作这些属性的取值。

因此,opt变量不是一个单一的数据类型,而是一个包含多个属性的对象。每个属性对应一个命令行参数,其取值可以是字符串、整数、浮点数等不同的数据类型,取决于参数在解析时的设置和用户输入的值。通过opt对象,程序可以方便地访问和获取各个命令行参数的取值,以便在后续的程序逻辑中使用。

四、run()函数

run()函数按照逻辑顺序可以分为载入参数、初始化配置、保存结果、加载模型、加载数据、推理部分、在终端里打印出运行的结果,这七个主要部分。

4.1 载入参数

这些参数就是上面的parse_opt()函数确定的参数。parse_opt()函数的主要作用是声明了这些参数的意义、默认取值和帮助信息,以便用户在命令行中传入相应的参数值。一旦用户在命令行中指定了这些参数的取值,argparse模块会解析这些参数,并将它们存储在一个命名空间对象(通常是opt)中。而run()函数则是载入这些确定好具体数值的参数。

weights=ROOT / "yolov5s.pt",  # model path or triton URLsource=ROOT / "data/images",  # file/dir/URL/glob/screen/0(webcam)data=ROOT / "data/coco128.yaml",  # dataset.yaml pathimgsz=(640, 640),  # inference size (height, width)conf_thres=0.25,  # confidence thresholdiou_thres=0.45,  # NMS IOU thresholdmax_det=1000,  # maximum detections per imagedevice="",  # cuda device, i.e. 0 or 0,1,2,3 or cpuview_img=False,  # show resultssave_txt=False,  # save results to *.txtsave_csv=False,  # save results in CSV formatsave_conf=False,  # save confidences in --save-txt labelssave_crop=False,  # save cropped prediction boxesnosave=False,  # do not save images/videosclasses=None,  # filter by class: --class 0, or --class 0 2 3agnostic_nms=False,  # class-agnostic NMSaugment=False,  # augmented inferencevisualize=False,  # visualize featuresupdate=False,  # update all modelsproject=ROOT / "runs/detect",  # save results to project/namename="exp",  # save results to project/nameexist_ok=False,  # existing project/name ok, do not incrementline_thickness=3,  # bounding box thickness (pixels)hide_labels=False,  # hide labelshide_conf=False,  # hide confidenceshalf=False,  # use FP16 half-precision inferencednn=False,  # use OpenCV DNN for ONNX inferencevid_stride=1,  # video frame-rate stride

4.2 判断source的参数及类型

source = str(source)save_img = not nosave and not source.endswith(".txt")  # save inference imagesis_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)is_url = source.lower().startswith(("rtsp://", "rtmp://", "http://", "https://"))webcam = source.isnumeric() or source.endswith(".streams") or (is_url and not is_file)screenshot = source.lower().startswith("screen")if is_url and is_file:source = check_file(source)  # download

这段代码对之前定义的source参数进行了进一步处理和判断,解释一下这段代码的功能:

  1. 转换source为字符串:将source参数转换为字符串类型,以确保后续操作的一致性。
  2. 判断是否保存推理图像:根据条件判断,确定是否保存推理图像。条件为不禁止保存(not nosave)且source不以".txt"结尾。
  3. 判断source的类型:判断source是文件还是URL。首先检查source的后缀是否在图片格式或视频格式中,以确定是否为文件。
  4. 判断source是否以特定协议开头,如"rtsp://", “rtmp://”, “http://”,
    “https://”,以确定是否为URL。
  5. 判断source是否为数字(摄像头编号)、以".streams"结尾或是URL但不是文件。
  6. 判断是否为截图或屏幕截图:判断source是否以"screen"开头,以确定是否为屏幕截图。
  7. 处理URL和文件的情况: 如果source同时是URL和文件,则调用check_file(source)函数进行下载处理。

通过这段代码,程序根据source参数的不同情况进行了不同的处理和判断,包括确定是否保存推理图像、判断source的类型(文件、URL、摄像头等)、是否为截图或屏幕截图,以及处理URL和文件的特殊情况

4.3 保存目录

# Directoriessave_dir = increment_path(Path(project) / name, exist_ok=exist_ok)  # increment run(save_dir / "labels" if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir

这段代码主要涉及目录的处理:

  1. 保存目录路径:
    save_dir是根据project和name参数构建的保存目录路径。如果exist_ok为True,则会递增命名以避免覆盖已存在的目录。
  2. 创建目录:
    根据条件判断,如果save_txt为True,则在save_dir下创建一个名为"labels"的子目录;否则直接在save_dir下创建目录。
  3. 使用mkdir(parents=True, exist_ok=True)方法创建目录,确保父目录存在且避免因目录已存在而引发异常。

通过这段代码,程序根据用户指定的project和name参数构建保存目录路径,并根据save_txt参数的取值决定是否在目录下创建特定的子目录。这样的目录处理逻辑有助于组织和保存模型推理过程中生成的结果文件

4.4 载入模型

 device = select_device(device)model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)stride, names, pt = model.stride, model.names, model.ptimgsz = check_img_size(imgsz, s=stride)  # check image size

这段代码涉及设备选择、模型初始化和图像尺寸检查:

  1. 选择设备:
    select_device(device)函数用于选择设备,根据device参数指定的值选择CUDA设备(GPU编号)或CPU设备。
  2. 初始化模型:
    使用DetectMultiBackend类初始化模型,传入模型权重路径weights、设备类型device、是否使用OpenCV
    DNN进行推理dnn、数据集配置文件路径data以及是否使用FP16半精度推理half等参数。
  3. 获取模型信息: 从初始化的模型中获取模型的步长(stride)、类别名称列表(names)和模型的pt属性。
  4. 检查图像尺寸: check_img_size(imgsz,
    s=stride)函数用于检查图像尺寸是否符合要求,根据模型的步长(stride)调整图像尺寸,以确保推理过程中输入图像的尺寸符合模型要求。

通过这段代码,程序选择设备并初始化模型,获取模型的相关信息,并根据模型的要求调整输入图像的尺寸,以确保推理过程的顺利进行.

使用DetectMultiBackend类来初始化模型,其中
weights指模型的权重路径
device指设备
dnn 指是否使用OpenCV DNN. data指数据集配置文件的路径
fp16指是否使用半精度浮点数进行推理

接着从模型中获取stride、 names和pt等参数,其中
stride指下采样率
names指模型预测的类别名称.
pt 是Pytorch模型对象

4.5 载入模型

# Dataloaderbs = 1  # batch_sizeif webcam:view_img = check_imshow(warn=True)dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)bs = len(dataset)elif screenshot:dataset = LoadScreenshots(source, img_size=imgsz, stride=stride, auto=pt)else:dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)vid_path, vid_writer = [None] * bs, [None] * bs

这段代码是根据输入的source参数来判断是否是通过webcam摄像头捕捉视频流
如果是,则使用LoadStreams加载视频流
否则,使用LoadImages加载图像
如果是webcam模式,则设置cudnn.benchmark =True以加速常量图像大小的推理。bs表示batch_size(批量大小),这里是1或视频流中的帧数。vid_path和vid_writer分别是视频路径和视频编写器,初始化为长度为batch_size 的空列表。

4.6 核心推理代码

4.6.1 数据的预热

model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz))  # warmupseen, windows, dt = 0, [], (Profile(device=device), Profile(device=device), Profile(device=device))for path, im, im0s, vid_cap, s in dataset:with dt[0]:im = torch.from_numpy(im).to(model.device)im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32im /= 255  # 0 - 255 to 0.0 - 1.0if len(im.shape) == 3:im = im[None]  # expand for batch dimif model.xml and im.shape[0] > 1:ims = torch.chunk(im, im.shape[0], 0)

这段代码主要涉及模型的预热(warmup)和数据处理过程,让我为您解释一下这部分代码的功能:

  1. 模型预热:
    调用model.warmup()方法对模型进行预热,传入图像尺寸参数。如果使用PyTorch或Triton推理引擎,则将图像尺寸设置为(1
    if pt or model.triton else bs, 3, *imgsz)。
  2. 初始化变量:
    初始化seen、windows和dt变量。seen用于记录已处理的数据量,windows用于存储窗口信息,dt是包含三个Profile对象的元组,用于记录时间性能信息
  3. 遍历数据集并处理数据: for循环中遍历数据集中的每个数据项,包括路径、图像数据、原始图像数据、视频捕获对象和其他信息。
    将图像数据转换为PyTorch张量(Tensor),并移动到模型所在的设备上。
    根据模型是否使用FP16半精度推理,将图像数据转换为半精度或全精度浮点数。 将像素值从0-255缩放到0.0-1.0之间。
    如果图像数据维度为3维,则扩展一个维度以匹配模型的输入要求。
    如果模型需要XML格式输入并且图像数据批量大小大于1,则对图像数据进行分块处理

通过这段代码,程序对模型进行预热操作,初始化变量用于记录处理过程中的信息,并遍历数据集中的数据项,将图像数据转换为模型可接受的格式并进行必要的处理,为后续的推理操作做好准备

4.6.2 可视化和预测结果处理

这段代码主要涉及模型推理过程中的可视化和预测结果处理,让我为您解释一下这部分代码的功能:

        with dt[1]:visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else Falseif model.xml and im.shape[0] > 1:pred = Nonefor image in ims:if pred is None:pred = model(image, augment=augment, visualize=visualize).unsqueeze(0)else:pred = torch.cat((pred, model(image, augment=augment, visualize=visualize).unsqueeze(0)), dim=0)pred = [pred, None]else:pred = model(im, augment=augment, visualize=visualize)
  1. 可视化处理:
    在dt[1]时间性能记录块中,根据条件判断,如果visualize为真,则将保存目录路径和当前图像文件名的基本名称构建为可视化路径,并确保目录存在。如果不需要可视化,则将visualize设置为False。
  2. 模型推理:
    根据条件判断,如果模型需要XML格式输入并且图像数据批量大小大于1,则对分块后的图像数据进行推理,将每个图像的预测结果存储在pred中。否则,直接对单个图像数据进行推理,将预测结果存储在pred中。

4.6.3 非极大值抑制和CSV文件操作

 # NMSwith dt[2]:pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)# Second-stage classifier (optional)# pred = utils.general.apply_classifier(pred, classifier_model, im, im0s)# Define the path for the CSV filecsv_path = save_dir / "predictions.csv"# Create or append to the CSV filedef write_to_csv(image_name, prediction, confidence):"""Writes prediction data for an image to a CSV file, appending if the file exists."""data = {"Image Name": image_name, "Prediction": prediction, "Confidence": confidence}with open(csv_path, mode="a", newline="") as f:writer = csv.DictWriter(f, fieldnames=data.keys())if not csv_path.is_file():writer.writeheader()writer.writerow(data)

这段代码主要涉及非极大值抑制(NMS)处理、CSV文件操作和写入数据到CSV文件,让我为您解释一下这部分代码的功能:

  1. 非极大值抑制(NMS)
    在dt[2]时间性能记录块中,调用non_max_suppression函数对预测结果进行非极大值抑制处理。该函数会根据置信度阈值(conf_thres)、IoU阈值(iou_thres)、类别列表(classes)、是否使用类别不可知的NMS(agnostic_nms)以及最大检测数(max_det)等参数进行NMS操作,过滤掉重叠度高的边界框。
  2. CSV文件操作: 定义了CSV文件的路径为保存目录下的"predictions.csv"
    定义了一个函数write_to_csv,用于将图像名称、预测结果和置信度写入CSV文件。如果CSV文件不存在,则会创建文件并写入表头;如果文件已存在,则会在文件末尾追加数据。
  3. 写入数据到CSV文件
    在推理结果处理后,调用write_to_csv函数将每张图像的名称、预测结果和置信度写入CSV文件中,用于记录模型的预测结果。

通过这段代码,程序对模型的预测结果进行NMS处理,过滤掉重叠的边界框;同时将每张图像的预测结果和置信度写入CSV文件中,以便后续分析和展示。

问题一:NMS非极大值抑制是什么?有什么作用?

  • 非极大值抑制(Non-Maximum
    Suppression,NMS)是一种常用的目标检测算法中的后处理技术,用于筛选和去除重叠度高的边界框,保留最具代表性的目标框。其作用主要包括以下几点:
    去除重叠框:在目标检测任务中,同一个目标可能会被多个边界框检测到,导致重叠的边界框。NMS通过保留具有最高置信度的边界框,同时抑制与其高度重叠的其他边界框,从而减少冗余检测结果
    提高检测精度:通过NMS算法,可以有效地过滤掉冗余的边界框,使得最终的检测结果更加精确和准确。只保留最具代表性的边界框,有助于提高目标检测算法的性能。
    减少误检率:NMS可以帮助减少误检率,即减少将背景区域误判为目标的情况。通过去除重叠的边界框,可以减少对同一目标的多次检测,从而降低误检率。
    提高目标定位准确性:NMS可以帮助提高目标的定位准确性,确保最终的检测结果能够准确地框出目标的位置,避免边界框之间的重叠和干扰。
    总的来说,非极大值抑制在目标检测领域起着非常重要的作用,能够帮助优化检测结果,提高检测精度和准确性,同时减少冗余信息,使得目标检测算法更加高效和可靠。

问题二:CSV文件是什么?有什么用?

  • CSV文件是一种常见的文本文件格式,其全称为逗号分隔值(Comma-Separated
    Values)。在CSV文件中,数据以逗号(或其他分隔符,如分号、制表符等)分隔的形式存储,每行代表一条记录,每个字段(列)之间用分隔符进行分隔
  • CSV文件的主要作用包括:
    数据存储和交换:CSV文件是一种简单且通用的数据存储格式,可以用于存储结构化数据,如表格数据、数据库导出数据等。它易于生成和解析,方便数据的交换和共享。
    数据导入导出:许多软件和工具支持CSV格式,可以将数据导出为CSV文件,也可以从CSV文件中导入数据。这种灵活性使得CSV文件成为数据迁移和数据备份的常用格式。
    数据处理和分析:CSV文件可以被各种数据处理工具(如Excel、Python的pandas库等)直接读取和处理,方便进行数据分析、统计和可视化操作。
    数据交换和集成:在不同系统之间进行数据交换时,CSV文件是一种常用的中间格式,可以帮助不同系统之间实现数据集成和数据共享。
  • 总的来说,CSV文件是一种简单且通用的数据存储格式,具有易读易写、易处理的特点,被广泛应用于数据存储、数据交换、数据处理和数据分析等领域。它为数据的管理和处理提供了便利,是数据处理中常用的文件格式之一。

4.6.4 预测的过程

①对每张图像的预测结果进行遍历处理,更新计数器并根据不同情况处理图像信息
        for i, det in enumerate(pred):  # per imageseen += 1if webcam:  # batch_size >= 1p, im0, frame = path[i], im0s[i].copy(), dataset.counts += f"{i}: "else:p, im0, frame = path, im0s.copy(), getattr(dataset, "frame", 0)

使用enumerate函数遍历每张图像的预测结果,其中i表示索引,det表示每张图像的检测结果。
对每张图像进行处理,包括更新seen计数器,根据是否使用摄像头数据源(webcam)来确定处理方式。

  • 如果使用摄像头数据源(webcam=True),则将当前图像的路径(path[i])、原始图像(im0s[i].copy())和数据集的帧数计数(dataset.count)分别赋值给p、im0和frame变量,并更新字符串S。

  • 如果不使用摄像头数据源,则将图像路径(path)、原始图像(im0s.copy())和数据集的帧数计数(getattr(dataset, “frame”, 0))分别赋值给p、im0和frame变量。

②对路径进行处理并生成保存路径和文本文件路径,输出图像尺寸信息,进行坐标归一化处理,以便后续保存图像文件、标签文件和处理边界框坐标等操作
p = Path(p)  # to Pathsave_path = str(save_dir / p.name)  # im.jpgtxt_path = str(save_dir / "labels" / p.stem) + ("" if dataset.mode == "image" else f"_{frame}")  # im.txts += "%gx%g " % im.shape[2:]  # print stringgn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwhimc = im0.copy() if save_crop else im0  # for save_cropannotator = Annotator(im0, line_width=line_thickness, example=str(names))if len(det):# Rescale boxes from img_size to im0 sizedet[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()
③程序遍历检测结果中的每个类别,统计每个类别的检测数量,并将类别名称和对应的检测数量添加到字符串s中,用于打印输出检测结果
 # Print resultsfor c in det[:, 5].unique():n = (det[:, 5] == c).sum()  # detections per classs += f"{n} {names[int(c)]}{'s' * (n > 1)}, "  # add to string

4.6.5 打印目标检测结果

# Write resultsfor *xyxy, conf, cls in reversed(det):c = int(cls)  # integer classlabel = names[c] if hide_conf else f"{names[c]}"confidence = float(conf)confidence_str = f"{confidence:.2f}"if save_csv:write_to_csv(p.name, label, confidence_str)if save_txt:  # Write to filexywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywhline = (cls, *xywh, conf) if save_conf else (cls, *xywh)  # label formatwith open(f"{txt_path}.txt", "a") as f:f.write(("%g " * len(line)).rstrip() % line + "\n")if save_img or save_crop or view_img:  # Add bbox to imagec = int(cls)  # integer classlabel = None if hide_labels else (names[c] if hide_conf else f"{names[c]} {conf:.2f}")annotator.box_label(xyxy, label, color=colors(c, True))if save_crop:save_one_box(xyxy, imc, file=save_dir / "crops" / names[c] / f"{p.stem}.jpg", BGR=True)

这段代码主要涉及将检测结果写入文件(包括CSV文件和文本文件)、在图像上绘制边界框和标签、保存裁剪的边界框等操作,解释一下这部分代码的功能:

  1. 写入结果: 对每个检测结果中的边界框坐标、置信度、类别进行处理,将其写入CSV文件和文本文件中。
    如果save_csv为True,则调用write_to_csv函数将图像名称、类别标签和置信度写入CSV文件。
    如果save_txt为True,则将归一化后的边界框坐标、类别、置信度写入文本文件,格式为cls, x_center,
    y_center, width, height, confidence。
  2. 绘制边界框和标签: 如果save_img、save_crop或view_img为True,则在图像上绘制边界框和标签。
    如果hide_labels为True,则不显示标签;否则根据hide_conf决定是否显示置信度。
    使用annotator.box_label函数在图像上绘制边界框和标签,颜色根据类别不同而变化。
  3. 保存裁剪的边界框: 如果save_crop为True,则将裁剪的边界框保存为单独的图像文件,文件名包括类别信息和图像名称。

通过这段代码,程序将检测结果写入文件(CSV和文本文件)、在图像上绘制边界框和标签,并保存裁剪的边界框,以便后续分析和展示

4.6.6 流式展示检测结果

 # Stream resultsim0 = annotator.result()if view_img:if platform.system() == "Linux" and p not in windows:windows.append(p)cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)  # allow window resize (Linux)cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])cv2.imshow(str(p), im0)cv2.waitKey(1)  # 1 millisecond

这段代码主要涉及流式展示检测结果,解释一下这部分代码的功能:

  1. 流式展示结果: 使用annotator.result()获取绘制了边界框和标签的图像im0。
    如果view_img为True,则将处理后的图像展示出来。
  2. 展示图像: 如果view_img为True,根据操作系统类型和窗口列表windows,判断是否需要创建新窗口并展示图像。
    在Linux系统下,如果图像路径p不在窗口列表windows中,则创建新窗口,并设置窗口属性为可调整大小和保持宽高比。
    使用OpenCV的cv2.imshow函数展示图像,窗口名称为图像路径p,图像内容为处理后的图像im0。
    使用cv2.waitKey(1)等待1毫秒,以便展示图像并等待用户操作。

通过这段代码,程序实现了对处理后的图像进行流式展示,方便用户实时查看检测结果。根据用户设置的参数,程序会在图像上绘制边界框和标签,并在图像上展示检测结果

4.6.7 保存检测后的图像及视频流

            # Save results (image with detections)if save_img:if dataset.mode == "image":cv2.imwrite(save_path, im0)else:  # 'video' or 'stream'if vid_path[i] != save_path:  # new videovid_path[i] = save_pathif isinstance(vid_writer[i], cv2.VideoWriter):vid_writer[i].release()  # release previous video writerif vid_cap:  # videofps = vid_cap.get(cv2.CAP_PROP_FPS)w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))else:  # streamfps, w, h = 30, im0.shape[1], im0.shape[0]save_path = str(Path(save_path).with_suffix(".mp4"))  # force *.mp4 suffix on results videosvid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))vid_writer[i].write(im0)

这段代码主要涉及保存带有检测结果的图像或视频流,让我为您解释一下这部分代码的功能:

  • 保存结果: 如果save_img为True,根据数据集模式(“image”、“video"或"stream”),将处理后的图像保存为文件。
    如果数据集模式为"image",直接使用cv2.imwrite保存图像到指定路径save_path。
    如果数据集模式为"video"或"stream",根据视频路径vid_path[i]和视频写入器vid_writer[i],将图像帧写入视频文件中。
    如果当前保存路径与之前不同,则更新保存路径,并根据视频捕获对象vid_cap的情况获取帧率、宽度和高度信息。
    如果之前存在视频写入器对象,先释放之前的视频写入器。
    创建新的视频写入器对象,设置视频编解码器为"mp4v",帧率为获取的帧率,宽度和高度为获取的宽度和高度信息。
    强制将结果视频文件的后缀名设置为".mp4",以确保视频文件格式正确。
    通过这段代码,程序根据数据集模式将处理后的图像保存为图像文件或视频流文件,根据用户设置的参数和视频信息,将带有检测结果的图像帧写入视频文件中

4.6.8 打印推断时间、打印结果、保存结果以及更新模型

# Print time (inference-only)LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1E3:.1f}ms")# Print resultst = tuple(x.t / seen * 1e3 for x in dt)  # speeds per imageLOGGER.info(f"Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}" % t)if save_txt or save_img:s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ""LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}")if update:strip_optimizer(weights[0])  # update model (to fix SourceChangeWarning)

这段代码主要涉及打印推断时间、打印结果、保存结果以及更新模型,让我为您解释一下这部分代码的功能:

  1. 打印推断时间: 使用LOGGER.info打印推断时间信息,包括每个类别的检测数量和推断时间。 如果没有检测到目标,则打印"(no
    detections)",并打印推断时间(以毫秒为单位)。
  2. 打印结果: 计算每个图像的预处理时间、推断时间和NMS(非极大值抑制)时间,并打印在日志中。
    打印每个图像的预处理时间、推断时间和NMS时间,以及图像的形状信息。
  3. 保存结果: 如果需要保存文本文件或图像文件,则根据保存的文本标签文件数量和保存路径,打印结果保存信息。
  4. 更新模型:
    如果需要更新模型,则调用strip_optimizer函数来更新模型,以修复可能出现的SourceChangeWarning。

五、 对于具体推理部分run()函数的代码总结

运行函数,主要包括以下几个步骤:

  1. 载入参数:指定模型权重、数据源、推断尺寸、置信度阈值等参数。

  2. 初始化配置:根据数据源类型进行初始化配置,如判断是否为文件、URL、摄像头等。

  3. 保存结果:创建保存结果的目录,并根据需要创建标签目录。

  4. 加载模型:选择设备并加载目标检测模型。

  5. 加载数据:根据数据源类型加载数据集。 数据预热:对模型进行数据预热。

  6. 推理过程:对每个图像进行推理,包括可视化处理和结果保存。

  7. 非极大值抑制:对检测结果进行非极大值抑制处理。

  8. 推理结果处理:处理推理结果,包括保存结果到CSV文件和文本文件。

  9. 打印时间信息:打印推理时间信息和结果保存路径。

  10. 更新模型:如果需要更新模型,则调用函数进行更新。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/536925.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【易语言】夸克网盘通用一键转存工具

这标题很熟悉吧,没错,之前是用python写的。 然而py编译的exe好像兼容性贼差,好几个人跟我反馈闪退、卡死。 所以用易语言重写了一下。 主要更新了读取数据库链接的功能,设置好一定的时间范围,就相当于是可以每日更新链…

盘点9款AI论文写作神器,轻松写出高质量论文

0. 未来百科 未来百科,是一个全球最大的 AI 产品导航网站 —— 为发现全球优质 AI 工具而生 。目前已 聚集全球 10000优质 AI 工具产品 ,旨在帮助用户发现全球最好的 AI 工具,同时为研发 AI 垂直应用的创业公司提供展示窗口,迎接…

全国自然保护区生态功能区分布数据

自然保护区,是指对有代表性的自然生态系统、珍稀濒危野生动植物物种的天然集中分布区、有特殊意义的自然遗迹等保护对象所在的陆地、陆地水体或者海域,依法划出一定面积予以特殊保护和管理的区域。 【分级】按事权划分原则,我国自然保护区分为…

05-ESP32-S3-IDF USART

ESP32-S3 IDF USART详解 USART简介 USART是一种串行通信协议,广泛应用于微控制器和计算机之间的通信。USART支持异步和同步模式,因此它可以在没有时钟信号的情况下(异步模式)或有时钟信号的情况下(同步模式&#xff…

【LLMs+小羊驼】23.03.Vicuna: 类似GPT4的开源聊天机器人( 90%* ChatGPT Quality)

官方在线demo: https://chat.lmsys.org/ Github项目代码:https://github.com/lm-sys/FastChat 官方博客:Vicuna: An Open-Source Chatbot Impressing GPT-4 with 90% ChatGPT Quality 模型下载: https://huggingface.co/lmsys/vicuna-7b-v1.5 | 所有的模…

Redirect相应重定向无法访问WEB-INF下的静态资源,可以跳到外部资源(比如www.baidu.com)

相应重定向无法访问WEB-INF目录下静态资源,WEB-INF目录下静态资源受保护。 访问外部资源 访问Servlet5.do,就跳到百度页面

【机器学习智能硬件开发全解】(三)—— 政安晨:嵌入式系统基本素养【计算机体系结构中的CPU关系】

通过上一篇文章的学习: 【机器学习智能硬件开发全解】(二)—— 政安晨:嵌入式系统基本素养【处理器原理】https://blog.csdn.net/snowdenkeke/article/details/136662796我们已经知道了CPU的设计流程和工作原理,紧接着一个新问题…

【PLC】现场总线和工业以太网汇总

1、 现场总线 1.1 什么是现场总线 1)非专业描述: 如下图:“人机界面”一般通过以太网连接“控制器(PLC)”,“控制器(PLC)”通过 “现场总线”和现场设备连接。 2)专业描述(维基百科) 现场总线…

WPS 云文档保存在本地的地址如何从c盘更改为其他盘?

程序代码园发文地址:WPS 云文档保存在本地的地址如何从c盘更改为其他盘?-程序代码园小说,Java,HTML,Java小工具,程序代码园,http://www.byqws.com/ ,WPS 云文档保存在本地的地址如何从c盘更改为其他盘?http://www.byqws.com/blog/3146.html?…

云计算 3月12号 (PEX)

什么是PXE? PXE,全名Pre-boot Execution Environment,预启动执行环境; 通过网络接口启动计算机,不依赖本地存储设备(如硬盘)或本地已安装的操作系统; 由Intel和Systemsoft公司于199…

MyBatis-Plus学习记录

目录 MyBatis-Plus快速入门 简介 快速入门 MyBatis-Plus核心功能 基于Mapper接口 CRUD 对比mybatis和mybatis-plus: CRUD方法介绍: 基于Service接口 CRUD 对比Mapper接口CRUD区别: 为什么要加强service层: 使用方式 CR…

cms垃圾回收

cms垃圾回收 CMS概述CMS收集器整体流程初始标记并发标记重新标记并发清除 CMS卡表什么是卡表(card table)什么是mod-union table CMS概述 CMS(Concurrent Mark Sweep)收集器是Java虚拟机中的一种老年代(old Generation)垃圾收集器,他主要目标是减少垃圾收集时的应用…