Skip to content

Commit

Permalink
Fix initializer for CumSum. (apache#9)
Browse files Browse the repository at this point in the history
  • Loading branch information
zxybazh authored and Josh Fromm committed Feb 3, 2023
1 parent b8f2a3a commit 5bc54eb
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
3 changes: 1 addition & 2 deletions python/tvm/relax/frontend/onnx_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,8 +454,7 @@ class CumSum(OnnxOpConverter):
def _impl_v13(cls, bb, inputs, attr):
assert getattr(attr, "reverse", 0) == 0, "reverse is not supported yet"
if len(inputs) > 1:
# axis = int(infer_value(inputs[1], params).numpy())
axis = inputs[1]
axis = int(inputs[1].data.numpy())
else:
axis = None
return bb.emit_te(
Expand Down
8 changes: 4 additions & 4 deletions tests/python/relax/frontend/test_onnx_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@


def generate_random_inputs(
model: ModelProto, inputs: Dict[str, np.array] = None
model: ModelProto, inputs: Optional[Dict[str, np.array]] = None
) -> Dict[str, np.array]:
input_values = {}
# Iterate through model inputs and extract their shape.
Expand Down Expand Up @@ -559,13 +559,13 @@ def test_cumsum():
"cumsum_test",
inputs=[
helper.make_tensor_value_info("x", TensorProto.FLOAT, shape),
helper.make_tensor_value_info("axis", TensorProto.INT64, ()),
],
initializer=[helper.make_tensor("axis", TensorProto.INT64, (), [1])],
outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, shape)],
)

model = helper.make_model(graph, producer_name="cumsum_test")
check_correctness(model, {"axis": [1]})
check_correctness(model)


if __name__ == "__main__":
Expand All @@ -585,6 +585,7 @@ def test_cumsum():
test_conv()
test_pow()
test_erf()
test_cumsum()

# TODO, still has issues
# test_reshape()
Expand All @@ -594,4 +595,3 @@ def test_cumsum():
test_transpose()
test_unsqueeze()
# test_shape()
# test_cumsum() # need axis as int

0 comments on commit 5bc54eb

Please sign in to comment.