图像抠图DIS——自然图像中高精度二分图像抠图的方法(C++/python模型推理)

概述

DIS(Dichotomous Image Segmentation)是一种新的图像分割任务,旨在从自然图像中分割出高精度的物体。与传统的图像分割任务相比,DIS更侧重于具有单个或几个目标的图像,因此可以提供更丰富准确的细节。

为了研究DIS任务,研究人员创建了一个名为DIS5K的大规模、可扩展的数据集。DIS5K数据集包含了5,470张高分辨率图像,每张图像都配有高精度的二值分割掩码。这个数据集的建立有助于推动多个应用方向的发展,如图像去背景、艺术设计、模拟视图运动、基于图像的增强现实(AR)应用、基于视频的AR应用、3D视频制作等。

通过研究DIS任务和使用DIS5K数据集,研究人员可以探索新的图像分割方法,并为各种应用领域提供更精确、更可靠的图像分割技术,从而推动分割技术在更广泛的领域中的应用。

官网:https://xuebinqin.github.io/dis/index.html
Github:https://github.com/xuebinqin/DIS

数据集

图像二类分割是将图像分割成两个主要区域:前景和背景。在这种情况下,前景代表图像中的某个类别的物体,而背景则是除了该物体之外的所有内容。
官方公布了算所使用的数据集DIS5K, DIS5K数据集中的每张图像都经过了像素级别的手工标注,标注的真值掩码非常精确,每张图像的标记时间相当长。这种高精度的标注使得数据集中的每个像素都与其相应的类别关联起来,从而为模型提供了可靠的训练数据。这种高精度的标注是实现图像二类分割的关键,因为模型需要能够准确地识别和分割出前景物体。

在DIS5K数据集中,标注对象的类型多样,包括透明和半透明的物体,标注使用单个像素的二值掩码进行。这种精确的标注确保了模型训练的有效性和准确性,并且使得模型能够预测出高精度的物体分割结果。

DIS5K数据集网盘地址:https://pan.baidu.com/s/1umNk2AeBG5aB5kXlHTHdIg
提取码:7qfs

模型训练

模型训练可参考git上的官方的文档

模型推理

模型C++使用onnxruntime进行推理

#include <opencv2/opencv.hpp>
#include <onnxruntime_cxx_api.h>class DIS
{
public:DIS(std::string model_path);void inference(cv::Mat& cv_src, cv::Mat& cv_mask);
private:std::vector<float> input_image_;int inpWidth;int inpHeight;int outWidth;int outHeight;const float score_th = 0;Ort::Env env = Ort::Env(ORT_LOGGING_LEVEL_ERROR, "DIS");Ort::Session* ort_session = nullptr;Ort::SessionOptions sessionOptions = Ort::SessionOptions();std::vector<char*> input_names;std::vector<char*> output_names;std::vector<std::vector<int64_t>> input_node_dims; // >=1 outputsstd::vector<std::vector<int64_t>> output_node_dims; // >=1 outputs
};DIS::DIS(std::string model_path)
{std::wstring widestr = std::wstring(model_path.begin(), model_path.end());//OrtStatus* status = OrtSessionOptionsAppendExecutionProvider_CUDA(sessionOptions, 0);sessionOptions.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);ort_session = new Ort::Session(env, widestr.c_str(), sessionOptions);size_t numInputNodes = ort_session->GetInputCount();size_t numOutputNodes = ort_session->GetOutputCount();Ort::AllocatorWithDefaultOptions allocator;for (int i = 0; i < numInputNodes; i++){input_names.push_back(ort_session->GetInputName(i, allocator));Ort::TypeInfo input_type_info = ort_session->GetInputTypeInfo(i);auto input_tensor_info = input_type_info.GetTensorTypeAndShapeInfo();auto input_dims = input_tensor_info.GetShape();input_node_dims.push_back(input_dims);}for (int i = 0; i < numOutputNodes; i++){output_names.push_back(ort_session->GetOutputName(i, allocator));Ort::TypeInfo output_type_info = ort_session->GetOutputTypeInfo(i);auto output_tensor_info = output_type_info.GetTensorTypeAndShapeInfo();auto output_dims = output_tensor_info.GetShape();output_node_dims.push_back(output_dims);}this->inpHeight = input_node_dims[0][2];this->inpWidth = input_node_dims[0][3];this->outHeight = output_node_dims[0][2];this->outWidth = output_node_dims[0][3];
}void DIS::inference(cv::Mat& cv_src, cv::Mat& cv_mask)
{cv::Mat cv_dst;cv::resize(cv_src, cv_dst, cv::Size(this->inpWidth, this->inpHeight));this->input_image_.resize(this->inpWidth * this->inpHeight * cv_dst.channels());for (int c = 0; c < 3; c++){for (int i = 0; i < this->inpHeight; i++){for (int j = 0; j < this->inpWidth; j++){float pix = cv_dst.ptr<uchar>(i)[j * 3 + 2 - c];this->input_image_[c * this->inpHeight * this->inpWidth + i * this->inpWidth + j] = pix / 255.0 - 0.5;}}}std::array<int64_t, 4> input_shape_{ 1, 3, this->inpHeight, this->inpWidth };auto allocator_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);Ort::Value input_tensor_ = Ort::Value::CreateTensor<float>(allocator_info,input_image_.data(), input_image_.size(), input_shape_.data(), input_shape_.size());std::vector<Ort::Value> ort_outputs = ort_session->Run(Ort::RunOptions{ nullptr }, &input_names[0],&input_tensor_, 1, output_names.data(), output_names.size());   // 开始推理float* pred = ort_outputs[0].GetTensorMutableData<float>();cv::Mat mask(outHeight, outWidth, CV_32FC1, pred);double min_value, max_value;minMaxLoc(mask, &min_value, &max_value, 0, 0);mask = (mask - min_value) / (max_value - min_value);cv::resize(mask, cv_mask, cv::Size(cv_src.cols, cv_src.rows));
}void show_img(std::string name, const cv::Mat& img)
{cv::namedWindow(name, 0);int max_rows = 500;int max_cols = 600;if (img.rows >= img.cols && img.rows > max_rows) {cv::resizeWindow(name, cv::Size(img.cols * max_rows / img.rows, max_rows));}else if (img.cols >= img.rows && img.cols > max_cols) {cv::resizeWindow(name, cv::Size(max_cols, img.rows * max_cols / img.cols));}cv::imshow(name, img);
}cv::Mat replaceBG(const cv::Mat cv_src, cv::Mat& alpha, std::vector<int>& bg_color)
{int width = cv_src.cols;int height = cv_src.rows;cv::Mat cv_matting = cv::Mat::zeros(cv::Size(width, height), CV_8UC3);float* alpha_data = (float*)alpha.data;for (int i = 0; i < height; i++){for (int j = 0; j < width; j++){float alpha_ = alpha_data[i * width + j];cv_matting.at < cv::Vec3b>(i, j)[0] = cv_src.at < cv::Vec3b>(i, j)[0] * alpha_ + (1 - alpha_) * bg_color[0];cv_matting.at < cv::Vec3b>(i, j)[1] = cv_src.at < cv::Vec3b>(i, j)[1] * alpha_ + (1 - alpha_) * bg_color[1];cv_matting.at < cv::Vec3b>(i, j)[2] = cv_src.at < cv::Vec3b>(i, j)[2] * alpha_ + (1 - alpha_) * bg_color[2];}}return cv_matting;
}int main()
{DIS dis_net("isnet_general_use_720x1280.onnx");std::string path = "images";std::vector<std::string> filenames;cv::glob(path, filenames, false);for (auto file_name : filenames){cv::Mat cv_src = cv::imread(file_name);//std::vector<cv::Mat> cv_dsts;cv::Mat cv_dst, cv_mask;dis_net.inference(cv_src, cv_mask);std::vector<int> color{255, 0, 0};cv_dst=replaceBG(cv_src, cv_mask, color);show_img("src", cv_src);show_img("mask", cv_mask);show_img("dst", cv_dst);cv::waitKey(0);}
}

python推理代码也依赖onnxruntime

import argparse
import cv2
import numpy as np
import onnxruntime
### onnxruntime load ['isnet_general_use_HxW.onnx', 'isnet_HxW.onnx', 'isnet_Nx3xHxW.onnx']  inference failed
class DIS():def __init__(self, modelpath, score_th=None):so = onnxruntime.SessionOptions()so.log_severity_level = 3self.net = onnxruntime.InferenceSession(modelpath, so)self.input_height = self.net.get_inputs()[0].shape[2]self.input_width = self.net.get_inputs()[0].shape[3]self.input_name = self.net.get_inputs()[0].nameself.output_name = self.net.get_outputs()[0].nameself.score_th = score_thdef detect(self, srcimg):img = cv2.resize(srcimg, dsize=(self.input_width, self.input_height))img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)img = img.astype(np.float32) / 255.0 - 0.5blob = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0).astype(np.float32)outs = self.net.run([self.output_name], {self.input_name: blob})mask = np.array(outs[0]).squeeze()min_value = np.min(mask)max_value = np.max(mask)mask = (mask - min_value) / (max_value - min_value)if self.score_th is not None:mask = np.where(mask < self.score_th, 0, 1)mask *= 255mask = mask.astype('uint8')mask = cv2.resize(mask, dsize=(srcimg.shape[1], srcimg.shape[0]), interpolation=cv2.INTER_LINEAR)return maskdef generate_overlay_image(srcimg, mask):overlay_image = np.zeros(srcimg.shape, dtype=np.uint8)overlay_image[:] = (255, 255, 255)mask = np.stack((mask,) * 3, axis=-1).astype('uint8') mask_image = np.where(mask, srcimg, overlay_image)return mask, mask_imageif __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument("--imgpath", type=str, default='images/cam_image47.jpg')parser.add_argument("--modelpath", type=str, default='weights/isnet_general_use_480x640.onnx')args = parser.parse_args()mynet = DIS(args.modelpath)srcimg = cv2.imread(args.imgpath)mask = mynet.detect(srcimg)mask, overlay_image = generate_overlay_image(srcimg, mask)winName = 'Deep learning object detection in onnxruntime'cv2.namedWindow(winName, cv2.WINDOW_NORMAL)cv2.imshow(winName, np.hstack((srcimg, mask)))cv2.waitKey(0)cv2.destroyAllWindows()

推理结果
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
资源和模型下载地址:https://download.csdn.net/download/matt45m/89024664

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/568522.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数据容器-dict以及总结-Python

师从黑马程序员 字典的定义 同样使用{},不过存储的元素是以个个的&#xff1a;键值对&#xff0c;如下语法&#xff1a; #定义字典 my_dict1{"王力宏":99,"周杰伦":88,"林俊杰":77} #定义空字典 my_dict2{} my_dict3dict() print(f"字典1…

ICC2:postmask eco限制绕线层次

更多学习内容请关注「拾陆楼」知识星球 拾陆楼知识星球入口 做postmask eco常会遇到限制绕线层次的要求 ICC2可以使用set_ignore_layer来限制: set_ignore_layer -min M1 -max M3 set_app_options -name route.common.net_max_layer_mode -value hard set_app_options -na…

常用类五(File类)

目录 File 类的基本用法 File 类的常见构造方法&#xff1a;public File(String pathname) 通过 File 对象可以访问文件的属性 通过 File 对象创建空文件或目录&#xff08;在该对象所指的文件或目录不存在的情况下&#xff09; 递归遍历目录结构和树状展现 File 类的基本…

Selenium 自动化 —— 浏览器窗口操作

更多内容请关注我的专栏&#xff1a; 入门和 Hello World 实例使用WebDriverManager自动下载驱动Selenium IDE录制、回放、导出Java源码 当用 Selenium 打开浏览器后&#xff0c;我们就可以通过 Selenium 对浏览器做各种操作&#xff0c;就像我们日常用鼠标和键盘操作浏览器一…

[计算机效率] 文件搜索工具:Listary(附详细使用教程)

3.5 文件搜索工具&#xff1a;Listary Listary是一款实用的搜索工具&#xff0c;它能为我的电脑&#xff08;资源管理器&#xff09;增添许多智能命令&#xff0c;提高用户日常收藏和整理文件的效率。它具备多种实用功能&#xff0c;例如收藏文件夹、快速打开最近浏览的文件夹…

BS系统的登录鉴权流程演变(高级必备)

BS系统的登录鉴权流程演变 1 基础知识1.1 Http Cookie1.2 重定向与前端路由Vue-router1.2.1 后端重定向1.2.2 Vue-router 1.3.JWT简介1.4 Spring-Security1.4.1 过滤器链[24]1.4.3 DelegationFilterProxy的实例化和拦截配置1.4.4 在项目中使用Spring Security1.4.5 用户认证 2.…

Spring Cloud - Openfeign 实现原理分析

OpenFeign简介 OpenFeign 是一个声明式 RESTful 网络请求客户端。OpenFeign 会根据带有注解的函数信息构建出网络请求的模板,在发送网络请求之前,OpenFeign 会将函数的参数值设置到这些请求模板中。虽然 OpenFeign 只能支持基于文本的网络请求,但是它可以极大简化网络请求的…

六西格玛绿带培训:量化进步的关键

在追求卓越的道路上&#xff0c;六西格玛绿带培训成为了一种革命性的思维方式&#xff0c;让我们能够以科学和系统的方法提升过程性能。但在这个追求中&#xff0c;我们如何确定我们的进步&#xff1f;过程能力分析为我们提供了明确的答案。通过计算过程能力&#xff0c;即6σ&…

竞赛 python 爬虫与协同过滤的新闻推荐系统

1 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 &#x1f6a9; python 爬虫与协同过滤的新闻推荐系统 &#x1f947;学长这里给一个题目综合评分(每项满分5分) 难度系数&#xff1a;3分工作量&#xff1a;3分创新点&#xff1a;4分 该项目较为新颖&…

python关于字符串基础学习

字符串 python字符串是不可改变的 Python不支持单字符类型&#xff0c;单字符也是作为一个字符串使用的。 字符串编码 python3直接支持Unicode,可以表示世界上任何书面语言的字符 python3的字符默认就是16位Unicode编码&#xff0c;ASCII是Unicode的子集 使用内置函数 ord()…

二十六 超级数据查看器 讲解稿 用输入值批量更新字段

二十六 超级数据查看器 讲解稿 用输入值批量更新字段 ​点击此处 以新页面 打开B站 播放当前教学视频 点击访问app下载页面 百度手机助手 下载地址 ​ 大家好&#xff0c;今天我们讲一下超级数据查看器的输入更新功能&#xff0c;输入更新功能是将选择的TXT文档的数据&…

【竞技宝】DOTA2-PGL联赛:niu神无解 LGD2-0轻松击败DH

北京时间2024年3月26日,PGL联赛中国区的比赛在昨日正式打响,首日共进行了四场胜者组首轮的比赛,第四场比赛由LGD对阵DH。本场比赛,DH两局都在前中期和LGD有来有回,但niu的中期节奏完全摧毁了DH,最终LGD2-0轻松击败DH。以下是本场比赛的详细战报。 第一局: 首局比赛,LGD在天辉方…