Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions python/tvm/relax/frontend/onnx/onnx_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,10 +740,12 @@ def _impl_v14(cls, bb, inputs, attr, params):
x = inputs[0]
k = inputs[1] if len(inputs) > 1 else 0

if isinstance(k, relax.Var) and k.name_hint in params:
k = get_constant(k, params)
elif isinstance(k, relax.Constant):
k = int(k.data.numpy()[0])
if len(inputs) > 1:
k = get_constant(inputs[1], params)
if isinstance(k, relax.Constant):
k = int(k.data.numpy()[0])
else:
raise ValueError("Currently only support constant k for Trilu op.")
else:
k = 0

Expand Down
62 changes: 0 additions & 62 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
from tvm.contrib import graph_executor, utils
from tvm.relay.frontend.common import infer_type
from tvm.relay.build_module import bind_params_by_name
from tvm.relax.frontend.onnx import from_onnx
from relay.utils.tag_span import _create_span, _set_span, _verify_structural_equal_with_span

import onnx
Expand Down Expand Up @@ -5441,67 +5440,6 @@ def verify_softplus(indata):
verify_softplus(input_data)


def test_load_cumsum():
"""test_load_cumsum"""

def create_cumsum_model():
input_shape = [2, 3]

graph = helper.make_graph(
[
helper.make_node("CumSum", inputs=["X", "axis"], outputs=["Y"]),
],
"cumsum_graph",
inputs=[
helper.make_tensor_value_info("X", onnx.TensorProto.DOUBLE, input_shape),
helper.make_tensor_value_info("axis", onnx.TensorProto.INT32, [1], "axis"),
],
outputs=[helper.make_tensor_value_info("Y", onnx.TensorProto.DOUBLE, input_shape)],
)
return helper.make_model(graph)

from_onnx(create_cumsum_model())


def test_load_trilu():
"""test_load_trilu"""

def create_trilu_model():
input_shape = [2, 3, 3]

graph = helper.make_graph(
[
helper.make_node("Trilu", inputs=["x", "k"], outputs=["y"]),
],
"trilu_graph",
inputs=[
helper.make_tensor_value_info("x", onnx.TensorProto.DOUBLE, input_shape),
helper.make_tensor_value_info("k", onnx.TensorProto.INT32, [1], "k"),
],
outputs=[helper.make_tensor_value_info("y", onnx.TensorProto.DOUBLE, input_shape)],
)
return helper.make_model(graph)

def create_trilu_model_const_k():
input_shape = [2, 3, 3]

graph = helper.make_graph(
[
make_constant_node("k", onnx.TensorProto.INT32, [1], [1]),
helper.make_node("Trilu", inputs=["x", "k"], outputs=["y"]),
],
"trilu_graph",
inputs=[
helper.make_tensor_value_info("x", onnx.TensorProto.DOUBLE, input_shape),
],
outputs=[helper.make_tensor_value_info("y", onnx.TensorProto.DOUBLE, input_shape)],
)
return helper.make_model(graph)

from_onnx(create_trilu_model())
from_onnx(create_trilu_model_const_k())


@tvm.testing.parametrize_targets
def test_cumsum(target, dev):
"""test_cumsum"""
Expand Down
43 changes: 43 additions & 0 deletions tests/python/relax/test_frontend_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,6 +710,28 @@ def test_trilu(upper: bool):
verify_unary("Trilu", [3, 5, 5], attrs={"upper": upper})


@pytest.mark.parametrize("k_value", [-1, 0, 1])
def test_trilu_with_const_k(k_value: int):
"""test_trilu_with_const_k"""

input_shape = [2, 3, 3]

graph = helper.make_graph(
[
make_constant_node("k", onnx.TensorProto.INT64, [1], [k_value]),
helper.make_node("Trilu", inputs=["x", "k"], outputs=["y"]),
],
"trilu_graph",
inputs=[
helper.make_tensor_value_info("x", onnx.TensorProto.DOUBLE, input_shape),
],
outputs=[helper.make_tensor_value_info("y", onnx.TensorProto.DOUBLE, input_shape)],
)

model = helper.make_model(graph, producer_name="trilu_graph")
check_correctness(model)


def test_selu():
verify_unary("Selu", [3, 32, 32])
verify_unary("Selu", [3, 32, 32], attrs={"alpha": 0.25, "gamma": 0.3})
Expand Down Expand Up @@ -859,6 +881,27 @@ def test_cumsum(reverse, exclusive):
check_correctness(model)


def test_cumsum1():
"""test_cumsum1"""

input_shape = [2, 3]

graph = helper.make_graph(
[
helper.make_node("CumSum", inputs=["X", "axis"], outputs=["Y"]),
],
"cumsum_graph",
inputs=[
helper.make_tensor_value_info("X", onnx.TensorProto.DOUBLE, input_shape),
helper.make_tensor_value_info("axis", onnx.TensorProto.INT32, [1], "axis"),
],
outputs=[helper.make_tensor_value_info("Y", onnx.TensorProto.DOUBLE, input_shape)],
)

model = helper.make_model(graph, producer_name="cumsum_graph")
check_correctness(model)


@pytest.mark.parametrize("axis", [[0, 2], None])
def test_squeeze(axis):
if axis:
Expand Down