第99步 深度学习图像目标检测:SSDlite建模

基于WIN10的64位系统演示

一、写在前面

本期,我们继续学习深度学习图像目标检测系列,SSD(Single Shot MultiBox Detector)模型的后续版本,SSDlite模型。

二、SSDlite简介

SSDLite 是 SSD 模型的一个变种,旨在为移动设备和边缘计算设备提供更高效的目标检测。SSDLite 的主要特点是使用了轻量级的骨干网络和特定的卷积操作来减少计算复杂性,从而提高检测速度,同时在大多数情况下仍保持了较高的准确性。

以下是 SSDLite 的主要特性和组件:

(1)轻量级骨干:

SSDLite 不使用 VGG 或 ResNet 这样的重量级骨干。相反,它使用 MobileNet 作为骨干,特别是 MobileNetV2 或 MobileNetV3。这些网络使用深度可分离的卷积和其他轻量级操作来减少计算成本。

(2)深度可分离的卷积:

这是 MobileNet 的核心组件,也被用于 SSDLite。深度可分离的卷积将传统的卷积操作分解为两个较小的操作:一个深度卷积和一个点卷积,这大大减少了计算和参数数量。

(3)多尺度特征映射:

与原始的 SSD 相似,SSDLite 也从不同的层级提取特征图以检测不同大小的物体。

(4)默认框:

SSDLite 也使用默认框(或称为锚框)来进行边界框预测。

(5)单阶段检测:

与 SSD 相同,SSDLite 也是一个单阶段检测器,同时进行边界框回归和分类。

(6)损失函数:

SSDLite 使用与 SSD 相同的组合损失,包括平滑 L1 损失和交叉熵损失。

综上,SSDLite 是为了速度和效率而设计的,特别是针对计算和内存资源有限的设备。通过使用轻量级的骨干和深度可分离的卷积,它能够在减少计算负担的同时,仍然保持合理的检测准确性。

三、数据源

来源于公共数据,文件设置如下:

大概的任务就是:用一个框框标记出MTB的位置。

四、SSDlite实战

直接上代码:

import os
import random
import torch
import torchvision
from torchvision.models.detection import ssdlite320_mobilenet_v3_large
from torchvision.transforms import functional as F
from PIL import Image
from torch.utils.data import DataLoader
import xml.etree.ElementTree as ET
import matplotlib.pyplot as plt
from torchvision import transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np# Function to parse XML annotations
def parse_xml(xml_path):tree = ET.parse(xml_path)root = tree.getroot()boxes = []for obj in root.findall("object"):bndbox = obj.find("bndbox")xmin = int(bndbox.find("xmin").text)ymin = int(bndbox.find("ymin").text)xmax = int(bndbox.find("xmax").text)ymax = int(bndbox.find("ymax").text)# Check if the bounding box is validif xmin < xmax and ymin < ymax:boxes.append((xmin, ymin, xmax, ymax))else:print(f"Warning: Ignored invalid box in {xml_path} - ({xmin}, {ymin}, {xmax}, {ymax})")return boxes# Function to split data into training and validation sets
def split_data(image_dir, split_ratio=0.8):all_images = [f for f in os.listdir(image_dir) if f.endswith(".jpg")]random.shuffle(all_images)split_idx = int(len(all_images) * split_ratio)train_images = all_images[:split_idx]val_images = all_images[split_idx:]return train_images, val_images# Dataset class for the Tuberculosis dataset
class TuberculosisDataset(torch.utils.data.Dataset):def __init__(self, image_dir, annotation_dir, image_list, transform=None):self.image_dir = image_dirself.annotation_dir = annotation_dirself.image_list = image_listself.transform = transformdef __len__(self):return len(self.image_list)def __getitem__(self, idx):image_path = os.path.join(self.image_dir, self.image_list[idx])image = Image.open(image_path).convert("RGB")xml_path = os.path.join(self.annotation_dir, self.image_list[idx].replace(".jpg", ".xml"))boxes = parse_xml(xml_path)# Check for empty bounding boxes and return Noneif len(boxes) == 0:return Noneboxes = torch.as_tensor(boxes, dtype=torch.float32)labels = torch.ones((len(boxes),), dtype=torch.int64)iscrowd = torch.zeros((len(boxes),), dtype=torch.int64)target = {}target["boxes"] = boxestarget["labels"] = labelstarget["image_id"] = torch.tensor([idx])target["iscrowd"] = iscrowd# Apply transformationsif self.transform:image = self.transform(image)return image, target# Define the transformations using torchvision
data_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),  # Convert PIL image to tensortorchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize the images
])# Adjusting the DataLoader collate function to handle None values
def collate_fn(batch):batch = list(filter(lambda x: x is not None, batch))return tuple(zip(*batch))def get_ssdlite_model_for_finetuning(num_classes):# Load an SSDlite model with a MobileNetV3 Large backbone without pre-trained weightsmodel = ssdlite320_mobilenet_v3_large(pretrained=False, num_classes=num_classes)return model# Function to save the model
def save_model(model, path="SSDlite_mtb.pth", save_full_model=False):if save_full_model:torch.save(model, path)else:torch.save(model.state_dict(), path)print(f"Model saved to {path}")# Function to compute Intersection over Union
def compute_iou(boxA, boxB):xA = max(boxA[0], boxB[0])yA = max(boxA[1], boxB[1])xB = min(boxA[2], boxB[2])yB = min(boxA[3], boxB[3])interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)iou = interArea / float(boxAArea + boxBArea - interArea)return iou# Adjusting the DataLoader collate function to handle None values and entirely empty batches
def collate_fn(batch):batch = list(filter(lambda x: x is not None, batch))if len(batch) == 0:# Return placeholder batch if entirely emptyreturn [torch.zeros(1, 3, 224, 224)], [{}]return tuple(zip(*batch))#Training function with modifications for collecting IoU and loss
def train_model(model, train_loader, optimizer, device, num_epochs=10):model.train()model.to(device)loss_values = []iou_values = []for epoch in range(num_epochs):epoch_loss = 0.0total_ious = 0num_boxes = 0for images, targets in train_loader:# Skip batches with placeholder dataif len(targets) == 1 and not targets[0]:continue# Skip batches with empty targetsif any(len(target["boxes"]) == 0 for target in targets):continueimages = [image.to(device) for image in images]targets = [{k: v.to(device) for k, v in t.items()} for t in targets]loss_dict = model(images, targets)losses = sum(loss for loss in loss_dict.values())optimizer.zero_grad()losses.backward()optimizer.step()epoch_loss += losses.item()# Compute IoU for evaluationwith torch.no_grad():model.eval()predictions = model(images)for i, prediction in enumerate(predictions):pred_boxes = prediction["boxes"].cpu().numpy()true_boxes = targets[i]["boxes"].cpu().numpy()for pred_box in pred_boxes:for true_box in true_boxes:iou = compute_iou(pred_box, true_box)total_ious += iounum_boxes += 1model.train()avg_loss = epoch_loss / len(train_loader)avg_iou = total_ious / num_boxes if num_boxes != 0 else 0loss_values.append(avg_loss)iou_values.append(avg_iou)print(f"Epoch {epoch+1}/{num_epochs} Loss: {avg_loss} Avg IoU: {avg_iou}")# Plotting loss and IoU valuesplt.figure(figsize=(12, 5))plt.subplot(1, 2, 1)plt.plot(loss_values, label="Training Loss")plt.title("Training Loss across Epochs")plt.xlabel("Epochs")plt.ylabel("Loss")plt.subplot(1, 2, 2)plt.plot(iou_values, label="IoU")plt.title("IoU across Epochs")plt.xlabel("Epochs")plt.ylabel("IoU")plt.show()# Save model after trainingsave_model(model)# Validation function
def validate_model(model, val_loader, device):model.eval()model.to(device)with torch.no_grad():for images, targets in val_loader:images = [image.to(device) for image in images]targets = [{k: v.to(device) for k, v in t.items()} for t in targets]model(images)# Paths to your data
image_dir = "tuberculosis-phonecamera"
annotation_dir = "tuberculosis-phonecamera"# Split data
train_images, val_images = split_data(image_dir)# Create datasets and dataloaders
train_dataset = TuberculosisDataset(image_dir, annotation_dir, train_images, transform=data_transform)
val_dataset = TuberculosisDataset(image_dir, annotation_dir, val_images, transform=data_transform)# Updated DataLoader with new collate function
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn)# Model and optimizer
model = get_ssdlite_model_for_finetuning(2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# Train and validate
train_model(model, train_loader, optimizer, device="cuda", num_epochs=10)
validate_model(model, val_loader, device="cuda")

需要从头训练的,就不跑了,摆烂了。

五、写在后面

目标检测模型门槛更高了,运行起来对硬件要求也很高,时间也很久,都是小时起步的。因此只是简单介绍,算是入个门了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/216150.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

再生式收音机踩坑记

下载《A Simple Regen Radio for Beginners》这篇文章也有好几年了&#xff0c;一直没有动手&#xff0c;上周末抽空做了一个&#xff0c;结果相当令人沮丧&#xff0c;一个台也收不到&#xff0c;用示波器测量三极管振荡波形&#xff0c;只有在调节再生电位器R2过程中&#xf…

百度 文心一言 sdk 试用

JMaven Central: com.baidu.aip:java-sdk (sonatype.com) Java sdk地址如上&#xff1a; 文心一言开发者 文心一言 (baidu.com) ERNIE Bot SDK https://yiyan.baidu.com/developer/doc#Fllzznonw ERNIE Bot SDK提供便捷易用的接口&#xff0c;可以调用文心一言的能力&#…

rocky8.9配置K8S集群kubernetes,centos同理

rocky8.9配置K8S集群 节点主机名IP地址mastertang1192.168.211.101node1tang2192.168.211.102node2tang3192.168.211.103 1&#xff09;准备工作 全部主机都配置静态ip vi /etc/sysconfig/network-scriptsTYPEEthernet PROXY_METHODnone BROWSER_ONLYno BOOTPROTOstatic DE…

中低压MOSFET 2N7002W 60V 300mA 双N通道 SOT-323封装

2N7002W小电流双N通道MOSFET&#xff0c;电压60V电流300mA&#xff0c;采用SOT-323封装形式。超高密度电池设计&#xff0c;适用于极低的ros (on)&#xff0c;具有导通电阻和最大直流电流能力&#xff0c;ESD保护。可应用于笔记本中的电源管理&#xff0c;电池供电系统等产品应…

Redis大key与热Key

什么是 bigkey&#xff1f; 简单来说&#xff0c;如果一个 key 对应的 value 所占用的内存比较大&#xff0c;那这个 key 就可以看作是 bigkey。具体多大才算大呢&#xff1f;有一个不是特别精确的参考标准&#xff1a; bigkey 是怎么产生的&#xff1f;有什么危害&#xff1f;…

保姆级 ARM64 CPU架构下安装部署Docker + rancher + K8S 说明文档

1 K8S 简介 K8S是Kubernetes的简称&#xff0c;是一个开源的容器编排平台&#xff0c;用于自动部署、扩展和管理“容器化&#xff08;containerized&#xff09;应用程序”的系统。它可以跨多个主机聚集在一起&#xff0c;控制和自动化应用的部署与更新。 K8S 架构 Kubernete…

Arduino库之 LedControl 库说明文档

LedControl 库最初是为基于 8 位 AVR 处理器的 Arduino 板编写的。用于通过MAX7219芯片控制LED矩阵和7段数码管。但由于该代码不使用处理器的任何复杂的内部功能&#xff0c;因此具有高度可移植性&#xff0c;并且应该在任何支持 和 功能的 Arduino&#xff08;类似&#xff09…

Vue3 设置点击后滚动条移动到固定的位置

需求&#xff1a; 点击不通过按钮&#xff0c;显示红框中表单&#xff0c;且滚动条滚动到底部 &#xff08;显示红框中表单默认不显示&#xff09; <el-button click"onApprovalPass">不通过</el-button> <div class"item" v-if"app…

杰发科技AC7801——ADC软件触发的简单使用

前言 7801资料读起来不是很好理解&#xff0c;大概率是之前MTK的大佬写的。在此以简单的方式进行描述。我们做一个简单的规则组软件触发Demo。因为规则组通道只有一个数据寄存器&#xff0c;因此还需要用上DMA方式搬运数据到内存。 AC7801的ADC简介 7801的ADC是一种 12 位 逐…

【数据库篇】关系模式的表示——(1)问题的提出

1、关系模式的表示 R&#xff1a;表示关系的名字比如&#xff1a;sc选课表&#xff0c;student学生表。 U&#xff1a;表示一个关系模式的所有属性&#xff0c;比如student表&#xff1a;U&#xff08;sno&#xff0c;sname&#xff0c;sage&#xff0c;ssex&#xff09;。 …

<蓝桥杯软件赛>零基础备赛20周--第7周--栈和二叉树

报名明年4月蓝桥杯软件赛的同学们&#xff0c;如果你是大一零基础&#xff0c;目前懵懂中&#xff0c;不知该怎么办&#xff0c;可以看看本博客系列&#xff1a;备赛20周合集 20周的完整安排请点击&#xff1a;20周计划 每周发1个博客&#xff0c;共20周&#xff08;读者可以按…

qt实现播放视屏的时候,加载外挂字幕(.srt文件解析)

之前用qt写了一个在windows下播放视频的软件&#xff0c;具体介绍参见qt编写的视频播放器&#xff0c;windows下使用&#xff0c;精致小巧_GreenHandBruce的博客-CSDN博客 后来发现有些视频没有内嵌字幕&#xff0c;需要外挂字幕&#xff0c;这时候&#xff0c;我就想着把加载…