本文介绍 ONNX里面的元运算

ONNX里面的元运算

This article was original written by Jin Tian, welcome re-post, first come with https://jinfagang.github.io . but please keep this copyright info, thanks, any question could be asked via wechat: jintianiloveu

其实这个有点标题党,但这篇文章要做的事情,还是很有意义的,我们很多时候再用onnx转模型的时候,遇到的都不是一个层的问题,而往往是某个node不对,要么缺少attribute,要么你的后端不支持,那么这个node到底是如何运算的你就需要心知肚明了,就好像这个node是你写的一样,这个我就叫做unit onnx, 将一个大的东西拆解,拆解到内裤都不剩,看看这里面到底卖的什么药。

onnx元运算第一步做什么?简单,我们拆解一个很简单的node: Slice.

元运算

先看一下Slice在onnx标准中的定义:

Produces a slice of the input tensor along multiple axes. Similar to numpy:

说白了,就是根据你指定的idx选择不同的目标,用numpy来表示就是:

1
2
3
>>> x = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
>>> x[1:7:2]
array([1, 3, 5])

这个Node可以说非常简单了,但是如果是你,你会怎么来实现呢?先不着急实现,在不同的onnx opset中,对Slice的定义其实是有区别的:

  • opset 10:

    opset10中的slice inputs要求3-5

    • data : T

      Tensor of data to extract slices from.

    • starts : Tind

      1-D tensor of starting indices of corresponding axis in axes

    • ends : Tind

      1-D tensor of ending indices (exclusive) of corresponding axis in axes

    • axes (optional) : Tind

      1-D tensor of axes that starts and ends apply to.

    • steps (optional) : Tind

      1-D tensor of slice step of corresponding axis in axes. Default to 1.

  • opset9:

    opset10以前版本中的slice,还有attribute属性:

    attritbute:

    • axes : list of ints

      Axes that starts and ends apply to. It's optional. If not present, will be treated as [0, 1, …, len(starts) - 1].

    • ends : list of ints (required)

      Ending indices (exclusive) of corresponding axis in axes`

    • starts : list of ints (required)

      Starting indices of corresponding axis in axes

    inputs:

    • data : T

      Tensor of data to extract slices from.

    outputs:

    • output : T

      Sliced data tensor.

总结来说,在opset10以前的版本,这些node导出来的时候,他们的属性值是不一样的,比如opset9,在attributes属性下,获取axes, starts, ends这些参数,而在opset10,直接通过inputs获取即可,其中还增加了steps参数,而这个还是可选的。

所以当你在打印一个onnx里面Slice层的时候是这样的:

input: "3020"
input: "3029"
input: "3030"
input: "3028"
output: "3031"
name: "in: 3020;3029;3030;3028. out: 3031"
op_type: "Slice"

那么这个opset是10或者10以上,但假如你将这个onnx转到trt,得到了如下错误:

onnx-tensorrt/builtin_op_importers.cpp:1803 In function importSlice:
[8] Assertion failed: input_name == "axes" || input_name == "steps"

这是啥意思呢?错误出在这个地方:

auto const& input_name = node.input(3);
ASSERT(input_name == "axes" ||  input_name == "steps", ErrorCode::kUNSUPPORTED_NODE);

上面每一个input都是有名字的,我们打印这个input的名字来看一下?难道说,上面的input应该是 data, axes, steps? 难道不应该是上一个Node的名字吗?这个就有点纠结了.

有点跑题了,继续说我们的ONNX元运算。假如我们要保存一个很小的ONNX graph,小到只有两个Node,我们应该怎么做呢?其实也很简单:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
"""

unit onnx node to tiny ones
"""
import onnx
from onnx import helper
from onnx import AttributeProto, TensorProto, GraphProto
import numpy as np

node_conv = onnx.helper.make_node(
    'Abs',
    inputs=['x'],
    outputs=['123'],
    name='122'
)

node_def = onnx.helper.make_node(
    'Slice',
    inputs=['x', '123', '123', '123', '123'],
    outputs=['y'],
    name='123'
)

x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [20, 10, 5])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 10])
W = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 10])
starts = helper.make_tensor_value_info(
    'starts', TensorProto.INT64, [0, 0])
ends = helper.make_tensor_value_info(
    'ends', TensorProto.INT64, [3, 10])
axes = helper.make_tensor_value_info(
    'axes', TensorProto.INT64, [0, 1])
steps = helper.make_tensor_value_info(
    'steps', TensorProto.INT64, [1, 1])

graph_def = helper.make_graph(
    [node_conv, node_def],
    'testmodel',
    [x],
    [y],
)

model_def = helper.make_model(graph_def, producer_name='fucking company')
onnx.checker.check_model(model_def)
onnx.save(model_def, 'testmodel.onnx')

这里我们定义了一张图,里面其实只有两个node,一个是Abs,一个是Slice,但是实际上他们的运算没有任何意义。目的是让大家知道,如何保存这么一个元图。