refactor onnx importer to do more static imports by constant folding
mbrookhart committed Feb 9, 2021
1 parent 9b37040 commit 9303222
Showing 2 changed files with 99 additions and 69 deletions.
6 changes: 6 additions & 0 deletions python/tvm/relay/frontend/common.py
@@ -491,6 +491,12 @@ def infer_type(node, mod=None):
return ret


def fold_constant(node, mod=None):
if mod is None:
mod = IRModule.from_expr(node)
return _transform.FoldConstantExpr(node, mod)


def infer_channels(inputs, transpose=False):
"""A hack for getting 'channels' or 'units' since caffe2 does not provide
these attributes. We check the shape of weights provided to get the number.
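The new fold_constant helper wraps an expression in an IRModule and runs Relay's constant-folding pass over it, so the shape arithmetic the importer builds can collapse to constants at import time instead of surviving as dynamic ops. A minimal sketch of the pattern, assuming fold_constant is importable from tvm.relay.frontend.common as added above (the pad values are illustrative, mirroring the Pad converter changed below):

from tvm import relay
from tvm.relay.frontend.common import fold_constant  # helper added in this commit

# ONNX supplies pads as a flat [d1_begin, d2_begin, d1_end, d2_end] tensor; the
# importer reshapes/transposes it into per-dimension (begin, end) pairs and folds
# the result back into a single constant.
pads = relay.const([1, 2, 3, 4], dtype="int64")
pad_width = fold_constant(relay.transpose(relay.reshape(pads, (2, -1))))
print(pad_width)  # expected: a 2x2 constant [[1, 3], [2, 4]]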
162 changes: 93 additions & 69 deletions python/tvm/relay/frontend/onnx.py
Expand Up @@ -34,7 +34,7 @@
from .. import ty as _ty

from .common import AttrCvt, Renamer
from .common import get_relay_op, new_var, infer_shape, infer_channels
from .common import get_relay_op, new_var, infer_shape, infer_channels, fold_constant
from .common import infer_type, get_name


@@ -364,7 +364,7 @@ def autopad(data, strides, kernel_shape, dilations, ndim, pad_type="constant", d
),
dtype="int64",
)
shape = _op.strided_slice(_op.shape_of(data, dtype="int64"), [2], [ndim])
shape = _op.strided_slice(shape_of(data, dtype="int64"), [2], [ndim])
# get input shape

# set up integer constants
@@ -545,19 +545,23 @@ class MatMul(OnnxOpConverter):
def _impl_v1(cls, inputs, attr, params):
assert len(inputs) == 2, "MatMul op take 2 inputs, {} given".format(len(inputs))
# Need to check input shape as batch matmul must be supported.
a_shape = _op.shape_of(inputs[0])
a_shape = shape_of(inputs[0])
a_rank = infer_shape(a_shape)[0]
b_shape = _op.shape_of(inputs[1])
b_shape = shape_of(inputs[1])
b_rank = infer_shape(b_shape)[0]
# When performing a batch matmul, we need to properly handle N-dim shapes.
if a_rank > 2 or b_rank > 2:

def flatten_to_3d(x, x_shape):
ndims = infer_shape(x_shape)[0]
newshape = _op.concatenate(
[_expr.const([-1]), _op.strided_slice(x_shape, [ndims - 2], [ndims])], 0
[
_expr.const([-1], dtype=infer_type(x_shape).checked_type.dtype),
_op.strided_slice(x_shape, [ndims - 2], [ndims]),
],
0,
)
out = _op.reshape(x, newshape)
out = _op.reshape(x, fold_constant(newshape))
return out

# Convert a and b into 3 dimensional tensors.
@@ -598,7 +602,7 @@ def flatten_to_3d(x, x_shape):
],
0,
)
return _op.reshape(output, final_shape)
return _op.reshape(output, fold_constant(final_shape))
# Otherwise a simple dense op will get the job done.
input_1_t = _op.transpose(inputs[1], axes=(1, 0))
return _op.nn.dense(inputs[0], input_1_t)
@@ -646,7 +650,7 @@ def _impl_v11(cls, inputs, attr, params):
multiplier = _op.concatenate(
[_expr.const([1, 1], dtype="int64"), _expr.const(list(strides), dtype="int64")], axis=0
)
total_output_shape = multiplier * _op.shape_of(data, dtype="int64")
total_output_shape = multiplier * shape_of(data, dtype="int64")
# Add extra dimensions from kernel size and stride mismatch
total_output_shape += _op.concatenate(
[_expr.const([0, 0], "int64"), _expr.const(list(kernel_shape), "int64")], axis=0
@@ -792,11 +796,11 @@ def _impl_v2(cls, inputs, attr, params):
def _impl_v11(cls, inputs, attr, params):
pads = inputs[1]
if len(inputs) == 3:
value = _op.take(inputs[2], _op.const(0))
value = fold_constant(_op.take(inputs[2], _op.const(0)))
else:
value = 0

pad_width_expr = _op.transpose(_op.reshape(pads, (2, -1)))
pad_width_expr = fold_constant(_op.transpose(_op.reshape(pads, (2, -1))))
pad_mode = attr.get("mode", b"constant").decode("utf-8")

if not pad_mode in ["constant", "edge", "reflect"]:
@@ -823,7 +827,7 @@ class Prelu(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
assert len(inputs) == 2, "Prelu need 2 inputs, {} given".format(len(inputs))
input_shape = _op.shape_of(inputs[0])
input_shape = shape_of(inputs[0])
alpha = _op.broadcast_to_like(inputs[1], inputs[0])
alpha = _op.reshape(alpha, [-1])
output = _op.nn.prelu(_op.reshape(inputs[0], [-1]), alpha, axis=0)
@@ -875,7 +879,6 @@ class DepthToSpace(OnnxOpConverter):

@classmethod
def _impl_v11(cls, inputs, attr, params):

block_size = int(attr["blocksize"])
mode = attr.get("mode", b"DCR").decode("utf-8")
return _op.nn.depth_to_space(inputs[0], block_size, mode=mode)
@@ -1015,8 +1018,9 @@ def _impl_v9(cls, inputs, attr, params):
scales = params[inputs[1].name_hint].asnumpy()
else:
scales = inputs[1]

if not isinstance(scales, _expr.Call):
if isinstance(scales, _expr.Constant):
scales = list(scales.data.asnumpy())
if not isinstance(scales, _expr.Expr):
assert scales[0] == 1.0 and scales[1] == 1.0

mode = attr.get("mode")
@@ -1067,12 +1071,19 @@ def _impl_v9(cls, inputs, attr, params):
return out


def shape_of(x, dtype="int64"):
ttype = infer_type(x).checked_type
if not _ty.is_dynamic(ttype):
return _expr.const([i for i in ttype.shape], dtype)
return _op.shape_of(x, dtype)


class Shape(OnnxOpConverter):
"""Operator converter for Shape."""

@classmethod
def _impl_v1(cls, inputs, attr, params):
return _op.shape_of(inputs[0], "int64")
return shape_of(inputs[0], "int64")


class Cast(OnnxOpConverter):
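The shape_of helper defined above is what makes most of these conversions static: when the input's type is fully known it returns the shape directly as a constant, and it only falls back to the runtime shape_of op for genuinely dynamic shapes. A small illustration, assuming the helper is importable from tvm.relay.frontend.onnx and using relay.Any() to mark a dynamic dimension:

from tvm import relay
from tvm.relay.frontend.onnx import shape_of  # helper defined in this commit

static_x = relay.var("x", shape=(2, 3), dtype="float32")
dynamic_x = relay.var("y", shape=(relay.Any(), 3), dtype="float32")

print(shape_of(static_x))   # a constant holding [2, 3]; nothing left to compute at runtime
print(shape_of(dynamic_x))  # falls back to a shape_of call evaluated at execution time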
@@ -1182,7 +1193,7 @@ def _impl_v10(cls, inputs, attr, params):

# Update the starts and ends according to axes if required.
if axes is not None:
data_shape = _op.shape_of(inputs[0], dtype=infer_type(ends).checked_type.dtype)
data_shape = shape_of(inputs[0], dtype=infer_type(ends).checked_type.dtype)
starts = _op.scatter(
_op.const([0] * data_rank, dtype=infer_type(starts).checked_type.dtype),
axes,
@@ -1201,7 +1212,9 @@ def _impl_v10(cls, inputs, attr, params):
if steps is None:
steps = _op.const([1] * data_rank, dtype=infer_type(starts).checked_type.dtype)

return _op.strided_slice(inputs[0], starts, ends, steps)
return _op.strided_slice(
inputs[0], fold_constant(starts), fold_constant(ends), fold_constant(steps)
)


class Gather(OnnxOpConverter):
@@ -1509,6 +1522,20 @@ def _impl_v9(cls, inputs, attr, params):
return output


class Constant(OnnxOpConverter):
"""Operator converter for ConstantOfShape."""

@classmethod
def _impl_v9(cls, inputs, attr, params):
if "value" in attr:
np_value = get_numpy(attr.pop("value"))
dtype = np_value.dtype.name
value = _expr.const(np_value, dtype)
return value
else:
raise "No Value in Constant"


class Sign(OnnxOpConverter):
"""Operator converter for Sign."""

@@ -1569,12 +1596,14 @@ def _impl_v9(cls, inputs, attr, params):
# to that shape.
max_rank = max(ranks)
max_rank_idxs = [i for i, x in enumerate(ranks) if x == max_rank]
broadcast_shape = _op.shape_of(inputs[max_rank_idxs[0]])
broadcast_shape = shape_of(inputs[max_rank_idxs[0]])
# If two or more inputs have the same rank, compute the broadcast
# shape by taking the maximum value of each dimensions.
if len(max_rank_idxs) > 1:
for idx in max_rank_idxs:
broadcast_shape = _op.maximum(broadcast_shape, _op.shape_of(inputs[idx]))
broadcast_shape = _op.maximum(broadcast_shape, shape_of(inputs[idx]))

broadcast_shape = fold_constant(broadcast_shape)

condition = _op.broadcast_to(inputs[0], broadcast_shape)
x = _op.broadcast_to(inputs[1], broadcast_shape)
@@ -1596,7 +1625,7 @@ class Expand(OnnxOpConverter):
@classmethod
def _impl_v8(cls, inputs, attr, params):
dtype = infer_type(inputs[1]).checked_type.dtype
in_shape = _op.shape_of(inputs[0], dtype=dtype)
in_shape = shape_of(inputs[0], dtype=dtype)
shape = inputs[1]

# Currently 'op.broadcast_to' expect the rank of the given 'shape'
@@ -1645,7 +1674,7 @@ def expand_shape(in_shape, shape):
new_shape = _op.maximum(in_shape, shape)
return new_shape

shape = expand_shape(in_shape, shape)
shape = fold_constant(expand_shape(in_shape, shape))
return _op.broadcast_to(inputs[0], shape=shape)


@@ -1920,10 +1949,10 @@ def _impl_v10(cls, inputs, attr, params):
)

scale = inputs[1]
size = _op.cast(_op.shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale

size = _op.cast(shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale
size = _op.cast(size, "int64")
layout = "NCHW" # ONNX assumes NCHW layout
out_size = _op.strided_slice(size, [2], [4])
out_size = fold_constant(_op.strided_slice(size, [2], [4]))
return _op.image.resize(inputs[0], out_size, layout, method, "asymmetric")

@classmethod
Expand All @@ -1947,7 +1976,8 @@ def _impl_v11(cls, inputs, attr, params):
size = inputs[3]
else:
assert len(scale_shape) != 0, "One of scale or size should be passed."
size = _op.cast(_op.shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale
size = _op.cast(shape_of(inputs[0]), infer_type(scale).checked_type.dtype) * scale
size = _op.cast(size, "int64")

coord_trans = attr.get("coordinate_transformation_mode")
if coord_trans in [b"pytorch_half_pixel", b"half_pixel"]:
@@ -1961,7 +1991,7 @@ def _impl_v11(cls, inputs, attr, params):
"Unsupported coordinate_transformation_mode: {}".format(coord_trans)
)
layout = "NCHW" # ONNX assumes NCHW layout
out_size = _op.strided_slice(size, [2], [4])
out_size = fold_constant(_op.strided_slice(size, [2], [4]))
return _op.image.resize(inputs[0], out_size, layout, method, coord_trans)


@@ -2224,7 +2254,7 @@ def body_fn(*loop_inputs):
expand_scan = _op.expand_dims(new_scan, axis=0)
# For non scalar outputs we need to broadcast the initial value.
if rank > 0:
new_scan_shape = _op.shape_of(new_scan, dtype=iter_dtype)
new_scan_shape = shape_of(new_scan, dtype=iter_dtype)
scan_broadcast = _op.concatenate(
[_op.reshape(loop_count, [1]), new_scan_shape], axis=0
)
@@ -2446,9 +2476,9 @@ def _first_body(
# partially prepare ONNX output format by labeling batch_num, class_id
nms_padded_out = _op.expand_dims(nms_ret[0], -1, 1)
batch_num = _op.expand_dims(_op.arange(_op.squeeze(B, [0]), dtype="int64"), -1, 1)
batch_num = _op.broadcast_to(batch_num, _op.shape_of(nms_ret[0], dtype="int64"))
batch_num = _op.broadcast_to(batch_num, shape_of(nms_ret[0], dtype="int64"))
batch_num = _op.expand_dims(batch_num, -1, 1)
class_num = _op.broadcast_to(i, _op.shape_of(nms_padded_out, dtype="int64"))
class_num = _op.broadcast_to(i, shape_of(nms_padded_out, dtype="int64"))
new_onnx_out = _op.concatenate(
[batch_num, class_num, _op.cast(nms_padded_out, "int64")], -1
)
@@ -2548,7 +2578,7 @@ def _outer_body(i, B, C, onnx_out, nms_size_out, out):
)

# Call the first loop, perform NMS
B, C, S = _op.split(_op.shape_of(scores, dtype="int64"), 3)
B, C, S = _op.split(shape_of(scores, dtype="int64"), 3)
init_count = _op.const(np.array([0]), dtype="int64")
init_onnx_out = _op.const([1], dtype="int64")
init_onnx_out = _op.broadcast_to(init_onnx_out, _op.concatenate([B, one, S, three], 0))
@@ -2595,6 +2625,7 @@ def _get_convert_map(opset):
"ThresholdedRelu": ThresholdedRelu.get_converter(opset),
"ScaledTanh": ScaledTanh.get_converter(opset),
"ParametricSoftplus": ParametricSoftPlus.get_converter(opset),
"Constant": Constant.get_converter(opset),
"ConstantOfShape": ConstantOfShape.get_converter(opset),
# 'GivenTensorFill'
"FC": AttrCvt("dense", ignores=["axis", "axis_w"]),
@@ -2827,12 +2858,16 @@ def from_onnx(self, graph, opset, freeze_params=False, get_output_expr=False):
for init_tensor in graph.initializer:
if not init_tensor.name.strip():
raise ValueError("Tensor's name is required.")
self._params[init_tensor.name] = self._parse_array(init_tensor)
self._nodes[init_tensor.name] = new_var(
init_tensor.name,
shape=self._params[init_tensor.name].shape,
dtype=self._params[init_tensor.name].dtype,
)
if freeze_params:
array = self._parse_array(init_tensor)
self._nodes[init_tensor.name] = _expr.const(array)
else:
self._params[init_tensor.name] = self._parse_array(init_tensor)
self._nodes[init_tensor.name] = new_var(
init_tensor.name,
shape=self._params[init_tensor.name].shape,
dtype=self._params[init_tensor.name].dtype,
)
for i in graph.input:
# from onnx v0.2, GraphProto.input has type ValueInfoProto,
# and the name is 'i.name'
@@ -2844,6 +2879,8 @@ def from_onnx(self, graph, opset, freeze_params=False, get_output_expr=False):
self._nodes[i_name] = new_var(
i_name, shape=self._params[i_name].shape, dtype=self._params[i_name].dtype
)
elif i_name in self._nodes:
continue
else:
self._num_input += 1
if i_name in self._shape:
@@ -2886,37 +2923,27 @@ def from_onnx(self, graph, opset, freeze_params=False, get_output_expr=False):
for i in node.input:
if i != "":
inputs[i] = self._nodes[self._renames.get(i, i)]
if op_name == "Constant":
t_proto = self._parse_attr(node.attribute)["value"]
self._num_param += 1
# We should convert scalar integers to int32, to normalize.
array = self._parse_array(t_proto)
self._params[node.output[0]] = array
self._nodes[node.output[0]] = new_var(
node.output[0], shape=list(t_proto.dims), dtype=array.dtype
)
i_name = self._parse_value_proto(node)
node_output = self._fix_outputs(op_name, node.output)
attr["tvm_custom"] = {}
attr["tvm_custom"]["name"] = i_name
attr["tvm_custom"]["num_outputs"] = len(node_output)

op = self._convert_operator(op_name, inputs, attr, opset)
if not isinstance(op, _expr.TupleWrapper):
outputs_num = 1
else:
i_name = self._parse_value_proto(node)
node_output = self._fix_outputs(op_name, node.output)
attr["tvm_custom"] = {}
attr["tvm_custom"]["name"] = i_name
attr["tvm_custom"]["num_outputs"] = len(node_output)

op = self._convert_operator(op_name, inputs, attr, opset)
if not isinstance(op, _expr.TupleWrapper):
outputs_num = 1
else:
outputs_num = len(op)
assert (
len(node_output) == outputs_num
), "Number of output mismatch {} vs {} in {}.".format(
len(node_output), outputs_num, op_name
)
if outputs_num == 1:
self._nodes[node_output[0]] = op
else:
for k, i in zip(list(node_output), range(len(node_output))):
self._nodes[k] = op[i]
outputs_num = len(op)
assert (
len(node_output) == outputs_num
), "Number of output mismatch {} vs {} in {}.".format(
len(node_output), outputs_num, op_name
)
if outputs_num == 1:
self._nodes[node_output[0]] = op
else:
for k, i in zip(list(node_output), range(len(node_output))):
self._nodes[k] = op[i]

# now return the outputs
outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output]
@@ -2934,9 +2961,6 @@ def from_onnx(self, graph, opset, freeze_params=False, get_output_expr=False):
self._inputs[i_name] = self._nodes[i_name]
# Create a function from our output expression and all input variables.
func = _function.Function([v for k, v in self._inputs.items()], outputs)
if freeze_params:
func, params = self.freeze(func, self._params)
return IRModule.from_expr(func), params
return IRModule.from_expr(func), self._params

def _parse_value_proto(self, value_proto):
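With initializers frozen into constants (freeze_params=True), the fold_constant and shape_of helpers can resolve shape-producing subgraphs while the graph is being imported, rather than leaving that work to later passes. A hedged usage sketch of the public entry point; the model path and input name below are placeholders:

import onnx
from tvm import relay

onnx_model = onnx.load("model.onnx")               # placeholder path
shape_dict = {"input": (1, 3, 224, 224)}           # placeholder input name and shape
mod, params = relay.frontend.from_onnx(onnx_model, shape_dict, freeze_params=True)
print(mod["main"])  # weights and static shapes appear inline as constants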
