Skip to content

Commit

Permalink
[Relay][Frontend][Onnx] Compare against onnxruntime more consistently…
Browse files Browse the repository at this point in the history
… during testing (apache#7300)



Co-authored-by: Josh Fromm <jwfromm@uw.edu>
  • Loading branch information
2 people authored and trevor-m committed Jan 21, 2021
1 parent 7e932b6 commit 5ba46e5
Show file tree
Hide file tree
Showing 2 changed files with 302 additions and 572 deletions.
73 changes: 39 additions & 34 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# pylint: disable=invalid-name, import-self, len-as-condition, unused-argument, too-many-lines
# pylint: disable=import-outside-toplevel
"""ONNX: Open Neural Network Exchange frontend for Relay."""
import copy
import warnings
import numpy as np
import tvm
Expand Down Expand Up @@ -1028,10 +1029,6 @@ def _impl_v9(cls, inputs, attr, params):
'Value {} in attribute "mode" of operator Upsample is not valid.'.format(mode)
)

if method == "nearest_neighbor":
align_corners = False
else:
align_corners = True
# in 3d case, we use the purely static op
if dims == 5:
if isinstance(scales, _expr.Call):
Expand Down Expand Up @@ -1065,7 +1062,7 @@ def _impl_v9(cls, inputs, attr, params):
scale_w,
layout=layout,
method=method,
align_corners=align_corners,
align_corners=False,
)
return out

Expand Down Expand Up @@ -1111,17 +1108,22 @@ class Split(OnnxOpConverter):

@classmethod
def _impl_v1(cls, inputs, attr, params):
splits = attr.get("split", False)
if splits:
splits = attr.get("split", None)
if splits is not None:
indices = []
attr["indices_or_sections"] = []
index = 0
for i in splits[:-1]:
index += i
attr["indices_or_sections"].append(index)
indices.append(index)
# When splits isnt specified divide evenly over axis.
else:
attr["indices_or_sections"] = attr["tvm_custom"]["num_outputs"]
return AttrCvt("split", ignores=["split"])(inputs, attr, params)
indices = attr["tvm_custom"]["num_outputs"]
output = _op.split(inputs[0], indices, attr.get("axis", 0))
# If the output of split is a single value, unpack if from the TupleWrapper
if len(output) == 1:
output = output[0]
return output


class Slice(OnnxOpConverter):
Expand Down Expand Up @@ -1227,7 +1229,9 @@ class GatherND(OnnxOpConverter):

@classmethod
def _impl_v1(cls, inputs, attr, params):
return _op.gather_nd(inputs[0], inputs[1])
indices_dims = len(infer_shape(inputs[1]))
indices = _op.transpose(inputs[1], axes=[-1] + list(range(indices_dims - 1)))
return _op.gather_nd(inputs[0], indices)


class Scatter(OnnxOpConverter):
Expand Down Expand Up @@ -1538,15 +1542,6 @@ def _impl_v1(cls, inputs, attr, params):
class Tile(Elemwise):
"""Operator converter for Tile"""

@classmethod
def _impl_v1(cls, inputs, attr, params):
if "repeats" not in attr:
raise tvm.error.OpAttributeInvalid(
'Attribute "repeats" should be set ' "for operator Tile."
)
reps = attr.pop("repeats") # The number of times repeating the tensor data.
return _op.tile(inputs[0], reps)

@classmethod
def _impl_v6(cls, inputs, attr, params):
return _op.tile(inputs[0], inputs[1])
Expand Down Expand Up @@ -2113,7 +2108,9 @@ def _impl_v11(cls, inputs, attr, params):
cond = inputs[1]
loop_deps = inputs[2:]
num_deps = len(loop_deps)
body = attr["body"]
# Create a copy of the body function to prevent the original
# from being modified.
body = copy.copy(attr["body"])
iter_dtype = infer_type(max_loop_count).checked_type.dtype

# Determine what condition mode we're in.
Expand Down Expand Up @@ -2150,6 +2147,8 @@ def get_var(name, val, scan=False):
checked_type = infer_type(val)
if hasattr(checked_type, "type_annotation"):
checked_type = checked_type.type_annotation
if hasattr(checked_type, "checked_type"):
checked_type = checked_type.checked_type
shape = get_const_tuple(checked_type.shape)
actual_shape = []
for dim in shape:
Expand Down Expand Up @@ -2185,8 +2184,14 @@ def get_var(name, val, scan=False):
scan_output_init = []
for i in range(num_scan_outputs):
name, shape, dtype, _ = get_info(body.output[i + 1 + num_deps])
scan_output_vars.append(_expr.var(name, shape=([_ty.Any()] + shape), dtype=dtype))
scan_output_init.append(_op.reshape(_expr.const([]), [0] + shape))
if dtype == "float":
dtype = "float32"
scan_output_vars.append(
_expr.var(name, shape=([_ty.Any()] * (len(shape) + 1)), dtype=dtype)
)
scan_output_init.append(
_op.reshape(_expr.const(np.array([]).astype(dtype)), [0] + [1] * len(shape))
)

# Now we can remove loop iter variables from our inner loop's inputs.
# This is kind of a hack since we have graph inputs that we don't
Expand Down Expand Up @@ -2219,18 +2224,18 @@ def body_fn(*loop_inputs):
new_loop_vars = [loop_outputs[i] for i in range(1, 1 + num_deps)]
new_scan_outputs = [loop_outputs[i] for i in range(1 + num_deps, len(loop_outputs))]

# Increment counter.
if max_loop_count is not None:
incr = _expr.const(1, dtype=iter_dtype)
loop_count = loop_count + incr

# Add new scan outputs to tracking
combined_scan_outputs = []
for i, scan in enumerate(scan_outputs):
new_scan = _op.expand_dims(new_scan_outputs[i], axis=0)
combined_scan = _op.concatenate([scan, new_scan], axis=0)
combined_scan_outputs.append(combined_scan)

# Increment counter.
if max_loop_count is not None:
incr = _expr.const(1, dtype=iter_dtype)
loop_count = loop_count + incr

# Pack loop outputs for next iteration
# [iter_count, cond, loop_deps, loop_scans]
return [loop_count, max_count, new_cond] + new_loop_vars + combined_scan_outputs
Expand Down Expand Up @@ -2630,12 +2635,12 @@ def _get_convert_map(opset):
"Greater": Greater.get_converter(opset),
"Less": Less.get_converter(opset),
"Log": Renamer("log"),
"ACos": Renamer("acos"),
"ACosh": Renamer("acosh"),
"ASin": Renamer("asin"),
"ASinh": Renamer("asinh"),
"ATan": Renamer("atan"),
"ATanh": Renamer("atanh"),
"Acos": Renamer("acos"),
"Acosh": Renamer("acosh"),
"Asin": Renamer("asin"),
"Asinh": Renamer("asinh"),
"Atan": Renamer("atan"),
"Atanh": Renamer("atanh"),
"Cos": Renamer("cos"),
"Cosh": Renamer("cosh"),
"Sin": Renamer("sin"),
Expand Down
Loading

0 comments on commit 5ba46e5

Please sign in to comment.