From e10fe31d30ff5e79f66636bab1d7d212a7c1a3c4 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Tue, 27 Apr 2021 22:31:38 +0000 Subject: [PATCH] Improve dtype detection in loop to fix onnx tests. --- python/tvm/relay/frontend/onnx.py | 9 +++++++-- tests/python/frontend/onnx/test_forward.py | 6 ++---- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index a695e0002b34..7c8e2e86af5d 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -43,7 +43,7 @@ class onnx_input: - """ Dual purpose list or dictionary access object.""" + """Dual purpose list or dictionary access object.""" def __init__(self): self.input_keys = [] @@ -126,7 +126,10 @@ def get_info(info_proto): shape.append(value) name = info_proto.name - dtype = get_type(info_proto.type.tensor_type.elem_type) + if info_proto.type.tensor_type.elem_type: + dtype = get_type(info_proto.type.tensor_type.elem_type) + else: + dtype = None return name, shape, dtype, shape_name @@ -2405,6 +2408,8 @@ def get_var(name, val, scan=False): scan_output_init = [] for i in range(num_scan_outputs): name, shape, dtype, _ = get_info(body.output[i + 1 + num_deps]) + if dtype is None: + dtype = infer_type(loop_deps[i]).checked_type.dtype if dtype == "float": dtype = "float32" scan_output_vars.append( diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 1a3d0d4ac6e0..cb54b7948134 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -45,7 +45,7 @@ def get_input_data_shape_dict(graph_def, input_data): def get_tvm_output_with_vm( graph_def, input_data, target, device, opset=None, freeze_params=False, convert_to_static=False ): - """ Generic function to execute and get tvm output with vm executor""" + """Generic function to execute and get tvm output with vm executor""" if not isinstance(input_data, list): input_data = [input_data] _, shape_dict = get_input_data_shape_dict(graph_def, input_data) @@ -67,7 +67,7 @@ def get_tvm_output_with_vm( def get_tvm_output( graph_def, input_data, target, device, output_shape=None, output_dtype="float32", opset=None ): - """ Generic function to execute and get tvm output""" + """Generic function to execute and get tvm output""" # TODO: Resolve the issues and remove the following lines target = "llvm" device = tvm.cpu(0) @@ -4222,8 +4222,6 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"): "test_qlinearconv/", "test_qlinearmatmul_2D/", "test_qlinearmatmul_3D/", - "test_range_float_type_positive_delta_expanded/", - "test_range_int32_type_negative_delta_expanded/", "test_resize_tf_crop_and_resize/", ## For these three tests, ONNX 1.6.0 has incorrect graphs, they pass with ONNX 1.7.0 "test_resize_upsample_sizes_nearest_ceil_half_pixel/",