diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 5777f51fe296..36a7823f8655 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -740,10 +740,12 @@ def _impl_v14(cls, bb, inputs, attr, params): x = inputs[0] k = inputs[1] if len(inputs) > 1 else 0 - if isinstance(k, relax.Var) and k.name_hint in params: - k = get_constant(k, params) - elif isinstance(k, relax.Constant): - k = int(k.data.numpy()[0]) + if len(inputs) > 1: + k = get_constant(inputs[1], params) + if isinstance(k, relax.Constant): + k = int(k.data.numpy()[0]) + else: + raise ValueError("Currently only support constant k for Trilu op.") else: k = 0 diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index a5811d0dbd46..a81352bb679f 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -37,7 +37,6 @@ from tvm.contrib import graph_executor, utils from tvm.relay.frontend.common import infer_type from tvm.relay.build_module import bind_params_by_name -from tvm.relax.frontend.onnx import from_onnx from relay.utils.tag_span import _create_span, _set_span, _verify_structural_equal_with_span import onnx @@ -5441,67 +5440,6 @@ def verify_softplus(indata): verify_softplus(input_data) -def test_load_cumsum(): - """test_load_cumsum""" - - def create_cumsum_model(): - input_shape = [2, 3] - - graph = helper.make_graph( - [ - helper.make_node("CumSum", inputs=["X", "axis"], outputs=["Y"]), - ], - "cumsum_graph", - inputs=[ - helper.make_tensor_value_info("X", onnx.TensorProto.DOUBLE, input_shape), - helper.make_tensor_value_info("axis", onnx.TensorProto.INT32, [1], "axis"), - ], - outputs=[helper.make_tensor_value_info("Y", onnx.TensorProto.DOUBLE, input_shape)], - ) - return helper.make_model(graph) - - from_onnx(create_cumsum_model()) - - -def test_load_trilu(): - """test_load_trilu""" - - def create_trilu_model(): - input_shape = [2, 3, 3] - - graph = helper.make_graph( - [ - helper.make_node("Trilu", inputs=["x", "k"], outputs=["y"]), - ], - "trilu_graph", - inputs=[ - helper.make_tensor_value_info("x", onnx.TensorProto.DOUBLE, input_shape), - helper.make_tensor_value_info("k", onnx.TensorProto.INT32, [1], "k"), - ], - outputs=[helper.make_tensor_value_info("y", onnx.TensorProto.DOUBLE, input_shape)], - ) - return helper.make_model(graph) - - def create_trilu_model_const_k(): - input_shape = [2, 3, 3] - - graph = helper.make_graph( - [ - make_constant_node("k", onnx.TensorProto.INT32, [1], [1]), - helper.make_node("Trilu", inputs=["x", "k"], outputs=["y"]), - ], - "trilu_graph", - inputs=[ - helper.make_tensor_value_info("x", onnx.TensorProto.DOUBLE, input_shape), - ], - outputs=[helper.make_tensor_value_info("y", onnx.TensorProto.DOUBLE, input_shape)], - ) - return helper.make_model(graph) - - from_onnx(create_trilu_model()) - from_onnx(create_trilu_model_const_k()) - - @tvm.testing.parametrize_targets def test_cumsum(target, dev): """test_cumsum""" diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 2837ad2185e9..f2bbd3f3f585 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -710,6 +710,28 @@ def test_trilu(upper: bool): verify_unary("Trilu", [3, 5, 5], attrs={"upper": upper}) +@pytest.mark.parametrize("k_value", [-1, 0, 1]) +def test_trilu_with_const_k(k_value: int): + """test_trilu_with_const_k""" + + input_shape = [2, 3, 3] + + graph = helper.make_graph( + [ + make_constant_node("k", onnx.TensorProto.INT64, [1], [k_value]), + helper.make_node("Trilu", inputs=["x", "k"], outputs=["y"]), + ], + "trilu_graph", + inputs=[ + helper.make_tensor_value_info("x", onnx.TensorProto.DOUBLE, input_shape), + ], + outputs=[helper.make_tensor_value_info("y", onnx.TensorProto.DOUBLE, input_shape)], + ) + + model = helper.make_model(graph, producer_name="trilu_graph") + check_correctness(model) + + def test_selu(): verify_unary("Selu", [3, 32, 32]) verify_unary("Selu", [3, 32, 32], attrs={"alpha": 0.25, "gamma": 0.3}) @@ -859,6 +881,27 @@ def test_cumsum(reverse, exclusive): check_correctness(model) +def test_cumsum1(): + """test_cumsum1""" + + input_shape = [2, 3] + + graph = helper.make_graph( + [ + helper.make_node("CumSum", inputs=["X", "axis"], outputs=["Y"]), + ], + "cumsum_graph", + inputs=[ + helper.make_tensor_value_info("X", onnx.TensorProto.DOUBLE, input_shape), + helper.make_tensor_value_info("axis", onnx.TensorProto.INT32, [1], "axis"), + ], + outputs=[helper.make_tensor_value_info("Y", onnx.TensorProto.DOUBLE, input_shape)], + ) + + model = helper.make_model(graph, producer_name="cumsum_graph") + check_correctness(model) + + @pytest.mark.parametrize("axis", [[0, 2], None]) def test_squeeze(axis): if axis: