模型部署 - TensorRT Triton 学习

news/2024/11/15 17:24:47/文章来源:https://www.cnblogs.com/ai-ldj/p/18300487

先介绍TensorRT、Triton的关系和区别:

TensorRT:为inference(推理)为生,是NVIDIA研发的一款针对深度学习模型在GPU上的计算,显著提高GPU上的模型推理性能。即一种专门针对高性能推理的模型框架,也可以解析其他框架的模型如tensorflow、torch。

主要优化手段如下:

 Triton:类似于TensorFlow Serving,但triton包括server和client。

 triton serving能够实现不同模型的统一部署和服务,提供http和grpc协议,给triton client请求模型推理。

---------------------------------------分割线------------------------------------------------

如果是要将模型和推理嵌入在服务或软硬件中,那么TensorRT是很好的选择,使用它来加载模型进行推理,提升性能(tensorrt runtime);

不然,常规的做法是模型推理和其他业务隔离,模型统一部署在triton server,然后其他业务通过triton client来进行模型推理的请求。

实验环境:Ubuntu18.04, GeForce RTX 2080Ti

Triton部署

安装

通过docker的形式,首先拉取镜像

# <xx.yy>为Triton的版本
docker pull nvcr.io/nvidia/tritonserver:<xx.yy>-py3# 例如,拉取 20.12
docker pull nvcr.io/nvidia/tritonserver:20.12-py3

<要注意不同版本的tritonserver对cuda驱动最低版本要求,以及对应的tensorrt版本>

例如,20.12的版本需要NVIDIA Driver需要455以上,支持TensorRT 7.2.2。TensorRT版本要对应,不然模型可能会无法部署。

其他版本信息可以前往官网查看:

启动

CPU版本的启动 NVIDIA Triton Inference Server 

 

docker run --rm -p8000:8000 -p8001:8001 -p8002:8002 -v/full/path/to/docs/examples/model_repository:/models nvcr.io/nvidia/tritonserver:22.01-py3 tritonserver --model-repository=/models

GPU版本的启动,使用1个gpu

docker run --gpus=1 --rm -p8000:8000 -p8001:8001 -p8002:8002 -v/full/path/to/docs/examples/model_repository:/models nvcr.io/nvidia/tritonserver:22.01-py3 tritonserver --model-repository=/models
  • /full/path/to/docs/examples/model_repository:模型仓库的路径。除了本地文件系统,还支持Google Cloud、S3、Azure这些云存储:
  • --rm:表示容器停止运行时会删除容器
  • --gpus=1: 分配 1 个 GPU 资源给容器使用。
  • 8000为http端口,8001为grpc端口
  • -p8000:8000 -p8001:8001 -p8002:8002: 将容器内部的 8000、8001 和 8002 端口映射到宿主机的对应端口。这样可以从宿主机访问容器内部的服务。
  • -v/full/path/to/docs/examples/model_repository:/models: 将宿主机上的 /full/path/to/docs/examples/model_repository 目录挂载到容器内的 /models 目录。这样容器可以访问宿主机上的模型文件。
  • nvcr.io/nvidia/tritonserver:22.01-py3: 使用 NVIDIA 提供的 Triton Inference Server 22.01 版本的 Python 3 镜像作为容器的基础镜像。
  • tritonserver --model-repository=/models: 启动 Triton Inference Server 服务,并指定模型仓库目录为 /models,也就是我们挂载的宿主机目录。

正常启动的话,可以看到部署的模型运行状态,以及对外提供的服务端口

 

模型生成

Triton支持以下模型:TensorRT、ONNX、TensorFlow、Torch、OpenVINO、DALI,还有Python backend自定义生成的Python模型。

我们以一个简单的模型结构来演示:

我们以一个简单的模型结构来演示:

  1. INPUT0节点通过四则运算得到OUTPUT0节点;
  2. INPUT1节点通过embedding table映射为OUTPUT1

 

TensorFlow

tensorflow可以生成SavedModel或者GraphDef的模型格式

SavedModel模型需要按照以下的目录结构进行存储:

<model-repository-path>/<model-name>/config.pbtxt1/model.savedmodel/<saved-model files>

GraphDef:

<model-repository-path>/<model-name>/config.pbtxt1/model.graphdef
import os
import tensorflow as tf
from tensorflow.python.framework import graph_iodef create_modelfile(model_version_dir, max_batch,save_type="graphdef",version_policy=None):# your model netinput0_shape = [None, 2]input1_shape = [None, 2]x1 = tf.placeholder(tf.float32, input0_shape, name='INPUT0')inputs_id = tf.placeholder(tf.int32, input1_shape, name='INPUT1')out = tf.add(tf.multiply(x1, 0.5), 2)embedding = tf.get_variable("embedding_table", shape=[100, 10])pre = tf.nn.embedding_lookup(embedding, inputs_id)out0 = tf.identity(out, "OUTPUT0")out1 = tf.identity(pre, "OUTPUT1")try:os.makedirs(model_version_dir)except OSError as ex:pass  # ignore existing dirwith tf.Session() as sess:sess.run(tf.global_variables_initializer())if save_type == 'graphdef':create_graphdef_modelfile(model_version_dir, sess,outputs=["OUTPUT0", "OUTPUT1"])elif save_type == 'savemodel':create_savedmodel_modelfile(model_version_dir,sess,inputs={"INPUT0": x1,"INPUT1": inputs_id},outputs={"OUTPUT0": out,"OUTPUT1": pre})else:raise ValueError("save_type must be one of ['tensorflow_graphdef', 'tensorflow_savedmodel']")create_modelconfig(models_dir=os.path.dirname(model_version_dir),max_batch=max_batch,save_type=save_type,version_policy=version_policy)def create_graphdef_modelfile(model_version_dir, sess, outputs):"""
    tensorflow graphdef只能保存constant,无法保存Variable可以借助tf.graph_util.convert_variables_to_constants将Variable转化为constant:param model_version_dir::param sess::return:"""
    graph = sess.graph.as_graph_def()new_graph = tf.graph_util.convert_variables_to_constants(sess=sess,input_graph_def=graph,output_node_names=outputs)graph_io.write_graph(new_graph,model_version_dir,"model.graphdef",as_text=False)def create_savedmodel_modelfile(model_version_dir, sess, inputs, outputs):"""
:param model_version_dir::param sess::param inputs: dict, {input_name: input_tensor}:param outputs: dict, {output_name: output_tensor}:return:"""
    tf.saved_model.simple_save(sess,model_version_dir + "/model.savedmodel",inputs=inputs,outputs=outputs)

torch

pytorch模型的目录结构格式:<model-repository-path>/
<model-name>/
config.pbtxt
1/
model.pt
import os
import torch
from torch import nnclass MyNet(nn.Module):def __init__(self):super(MyNet, self).__init__()self.embedding = nn.Embedding(num_embeddings=100,embedding_dim=10)def forward(self, input0, input1):# tf.add(tf.multiply(x1, 0.5), 2)output0 = torch.add(torch.multiply(input0, 0.5), 2)output1 = self.embedding(input1)return output0, output1def create_modelfile(model_version_dir, max_batch,version_policy=None):# your model net# 定义输入的格式example_input0 = torch.zeros([2], dtype=torch.float32)example_input1 = torch.zeros([2], dtype=torch.int32)my_model = MyNet()traced = torch.jit.trace(my_model, (example_input0, example_input1))try:os.makedirs(model_version_dir)except OSError as ex:pass  # ignore existing dirtraced.save(model_version_dir + "/model.pt")

ONNX

ONNX的目录结构:

<model-repository-path>/<model-name>/config.pbtxt1/model.onnx

ONNX提供一种开源的深度学习和传统的机器学习模型格式,目的在于模型在不同框架之间进行转移。

下面我们介绍最常用的tensorflow和torch模型转成ONNX的方法。

tensorflow模型 --> ONNX

pip install -U tf2onnx# savedmodel
python -m tf2onnx.convert --saved-model tensorflow-model-path --output model.onnx# checkpoint
python -m tf2onnx.convert --checkpoint tensorflow-model-meta-file-path --output model.onnx --inputs input0:0,input1:0 --outputs output0:0# graphdef
python -m tf2onnx.convert --graphdef tensorflow-model-graphdef-file --output model.onnx --inputs input0:0,input1:0 --outputs output0:0

torch --> ONNX

import os
import torch
import torch.onnxdef torch2onnx(model_version_dir, max_batch):# 定义输入的格式example_input0 = torch.zeros([max_batch, 2], dtype=torch.float32)example_input1 = torch.zeros([max_batch, 2], dtype=torch.int32)my_model = MyNet()try:os.makedirs(model_version_dir)except OSError as ex:pass  # ignore existing dirtorch.onnx.export(my_model,(example_input0, example_input1),os.path.join(model_version_dir, 'model.onnx'),# 输入节点的名称input_names=("INPUT0", "INPUT1"),# 输出节点的名称output_names=("OUTPUT0", "OUTPUT1"),# 设置batch_size的维度dynamic_axes={"INPUT0": [0], "INPUT1": [0], "OUTPUT0": [0], "OUTPUT1": [0]},verbose=True)

TensorRT

需要注意:TensorRT仅支持GPU。

<model-repository-path>/<model-name>/config.pbtxt1/model.plan

比较推荐的方式是从ONNX解析得到TensorRT模型(TensorRT)

import tensorrt as trt
import osdef onnx2trt(model_version_dir, onnx_model_file, max_batch):logger = trt.Logger(trt.Logger.WARNING)builder = trt.Builder(logger)# The EXPLICIT_BATCH flag is required in order to import models using the ONNX parsernetwork = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))parser = trt.OnnxParser(network, logger)success = parser.parse_from_file(onnx_model_file)for idx in range(parser.num_errors):print(parser.get_error(idx))if not success:pass  # Error handling code hereprofile = builder.create_optimization_profile()# INPUT0可以接收[1, 2] -> [max_batch, 2]的维度profile.set_shape("INPUT0", [1, 2], [1, 2], [max_batch, 2])profile.set_shape("INPUT1", [1, 2], [1, 2], [max_batch, 2])config = builder.create_builder_config()config.add_optimization_profile(profile)# tensorrt 8.x# config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 20)  # 1 MiB# tensorrt 7.xconfig.max_workspace_size = 1 << 20try:engine_bytes = builder.build_serialized_network(network, config)except AttributeError:engine = builder.build_engine(network, config)engine_bytes = engine.serialize()del enginewith open(os.path.join(model_version_dir, 'model.plan'), "wb") as f:f.write(engine_bytes)

模型配置文件

name: "tf_savemodel"
platform: "tensorflow_savedmodel"
max_batch_size: 8
version_policy: { latest { num_versions: 1 }}
input [{name: "INPUT0"data_type: TYPE_FP32dims: [ 2 ]},{name: "INPUT1"data_type: TYPE_INT32dims: [ 2 ]}
]
output [{name: "OUTPUT0"data_type: TYPE_FP32dims: [ 2 ]},{name: "OUTPUT1"data_type: TYPE_FP32dims: [ 2,10 ]}
]

name:模型名称,要跟模型路径对应。

platform:不同的模型存储格式都有自己对应的值。

max_batch_size:最大的batch_size,客户端超过这个batch_size的请求会报错。

version_policy:版本控制,这里是使用最新的一个版本。

input、output:输入和输出节点的名称,数据类型,维度。

维度一般不包括batch_size这个维度;

下表为不同框架对应的platform:

 下表是不同框架的数据类型对应关系:Model Config是配置文件的,API是triton client。其他框架是c++源码的命名空间,不过很好理解,主要包括16位和32位的int和float等等。

 

Triton Client

上述提到了,我们可以通过triton client来进行模型推理的请求,并且提供了http和grpc两种协议。

接下来,将以python来演示,仍然是上面那个简单的模型请求例子。

# 安装依赖包
pip install tritonclient[all]
import gevent.ssl
import numpy as np
import tritonclient.http as httpclientdef client_init(url="localhost:8000",ssl=False, key_file=None, cert_file=None, ca_certs=None, insecure=False,verbose=False):"""
:param url::param ssl: Enable encrypted link to the server using HTTPS:param key_file: File holding client private key:param cert_file: File holding client certificate:param ca_certs: File holding ca certificate:param insecure: Use no peer verification in SSL communications. Use with caution:param verbose: Enable verbose output:return:"""
    if ssl:ssl_options = {}if key_file is not None:ssl_options['keyfile'] = key_fileif cert_file is not None:ssl_options['certfile'] = cert_fileif ca_certs is not None:ssl_options['ca_certs'] = ca_certsssl_context_factory = Noneif insecure:ssl_context_factory = gevent.ssl._create_unverified_contexttriton_client = httpclient.InferenceServerClient(url=url,verbose=verbose,ssl=True,ssl_options=ssl_options,insecure=insecure,ssl_context_factory=ssl_context_factory)else:triton_client = httpclient.InferenceServerClient(url=url, verbose=verbose)return triton_clientdef infer(triton_client, model_name,input0='INPUT0', input1='INPUT1',output0='OUTPUT0', output1='OUTPUT1',request_compression_algorithm=None,response_compression_algorithm=None):"""
:param triton_client::param model_name::param input0::param input1::param output0::param output1::param request_compression_algorithm: Optional HTTP compression algorithm to use for the request body on client side.Currently supports "deflate", "gzip" and None. By default, no compression is used.:param response_compression_algorithm::return:"""
    inputs = []outputs = []# batch_size=8# 如果batch_size超过配置文件的max_batch_size,infer则会报错# INPUT0、INPUT1为配置文件中的输入节点名称inputs.append(httpclient.InferInput(input0, [8, 2], "FP32"))inputs.append(httpclient.InferInput(input1, [8, 2], "INT32"))# Initialize the data# np.random.seed(2022)inputs[0].set_data_from_numpy(np.random.random([8, 2]).astype(np.float32), binary_data=False)# np.random.seed(2022)inputs[1].set_data_from_numpy(np.random.randint(0, 20, [8, 2]).astype(np.int32), binary_data=False)# OUTPUT0、OUTPUT1为配置文件中的输出节点名称outputs.append(httpclient.InferRequestedOutput(output0, binary_data=False))outputs.append(httpclient.InferRequestedOutput(output1,binary_data=False))query_params = {'test_1': 1, 'test_2': 2}results = triton_client.infer(model_name=model_name,inputs=inputs,outputs=outputs,request_compression_algorithm=request_compression_algorithm,response_compression_algorithm=response_compression_algorithm)print(results)# 转化为numpy格式print(results.as_numpy(output0))print(results.as_numpy(output1))

 

grpc的代码基本相同,就不展示了,

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/743203.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Vulnhub靶场 | DC系列 | DC-1

DC-1 目录环境搭建 渗透测试1. 信息收集 2. 漏洞利用 3. 提权DC ~ VulnHub VulnHub provides materials allowing anyone to gain practical hands-on experience with digital security, computer applications and network administration tasks. https://www.vulnhub.com/s…

Smart-doc:零注解侵入的API接口文档生成插件

零注解侵入的API接口文档生成插件——Smart-doc smart-doc 是一款同时支持 JAVA REST API 和 Apache Dubbo RPC 接口文档生成的工具,在业内率先提出基于JAVA泛型定义推导的理念, 完全基于接口源码来分析生成接口文档,不采用任何注解侵入到业务代码中。 你只需要按照java-doc…

如何对Linux系统进行基准测试4工具Sysbench

Sysbench简介 Sysbench是一款多用途基准测试工具,可对CPU、内存、I/O甚至数据库性能进行测试。它是一个基本的命令行工具,提供了直接、简便的系统测试方法。github地址:https://github.com/akopytov/sysbench 。主要功能:CPU: 衡量CPU执行计算密集型任务的能力。 内存: 衡量…

containerd 容器基础环境组件的搭建

1 基础环境说明(1)本次所有部署软件版本说明软件名称 版本号操作系统内核(后续升级为 lt-5.4.278)CentOS 7.9.2009 (3.10.0-1160.el7) 1c1GB 20GBCentOS-7-x86_64-Minimal-2009.isocontainerd v1.6.6cfssl v1.6.1cni v1.1.1crictl v1.24.2nerdctl 1.7.6buildkit v0.14.1(2)系统…

【JavaScript】聊一聊js中的浅拷贝与深拷贝与手写实现

什么是深拷贝与浅拷贝?深拷贝与浅拷贝是js中处理对象或数据复制操作的两种方式。‌在聊深浅拷贝之前咱得了解一下js中的两种数据类型:前言 什么是深拷贝与浅拷贝?深拷贝与浅拷贝是js中处理对象或数据复制操作的两种方式。‌在聊深浅拷贝之前咱得了解一下js中的两种数据类型:…

分页查询及其拓展应用案例

分页查询 分页查询是处理大量数据时常用的技术,通过分页可以将数据分成多个小部分,方便用户逐页查看。SQLAlchemy 提供了简单易用的方法来实现分页查询。 本篇我们也会在最终实现这样的分页效果:1. 什么是分页查询 分页查询是将查询结果按照一定数量分成多页展示,每页显示固…

delphi dev cxgrid 列绑定Richedti 支持过滤

默认是不支持过滤的,这里需要改到内部的一些源码文件。 先说思路: 1.要让列支持过滤需要重载richedit类的 GetSupportedOperations,typeTcxRichEditProperties = class(cxRichEdit.TcxRichEditProperties)publicfunction GetSupportedOperations: TcxEditSupportedOperation…

《项目管理》-笔记1

PMBOK解读 1.1项目和项目管理 项目:项目是为创造独特的产品、服务或成果而进行的临时性工作。 项目管理:在项目的活动中运用知识、技术、工具、技巧,以满足项目要求。 1.2十大知识领域 (1)项目整合管理 项目整合管理包括为识别、定义、组合、统一和协调各项目管理过程组的各…

Cilium Ingress 特性(转载)

Cilium Ingress 特性Cilium Ingress 特性(转载) 一、环境信息主机 IPubuntu 10.0.0.234软件 版本docker 26.1.4helm v3.15.0-rc.2kind 0.18.0kubernetes 1.23.4ubuntu os Ubuntu 22.04.6 LTSkernel 5.15.0-106二、Cilium Ingress 流程图Cilium 现在提供了开箱即用的 Kuberne…

N1盒子挂载阿里云盘-Alist工具

Markdown Example.centered-text { text-align: center; font-size: 40px; font-family: "Times New Roman", Georgia, serif }N1盒子挂载阿里云盘安装Alist手动安装 参考:官方文档step 1step 2配置-启动step 3打开web网页:http://192.168.1.254:5244/ 登录界面、拉…

WindowsLinux搭建frp内网穿透(自用)

Linux服务器搭建服务端 1、下载官方frp包,软件是开源的,下载链接: https://github.com/fatedier/frp/releases 根据自己的版本需求,自行下载对应的版本号,本文章以0.37版本为例wget -c https://github.com/fatedier/frp/releases/download/v0.37.1/frp_0.37.1_linux_amd64…

2024-07-13:用go语言,给定一个从0开始的长度为n的整数数组nums和一个从0开始的长度为m的整数数组pattern,其中pattern数组仅包含整数-1、0和1。 一个子数组nums[i.

2024-07-13:用go语言,给定一个从0开始的长度为n的整数数组nums和一个从0开始的长度为m的整数数组pattern,其中pattern数组仅包含整数-1、0和1。 一个子数组nums[i..j]的大小为m+1,如果满足以下条件,则我们称该子数组与模式数组pattern匹配: 1.若pattern[k]为1,则nums[i+…