From 03997a1689510c7a56791e610af08348b001d406 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Sat, 3 Apr 2021 17:36:45 -0700 Subject: [PATCH] [Relay]Frontend][Onnx] Remove pop that interferes with nested loops. (#7781) * Remove popping that interferes with nested loops. * Only check user inputs in the outer-most graph scope. * Fix style. Co-authored-by: Ubuntu --- python/tvm/relay/frontend/onnx.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 624a61efee27..669eab8cc250 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2981,6 +2981,7 @@ def __init__(self, shape, dtype, freeze_params=False): self._num_input = 0 self._num_param = 0 self._shape = shape if shape else {} + self._input_names = [] self._dtype = dtype self.opset = None self._freeze_params = freeze_params @@ -3062,8 +3063,9 @@ def from_onnx(self, graph, opset, get_output_expr=False): continue else: self._num_input += 1 + self._input_names.append(i_name) if i_name in self._shape: - i_shape = self._shape.pop(i_name) + i_shape = self._shape[i_name] else: if "?" in str(i_shape): warning_msg = ( @@ -3078,11 +3080,13 @@ def from_onnx(self, graph, opset, get_output_expr=False): dtype = d_type self._nodes[i_name] = new_var(i_name, shape=i_shape, dtype=dtype) self._inputs[i_name] = self._nodes[i_name] - assert ( - len(self._shape) == 0 - ), "User specified the shape for inputs that weren't found in the graph: " + str( - self._shape - ) + # Only check user inputs in the outer-most graph scope. + if self._old_manager is None: + assert all( + [name in self._input_names for name in self._shape.keys()] + ), "User specified the shape for inputs that weren't found in the graph: " + str( + self._shape + ) # get list of unsupported ops convert_map = _get_convert_map(opset) unsupported_ops = set()