基于华为atlas的unet分割模型探索

Unet模型使用官方基于kaggle Carvana Image Masking Challenge数据集训练的模型。

模型输入为572*572*3,输出为572*572*2。分割目标分别为,0:背景,1:汽车。

Pytorch的pth模型转化onnx模型:

import torchfrom unet import UNetmodel = UNet(n_channels=3, n_classes=2, bilinear=False)
model = model.to(memory_format=torch.channels_last)state_dict = torch.load("unet_carvana_scale1.0_epoch2.pth", map_location="cpu")
#del state_dict['mask_values']
model.load_state_dict(state_dict)dummy_input = torch.randn(1, 3, 572, 572)torch.onnx.export(model, dummy_input, "unet.onnx", verbose=True)

模型输入输出节点分析:

使用工具Netron查看模型结构,确定模型输入节点名称为input.1,输出节点名称为/outc/conv/Conv

onnx模型转化atlas模型:

atc --model=./unet.onnx --framework=5 --output=unet --soc_version=Ascend310P3  --input_shape="input.1:1,3,572,572" --output_type="/outc/conv/Conv:0:FP32" --out_nodes="/outc/conv/Conv:0"

推理代码实现:

import base64
import json
import os
import timeimport numpy as np
import cv2import MxpiDataType_pb2 as mxpi_data
from StreamManagerApi import InProtobufVector
from StreamManagerApi import MxProtobufIn
from StreamManagerApi import StreamManagerApidef check_dir(dir):if not os.path.exists(dir):os.makedirs(dir, exist_ok=True)class SDKInferWrapper:def __init__(self): # 完成初始化self._stream_name = Noneself._stream_mgr_api = StreamManagerApi()if self._stream_mgr_api.InitManager() != 0:raise RuntimeError("Failed to init stream manager.")pipeline_name = './nested_unet.pipeline'self.load_pipeline(pipeline_name)self.width = 572self.height = 572def load_pipeline(self, pipeline_path):with open(pipeline_path, 'r') as f:pipeline = json.load(f)self._stream_name = list(pipeline.keys())[0].encode() # 'unet_pytorch'if self._stream_mgr_api.CreateMultipleStreams(json.dumps(pipeline).encode()) != 0:raise RuntimeError("Failed to create stream.")def do_infer(self, img_bgr):# preprocessimage = cv2.resize(img_bgr, (self.width, self.height))image = cv2.cvtColor(image,cv2.COLOR_BGR2RGB)image = image.astype('float32') / 255.0image = image.transpose(2, 0, 1)tensor_pkg_list = mxpi_data.MxpiTensorPackageList()tensor_pkg = tensor_pkg_list.tensorPackageVec.add()tensor_vec = tensor_pkg.tensorVec.add()tensor_vec.deviceId = 0tensor_vec.memType = 0for dim in [1, *image.shape]:tensor_vec.tensorShape.append(dim) # tensorshape属性为[1,3,572,572]input_data = image.tobytes()tensor_vec.dataStr = input_datatensor_vec.tensorDataSize = len(input_data)protobuf_vec = InProtobufVector()protobuf = MxProtobufIn()protobuf.key = b'appsrc0'protobuf.type = b'MxTools.MxpiTensorPackageList'protobuf.protobuf = tensor_pkg_list.SerializeToString()protobuf_vec.push_back(protobuf)unique_id = self._stream_mgr_api.SendProtobuf(self._stream_name, 0, protobuf_vec)if unique_id < 0:raise RuntimeError("Failed to send data to stream.")infer_result = self._stream_mgr_api.GetResult(self._stream_name, unique_id)if infer_result.errorCode != 0:raise RuntimeError(f"GetResult error. errorCode={infer_result.errorCode}, "f"errorMsg={infer_result.data.decode()}")output_tensor = self._parse_output_data(infer_result)output_tensor = np.squeeze(output_tensor)output_tensor = softmax(output_tensor)mask = np.argmax(output_tensor, axis =0)score = np.max(output_tensor, axis = 0)mask = cv2.resize(mask, [img_bgr.shape[1], img_bgr.shape[0]], interpolation=cv2.INTER_NEAREST)score = cv2.resize(score, [img_bgr.shape[1], img_bgr.shape[0]], interpolation=cv2.INTER_NEAREST)return mask, scoredef _parse_output_data(self, output_data):infer_result_data = json.loads(output_data.data.decode())content = json.loads(infer_result_data['metaData'][0]['content'])tensor_vec = content['tensorPackageVec'][0]['tensorVec'][0]data_str = tensor_vec['dataStr']tensor_shape = tensor_vec['tensorShape']infer_array = np.frombuffer(base64.b64decode(data_str), dtype=np.float32)return infer_array.reshape(tensor_shape)def draw(self, mask):color_lists = [(255, 0, 0), (0, 255, 0), (0, 0, 255)]drawed_img = np.stack([mask, mask, mask], axis = 2)for i in np.unique(mask):drawed_img[:,:,0][drawed_img[:,:,0]==i] = color_lists[i][0]drawed_img[:,:,1][drawed_img[:,:,1]==i] = color_lists[i][1]drawed_img[:,:,2][drawed_img[:,:,2]==i] = color_lists[i][2]return drawed_imgdef softmax(x):exps = np.exp(x - np.max(x))return exps/np.sum(exps)def sigmoid(x):y = x.copy()y[x >= 0] = 1.0 / (1 + np.exp(-x[x >= 0]))y[x < 0] = np.exp(x[x < 0]) / (1 + np.exp(x[x < 0]))return ydef check_dir(dir):if not os.path.exists(dir):os.makedirs(dir, exist_ok=True)def test():dataset_dir = './sample_data'output_folder = "./infer_result"   os.makedirs(output_folder, exist_ok=True)sdk_infer = SDKInferWrapper()# read imgimage_name = "./sample_data/images/111.jpg"img_bgr = cv2.imread(image_name)# infert1 = time.time()mask, score = sdk_infer.do_infer(img_bgr)t2 = time.time()print(t2-t1, mask, score)drawed_img = sdk_infer.draw(mask)cv2.imwrite("infer_result/draw.png", drawed_img)if __name__ == "__main__":test()

运行代码:

set -e
. /usr/local/Ascend/ascend-toolkit/set_env.sh
# Simple log helper functions
info() { echo -e "\033[1;34m[INFO ][MxStream] $1\033[1;37m" ; }
warn() { echo >&2 -e "\033[1;31m[WARN ][MxStream] $1\033[1;37m" ; }#export MX_SDK_HOME=/home/work/mxVision
export LD_LIBRARY_PATH=${MX_SDK_HOME}/lib:${MX_SDK_HOME}/opensource/lib:${MX_SDK_HOME}/opensource/lib64:/usr/local/Ascend/ascend-toolkit/latest/acllib/lib64:${LD_LIBRARY_PATH}
export GST_PLUGIN_SCANNER=${MX_SDK_HOME}/opensource/libexec/gstreamer-1.0/gst-plugin-scanner
export GST_PLUGIN_PATH=${MX_SDK_HOME}/opensource/lib/gstreamer-1.0:${MX_SDK_HOME}/lib/plugins#to set PYTHONPATH, import the StreamManagerApi.py
export PYTHONPATH=$PYTHONPATH:${MX_SDK_HOME}/pythonpython3 unet.py
exit 0

运行效果:

个人思考:

华为atlas的参考案例细节不到位,步骤缺失较多,摸索困难,代码写法较差,信创化道路任重而道远。

参考资料:

GitHub - milesial/Pytorch-UNet: PyTorch implementation of the U-Net for image semantic segmentation with high quality images

https://gitee.com/ascend/samples/tree/master/python/level2_simple_inference/3_segmentation/unet++

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/513869.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux笔记--make

使用上一节的 main.c、add.c、sub.c文件进行编译&#xff0c;编译的过程有很多步骤&#xff0c;如果要重新编译&#xff0c;还需要再重来一遍&#xff0c;能不能一步完成这些步骤?将这些步骤写到makefile文件中&#xff0c;通过make工具进行编译 一个工程中的源文件不计其数&a…

Vue基础入门(4)- Vuex的使用

Vue基础入门&#xff08;4&#xff09;- Vuex的使用 Vuex 主要内容&#xff1a;Store以及其中的state、mutations、actions、getters、modules属性 介绍&#xff1a;Vuex 是一个 Vue 的 状态管理工具&#xff0c;状态就是数据。 大白话&#xff1a;Vuex 是一个插件&#xff…

Ubuntu18.04运行ORB-SLAM3

ORB-SLAM3复现(ubuntu18) 文章目录 ORB-SLAM3复现(ubuntu18)1 坐标系与外参Intrinsic parameters2 内参Intrinsic parameters2.1 相机内参① 针孔模型Pinhole② KannalaBrandt8模型③ Rectified相机 2.2 IMU内参 3 VI标定—外参3.1 Visual calibration3.2 Inertial calibration…

CSS标准文档流与脱离文档流,分享一点面试小经验

大厂面试真题整理 CSS&#xff1a; 1&#xff0c;盒模型 2&#xff0c;如何让一个盒子水平垂直居中&#xff1f; 3&#xff0c;css 优先级确定 4&#xff0c;解释下浮动和它的工作原理&#xff0c;清除浮动的方法&#xff1f; 5&#xff0c;CSS隐藏元素的几种方法 6&#xff0…

CSS常用五类选择器,附面试题

学习路线 第一阶段&#xff1a;网页制作 HTML&#xff1a;常用标签&#xff0c;锚点&#xff0c;列表标签&#xff0c;表单标签&#xff0c;表格标签&#xff0c;标签分类&#xff0c;标签语义化&#xff0c;注释&#xff0c;字符实体 CSS&#xff1a;CSS介绍&#xff0c;全局…

docker 安装rabbitmq并配置hyperf使用

这里我想完成的是 制作消息&#xff08;多个协程制造&#xff09;——》推送到rabbitmq——》订阅消息队列——》消费消息&#xff08;ws协程客户端【一次消费多条】/ws前端&#xff09; 利用 WebSocket 协议让客户端和服务器端保持有状态的长链接&#xff0c;保存链接上来的客…

06 - 镜像管理

1 了解镜像 Docker镜像是一个特殊的文件系统&#xff0c;除了提供容器运行时所需的程序、库、资源、配置等文件外&#xff0c;还包含了一些为运行时准备的一些配置参数&#xff08;如匿名卷、环境变量、用户等&#xff09;。 但注意&#xff0c; 镜像不包含任何动态数据&#…

【Mining Data】收集数据(使用 Python 挖掘 Twitter 数据)

@[TOC](【Mining Data】收集数据(使用 Python 挖掘 Twitter 数据)) 具体步骤 第一步是注册您的应用程序。特别是,您需要将浏览器指向 http://apps.twitter.com,登录 Twitter(如果您尚未登录)并注册新应用程序。您现在可以为您的应用程序选择名称和描述(例如“Mining Demo”…

政安晨【TypeScript高级用法】(四):模块与声明文件

TypeScript是一种静态类型的JavaScript超集语言&#xff0c;它支持模块化开发和声明文件。 模块化开发是一种将代码分割为独立的模块&#xff0c;每个模块只关注自己的功能&#xff0c;然后通过导入和导出来实现模块之间的交互和复用。在TypeScript中&#xff0c;可以使用impo…

设置video的进度条常显

根据网上的方法&#xff0c;把隐藏的元素显示 在chrome中F12或者通过其他方式打开开发者工具&#xff08;相信应该知道从哪里打开&#xff09;&#xff0c;然后点击右上的齿轮&#xff0c;进入设置&#xff0c;勾选Show user agent shadow DOM 就能在elements里面查看视频的播…

【Leetcode】1588.所有奇数长度子数组的和

题目描述 思路 题目要求我们求解所有奇数长度数组的和。若暴力循环求解&#xff0c;时间复杂度过高。所以&#xff0c;我们可以采用前缀和优化。 如上图输入arr数组&#xff0c;sum[i]用于计算arr数组中前i个数的和。(在程序中&#xff0c;先给sum[0]赋值&#xff0c;等于arr[0…

libigl 极小曲面(全局优化)

文章目录 一、简介二、实现代码三、实现效果参考资料一、简介 二、实现代码 #include <igl/colon.h> #include <igl/harmonic.h>