记录pytorch实现自定义算子并转onnx文件输出

概览:记录了如何自定义一个算子,实现pytorch注册,通过C++编译为库文件供python端调用,并转为onnx文件输出

整体大概流程:

  • 定义算子实现为torch的C++版本文件
  • 注册算子
  • 编译算子生成库文件
  • 调用自定义算子

一、编译环境准备

1,在pytorch官网下载如下C++的libTorch package,下载完成后解压文件,是一个libtorch文件夹。

2,提前准备好python,以及pytorch

3,本示例使用了opencv库,所以需要提前安装好opencv。

二、自定义算子的实现

1,实现自定义算子函数

在解压后的libtorch文件夹统计目录,实现自定义算子,用opencv库实现的图像投射函数:warp_perspective。warp_perspective函数后面几行就是实现自定义算子的注册

warpPerspective.cpp文件:

#include "torch/script.h"
#include "opencv2/opencv.hpp"torch::Tensor warp_perspective(torch::Tensor image, torch::Tensor warp) {// BEGIN image_matcv::Mat image_mat(/*rows=*/image.size(0),/*cols=*/image.size(1),/*type=*/CV_32FC1,/*data=*/image.data_ptr<float>());// END image_mat// BEGIN warp_matcv::Mat warp_mat(/*rows=*/warp.size(0),/*cols=*/warp.size(1),/*type=*/CV_32FC1,/*data=*/warp.data_ptr<float>());// END warp_mat// BEGIN output_matcv::Mat output_mat;cv::warpPerspective(image_mat, output_mat, warp_mat, /*dsize=*/{ image.size(0),image.size(1) });// END output_mat// BEGIN output_tensortorch::Tensor output = torch::from_blob(output_mat.ptr<float>(), /*sizes=*/{ image.size(0),image.size(1) });return output.clone();// END output_tensor
}
//static auto registry = torch::RegisterOperators("my_ops::warp_perspective", &warp_perspective);  // torch.__version__: 1.5.0torch.__version__ >= 1.6.0  torch/include/torch/library.h
TORCH_LIBRARY(my_ops, m) {m.def("warp_perspective", warp_perspective);
}

2,同级目录创建CMakeList.txt文件

里面需要修改你自己的python下torch的路径,以及你对应安装python版pytorch是cpu还是gpu的。

cmake_minimum_required(VERSION 3.10 FATAL_ERROR)
project(warp_perspective)set(CMAKE_VERBOSE_MAKEFILE ON)
# >>> build type 
set(CMAKE_BUILD_TYPE "Release")				# 指定生成的版本
set(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} -O0 -Wall -g2 -ggdb")
set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O3 -Wall")set(TORCH_ROOT "/home/xxx/anaconda3/lib/python3.10/site-packages/torch")   
include_directories(${TORCH_ROOT}/include)
link_directories(${TORCH_ROOT}/lib/)# Opencv
find_package(OpenCV REQUIRED)# Define our library target
add_library(warp_perspective SHARED warpPerspective.cpp)# Enable C++14
target_compile_features(warp_perspective PRIVATE cxx_std_17)# libtorch库文件
target_link_libraries(warp_perspective # CPUc10 torch_cpu# GPU# c10_cuda # torch_cuda)# opencv库文件
target_link_libraries(warp_perspective${OpenCV_LIBS}
)add_definitions(-D _GLIBCXX_USE_CXX11_ABI=0)

3,编译生成库文件

同级目录创建build文件夹,进入build文件夹利用CMakeList.txt进行编译,生成libwarp_perspective.so库文件

mkdir build
cd build
cmake ..
make

4,python版pytorch进行自定义算子的测试

注意我的以上代码都是放在了/data/xxx/mylib路径下,所以torch.ops.load_library("/data/xxx/mylib/build/libwarp_perspective.so")就找到库文件的位置。

这里我随便找了一张图片,和直接用python版的opencv做投射变换的结果作为golden对比。如下分别是原图,golden, 自定义pytorch算子的输出。自定义算子的输出不太对,但是图像轮廓和投射效果是对的,后面有时间我再检查一下是什么原因。

测试代码: 

import torch
import cv2
import numpy as nptorch.ops.load_library("/data/xxx/mylib/build/libwarp_perspective.so")im=cv2.imread("/data/xxx/mylib/cat.jpg",0)pst1 = np.float32([[56,65], [368,52], [28,387], [389,390]])
pst2 = np.float32([[100,145], [300,100], [80,290], [310,300]])
#2.2获取透视变换矩阵
T = cv2.getPerspectiveTransform(pst1, pst2)in_data =torch.from_numpy(np.float32(im))
in2_data = torch.Tensor(T)out1=torch.ops.my_ops.warp_perspective(in_data,in2_data)
dst0=np.uint8(out1.numpy())
cv2.imwrite("/data/xxx/mylib/cat_warp.jpg",dst0)dst = cv2.warpPerspective(im, np.float32(T), (im.shape[1], im.shape[0]))
cv2.imwrite("/data/xxx/mylib/cat_warp_gold.jpg",dst)

三、自定义算子导出为onnx文件

将注册的pytorch的自定义算子导出为onnx文件查看,效果图如下:

导出代码文件如下

import torch
import numpy as nptorch.ops.load_library("/data/xxx/mylib/build/libwarp_perspective.so")
class MyNet(torch.nn.Module):def __init__(self, name):super(MyNet, self).__init__()self.model_name = namedef forward(self, in_data, warp_data):return torch.ops.my_ops.warp_perspective(in_data, warp_data)def my_custom(g, in_data, warp_data):return g.op("cus_ops::warp_perspective", in_data, warp_data)
torch.onnx.register_custom_op_symbolic("my_ops::warp_perspective", my_custom, 9)if __name__ == "__main__":net = MyNet("my_ops")in_data = torch.randn((32, 32))warp_data = torch.rand((3, 3))out = net(in_data, warp_data)print("out: ", out)# export onnxtorch.onnx.export(net,(in_data, warp_data),"./my_ops_export_model2.onnx",input_names=["img_data", "warp_mat"],output_names=["out_img"],custom_opsets={"cus_ops": 11},dynamic_axes={"img_data": {0: "width", 1: "height"},"out_img": {0: "width", 1: "height"}})

参考:https://www.cnblogs.com/xiaxuexiaoab/p/15524047.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/176322.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C++算法:矩阵中的最长递增路径

涉及知识点 拓扑排序 题目 给定一个 m x n 整数矩阵 matrix &#xff0c;找出其中 最长递增路径 的长度。 对于每个单元格&#xff0c;你可以往上&#xff0c;下&#xff0c;左&#xff0c;右四个方向移动。 你 不能 在 对角线 方向上移动或移动到 边界外&#xff08;即不允…

面试经典(6/150)轮转数组

面试经典&#xff08;6/150&#xff09;轮转数组 给定一个整数数组 nums&#xff0c;将数组中的元素向右轮转 k 个位置&#xff0c;其中 k 是非负数。 以下为自己的思路&#xff0c;我不明白最终的返回值为什么有误&#xff0c;好像是题目里要求原地解决问题&#xff0c;而我创…

如何把小米路由器刷入OpenWRT系统并通过内网穿透工具实现公网远程访问

小米路由器4A千兆版刷入OpenWRT并远程访问 文章目录 小米路由器4A千兆版刷入OpenWRT并远程访问前言1. 安装Python和需要的库2. 使用 OpenWRTInvasion 破解路由器3. 备份当前分区并刷入新的Breed4. 安装cpolar内网穿透4.1 注册账号4.2 下载cpolar客户端4.3 登录cpolar web ui管理…

解析数据洁净之道:BI中数据清理对见解的深远影响

本文由葡萄城技术团队发布。转载请注明出处&#xff1a;葡萄城官网&#xff0c;葡萄城为开发者提供专业的开发工具、解决方案和服务&#xff0c;赋能开发者。 前言 随着数字化和信息化进程的不断发展&#xff0c;数据已经成为企业的一项不可或缺的重要资源。然而&#xff0c;这…

手把手带你学习 JavaScript 的 ES6 ~ ESn

文章目录 一、引言二、了解 ES6~ESn 的新特性三、掌握 ES6~ESn 的用法和实现原理四、深入挖掘和拓展《深入理解现代JavaScript》编辑推荐内容简介作者简介精彩书评目录 一、引言 JavaScript 是一种广泛使用的网络编程语言&#xff0c;它在前端开发中扮演着重要角色。随着时间的…

实战leetcode(二)

Practice makes perfect&#xff01; 实战一&#xff1a; 这里我们运用快慢指针的思想&#xff0c;我们的slow和fast都指向第一个节点&#xff0c;我们的快指针一次走两步&#xff0c;慢指针一次走一步&#xff0c;当我们的fast指针走到尾的时候&#xff0c;我们的慢指针正好…

Python接口测试框架选择之pytest+yaml+Allure!

一、为什么选择pytest&#xff1f; pytest完全兼容python自带的unittest pytest让单元测试更简单&#xff0c;能很好的管理测试用例。 对于实现接口测试的复杂场景&#xff0c;pytest的fixture、PDB等高阶用法都能实现需求。 入门简单&#xff0c;对于代码基础薄弱的团队人…

基于单片机智能浇花系统仿真设计

**单片机设计介绍&#xff0c; 基于单片机智能浇花系统仿真设计 文章目录 一 概要二、功能设计设计思路 三、 软件设计原理图 五、 程序六、 文章目录 一 概要 基于单片机的智能浇花系统可以实现自动化浇水、测土湿度和温度等功能&#xff0c;以下是一个基本的仿真设计步骤&am…

安全区域边界(设备和技术注解)

网络安全等级保护相关标准参考《GB/T 22239-2019 网络安全等级保护基本要求》和《GB/T 28448-2019 网络安全等级保护测评要求》 密码应用安全性相关标准参考《GB/T 39786-2021 信息系统密码应用基本要求》和《GM/T 0115-2021 信息系统密码应用测评要求》 1边界防护 1.1应保证跨…

JavaScript从入门到精通系列第三十七篇:详解JavaScript中文档的加载顺序

文章目录 一&#xff1a;文档加载说明 1&#xff1a;回顾一个代码 2&#xff1a;问题分析和说明 二&#xff1a;如何给JS换个位置&#xff1f; 1&#xff1a;过程分析 2&#xff1a;代码编写 3&#xff1a;运行结果 4&#xff1a;解释说明 大神链接&#xff1a;作者有幸…

Redis 事务是什么?又和MySQL事务有什么区别?

目录 1. Redis 事务的概念 2. Redis 事务和 MySQL事务的区别&#xff1f; 3. Redis 事务常用命令 1. Redis 事务的概念 下面是在 Redis 官网上找到的关于事务的解释&#xff0c;这里划重点&#xff0c;一组命令&#xff0c;一个步骤。 也就是说&#xff0c;在客户端与 Redi…

【QT系列教程】之二创建项目和helloworld案例

文章目录 一、QT创建项目1.1、创建项目1.2、选择创建项目属性1.3、选择路径和项目名称1.4、选择构建项目类型1.5、布局方式1.6、翻译文件&#xff0c;根据自己需求选择1.7、选择套件1.8、项目管理&#xff0c;自行配置1.9、配置完成&#xff0c;系统自动更新配置 二、QT界面介绍…