iou的cpu和gpu源码实现

本专栏主要是深度学习/自动驾驶相关的源码实现,获取全套代码请参考

简介

IoU(Intersection over Union)是一种测量在特定数据集中检测相应物体准确度的一个标准,通常用于目标检测中预测框(bounding box)之间准确度的一个度量(预测框和实际目标框)。
在这里插入图片描述

IoU计算的是“预测的边框”和“真实的边框”的交叠率,即它们的交集和并集的比值。最理想情况是完全重叠,即比值为1。

IoU的计算方法如下:

计算两个框的交集面积,即两个框的左、上、右、下四个点的交集。
计算两个框的并集面积,即两个框的左、上、右、下四个点的并集。
计算交集面积和并集面积的比值,即为 IoU 值。
IoU的优点是可以反映预测检测框与真实检测框的检测效果,并且具有尺度不变性,即对尺度不敏感。但是,IoU也存在一些缺点,例如无法反映两个框之间的距离大小(重合度),如果两个框没有相交,则 IoU 值为 0,无法进行学习训练。

源码实现:

cpu版源码实现:

def iou_core(box1: Tensor, box2: Tensor, area_sum: Tensor):overlap_w = torch.min(box1[2],box2[2]) - torch.max(box1[0],box2[0])overlap_h = torch.min(box1[3],box2[3]) - torch.max(box1[1],box2[1])if overlap_w <= 0 or overlap_h <= 0:return 0overlap_area = overlap_h * overlap_wreturn overlap_area / (area_sum - overlap_area)def iou_cpu(box1: Tensor, box2: Tensor):box1_num = box1.size(0)box2_num = box2.size(0)box1_dim = box1.size(1)box2_dim = box2.size(1)if box1_dim != 4 or box2_dim != 4:return -1box1_area = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1])box2_area = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])result = torch.zeros(size=(box1_num, box2_num))for i in range(box1_num):for j in range(box2_num):if box1_area[i] >= 0 and box2_area[j] >= 0:result[i, j] = iou_core(box1[i], box2[j], box1_area[i] + box2_area[j])else:result[i, j] = 9999return result

gpu版源码实现:

__device__ float iou_core(const float* box1 ,const float* box2){float box1_x0 = *(box1 + 0);float box1_y0 = *(box1 + 1);float box1_x1 = *(box1 + 2);float box1_y1 = *(box1 + 3);float box2_x0 = *(box2 + 0);float box2_y0 = *(box2 + 1);float box2_x1 = *(box2 + 2);float box2_y1 = *(box2 + 3);if(!(box1_x0 < box1_x1 && box1_y0 < box1_y1 && box2_x0 < box2_x1 && box2_y0 < box2_y1)){return 9999;}float inter_x0 = std::max(box1_x0, box2_x0);float inter_x1 = std::min(box1_x1, box2_x1);float inter_y0 = std::max(box1_y0, box2_y0);float inter_y1 = std::min(box1_y1, box2_y1);float inter_area = (inter_x1 - inter_x0)*(inter_y1-inter_y0);inter_area = std::max(inter_area, 0.0f);float box1_area = (box1_x1 - box1_x0)*(box1_y1-box1_y0);float box2_area = (box2_x1 - box2_x0)*(box2_y1-box2_y0);float iou = inter_area / (box1_area + box2_area - inter_area);printf("iou =%f\n",iou);return iou;
}__global__ void iou_gpu_kernel(const int box1_num,
const float* box1_ptr,
const int box2_num,
const float* box2_ptr,
float* result_ptr){const int box1_idx = blockIdx.x * THREADS_PER_BLOCK + threadIdx.x;const int box2_idx = blockIdx.y * THREADS_PER_BLOCK + threadIdx.y;printf("gpu: box1_idx = %d, box2_idy= %d\n",box1_idx,box2_idx);if(box1_idx>=box1_num || box2_idx>=box2_num){return;}printf("gpu: box1_idx = %d, box2_idy= %d, result_id= %d\n",box1_idx,box2_idx,box1_idx * box2_num + box2_idx);const float* box1 = box1_ptr + box1_idx * 4;const float* box2 = box2_ptr + box2_idx * 4;float iou = iou_core(box1, box2);*(result_ptr + box1_idx * box2_num + box2_idx) = iou;
}void iou_gpu_launch(const int box1_num,
const float* box1_ptr,
const int box2_num,
const float* box2_ptr,
float* result_ptr){dim3 blocks(DIVUP(box1_num, THREADS_PER_BLOCK),DIVUP(box2_num, THREADS_PER_BLOCK));//每个grid的blocksdim3 threads(THREADS_PER_BLOCK,THREADS_PER_BLOCK);//每个block里面的threadprintf("blocks=(%d %d), threads=(%d %d)\n",DIVUP(box1_num, THREADS_PER_BLOCK),DIVUP(box2_num, THREADS_PER_BLOCK),THREADS_PER_BLOCK,THREADS_PER_BLOCK);iou_gpu_kernel<<<blocks,threads>>>(box1_num,box1_ptr,box2_num,box2_ptr,result_ptr);cudaDeviceSynchronize();// waiting for gpu workprintf("gpu done\n");
}

耗时测试:

import torch
from iou import iou_gpu, iou_cpu
from utils import TicTocdevice = torch.device('cuda:0')
input1 = torch.Tensor([[0, 0, 1, 1],[0, 2, 1, 3],[0.2, 0, 1, 1],[0.1, 2, 1, 3],[0.11, 0, 1, 1],[0, 2.4, 1, 3],[0.2, 0.1, 1, 1],[0.7, 2.5, 1, 3],[0, 0, 6, 1],[1.5, 2, 1, 3]]).to(device)
input2 = torch.Tensor([[0.5, 0, 1.5, 1],[0, 0.5, 1, 1.5],[0.5, 0.5, 1.5, 1.5],[0, 0.5, 1, 2.5]]).to(device)tictic = TicToc('iou fun')
for i in range(1000):result = iou_gpu(input1, input2)
tictic.toc()
tictic.tic()
for i in range(1000):result2 = iou_cpu(input1.to('cpu'), input2.to('cpu'))
tictic.toc()
pass

具体流程说明:

IoU的计算方法如下:
计算两个框的交集面积,即两个框的左、上、右、下四个点的交集。
计算两个框的并集面积,即两个框的左、上、右、下四个点的并集。
计算交集面积和并集面积的比值,即为 IoU 值。
在实际应用中,通常设定 IoU 的阈值,例如 0.5 或 0.7 等,当 IoU 值大于阈值时,认为预测成功。通过调整阈值,可以得到不同的模型,再通过不同的评价指标(如 ROC 曲线、F1 值等)来确定最优模型。

如需获取全套代码请参考

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/427133.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Tomcat 优化

1、隐藏版本信息 隐藏 HTTP 头部的版本信息 # 为 Connector 添加 server 属性 vim /usr/local/tomcat/conf/server.xml <Connector port"8080" protocol"HTTP/1.1" connectionTimeout"20000" redirectPort"8443" server"w…

红日二靶场

红日二靶场 靶场的搭建配置环境 一&#xff0c;信息收集1.网段探测2.端口扫描 二&#xff0c;渗透测试1.漏洞发现2.漏洞利用1.上传木马文件2.生成哥斯拉木马文件3.连接哥斯拉4.生成msf恶意程序5.本机开启监听6.将生成的msf文件上传 3.永恒之蓝漏洞1.EarthWorm隧道搭建2.ms17-01…

MyBatis详解(1)-- ORM模型

MyBatis详解&#xff08;1&#xff09; JDBC的弊端&#xff1a; ORM 模型常见的ORM模型&#xff1a;mybatis和Hibernate的区别 ***优势&#xff1a;mybatis解决问题&#xff1a;优点&#xff1a; MyBatisMyBatis环境搭建项目架构mybatis生命周期 JDBC的弊端&#xff1a; 1.硬编…

【极数系列】Flink环境搭建(02)

【极数系列】Flink环境搭建&#xff08;02&#xff09; 引言 1.linux 直接在linux上使用jdk11flink1.18.0版本部署 2.docker 使用容器部署比较方便&#xff0c;一键启动停止&#xff0c;方便参数调整 3.windows 搭建Flink 1.18.0版本需要使用Cygwin或wsl工具模拟unix环境…

[设计模式Java实现附plantuml源码~创建型] 产品族的创建——抽象工厂模式

前言&#xff1a; 为什么之前写过Golang 版的设计模式&#xff0c;还在重新写Java 版&#xff1f; 答&#xff1a;因为对于我而言&#xff0c;当然也希望对正在学习的大伙有帮助。Java作为一门纯面向对象的语言&#xff0c;更适合用于学习设计模式。 为什么类图要附上uml 因为很…

基于web的亚热带常见自然林病虫害识别系统——数据集与数据集划分

文章目录 概要数据收集数据集划分技术代码小结 概要 本篇文章先为病虫害识别进行数据的分类&#xff0c;划分训练集&#xff0c;划分为三个数据集&#xff0c;病虫害的数据集我已经放在我的资源里面&#xff0c;有需要的小伙伴可以自己下载。 声明&#xff1a; 我的数据集照片…

【nowcoder】链表的回文结构

牛客题目链接 链表的回文结构 /* struct ListNode {int val;struct ListNode *next;ListNode(int x) : val(x), next(NULL) {} };*/ #include <cstdlib> // 建议大伙自己对照我的代码画下图&#xff0c;假设A链表是&#xff1a;1 2 3 2 1 class PalindromeList { publi…

JavaScript进阶:WebAPIs重点知识整理3

1 本地存储 存储容量大&#xff1a;约5M 1.1 localStorage 1.1.1 存储 localStorage.setItem(username,张三) localStorage.setItem(password,123456) 1.1.2 获取 console.log(localStorage.getItem(username)) 1.1.3 删除 localStorage.removeItem(username) 1.2 session…

【业务功能篇133】 Mysql连接串优化性能问题

rewriteBatchedStatementstrue开启了MySQL驱动程序的批量处理功能。 spring.datasource.urljdbc:mysql://localhost:3306/mydatabase?rewriteBatchedStatementstrue 在MyBatis Plus框架中&#xff0c;批量插入是一种高效的数据库操作方式。通过开启rewriteBatchedStatementstr…

栈实现队列(附带源码)

一、思路图解 首先&#xff0c;队列&#xff1a;先进先出 栈&#xff1a;先进后出 那么&#xff0c;怎么用栈实现队列呢&#xff1f; 很简单&#xff0c;首先&#xff0c;创建两个栈 一个叫pushsatck,用来入队列 一个叫popstack,用来出队列 队列的核心在于先进先出&#xf…

2024.1.24 C++QT 作业

思维导图 练习题 1.提示并输入一个字符串&#xff0c;统计该字符中大写、小写字母个数、数字个数、空格个数以及其他字符个数 #include <iostream> #include <string.h> #include <array> using namespace std;int main() {string str;cout << "…

Linux shell编程学习笔记41:lsblk命令

边缘计算的挑战和机遇 边缘计算面临着数据安全与隐私保护、网络稳定性等挑战&#xff0c;但同时也带来了更强的实时性和本地处理能力&#xff0c;为企业降低了成本和压力&#xff0c;提高了数据处理效率。因此&#xff0c;边缘计算既带来了挑战也带来了机遇&#xff0c;需要我…