Skip to content

Commit

Permalink
fix style, two bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
mbrookhart committed Feb 12, 2021
1 parent feddd6a commit 9c2c5a6
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1076,7 +1076,7 @@ def shape_of(x, dtype="int64"):
if not _ty.is_dynamic(ttype):
shape = list(ttype.shape)
return _expr.const(shape, dtype)
return _op.shape_of(x, "int64")
return _op.shape_of(x, dtype)


class Shape(OnnxOpConverter):
Expand Down Expand Up @@ -2879,11 +2879,11 @@ def from_onnx(self, graph, opset, freeze_params=False, get_output_expr=False):
for init_tensor in graph.initializer:
if not init_tensor.name.strip():
raise ValueError("Tensor's name is required.")
array = self._parse_array(init_tensor)
if freeze_params:
array = self._parse_array(init_tensor)
self._nodes[init_tensor.name] = _expr.const(array)
else:
self._params[init_tensor.name] = self._parse_array(init_tensor)
self._params[init_tensor.name] = array
self._nodes[init_tensor.name] = new_var(
init_tensor.name,
shape=self._params[init_tensor.name].shape,
Expand Down Expand Up @@ -2961,10 +2961,10 @@ def from_onnx(self, graph, opset, freeze_params=False, get_output_expr=False):
len(node_output), outputs_num, op_name
)
if outputs_num == 1:
self._nodes[node_output[0]] = op
self._nodes[node_output[0]] = fold_constant(op)
else:
for k, i in zip(list(node_output), range(len(node_output))):
self._nodes[k] = op[i]
self._nodes[k] = fold_constant(op[i])

# now return the outputs
outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output]
Expand Down

0 comments on commit 9c2c5a6

Please sign in to comment.