Skip to content

Commit

Permalink
Change behavior of onnx importer to throw when user provides an input…
Browse files Browse the repository at this point in the history
… no in the graph. (apache#7699)
  • Loading branch information
jwfromm authored and Trevor Morris committed May 6, 2021
1 parent 7d4626e commit da6dde7
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 7 deletions.
7 changes: 6 additions & 1 deletion python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2914,7 +2914,7 @@ def from_onnx(self, graph, opset, get_output_expr=False):
else:
self._num_input += 1
if i_name in self._shape:
i_shape = self._shape[i_name]
i_shape = self._shape.pop(i_name)
else:
if "?" in str(i_shape):
warning_msg = (
Expand All @@ -2929,6 +2929,11 @@ def from_onnx(self, graph, opset, get_output_expr=False):
dtype = d_type
self._nodes[i_name] = new_var(i_name, shape=i_shape, dtype=dtype)
self._inputs[i_name] = self._nodes[i_name]
assert (
len(self._shape) == 0
), "User specified the shape for inputs that weren't found in the graph: " + str(
self._shape
)
# get list of unsupported ops
convert_map = _get_convert_map(opset)
unsupported_ops = set()
Expand Down
39 changes: 33 additions & 6 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from onnx import helper, TensorProto, mapping, numpy_helper
import torch
import torchvision
import pytest
import tvm.topi.testing
import tvm
from tvm import relay
Expand Down Expand Up @@ -57,7 +58,7 @@ def get_tvm_output_with_vm(
mod = relay.transform.DynamicToStatic()(mod)

ex = relay.create_executor("vm", mod=mod, ctx=ctx, target=target)
result = ex.evaluate()(*input_data)
result = ex.evaluate()(*input_data, **params)
if isinstance(result, tvm.runtime.NDArray):
return result.asnumpy()
return [r.asnumpy() for r in result]
Expand Down Expand Up @@ -500,7 +501,7 @@ def test_squeeze():

model = helper.make_model(graph, producer_name="squeeze_test")
x = np.random.uniform(size=in_shape).astype("float32")
verify_with_ort_with_inputs(model, [x], [out_shape])
verify_with_ort_with_inputs(model, [x], [out_shape], opset=11)


@tvm.testing.uses_gpu
Expand Down Expand Up @@ -538,7 +539,7 @@ def test_unsqueeze():
)

model = helper.make_model(graph, producer_name="squeeze_test")
verify_with_ort(model, [in_shape])
verify_with_ort(model, [in_shape], opset=11)


def verify_gather(in_shape, indices, axis, dtype):
Expand Down Expand Up @@ -1584,7 +1585,7 @@ def verify_pad_v11(indata, pads, mode="constant", value=0.0):
pads = np.array(pads)
# onnx graph
if mode in ["edge", "reflect"]:
inputs = [indata, pads]
inputs = [indata]
outdata = np.pad(indata, pad_width=np_pads, mode=mode)
node = helper.make_node("Pad", inputs=["input", "pads"], outputs=["output"], mode=mode)
graph = helper.make_graph(
Expand All @@ -1600,7 +1601,7 @@ def verify_pad_v11(indata, pads, mode="constant", value=0.0):
],
)
else:
inputs = [indata, pads, np.array([value]).astype("float32")]
inputs = [indata]
outdata = np.pad(indata, pad_width=np_pads, mode="constant", constant_values=value)
node = helper.make_node(
"Pad", inputs=["input", "pads", "constant_value"], outputs=["output"], mode="constant"
Expand Down Expand Up @@ -1663,7 +1664,7 @@ def verify_reduce_func(func, data, axis, keepdims):

model = helper.make_model(graph, producer_name="reduce_test")

verify_with_ort_with_inputs(model, [data], [outshape])
verify_with_ort_with_inputs(model, [data], [outshape], opset=11)


@tvm.testing.uses_gpu
Expand Down Expand Up @@ -4089,6 +4090,31 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"):
verify_cumsum(data, 1, 1, 1, type="int32")


def test_wrong_input():
node = helper.make_node(
"Softplus",
inputs=["X"],
outputs=["Y"],
)

graph = helper.make_graph(
[node],
"softplus_test",
inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, list([5]))],
outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, list([5]))],
)
model = helper.make_model(graph, producer_name="softplus_test")

# Check that the graph can import correctly with proper shape definitions.
correct_shape_dict = {"X": [5]}
relay.frontend.from_onnx(model, shape=correct_shape_dict)

# Check that an assertion is triggered when an input not in the graph is provided.
wrong_shape_dict = {"Z": [5]}
with pytest.raises(AssertionError):
relay.frontend.from_onnx(model, shape=wrong_shape_dict)


if __name__ == "__main__":
test_flatten()
test_reshape()
Expand Down Expand Up @@ -4167,3 +4193,4 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"):
test_maxunpool()
test_softplus()
test_cumsum()
test_wrong_input()

0 comments on commit da6dde7

Please sign in to comment.