PyTorch 内 LibTorch/TorchScript 的使用

PyTorch 内 LibTorch/TorchScript 的使用

  • 1. .pt .pth .bin .onnx 格式
    • 1.1 模型的保存与加载到底在做什么?
    • 1.2 为什么要约定格式?
    • 1.3 格式汇总
      • 1.3.1 .pt .pth 格式
      • 1.3.2 .bin 格式
      • 1.3.3 直接保存完整模型
      • 1.3.4 .onnx 格式
      • 1.3.5 jit.trace
      • 1.3.6 jit.script
    • 1.4 总结
  • 2. TorchScript 的转换
    • 2.1 jit trace 注意事项
    • 2.2 jit trace 验证技巧
    • 2.3 混合使用 trace 和 script
    • 2.4 trace 和 script 的性能
    • 2.5 总结
  • 3. LibTorch 的使用
    • 3.1 LibTorch 的链接
    • 3.2 接口和实现

Reference:

  1. [Pytorch].pth转.pt文件
  2. Pytorch格式 .pt .pth .bin .onnx 详解
  3. pytorch 基于tracing/script方式转ONNX

1. .pt .pth .bin .onnx 格式

1.1 模型的保存与加载到底在做什么?

我们在使用pytorch构建模型并且训练完成后,下一步要做的就是把这个模型放到实际场景中应用,或者是分享给其他人学习、研究、使用。因此,我们开始思考一个问题,提供哪些模型信息,能够让对方能够完全复现我们的模型?

  • 模型代码
    1. 包含了我们如何定义模型的结构,包括模型有多少层/每层有多少神经元等等信息;
    2. 包含了我们如何定义的训练过程,包括epoch batch_size等参数;
    3. 包含了我们如何加载数据和使用;
    4. 包含了我们如何测试评估模型。
  • 模型参数:提供了模型代码之后,对方确实能够复现模型,但是运行的参数需要重新训练才能得到,而没有办法在我们的模型参数基础上继续训练,因此对方还希望我们能够把模型的参数也保存下来给对方。
    1. 包含model.state_dict(),这是模型每一层可学习的节点的参数,比如weight/bias;
    2. 包含optimizer.state_dict(),这是模型的优化器中的参数;
    3. 包含我们其他参数信息,如epoch/batch_size/loss等。
  • 数据集
    1. 包含了我们训练模型使用的所有数据;
    2. 可以提示对方如何去准备同样格式的数据来训练模型。
  • 使用文档
    1. 根据使用文档的步骤,每个人都可以重现模型;
    2. 包含了模型的使用细节和我们相关参数的设置依据等信息。

可以看到,根据我们提供的模型代码/模型参数/数据集/使用文档,我们就可以有理由相信对方是有手就会了,那么目的就达到了。

现在我们反转一下思路,我们希望别人给我们提供模型的时候也能够提供这些信息,那么我们就可以拿捏住别人的模型了。

1.2 为什么要约定格式?

根据上一段的思路,我们知道模型重现的关键是模型结构/模型参数/数据集,那么我们提供或者希望别人提供这些信息,需要一个交流的规范,这样才不会1000个人给出1000种格式,而 .pt .pth .bin 以及 .onnx 就是约定的格式。

torch.save: Saves a serialized object to disk. This function uses Python’s pickle utility for serialization. Models, tensors, and dictionaries of all kinds of objects can be saved using this function.

不同的后缀只是用于提示我们文件可能包含的内容,但是具体的内容需要看模型提供者编写的 README.md 才知道。而在使用 torch.load() 方法加载模型信息的时候,并不是根据文件的后缀进行的读取,而是根据文件的实际内容自动识别的,因此对于 torch.load() 方法而言,不管你把后缀改成是什么,只要文件是对的都可以读取

torch.load: Uses pickle’s unpickling facilities to deserialize pickled object files to memory. This function also facilitates the device to load the data into.

1.3 格式汇总

格式解释适用场景可对应的后缀
.pt 或 .pthPyTorch 的默认模型文件格式,用于保存和加载完整的 PyTorch 模型,包含模型的结构和参数等信息需要保存和加载完整的 PyTorch 模型的场景,例如在训练中保存最佳的模型或在部署中加载训练好的模型.pt 或 .pth
.bin一种通用的二进制格式,可以用于保存和加载各种类型的模型和数据需要将 PyTorch 模型转换为通用的二进制格式的场景.bin
ONNX一种通用的模型交换格式,可以用于将模型从一个深度学习框架转换到另一个深度学习框架或硬件平台。在 PyTorch 中,可以使用 torch.onnx.export 函数将 PyTorch 模型转换为 ONNX 格式需要将 PyTorch 模型转换为其他深度学习框架或硬件平台可用的格式的场景.onnx
TorchScriptPyTorch 提供的一种序列化和优化模型的方法,可以将 PyTorch 模型转换为一个序列化的程序,并使用 JIT 编译器对模型进行优化。在 PyTorch 中,可以使用 torch.git.trace 或 torch.git.script 函数将 PyTorch 模型转换为 TorchScript 格式需要将 PyTorch 模型序列化和优化,并在没有 Python 环境的情况下运行模型的场景.pt 或 .pth

1.3.1 .pt .pth 格式

一个完整的 PyTorch 模型文件,包含了如下参数:

  • model_state_dict:模型参数
  • optimizer_state_dict:优化器的状态
  • epoch:当前的训练轮数
  • loss:当前的损失值

下面是一个 .pt 文件的保存和加载示例(注意,后缀也可以是 .pth):

  • .state_dict():包含所有的参数和持久化缓存的字典,model 和 optimizer 都有这个方法
  • torch.save():将所有的组件保存到文件中

模型保存

import torch
import torch.nn as nn# 定义一个简单的模型
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.fc1 = nn.Linear(10, 5)self.fc2 = nn.Linear(5, 1)def forward(self, x):x = self.fc1(x)x = self.fc2(x)return xmodel = Net()optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)# 初始化优化器loss = nn.MSELoss()# 初始化损失函数PATH = "model.pth" # 保存路径# 保存模型
torch.save({'epoch': 10,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss,}, PATH)

netron 可得:
在这里插入图片描述

模型加载

import torch
import torch.nn as nn# 定义同样的模型结构
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.fc1 = nn.Linear(10, 5)self.fc2 = nn.Linear(5, 1)def forward(self, x):x = self.fc1(x)x = self.fc2(x)return x# 加载模型
model = Net()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
PATH = "model.pth"
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.eval()

1.3.2 .bin 格式

.bin 文件是一个二进制文件,可以保存 PyTorch 模型的参数和持久化缓存。.bin 文件的大小较小,加载速度较快,因此在生产环境中使用较多。

下面是一个.bin文件的保存和加载示例(注意:也可以使用 .pt .pth 后缀—后缀无意义):
保存模型

import torch
import torch.nn as nn# 定义一个简单的模型
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.fc1 = nn.Linear(10, 5)self.fc2 = nn.Linear(5, 1)def forward(self, x):x = self.fc1(x)x = self.fc2(x)return xmodel = Net()
# 保存参数到.bin文件
torch.save(model.state_dict(), PATH)

加载模型

import torch
import torch.nn as nn# 定义相同的模型结构
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.fc1 = nn.Linear(10, 5)self.fc2 = nn.Linear(5, 1)def forward(self, x):x = self.fc1(x)x = self.fc2(x)return x# 加载.bin文件
model = Net()
model.load_state_dict(torch.load(PATH))
model.eval()

1.3.3 直接保存完整模型

可以看出来,我们在之前的保存方式中,都是保存了 .state_dict(),但是没有保存模型的结构,在其他地方使用的时候,必须先重新定义相同结构的模型(或兼容模型),才能够加载模型参数进行使用,如果我们想直接把整个模型都保存下来,避免重新定义模型,可以按如下操作:
保存模型

PATH = "entire_model.pt"
# PATH = "entire_model.pth"
# PATH = "entire_model.bin"
torch.save(model, PATH)

netron 可得:
在这里插入图片描述

可以看到与上面仅保存参数的方式相比,多了很多信息。

加载模型

model = torch.load("entire_model.pt")
model.eval()

1.3.4 .onnx 格式

上述保存的文件可以通过 PyTorch 提供的 torch.onnx.export 函数转化为ONNX格式,这样可以在其他深度学习框架中使用 PyTorch 训练的模型。转化方法如下:

import torch
import torch.onnx# 将模型保存为.bin文件
model = torch.nn.Linear(3, 1)
torch.save(model.state_dict(), "model.bin")
# torch.save(model.state_dict(), "model.pt")
# torch.save(model.state_dict(), "model.pth")# 将.bin文件转化为ONNX格式
model = torch.nn.Linear(3, 1)
model.load_state_dict(torch.load("model.bin"))
# model.load_state_dict(torch.load("model.pt"))
# model.load_state_dict(torch.load("model.pth"))
example_input = torch.randn(1, 3)
torch.onnx.export(model, example_input, "model.onnx", input_names=["input"], output_names=["output"])

加载 ONNX 格式的代码可以参考以下示例代码(注意 ONNX 只能推理不能训练,不包含反向信息的):

import onnx
import onnxruntime# 加载ONNX文件
onnx_model = onnx.load("model.onnx")# 将ONNX文件转化为ORT格式
ort_session = onnxruntime.InferenceSession("model.onnx")# 输入数据
input_data = np.random.random(size=(1, 3)).astype(np.float32)# 运行模型
outputs = ort_session.run(None, {"input": input_data})# 输出结果
print(outputs)

注意,需要安装 onnxonnxruntime 两个 Python 包。此外,还需要使用 numpy 等其他常用的科学计算库。

1.3.5 jit.trace

保存模型

import torch
import torch.nn as nn# 定义一个简单的模型
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.fc1 = nn.Linear(10, 5)self.fc2 = nn.Linear(5, 1)def forward(self, x):x = self.fc1(x)x = self.fc2(x)return xmodel = Net()optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) # 初始化优化器
loss = nn.MSELoss() # 初始化损失函数
model.eval()PATH = "model_trace.pth"# 保存模型
example = torch.rand(1, 10)
traced_module = torch.jit.trace(model, example)
traced_module.save(PATH)

在这里插入图片描述

1.3.6 jit.script

保存模型

import torch
import torch.nn as nn# 定义一个简单的模型
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.fc1 = nn.Linear(10, 5)self.fc2 = nn.Linear(5, 1)def forward(self, x):x = self.fc1(x)x = self.fc2(x)return xmodel = Net()optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) # 初始化优化器
loss = nn.MSELoss() # 初始化损失函数
model.eval()PATH = "model_script.pth" # 保存路径# 保存模型
scripted_module = torch.jit.script(model)
scripted_module.save(PATH)

netron 可得:
在这里插入图片描述

1.4 总结

综上,PyTorch 可以导出的模型的几种后缀格式,但是模型导出的关键并不是后缀,而是到处时候提供的信息到底是什么,只要知道了模型的 model.state_dict()optimizer.state_dict(),以及相应的epoch batch_size loss等信息,我们就能够重建出模型,至于要导出哪些信息,就取决于你了,务必在 readme.md 中写清楚,导出了哪些信息。

保存场景保存方法文件后缀
整个模型(保存模型结构)model = Net()
torch.save(model, PATH)
.pt .pth .bin
仅模型参数(不保存模型结构)model = Net()
torch.save(model.state_dict(), PATH)
.pt .pth .bin
checkpoints使用model = Net()
torch.save({‘epoch’:10,‘model_state_dict’:model.state_dict(),‘optimizer_state_dict’: optimizer.state_dict(),‘loss’: loss,}, PATH)
.pt .pth .bin
ONNX通用保存model = Net()
model.load_state_dict(torch.load(“model.bin”))
example_input = torch.randn(1, 3)
torch.onnx.export(model, example_input, “model.onnx”, input_names=[“input”], output_names=[“output”])
.onnx
TorchScript 无 Python 环境使用model = Net()
model_scripted = torch.jit.script(model) # Export to TorchScript
model_scripted.save(‘model_scripted.pt’)
model = torch.jit.load(‘model_scripted.pt’)
model.eval()
.pt .pth

2. TorchScript 的转换

上文内提到 .pthpt 等价,而且后缀主要用于提示。不过相对来说,PyTorch 的模型文件一般保存为 .pth 文件的更多一点,而 C++ 接口一般读取的是 .pt 文件,因此,C++ 在调用 PyTorch 训练好的模型文件的时候,就需要转换为以 .pt 为代表的 TorchScript 文件,才能够读取。

Script mode 通过 torch.jit.trace 或者 torch.jit.script 来调用。这两个函数都是将 Python 代码转换为 TorchScript 的两种不同的方法。

  • torch.jit.trace:将一个特定的输入(通常是一个张量,需要我们提供一个input)传递给一个 PyTorch 模型,torch.jit.trace 会跟踪此 input 在 model 中的计算过程,然后将其转换为 Torch 脚本。这个方法适用于那些在静态图中可以完全定义的模型,例如具有固定输入大小的神经网络。通常用于转换预训练模型。

  • torch.jit.script 直接将 Python 函数(或者一个 Python 模块)通过 Python 语法规则和编译转换为 Torch 脚本。torch.jit.script 更适用于动态图模型,这些模型的结构和输入可以在运行时发生变化。例如,对于 RNN 或者一些具有可变序列长度的模型,使用 torch.jit.script 会更为方便。

在通常情况下,更应该倾向于使用 torch.jit.trace 而不是 torch.jit.script

在模型部署方面,ONNX 被大量使用。而导出 ONNX 的过程,也是 model 进行 torch.jit.trace 的过程,因此这里我们把 torch 的 trace 做稍微详细一点的介绍。

2.1 jit trace 注意事项

为了能够把模型编写的更能够被 jit trace,需要对代码做一些妥协,例如:

  1. 如果 model 中有 DataParallel 的子模块,或者 model 中有将 tensors 转换为 numpy arrays,或者调用了 OpenCV 的函数等,这种情况下,model 不是一个正确的在单个设备上、正确连接的 graph,这种情况下,不管是使用 torch.jit.script 还是 torch.jit.trace 都不能 trace 出正确的 TorchScript 来。

  2. model 的输入输出应该是 Union[Tensor, Tuple[Tensor], Dict[str, Tensor]] 的类型,而且在 dict 中的值,应该是同样的类型。但是对于 model 中间子模块的输入输出,可以是任意类型,例如 dicts of Any, classes, kwargs 以及 Python 支持的都可以。对于 model 输入输出类型的限制是比较容易满足的,在Detectron2中,有类似的例子:

    outputs = model(inputs)   # inputs和outputs是python的类型, 例如dictsor classes
    # torch.jit.trace(model, inputs)  # 失败!trace只支持Union[Tensor,Tuple[Tensor], Dict[str, Tensor]]类型
    adapter = TracingAdapter(model, inputs)  # 使用Adapter,将model inputs包装为trace支持的类型
    traced = torch.jit.trace(adapter, adapter.flattened_inputs)  # 现在以trace成功# Traced model的输出只能是tuple tensors类型:
    flattened_outputs = traced(*adapter.flattened_inputs)
    # 再通过adapter转换为想要的输出类型
    new_outputs = adapter.outputs_schema(flattened_outputs)
    
  3. 一些数值类型的问题。比如下面的代码片段:

    import torch
    a=torch.tensor([1,2])
    print(type(a.size(0)))
    print(type(a.size()[0]))
    print(type(a.shape[0]))
    

    在eager mode下,这几个返回值的类型都是int型。上面代码的输出为:

    <class 'int'>
    <class 'int'>
    <class 'int'>
    

    但是在 trace mode 下,这几个表达式的返回值类型都是 Tensor 类型。因此,有些表达式使用不当,如果在 trace 过程中,一些 shape 表达式的返回值类型是 int 型,那么可能造成这块代码没有被 trace。在代码中,可以通过使用 torch.jit.is_tracing 来检查这块代码在 trace mode 下有没有被执行。

  4. 由于动态的 control flow,造成模型没有被完整的 trace。看下面的例子:

    import torchdef f(x):return torch.sqrt(x) if x.sum() > 0 else torch.square(x)m = torch.jit.trace(f, torch.tensor(3))
    print(m.code)
    

    输出为:

    def f(x: Tensor) -> Tensor:return torch.sqrt(x)
    

    可以看到 trace 后的 model 只保留了一条分支。因此由于输入造成的 dynamic 的 control flow,trace 后容易出现错误。

    这种情况下,我们可以使用 torch.jit.script 来进行 TorchScript 的转换。

    import torchdef f(x):return torch.sqrt(x) if x.sum() > 0 else torch.square(x)m = torch.jit.script(f)
    print(m.code)
    

    输出为:

    def f(x: Tensor) -> Tensor:if bool(torch.gt(torch.sum(x), 0)):_0 = torch.sqrt(x)else:_0 = torch.square(x)return _0
    

    在大多数情况下,我们应该使用 torch.jit.trace,但是像上面的这种 dynamic control flow 的情况,我们可以混合使用 torch.jit.tracetorch.jit.script,在后面会进行阐述
    另外在一些 Blog 中,对于 dynamic control flow 的定义是有错误的,例如 if x[0] == 4: x += 1 是 dynamic control flow,但是:

    model: nn.Sequential = ...
    for m in model:x = m(x)
    

    以及:

    class A(nn.Module):backbone: nn.Modulehead: Optiona[nn.Module]def forward(self, x):x = self.backbone(x)if self.head is not None:x = self.head(x)return x
    

    都不是 dynamic control flowdynamic control flow 是由于对输入条件的判断造成的不同分支的执行

  5. trace 过程中,将变量 trace 成了常量。看下面一个例子:

    import torch
    a, b = torch.rand(1), torch.rand(2)def f1(x): return torch.arange(x.shape[0])
    def f2(x): return torch.arange(len(x))print(torch.jit.trace(f1, a)(b))
    # 输出: tensor([0, 1])
    # 可以看到trace后的model是没问题的,这里使用变量a作为torch.jit.trace的example input,然后将转换后的TorchScript用变量b作为输入,正常情况下,b的shape是2维的,因此返回值是tensor([0,1])是正确的print(torch.jit.trace(f2, a)(b))
    # 输出:
    # TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
    # tensor([0])
    # 可以看到这个输出结果是错误的,b的维度是2维,输出应该是tensor([0,1]),这里torch.jit.trace也提示了,使用len可能会造成不正确的trace。# 我们打印一下两者的区别
    print(torch.jit.trace(f1, a).code, '\n',torch.jit.trace(f2, a).code)
    # 输出
    # def f1(x: Tensor) -> Tensor:
    #   _0 = ops.prim.NumToTensor(torch.size(x, 0))
    #   _1 = torch.arange(annotate(number, _0), dtype=None, layout=None, device=torch.device("cpu"), pin_memory=False)
    #   return _1#  def f2(x: Tensor) -> Tensor:
    #   _0 = torch.arange(1, dtype=None, layout=None, device=torch.device("cpu"), pin_memory=False)
    #   return _0# TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.# 从trace的code中可以看出,使用x.shape这种方式,在trace后的code里面,是有shape的一个变量值存在的,但是直接使用len这种方式,trace后的code里面,就直接是1
    

    我们导出 ONNX 的过程,也是进行 torch.jit.trace 的过程,在导出 ONNX 的时候,有时候也会遇到

    TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.

    这样的提示信息,这时候要检查一下代码中是不是有可能 trace 过程中,变量会被当做常量的情况,有可能会导致导出的 ONNX 精度异常。

    • 关于 ONNX
      ONNX 默认基于 trace 的方式,运行一次模型,记录下和 tensor 的相关操作。trace 将不会捕获根据输入数据而改变的行为。比如 if 语句,只会记录执行的那一条分支,同样的,for 循环的次数,导出与跟踪运行完全相同的静态图。如果要使用动态控制流导出模型,则需要使用 torch.jit.script
      torch.jit.script:真正的去编译,在 PYTHON 的 AST 语法树做语法分析句法分析。因此可以使用if等动态控制流。返回 ScriptModule。
      torch.onnx.export 在运行时,先判断是否是 SriptModule,如果不是,则进行 torch.jit.trace,因此 export 需要一个随机生成的输入参数。
      import torch.nn as nn
      import torch
      import torch.nn.functional as F
      import cv2
      import numpy as np
      import onnx
      import onnxruntime as ort#from torch.onnx import register_custom_op_symbolic # 私有层支持class test_net(nn.Module):def __init__(self,):super(test_net, self).__init__()#self.model = nn.MaxPool3d(kernel_size=(1,3,3), stride=(2,1,2))#self.model = nn.AvgPool3d(kernel_size=(1,3,3), stride=(2,1,2)) #-> AveragePoolself.model = nn.Conv3d(3,64,kernel_size=(1,3,3), stride=(2,1,2))self.relu = nn.ReLU()self.relu6 = nn.ReLU6()self.relu66 = nn.ReLU6()def forward(self, x):out1 = self.model(x)f_mean = torch.mean(out1) # -> ReduceMean#f_mean = torch.mean(out1).item() # item()会将f_mean转换为常数 会丢失 mean操作# script模式转onnx会报错 torch._C._jit_pass_erase_number_types(graph) RuntimeError: Unknown number type: Scalarout2 = torch.div(out1, f_mean)#outlist = list()#for i in range(3):#    if i in [0]:#        #outlist.append(nn.ReLU()(out2))  # script模式下报错 类对象要提前构建#        outlist.append(self.relu(out2))   # scrip_to_onnx 报错 找不到25 BUG#    else:#        #outlist.append(nn.ReLU6()(out2))#        outlist.append(self.relu6(out2))#out = torch.cat(outlist)# 上述 for循环构图在tracing模式下会展开# script模式下难转换,报错# 手动平铺o1 = self.relu(out2)o2 = self.relu6(out2)#o3 = self.relu6(out2)   # script模式下被优化掉了 BUGo3 = self.relu66(out2)   # script模式下被优化掉了out = torch.cat([o1,o2,o3])return out# 模型构建和运行
      imgh, imgw = 24, 94
      net = test_net().eval() # 若存在batchnorm、dropout层则一定要eval() 使得BN层参数不更新
      dummy_input = torch.randn(1,3,3,imgh, imgw)# n c d h w
      torch_out = net.forward(dummy_input)# net(dummy_input)# export onnx
      dynamic_axes = {'input': {3: 'height', 4: 'width'}, 'output': {3: 'height', 4: 'width'}} # 配置动态分辨率
      onnx_pth = "test-conv-relu.onnx"# 传入原model,采用默认trace方式捕获模型,需要运行模型
      torch.onnx.export(net, dummy_input, onnx_pth, input_names=['input'], output_names=['output'], dynamic_axes=dynamic_axes)
      # 也可传入 scriptModule
      #net_script= torch.jit.script(test_net())
      # 需要外加配置 example_outputs,用来获取输出的shape和dtype,无需运行模型
      #torch.onnx.export(net_script, dummy_input, onnx_pth, input_names=['input'], output_names=['output'], dynamic_axes=dynamic_axes, example_outputs=[torch_out])# ort run
      oxx_m = ort.InferenceSession(onnx_pth)
      onnx_blob = dummy_input.data.numpy()
      onnx_out = oxx_m.run(None, {'input':onnx_blob})[0]dummy_input2 = torch.randn(1,3,3,imgh*2, imgw*2)
      onnx_blob2 = dummy_input2.data.numpy()
      onnx_out2 = oxx_m.run(None, {'input':onnx_blob2})[0]# opencv run
      #cv_m = cv2.dnn.readNet(onnx_pth)print('mean diff = ', np.mean(onnx_out - torch_out.data.numpy()))
      

    除了 len 会导致 trace 错误,其他几个也会导致 trace 出现问题:

    • .item() 会在 trace 过程中将 tensors 转为 int/float

    • 任何将 torch 类型转为 numpy/python 类型的代码

    • 一些有问题的算子,例如 advanced indexing

    • torch.jit.trace 不会对传入的 device 生效

      import torch
      def f(x):return torch.arange(x.shape[0], device=x.device)
      m = torch.jit.trace(f, torch.tensor([3]))
      print(m.code)
      # 输出
      # def f(x: Tensor) -> Tensor:
      #   _0 = ops.prim.NumToTensor(torch.size(x, 0))
      #   _1 = torch.arange(annotate(number, _0), dtype=None, layout=None, device=torch.device("cpu"), pin_memory=False)
      #   return _1
      print(m(torch.tensor([3]).cuda()).device)
      # 输出:device(type='cpu')
      

      trace 不会对传入的 cuda device 生效。

2.2 jit trace 验证技巧

为了保证trace的正确,我们可以通过一下的一些方法来尽量保证 trace 后的模型不会出错:
1.注意 warnings 信息。类似这样的:

TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.

TraceWarnings信息,它会造成模型的结果有可能不正确,但是它只是个 warning 等级。
2. 做单元测试。需要验证一下 eager mode 的模型输出与 trace 后的模型输出是否一致。

assert allclose(torch.jit.trace(model, input1)(input2), model(input2))
  1. 避免一些特殊的情况。例如下面的代码:
if x.numel() > 0:output = self.layers(x)
else:output = torch.zeros((0, C, H, W))  # 会创建一个空的输出

避免一些特殊情况比如空的输入输出之类的。

  1. 注意shape的使用。前面提到,tensor.size()在trace过程中会返回Tensor类型的数据,Tensor类型会在计算过程中被添加到计算图中,应该避免将Tensor类型的shape转为了常量。主要注意以下两点:
  • 使用 torch.size(0) 来代替 len(tensor),因为 torch.size(0) 返回的是 Tensor,len(tensor) 返回的是 int。对于自定义类,实现一个 .size 方法或者使用 .__len__() 方法来代替 len() ,例如这个例子
  • 不要使用 int() 或者 torch.as_tensor 来转换 size 的类型,因为这些操作也会被视为常量。
  1. 混合 tracing 和 scripting 方法。可以使用 torch.jit.script 来转换一些 torch.jit.trace 不能搞定的小的代码片段,混合使用 tracing 和 scripting,基本可以解决所有的问题。

2.3 混合使用 trace 和 script

trace 和 script 都有他们的问题,混合使用可以解决大部分问题。但是为了尽可能减小对于代码质量的负面影响,大部分情况下,都应该使用 torch.jit.trace,必要时才使用 torch.jit.script

  1. 在使用 torch.jit.trace 时,使用 @script_if_tracing 装饰器可以让被装饰的函数使用 script 方式进行编译

    def forward(self, ...):# ... some forward logic@torch.jit.script_if_tracingdef _inner_impl(x, y, z, flag: bool):# use control flow, etc.return ...output = _inner_impl(x, y, z, flag)# ... other forward logic
    

    但是使用 @script_if_tracing 时,需要保证函数中没有 PyTorch 的 modules,如果有的话,需要做一些修改,例如下面的:

    # 因为代码中有self.layers(),是一个pytorch的module,因此不能使用@script_if_tracing
    if x.numel() > 0:x = preprocess(x)output = self.layers(x)
    else:# Create empty outputsoutput = torch.zeros(...)
    

    这里需要做如下修改:

    # 需要将self.layers移出if判断,这时候可以用@script_if_tracing
    if x.numel() > 0:x = preprocess(x)
    else:# Create empty inputsx = torch.zeros(...)
    # 需要将self.layers()修改为支持empty的输入,或者将原先的条件判断加入到self.layers中
    output = self.layers(x)
    
  2. 合并多次 trace 的结果
    使用 torch.jit.script 生成的模型相比使用 torch.jit.trace 有两个好处:

    • 可以使用条件控制流,例如模型中使用一个 bool 值来控制 forward 的 flow,在 traced modules 里面是不支持的
    • 使用 traced module,只能有一个 forward() 函数,但是使用 scripted module,可以有多个前向计算的函数
    class Detector(nn.Module):do_keypoint: booldef forward(self, img):box = self.predict_boxes(img)if self.do_keypoint:kpts = self.predict_keypoint(img, box)@torch.jit.exportdef predict_boxes(self, img): pass@torch.jit.exportdef predict_keypoint(self, img, box): pass
    

    对于这种有 bool 值的控制流,除了使用 script,还可以多次进行 trace,然后将结果合并。

    det1 = torch.jit.trace(Detector(do_keypoint=True), inputs)
    det2 = torch.jit.trace(Detector(do_keypoint=False), inputs)
    

    然后将他们的 weight 复制一遍,并合并两次 trace 的结果:

    det2.submodule.weight = det1.submodule.weight
    class Wrapper(nn.ModuleList):def forward(self, img, do_keypoint: bool):if do_keypoint:return self[0](img)else:return self[1](img)
    exported = torch.jit.script(Wrapper([det1, det2]))
    

    对于这种有 bool 值的控制流,除了使用 script,还可以多次进行 trace,然后将结果合并。

    det1 = torch.jit.trace(Detector(do_keypoint=True), inputs)
    det2 = torch.jit.trace(Detector(do_keypoint=False), inputs)
    

    然后将他们的 weight 复制一遍,并合并两次 trace 的结果:

    det2.submodule.weight = det1.submodule.weight
    class Wrapper(nn.ModuleList):def forward(self, img, do_keypoint: bool):if do_keypoint:return self[0](img)else:return self[1](img)
    exported = torch.jit.script(Wrapper([det1, det2]))
    

2.4 trace 和 script 的性能

trace 总是会比 script 生成一样或者更简单的计算图,因此性能会更好一些。因为 script 会完整的表达 Python 代码的逻辑,甚至一些不必要的代码也会如实表达。例如下面的例子:

class A(nn.Module):def forward(self, x1, x2, x3):z = [0, 1, 2]xs = [x1, x2, x3]for k in z: x1 += xs[k]return x1
model = A()
print(torch.jit.script(model).code)
# def forward(self, x1: Tensor, x2: Tensor, x3: Tensor) -> Tensor:
#   z = [0, 1, 2]
#   xs = [x1, x2, x3]
#   x10 = x1
#   for _0 in range(torch.len(z)):
#     k = z[_0]
#     x10 = torch.add_(x10, xs[k])
#   return x10
print(torch.jit.trace(model, [torch.tensor(1)] * 3).code)
# def forward(self, x1: Tensor, x2: Tensor, x3: Tensor) -> Tensor:
#   x10 = torch.add_(x1, x1)
#   x11 = torch.add_(x10, x2)
#   return torch.add_(x11, x3)

2.5 总结

trace 具有明显的局限性:这篇文章的大部分篇幅都在谈论 trace 的局限性以及如何解决这些问题。实际上,这正是 trace 的优势所在:它有明确的局限性(和解决方案),因此你可以推理它是否有效。

相反,script 更像是一个黑盒子:在尝试之前,没有人知道它是否有效。文章中没有提到如何修复 script 的任何诀窍:有很多诀窍,但不值得你花时间去探究和修复一个黑盒子。

trace 和 script 都会影响代码的编写方式,但 trace 因为我们明确它的要求,对我们原始的代码造成的一些修改也不会太严重:

  • 它限制了输入/输出格式,但仅限于最外层的模块。(如上所述,这个问题可以通过一个wrapper解决)。
  • 它需要修改一些代码才能通用(例如在 trace 时添加一些 script),但这些修改只涉及受影响模块的内部实现,而不是它们的接口。

3. LibTorch 的使用

在得到所需模型后,可以尝试在 C++ 环境下使用得到的模型,这里就用到了 LibTorch。

3.1 LibTorch 的链接

结合自己环境的 CUDA 版本,去官网下载对应版本的 libTorch。例如 CUDA 版本为 11.1,则需要在下载地址中找到 libtorch-cxx11-abi-shared-with-deps-1.9.1%2Bcu111.zip 进行下载。

链接进需要再 cmake 内加上这几行即可:

set(TORCH_PATH "/home/yj/libtorch/share/cmake/Torch")
message("TORCH_PATH set to: ${TORCH_PATH}")
set(Torch_DIR ${TORCH_PATH})find_package(Torch REQUIRED)
message(STATUS "Torch version is: ${Torch_VERSION}")# <target> is your target's name
target_link_libraries(<target> ${TORCH_LIBRARIES}
)

3.2 接口和实现

  1. 头文件引入 :

    #include <torch/script.h>
    #include <torch/torch.h>
    
  2. 加载模型

    module = torch::jit::load(PATH);
    
  3. 函数实现

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/417672.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

AtCoder Regular Contest 115 E. LEQ and NEQ(容斥 单调栈优化dp)

题目 n(n<5e5)个数&#xff0c;第i个数ai(1<ai<1e9) 构造一个序列b&#xff0c;要求bi∈[1,ai]&#xff0c;且b[i]不等于b[i1] 求方案数&#xff0c;答案对998244353取模 思路来源 洛谷题解Xu_brezza 一模一样的cf题&#xff1a; Codeforces Round 759 (Div. 2…

解析智能酒精壁炉不完全燃烧的成因及潜在问题

解析智能酒精壁炉不完全燃烧的成因及潜在问题 智能酒精壁炉作为一种环保、高效、现代化的取暖工具&#xff0c;其采用酒精作为燃料进行燃烧&#xff0c;但在一些情况下&#xff0c;可能会出现酒精燃烧不完全的问题。下面将深入探讨这一现象的成因以及可能引发的问题。 成因分析…

php isset和array_key_exists区别

在PHP中&#xff0c;可以使用array_key_exists函数或者isset函数来判断一个字典&#xff08;关联数组&#xff09;中是否存在某个下标。 使用 array_key_exists 函数: $myArray array("key1" > "value1", "key2" > "value2",…

基于动态顺序表实现通讯录项目

本文中&#xff0c;我们将使用顺序表的结构来完成通讯录的实现。 我们都知道&#xff0c;顺序表实际上就是一个数组。而使用顺序表来实现通讯录&#xff0c;其内核是将顺序表中存放的数据类型改为结构体&#xff0c;将联系人的信息存放到结构体中&#xff0c;通过对顺序表的操…

【数据结构与算法】1.时间复杂度和空间复杂度

&#x1f4da;博客主页&#xff1a;爱敲代码的小杨. ✨专栏&#xff1a;《Java SE语法》 ❤️感谢大家点赞&#x1f44d;&#x1f3fb;收藏⭐评论✍&#x1f3fb;&#xff0c;您的三连就是我持续更新的动力❤️ &#x1f64f;小杨水平有限&#xff0c;欢迎各位大佬指点&…

架构的演进

1.1单体架构 单体架构也称之为单体系统或者是单体应用。就是一种把系统中所有的功能、模块耦合在一个应用中的架构方式。 存在的问题&#xff1a; 代码耦合&#xff1a;模块的边界模糊、依赖关系不清晰&#xff0c;整个项目非常复杂&#xff0c;每次修改代码都心惊胆战迭代困…

linux基础学习(5):yum

yum是为了解决rpm包安装依赖性而产生的一种安装工具 1.yum源 1.1配置文件位置 yum源的配置文件在/etc/yum.repos.d/中 *Base源是网络yum源&#xff0c;也就是需要联网才能使用的yum源。默认情况下&#xff0c;系统会使用Base源 *Media源是光盘yum源&#xff0c;是本地yum源…

论文阅读笔记AI篇 —— Transformer模型理论+实战 (三)

论文阅读笔记AI篇 —— Transformer模型理论实战 &#xff08;三&#xff09; 第三遍阅读&#xff08;精读&#xff09;3.1 Attention和Self-Attention的区别&#xff1f;3.2 Transformer是如何进行堆叠的&#xff1f;3.3 如何理解Positional Encoding&#xff1f;3.x 文章涉及…

【数据结构】详谈队列的顺序存储及C语言实现

循环队列及其基本操作的C语言实现 前言一、队列的顺序存储1.1 队尾指针与队头指针1.2 基本操作实现的底层逻辑1.2.1 队列的创建与销毁1.2.2 队列的增加与删除1.2.3 队列的判空与判满1.2.4 逻辑的局限性 二、循环队列2.1 循环队列的实现逻辑一2.2 循环队列的实现逻辑二2.3 循环队…

CodeReview 小工具

大家开发中有没有遇到一个版本开发的非常杂&#xff0c;开发很多个项目&#xff0c;改动几周后甚至已经忘了自己改了些什么&#xff0c;领导要对代码review的时候&#xff0c;理不清楚自己改过的代码&#xff0c;只能将主要改动的大功能过一遍。这样就很容易造成review遗漏&…

软件测试(一)

软件测试——测试用例 &#x1f3d0;测试用例要素&#xff08;四个重要的要素&#xff09;&#x1f3d0;测试用例的设计方法&#x1f3c0;基于需求的设计方法&#x1f3c0;等价类&#x1f3c0;边界值&#x1f3c0;判定表&#x1f3c0;正交表法&#x1f3c0;场景设计法&#x1f…

Prompt高级技巧:Few-Shots、COT、SC、TOT、Step-Back

CRISPE框架 如图所示。所谓CRISPE框架&#xff0c;指的是&#xff1a; CR&#xff1a;Capacity and Role&#xff08;能力与角色&#xff09;。你希望 ChatGPT 扮演怎样的角色。I&#xff1a;Insight&#xff08;洞察&#xff09;&#xff0c;背景信息和上下文。S:&#xff08…