【目标检测从零开始】torch实现yolov3数据加载

文章目录

  • 数据简介
  • Dataset读取
    • Step1:类别定义
    • Step2:解析xml
    • Step3:实现Dataset
    • Step4:数据增强
    • Step5:添加dataset_collate
    • Step6:测试
  • 小结

数据简介

  • 林业病虫害防治项目用到的AI识虫数据集,该数据集提供了2183张图片,其中训练集1693张,验证集245,测试集245张。下载地址

  • 图片和标签示例如下:

# 根据坐标把框画到图上
import xml.etree.ElementTree as ET
import cv2
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image
import osdef read_xml(xml_path):tree = ET.parse(xml_path)root = tree.getroot()boxes = []for obj in root.findall('object'):bbox = obj.find('bndbox')xmin = int(bbox.find('xmin').text)ymin = int(bbox.find('ymin').text)xmax = int(bbox.find('xmax').text)ymax = int(bbox.find('ymax').text)# Read class labelclass_label = obj.find('name').textboxes.append((xmin, ymin, xmax, ymax, class_label))return boxesdef visualize_boxes(image_path, boxes):# Read the image using OpenCVimage = cv2.imread(image_path)# Convert BGR image to RGBimage_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)# Create figure and axesfig, ax = plt.subplots(1)# Display the imageax.imshow(image_rgb)# Add bounding boxes to the imagefor box in boxes:xmin, ymin, xmax, ymax, class_label = boxrect = patches.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, linewidth=1, edgecolor='g', facecolor='none')ax.add_patch(rect)# Display class labelplt.text(xmin, ymin, class_label, color='r', fontsize=8, bbox=dict(facecolor='white', alpha=0.7))# Set the title as the file nameplt.title(os.path.splitext(os.path.basename(image_path))[0])# Show the plotplt.show()if __name__ == "__main__":xml_folder = r"D:\work\data\insects\train\annotations\xmls"image_folder = r"D:\work\data\insects\train\images"# Specify the file name of the image you want to visualizeimage_file_name = "1.jpeg"xml_file = os.path.join(xml_folder, os.path.splitext(image_file_name)[0] + ".xml")image_path = os.path.join(image_folder, image_file_name)boxes = read_xml(xml_file)visualize_boxes(image_path, boxes)

Dataset读取

继承torch.utils.Dataset类来读取数据集,在getitem函数中返回图片、框坐标、框类别,主要分为以下步骤:

Step1:类别定义

  • 定义数据集的路径、类别

    DATA_ROOT = r'D:\work\data\insects'
    CATEGORY_NAMES = ['Boerner', 'Leconte', 'Linnaeus','acuminatus', 'armandi', 'coleoptera', 'linnaeus']
    # 根据类名返回对应的id
    def get_insect_names():insect_category2id = {}for i, item in enumerate(CATEGORY_NAMES):insect_category2id[item] = ireturn insect_category2idCATEGORY_NAME_ID = get_insect_names()
    NUM_CLASSES = len(CATEGORY_NAMES)
    

Step2:解析xml

  • 解析xml文件,获取框的位置、类别

  • 框坐标从xyxy改成了xywh

    import xml.etree.ElementTree as ET
    import os
    import numpy as npdef read_xml(xml_path):"""解析xml文件,返回坐标和类别信息:param xml_path::return:"""tree = ET.parse(xml_path)root = tree.getroot()fname = os.path.basename(xml_path).split()[0]objs = tree.findall('object')# 存框坐标和类别gt_bbox = np.zeros((len(objs), 4), dtype=np.float32)gt_class = np.zeros((len(objs),), dtype=np.int32)difficult = np.zeros((len(objs),), dtype=np.int32)for i, obj in enumerate(root.findall('object')):bbox = obj.find('bndbox')xmin = int(bbox.find('xmin').text)ymin = int(bbox.find('ymin').text)xmax = int(bbox.find('xmax').text)ymax = int(bbox.find('ymax').text)_difficult = int(obj.find('difficult').text)cname = obj.find('name').text# 直接改成 xywh格式gt_bbox[i] = [(xmin + xmax) / 2.0, (ymin + ymax) / 2.0, ymax - ymin + 1., ymax - ymin + 1.]gt_class[i] = CATEGORY_NAME_ID[cname]difficult[i] = _difficultrecord = {'fname': fname,'gt_bbox': gt_bbox,'gt_class': gt_class,'difficult': difficult}return record
    

Step3:实现Dataset

  • 继承torch.nn.Dataset,定义InsectDataset类,包含 init/getitem/lenget_annotations四个方法

    • init():定义数据集路径、数据增强等参数
    • **len():**数据集数量
    • get_annotations():将Step2中解析出来的xml结果包裹起来,获取所有框
    • get_item():读取records,根据idx拿到对应图片的框(同时将框改成相对坐标)

    returns: image, gt_boxes, labels

import os
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoaderclass InsectDataset(Dataset):""":returns img, gt_boxes, labelsimg: tensorgt_boxes: list 框的相对位置labels: list   框的标签"""def __init__(self, datadir, mode='train', transforms=None):super(InsectDataset, self).__init__()self.datadir = os.path.join(datadir, mode)self.records = self.get_annotations()self.transforms = transformsdef __getitem__(self, idx):record = self.records[idx]gt_boxes = record['gt_bbox']labels = record['gt_class']image = np.array(Image.open(record['im_file']))w = image.shape[0]h = image.shape[1]# gt_bbox 用相对值gt_boxes[:, 0] = gt_boxes[:, 0] / float(w)gt_boxes[:, 1] = gt_boxes[:, 1] / float(h)gt_boxes[:, 2] = gt_boxes[:, 2] / float(w)gt_boxes[:, 3] = gt_boxes[:, 3] / float(h)if self.transforms:transformed = self.transforms(image=image, bboxes=gt_boxes, class_labels=labels)image = transformed['image']gt_boxes = np.array(transformed['bboxes'])labels = np.array(transformed['class_labels'])image = image.transpose((2,1,0)) # h,w,c -> c,w,hreturn image, gt_boxes, labelsdef __len__(self):return len(self.records)def get_annotations(self):"""从xml目录下面读取所有文件的标注信息:param cname2cid::param datadir::return: record:[{im_file:    arraygt_boxes:   arraygt_classes: arraydifficult:  array}]"""datadir = self.datadirfilenames = os.listdir(os.path.join(datadir, 'annotations', 'xmls'))records = []for fname in filenames:# 拿到文件名fid = fname.split('.')[0]fpath = os.path.join(datadir, 'annotations', 'xmls', fname)img_file = os.path.join(datadir, 'images', fid + '.jpeg')# 解析xml文件record = read_xml(fpath)record['im_file'] = img_file  # 把图片路径加上records.append(record)return records

Step4:数据增强

  • 这里采用albumentations进行数据增强,参考官网的目标检测数据增强教程即可,这里加入normalize、resize以及一些常见的数据增强策略,后续完善

    import albumentations as Atransforms = A.Compose([# A.RandomCrop(width=450, height=450),A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0),A.Resize(width=640, height=640),A.HorizontalFlip(p=0.5),A.RandomBrightnessContrast(p=0.2),
    ], bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))
    
  • 在调用的时候注意框坐标的format,这里统一用yolo格式(xywh相对坐标)

    if self.transforms:transformed = self.transforms(image=image, bboxes=gt_boxes, class_labels=labels)image = transformed['image']gt_boxes = np.array(transformed['bboxes'])labels = np.array(transformed['class_labels'])
    

Step5:添加dataset_collate

由于不同图片的框数量不同,在用dataloader加载数据的时候,getitem的返回值shape不同会报错,因此用一个list包裹起来

def dataset_collate(batch):"""用list包一下 img, bboxes, labels:param batch::return:"""images = []bboxes = []labels = []for img, box, label in batch:images.append(img)bboxes.append(box)labels.append(label)images = torch.tensor(np.array(images))return images, bboxes, labels

Step6:测试

  • 测试Dataset的getitem函数以及用Dataloader加载后能否正常读取
if __name__ == '__main__':dataset = InsectDataset(DATA_ROOT, transforms=transforms)print(dataset.__len__())print('image_shape: ', dataset.__getitem__(1)[0].shape)batch_size = 4print()train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0,collate_fn=dataset_collate)for inputs in train_loader:print('img_shape:', inputs[0].shape)print('gt_boxes:', inputs[1])print('gt_labels:', inputs[2])

小结

  • 读取voc格式的数据集主要以下三个点需要注意一下

    • 解析xml文件,获取关键的框坐标和类别信息,并非所有信息都有作用
    • 弄清楚数据集格式是xyxy还是xywh,是相对坐标还是绝对坐标(既然要做数据增强变换图像大小,那相对坐标更方便)
    • 用Dataloader读取的时候每个图片的框数量不一样,加上dataset_collate用list包裹一下。
  • 把画框的代码单独放在一个文件里,但其中read_xml的方法跟dataset中类似,框架搭好之后进一步优化一下

  • 如果是anchor base的模型后续还需要根据锚框来处理得到每个锚框的objectness和坐标

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/263412.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Base64编码解码

一、Base64编码技术简介 Base64编码是一种广泛应用于网络传输和数据存储的编码方式。它将原始数据转换为可打印的字符形式,以便于传输和存储。Base64编码后的数据长度是原始数据长度的约3/4,具有一定的压缩效果。 Base64编码解码 -- 一个覆盖广泛主题工…

【蜗牛到家】获南明电子信息产业引导基金战略投资

智慧社区生活服务平台「蜗牛到家」已于近期获得贵阳南明电子信息产业引导基金、华科明德战略投资。 贵阳南明电子信息产业引导基金属于政府旗下产业引导基金,贵州华科明德基金管理有限公司擅长电子信息产业、高科技产业、城市建设及民生保障领域的投资,双…

主窗体、QFile、编码转换、事件、禁止输入特殊字符

主窗体 部件构成 菜单栏、工具栏、主窗体、状态栏。 UI 编辑器设计主窗体 💡 简易记事本的实现(part 1) 菜单栏 工具栏(图标) 主窗体 完善菜单栏: mainwindow.cpp #include "mainwindow.h"…

《PySpark大数据分析实战》-01.关于数据

📋 博主简介 💖 作者简介:大家好,我是wux_labs。😜 热衷于各种主流技术,热爱数据科学、机器学习、云计算、人工智能。 通过了TiDB数据库专员(PCTA)、TiDB数据库专家(PCTP…

PHP 二维码内容解析、二维码识别

目录 1.首先是一些错误的示例 2.正确示例 3.二维码解析 4.完整示例,含生成 5.代码执行结果 6.参考文档 1.首先是一些错误的示例 本示例使用的是php7.3 通过搜索各种结果逐个尝试以后,得出一个可使用版本 解析错误经历:vendor核心报错 …

[C++]:10.vector使用

vector使用 一.vector使用1.构造函数:2.迭代器遍历数据:3.空间问题:1.size():返回有效数据个数:2.capacity():返回容量大小:3.容量检测:4.emptr():判断顺序表是否为空:5.…

Linux6-配置网络、源码包的编译和安装

配置 linux 网络 配置主机名 修改/etc/hostname 配置文件,永久配置主机名 [rootlocalhost ~]# vim /etc/hostname svr7.tedu.cn [rootlocalhost ~]# cat /etc/hostname svr7.tedu.cn [rootlocalhost ~]# reboot #重启生效命令行永久修改主机名 [rootlocalhost ~…

Vue3使用Tailwind CSS

安装 Tailwind 以及其它依赖项 npm install -D tailwindcsslatest postcsslatest autoprefixerlatest生成配置文件: npx tailwindcss init -p.修改配置文件 tailwind.config.js 2.6版本 : module.exports {purge: [./index.html, ./src/**/*.{vue,j…

arm-none-eabi-gcc not find

解决办法:安装:gcc-arm-none-eabi sudo apt install gcc-arm-none-eabi; 如果上边解决问题了就不用管了,如果解决不了,加上下面这句试试运气: $ sudo apt-get install lsb-core看吧方正我是运气还不错,感…

call,apply,bind

1.这三个方法都能改变this的指向 2.代码实战 let obj1 {name: "小红",age: 20,fn: function () {console.log(当前this的指向,this);console.log(我叫${this.name},今年${this.age}岁);},};obj1.fn(); 这里的代码,obj1是一个对象,里面有属性name和age 正常情况下我…

计数排序详解

前言:这篇文章会给大家把计数排序安排的明明白白,详细的讲解计数排序的原理 例子:现在我有一个数组不知道里面到底有多少个元素,但是我要把它进行排序,怎么排序呢? 我先随便拿一个数组(你假装你…

Java JMM

JMM 全称: Java Memory Model (Java 内存模式)。 它是一种虚拟机规范, 用于屏蔽掉各种硬件和操作系统的内存访问差异, 以实现 Java 程序在各种平台下都能达到一致的并发效果。 主要规定了以下两点 一个线程如何以及何时可以看到其他线程修改过后的共享变量的值, 即线程之间共享…