c++通过tensorRT调用模型进行推理

模型来源
算法工程师训练得到的onnx模型

c++对模型的转换
拿到onnx模型后,通过tensorRT将onnx模型转换为对应的engine模型,注意:训练用的tensorRT版本和c++调用的tensorRT版本必须一致。

如何转换:

  1. 算法工程师直接转换为.engine文件进行交付。
  2. 自己转换,进入tensorRT安装目录\bin目录下,将onnx模型拷贝到bin目录,地址栏中输入cmd回车弹出控制台窗口,然后输入转换命令,如:

trtexec --onnx=model.onnx --saveEngine=model.engine --workspace=1024 --optShapes=input:1x13x512x640 --fp16

然后回车,等待转换完成,完成后如图所示:
在这里插入图片描述
并且在bin目录下生成.engine模型文件。

c++对.engine模型文件的调用和推理
首先将tensorRT对模型的加载及推理进行封装,命名为CTensorRT.cpp,老样子贴代码:

//CTensorRT.cpp
class Logger : public nvinfer1::ILogger {void log(Severity severity, const char* msg) noexcept override {if (severity <= Severity::kWARNING)std::cout << msg << std::endl;}
};Logger logger;
class CtensorRT
{
public:CtensorRT() {}~CtensorRT() {}private:std::shared_ptr<nvinfer1::IExecutionContext> _context;std::shared_ptr<nvinfer1::ICudaEngine> _engine;nvinfer1::Dims _inputDims;nvinfer1::Dims _outputDims;
public:void cudaCheck(cudaError_t ret, std::ostream& err = std::cerr){if (ret != cudaSuccess){err << "Cuda failure: " << cudaGetErrorString(ret) << std::endl;abort();}}bool loadOnnxModel(const std::string& filepath){auto builder = std::unique_ptr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(logger));if (!builder){return false;}const auto explicitBatch = 1U << static_cast<uint32_t>(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);auto network = std::unique_ptr<nvinfer1::INetworkDefinition>(builder->createNetworkV2(explicitBatch));if (!network){return false;}auto config = std::unique_ptr<nvinfer1::IBuilderConfig>(builder->createBuilderConfig());if (!config){return false;}auto parser = std::unique_ptr<nvonnxparser::IParser>(nvonnxparser::createParser(*network, logger));if (!parser){return false;}parser->parseFromFile(filepath.c_str(), static_cast<int32_t>(nvinfer1::ILogger::Severity::kWARNING));std::unique_ptr<IHostMemory> plan{ builder->buildSerializedNetwork(*network, *config) };if (!plan){return false;}std::unique_ptr<IRuntime> runtime{ createInferRuntime(logger) };if (!runtime){return false;}_engine = std::shared_ptr<nvinfer1::ICudaEngine>(runtime->deserializeCudaEngine(plan->data(), plan->size()));if (!_engine){return false;}_context = std::shared_ptr<nvinfer1::IExecutionContext>(_engine->createExecutionContext());if (!_context){return false;}int nbBindings = _engine->getNbBindings();assert(nbBindings == 2); // 输入和输出,一共是2个for (int i = 0; i < nbBindings; i++){if (_engine->bindingIsInput(i))_inputDims = _engine->getBindingDimensions(i);    // (1,3,752,752)else_outputDims = _engine->getBindingDimensions(i);}return true;}bool loadEngineModel(const std::string& filepath){std::ifstream file(filepath, std::ios::binary);if (!file.good()){return false;}std::vector<char> data;try{file.seekg(0, file.end);const auto size = file.tellg();file.seekg(0, file.beg);data.resize(size);file.read(data.data(), size);}catch (const std::exception& e){file.close();return false;}file.close();auto runtime = std::unique_ptr<nvinfer1::IRuntime>(nvinfer1::createInferRuntime(logger));_engine = std::shared_ptr<nvinfer1::ICudaEngine>(runtime->deserializeCudaEngine(data.data(), data.size()));if (!_engine){return false;}_context = std::shared_ptr<nvinfer1::IExecutionContext>(_engine->createExecutionContext());if (!_context){return false;}int nbBindings = _engine->getNbBindings();assert(nbBindings == 2); // 输入和输出,一共是2个// 为输入和输出创建空间for (int i = 0; i < nbBindings; i++){if (_engine->bindingIsInput(i))_inputDims = _engine->getBindingDimensions(i);    //得到输入结构else_outputDims = _engine->getBindingDimensions(i);//得到输出结构}return true;}void ONNX2TensorRT(const char* ONNX_file, std::string save_ngine){// 1.创建构建器的实例nvinfer1::IBuilder* builder = nvinfer1::createInferBuilder(logger);// 2.创建网络定义uint32_t flag = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);nvinfer1::INetworkDefinition* network = builder->createNetworkV2(flag);// 3.创建一个 ONNX 解析器来填充网络nvonnxparser::IParser* parser = nvonnxparser::createParser(*network, logger);// 4.读取模型文件并处理任何错误parser->parseFromFile(ONNX_file, static_cast<int32_t>(nvinfer1::ILogger::Severity::kWARNING));for (int32_t i = 0; i < parser->getNbErrors(); ++i){std::cout << parser->getError(i)->desc() << std::endl;}// 5.创建一个构建配置,指定 TensorRT 应该如何优化模型nvinfer1::IBuilderConfig* config = builder->createBuilderConfig();// 7.指定配置后,构建引擎nvinfer1::IHostMemory* serializedModel = builder->buildSerializedNetwork(*network, *config);// 8.保存TensorRT模型std::ofstream p(save_ngine, std::ios::binary);p.write(reinterpret_cast<const char*>(serializedModel->data()), serializedModel->size());// 9.序列化引擎包含权重的必要副本,因此不再需要解析器、网络定义、构建器配置和构建器,可以安全地删除delete parser;delete network;delete config;delete builder;// 10.将引擎保存到磁盘,并且可以删除它被序列化到的缓冲区delete serializedModel;}uint32_t getElementSize(nvinfer1::DataType t) noexcept{switch (t){case nvinfer1::DataType::kINT32: return 4;case nvinfer1::DataType::kFLOAT: return 4;case nvinfer1::DataType::kHALF: return 2;case nvinfer1::DataType::kBOOL:case nvinfer1::DataType::kINT8: return 1;}return 0;}int64_t volume(const nvinfer1::Dims& d){return std::accumulate(d.d, d.d + d.nbDims, 1, std::multiplies<int64_t>());}bool infer(unsigned char* input, int real_input_size, cv::Mat& out_mat){tensor_custom::BufferManager buffer(_engine);cudaStream_t stream;cudaStreamCreate(&stream); // 创建异步cuda流int binds = _engine->getNbBindings();for (int i = 0; i < binds; i++){if (_engine->bindingIsInput(i)){size_t input_size;float* host_buf = static_cast<float*>(buffer.getHostBufferData(i, input_size));memcpy(host_buf, input, real_input_size);break;}}// 将输入传递到GPUbuffer.copyInputToDeviceAsync(stream);// 异步执行bool status = _context->enqueueV2(buffer.getDeviceBindngs().data(), stream, nullptr);if (!status)return false;buffer.copyOutputToHostAsync(stream);for (int i = 0; i < binds; i++){if (!_engine->bindingIsInput(i)){size_t output_size;float* tmp_out = static_cast<float*>(buffer.getHostBufferData(i, output_size));//do your something herebreak;}}cudaStreamSynchronize(stream);cudaStreamDestroy(stream);return true;}
};

调用方式

int main()
{vector<int> dims = { 1,13,512,640 };vector<float> vall;for (int i=0;i<13;i++){string file = "D:\\xxx\\" + to_string(i) + ".png";cv::Mat mt = imread(file, IMREAD_GRAYSCALE);cv::resize(mt, mt, cv::Size(640,512));mt.convertTo(mt, CV_32F, 1.0 / 255);cv::Mat shape_xr = mt.reshape(1, mt.total() * mt.channels());std::vector<float> vec_xr = mt.isContinuous() ? shape_xr : shape_xr.clone();vall.insert(vall.end(), vec_xr.begin(), vec_xr.end());}cv::Mat mt_4d(4, &dims[0], CV_32F, vall.data());string engine_model_file = "model.engine";CtensorRT cTensor;if (cTensor.loadEngineModel(engine_model_file)){cv::Mat out_mat;if (!cTensor.infer(mt_4d.data, vall.size() * 4, out_mat))std::cout << "infer error!" << endl;elsecv::imshow("out", out_mat);}elsestd::cout << "load model file failed!" << endl;cv::waitKey(0);return 0;
}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/101883.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【2023高教社杯】D题 圈养湖羊的空间利用率 问题分析、数学模型及MATLAB代码

【2023高教社杯】D题 圈养湖羊的空间利用率 问题分析、数学模型及MATLAB代码 1 题目 题目 D 题 圈养湖羊的空间利用率 规模化的圈养养殖场通常根据牲畜的性别和生长阶段分群饲养&#xff0c;适应不同种类、不同阶段的牲畜对空间的不同要求&#xff0c;以保障牲畜安全和健康&a…

linux-进程-execl族函数

exec函数的作用&#xff1a; 我们用fork函数创建新进程后&#xff0c;经常会在新进程中调用exec函数去执行另外一个程序。当进程调用exec函数时&#xff0c;该进程被完全替换为新程序。因为调用exec函数并不创建新进程&#xff0c;所以前后进程的ID并没有改变。 简单来说就是&…

SpringBoot粗浅分析

应用分析 1、依赖管理机制 在springBoot项目中&#xff0c;导入starter-web所有想换依赖都会被导入&#xff0c;甚至不用去规定它们的版本号。它是根据Maven的依赖传递原则来设置&#xff0c;只需要导入场景启动器&#xff0c;场景启动器自动把这个场景的所有核心依赖全部导入…

广东成人高考报名将于9月14日开始!

截图来自广东省教育考试院官网* 今年的广东成人高考正式报名时间终于确定了&#xff01; 报名时间&#xff1a;2023年 9 月14—20日 准考证打印时间&#xff1a;考前一周左右 考试时间&#xff1a;2023年10月21—22日 录取时间&#xff1a;2023年12 月中上旬 报名条件: …

Android 状态栏显示运营商名称

Android 原生设计中在锁屏界面会显示运营商名称&#xff0c;用户界面中&#xff0c;大概是基于 icon 数量长度显示考虑&#xff0c;对运营商名称不作显示。但是国内基本都加上运营商名称。对图标显示长度优化基本都是&#xff1a;缩小运营商字体、限制字数长度、信号图标压缩上…

FPGA实战小项目2

基于FPGA的贪吃蛇游戏 基于FPGA的贪吃蛇游戏 基于fpga的数字密码锁ego1 基于fpga的数字密码锁ego1 基于fpga的数字时钟 basys3 基于fpga的数字时钟 basys3

磁盘分析 wiztree[win32] baobab[linux]

磁盘分析 wiztree[win32] && baobab[linux] wiztree[win32]baobab 又叫 Disk Usage Analyzer[linux]安装使用 参考 wiztree[win32] baobab 又叫 Disk Usage Analyzer[linux] baobab 又叫 Disk Usage Analyzer&#xff0c;是 Ubuntu 系统默认自带的磁盘分析工具&#x…

WebGIS外包开发流程

WebGIS开发流程需要综合考虑前端和后端开发、地理信息数据处理、用户需求和安全性等多个方面。成功的WebGIS应用程序需要不断地进行更新和维护&#xff0c;以适应变化的需求和技术。WebGIS开发是一个复杂的过程&#xff0c;通常包括以下主要步骤。北京木奇移动技术有限公司&…

C++项目实战——基于多设计模式下的同步异步日志系统-①-项目介绍

文章目录 专栏导读项目介绍开发环境核心技术环境搭建日志系统介绍1.为什么需要日志系统2.日志系统技术实现2.1同步写日志2.2异步写日志 专栏导读 &#x1f338;作者简介&#xff1a;花想云 &#xff0c;在读本科生一枚&#xff0c;C/C领域新星创作者&#xff0c;新星计划导师&a…

java+ssm+mysql电梯管理系统

项目介绍&#xff1a; 使用javassmmysql开发的电梯管理系统&#xff0c;系统包含管理员&#xff0c;监管员、安全员、维保员角色&#xff0c;功能如下&#xff1a; 管理员&#xff1a;系统用户管理&#xff08;监管员、安全员、维保员&#xff09;&#xff1b;系统公告&#…

Java测试(10)--- selenium

1.定位一组元素 &#xff08;1&#xff09;如何打开本地的HTML页面 拼成一个URL &#xff1a;file: /// 文件的绝对路径 import os os.path.abspath(文件的绝对路径&#xff09; &#xff08;2&#xff09;先定位出同一类元素&#xff08;tag name&#xff0c;name&…

【Linux】LVM原理及核心概念

LVM是什么&#xff1f;LVM核心概念LVM的优势在Linux上使用LVM感谢 &#x1f496; LVM是什么&#xff1f; LVM是一种高级的磁盘管理工具&#xff0c;用于在Linux和其他类Unix操作系统中管理磁盘存储。它的核心思想是将底层物理存储抽象为逻辑存储单元&#xff0c;从而提供了更大…