Skip to content

Commit 96f616b

Browse files
committed
[Relax][ONNX] Update Reduce ops to support axes as input
- Support axes as an input for Reduce ops (e.g., ReduceL2, ReduceMax, …) - Add corresponding test cases
1 parent 8327a8c commit 96f616b

File tree

2 files changed

+222
-22
lines changed

2 files changed

+222
-22
lines changed

python/tvm/relax/frontend/onnx/onnx_frontend.py

Lines changed: 212 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2500,26 +2500,68 @@ def _impl_v17(cls, bb, inputs, attr, params):
25002500

25012501

25022502
class ReduceMax(OnnxOpConverter):
2503-
"""Converts an onnx ReduceMax node into an equivalent Relax expression."""
2504-
25052503
@classmethod
25062504
def _impl_v11(cls, bb, inputs, attr, params):
25072505
data = inputs[0]
25082506
axes = attr.get("axes", None)
25092507
keepdims = attr.get("keepdims", 1)
25102508
return relax.op.max(data, axes, keepdims)
25112509

2510+
@classmethod
2511+
def _impl_v18(cls, bb, inputs, attr, params):
2512+
data = inputs[0]
2513+
keepdims = attr.get("keepdims", 1)
2514+
noop_with_empty_axes = attr.get("noop_with_empty_axes", 0)
2515+
2516+
# Optional axes input
2517+
axes = None
2518+
if len(inputs) > 1 and inputs[1] is not None:
2519+
axes_const = get_constant(inputs[1], params)
2520+
assert isinstance(axes_const, relax.Constant), "Only constant axes currently supported"
2521+
axes = axes_const.data.numpy().tolist()
2522+
2523+
# If axes is empty and noop_with_empty_axes is False, reduce all dims
2524+
if not axes and not noop_with_empty_axes:
2525+
return relax.op.max(data, None, keepdims)
2526+
# If axes is empty and noop_with_empty_axes is True, return input unchanged
2527+
elif not axes and noop_with_empty_axes:
2528+
return data
2529+
# Otherwise reduce over specified axes
2530+
else:
2531+
return relax.op.max(data, axes, keepdims)
25122532

2513-
class ReduceMin(OnnxOpConverter):
2514-
"""Converts an onnx ReduceMin node into an equivalent Relax expression."""
25152533

2534+
class ReduceMin(OnnxOpConverter):
25162535
@classmethod
25172536
def _impl_v11(cls, bb, inputs, attr, params):
25182537
data = inputs[0]
25192538
axes = attr.get("axes", None)
25202539
keepdims = attr.get("keepdims", 1)
25212540
return relax.op.min(data, axes, keepdims)
25222541

2542+
@classmethod
2543+
def _impl_v18(cls, bb, inputs, attr, params):
2544+
data = inputs[0]
2545+
keepdims = attr.get("keepdims", 1)
2546+
noop_with_empty_axes = attr.get("noop_with_empty_axes", 0)
2547+
2548+
# Optional axes input
2549+
axes = None
2550+
if len(inputs) > 1 and inputs[1] is not None:
2551+
axes_const = get_constant(inputs[1], params)
2552+
assert isinstance(axes_const, relax.Constant), "Only constant axes currently supported"
2553+
axes = axes_const.data.numpy().tolist()
2554+
2555+
# If axes is empty and noop_with_empty_axes is False, reduce all dims
2556+
if not axes and not noop_with_empty_axes:
2557+
return relax.op.min(data, None, keepdims)
2558+
# If axes is empty and noop_with_empty_axes is True, return input unchanged
2559+
elif not axes and noop_with_empty_axes:
2560+
return data
2561+
# Otherwise reduce over specified axes
2562+
else:
2563+
return relax.op.min(data, axes, keepdims)
2564+
25232565

25242566
class ReduceSum(OnnxOpConverter):
25252567
"""Converts an onnx ReduceSum node into an equivalent Relax expression."""
@@ -2534,11 +2576,25 @@ def _impl_v11(cls, bb, inputs, attr, params):
25342576
@classmethod
25352577
def _impl_v13(cls, bb, inputs, attr, params):
25362578
data = inputs[0]
2537-
axes = inputs[1]
25382579
keepdims = attr.get("keepdims", 1)
2539-
assert isinstance(axes, relax.Constant), "Only constant axes currently supported."
2540-
axes = axes.data.numpy().tolist()
2541-
return relax.op.sum(data, axes, keepdims)
2580+
noop_with_empty_axes = attr.get("noop_with_empty_axes", 0)
2581+
2582+
# Optional axes input
2583+
axes = None
2584+
if len(inputs) > 1 and inputs[1] is not None:
2585+
axes_const = get_constant(inputs[1], params)
2586+
assert isinstance(axes_const, relax.Constant), "Only constant axes currently supported"
2587+
axes = axes_const.data.numpy().tolist()
2588+
2589+
# If axes is empty and noop_with_empty_axes is 0, reduce all dimensions
2590+
if not axes and not noop_with_empty_axes:
2591+
return relax.op.sum(data, None, keepdims)
2592+
# If axes is empty and noop_with_empty_axes is 1, return the input data unchanged.
2593+
elif not axes and noop_with_empty_axes:
2594+
return data
2595+
# If axes is provided, reduce over the specified axes
2596+
else:
2597+
return relax.op.sum(data, axes, keepdims)
25422598

25432599

25442600
class ReduceMean(OnnxOpConverter):
@@ -2550,6 +2606,29 @@ def _impl_v13(cls, bb, inputs, attr, params):
25502606
axes = attr.get("axes", None)
25512607
keepdims = attr.get("keepdims", 1)
25522608
return relax.op.mean(data, axes, keepdims)
2609+
2610+
@classmethod
2611+
def _impl_v18(cls, bb, inputs, attr, params):
2612+
data = inputs[0]
2613+
keepdims = attr.get("keepdims", 1)
2614+
noop_with_empty_axes = attr.get("noop_with_empty_axes", 0)
2615+
2616+
# Optional axes input
2617+
axes = None
2618+
if len(inputs) > 1 and inputs[1] is not None:
2619+
axes_const = get_constant(inputs[1], params)
2620+
assert isinstance(axes_const, relax.Constant), "Only constant axes currently supported"
2621+
axes = axes_const.data.numpy().tolist()
2622+
2623+
# If axes is empty and noop_with_empty_axes is 0, reduce all dimensions
2624+
if not axes and not noop_with_empty_axes:
2625+
return relax.op.mean(data, None, keepdims)
2626+
# If axes is empty and noop_with_empty_axes is 1, return the input data unchanged.
2627+
elif not axes and noop_with_empty_axes:
2628+
return data
2629+
# If axes is provided, reduce over the specified axes
2630+
else:
2631+
return relax.op.mean(data, axes, keepdims)
25532632

25542633

25552634
class ReduceProd(OnnxOpConverter):
@@ -2561,6 +2640,29 @@ def _impl_v13(cls, bb, inputs, attr, params):
25612640
axes = attr.get("axes", None)
25622641
keepdims = attr.get("keepdims", 1)
25632642
return relax.op.prod(data, axes, keepdims)
2643+
2644+
@classmethod
2645+
def _impl_v18(cls, bb, inputs, attr, params):
2646+
data = inputs[0]
2647+
keepdims = attr.get("keepdims", 1)
2648+
noop_with_empty_axes = attr.get("noop_with_empty_axes", 0)
2649+
2650+
# Optional axes input
2651+
axes = None
2652+
if len(inputs) > 1 and inputs[1] is not None:
2653+
axes_const = get_constant(inputs[1], params)
2654+
assert isinstance(axes_const, relax.Constant), "Only constant axes currently supported"
2655+
axes = axes_const.data.numpy().tolist()
2656+
2657+
# If axes is empty and noop_with_empty_axes is 0, reduce all dimensions
2658+
if not axes and not noop_with_empty_axes:
2659+
return relax.op.prod(data, None, keepdims)
2660+
# If axes is empty and noop_with_empty_axes is 1, return the input data unchanged.
2661+
elif not axes and noop_with_empty_axes:
2662+
return data
2663+
# If axes is provided, reduce over the specified axes
2664+
else:
2665+
return relax.op.prod(data, axes, keepdims)
25642666

25652667

25662668
class ReduceLogSumExp(OnnxOpConverter):
@@ -2578,6 +2680,38 @@ def _impl_v13(cls, bb, inputs, attr, params):
25782680
if not keepdims:
25792681
out_x = relax.op.squeeze(out_x, axes)
25802682
return out_x
2683+
2684+
@classmethod
2685+
def _impl_v18(cls, bb, inputs, attr, params):
2686+
x = inputs[0]
2687+
keepdims = attr.get("keepdims", 1)
2688+
noop_with_empty_axes = attr.get("noop_with_empty_axes", 0)
2689+
2690+
# Optional axes input (second input)
2691+
axes = None
2692+
if len(inputs) > 1 and inputs[1] is not None:
2693+
axes_const = get_constant(inputs[1], params)
2694+
assert isinstance(axes_const, relax.Constant), "Only constant axes currently supported"
2695+
axes = axes_const.data.numpy().tolist()
2696+
2697+
# Calculate LogSumExp
2698+
log_sum_exp = lambda axes: (
2699+
max_x := relax.op.max(x, axes, True),
2700+
exp_x := relax.op.exp(relax.op.subtract(x, max_x)),
2701+
sum_x := relax.op.sum(exp_x, axes, True),
2702+
out_x := relax.op.add(relax.op.log(sum_x), max_x),
2703+
relax.op.squeeze(out_x, axes) if not keepdims else out_x
2704+
)[-1]
2705+
2706+
# If axes is empty and noop_with_empty_axes is 0, reduce all dimensions
2707+
if not axes and not noop_with_empty_axes:
2708+
return log_sum_exp(None)
2709+
# If axes is empty and noop_with_empty_axes is 1, return the input data unchanged.
2710+
elif not axes and noop_with_empty_axes:
2711+
return x
2712+
# If axes is provided, reduce over the specified axes
2713+
else:
2714+
return log_sum_exp(axes)
25812715

25822716

25832717
class ReduceLogSum(OnnxOpConverter):
@@ -2589,6 +2723,29 @@ def _impl_v13(cls, bb, inputs, attr, params):
25892723
axes = attr.get("axes", None)
25902724
keepdims = attr.get("keepdims", 1)
25912725
return relax.op.log(relax.op.sum(data, axes, keepdims))
2726+
2727+
@classmethod
2728+
def _impl_v18(cls, bb, inputs, attr, params):
2729+
data = inputs[0]
2730+
keepdims = attr.get("keepdims", 1)
2731+
noop_with_empty_axes = attr.get("noop_with_empty_axes", 0)
2732+
2733+
# Optional axes input
2734+
axes = None
2735+
if len(inputs) > 1 and inputs[1] is not None:
2736+
axes_const = get_constant(inputs[1], params)
2737+
assert isinstance(axes_const, relax.Constant), "Only constant axes currently supported"
2738+
axes = axes_const.data.numpy().tolist()
2739+
2740+
# If axes is empty and noop_with_empty_axes is 0, reduce all dimensions
2741+
if not axes and not noop_with_empty_axes:
2742+
return relax.op.log(relax.op.sum(data, None, keepdims))
2743+
# If axes is empty and noop_with_empty_axes is 1, return the input data unchanged.
2744+
elif not axes and noop_with_empty_axes:
2745+
return data
2746+
# If axes is provided, reduce over the specified axes
2747+
else:
2748+
return relax.op.log(relax.op.sum(data, axes, keepdims))
25922749

25932750

25942751
class ReduceSumSquare(OnnxOpConverter):
@@ -2600,6 +2757,29 @@ def _impl_v13(cls, bb, inputs, attr, params):
26002757
axes = attr.get("axes", None)
26012758
keepdims = attr.get("keepdims", 1)
26022759
return relax.op.sum(relax.op.multiply(data, data), axes, keepdims)
2760+
2761+
@classmethod
2762+
def _impl_v18(cls, bb, inputs, attr, params):
2763+
data = inputs[0]
2764+
keepdims = attr.get("keepdims", 1)
2765+
noop_with_empty_axes = attr.get("noop_with_empty_axes", 0)
2766+
2767+
# Optional axes input
2768+
axes = None
2769+
if len(inputs) > 1 and inputs[1] is not None:
2770+
axes_const = get_constant(inputs[1], params)
2771+
assert isinstance(axes_const, relax.Constant), "Only constant axes currently supported"
2772+
axes = axes_const.data.numpy().tolist()
2773+
2774+
# If axes is empty and noop_with_empty_axes is 0, reduce all dimensions
2775+
if not axes and not noop_with_empty_axes:
2776+
return relax.op.sum(relax.op.multiply(data, data), None, keepdims)
2777+
# If axes is empty and noop_with_empty_axes is 1, return the input data unchanged.
2778+
elif not axes and noop_with_empty_axes:
2779+
return data
2780+
# If axes is provided, reduce over the specified axes
2781+
else:
2782+
return relax.op.sum(relax.op.multiply(data, data), axes, keepdims)
26032783

26042784

26052785
class ReduceL1(OnnxOpConverter):
@@ -2631,7 +2811,7 @@ def _impl_v18(cls, bb, inputs, attr, params):
26312811
# If axes is empty and noop_with_empty_axes is 1, return the input data unchanged.
26322812
elif not axes and noop_with_empty_axes:
26332813
return data
2634-
# If axes is provided, reduce over specified axes
2814+
# If axes is provided, reduce over the specified axes
26352815
else:
26362816
return relax.op.sum(relax.op.abs(data), axes, keepdims)
26372817

@@ -2645,6 +2825,29 @@ def _impl_v13(cls, bb, inputs, attr, params):
26452825
axes = attr.get("axes", None)
26462826
keepdims = attr.get("keepdims", 1)
26472827
return relax.op.sqrt(relax.op.sum(relax.op.multiply(data, data), axes, keepdims))
2828+
2829+
@classmethod
2830+
def _impl_v18(cls, bb, inputs, attr, params):
2831+
data = inputs[0]
2832+
keepdims = attr.get("keepdims", 1)
2833+
noop_with_empty_axes = attr.get("noop_with_empty_axes", 0)
2834+
2835+
# Optional axes input
2836+
axes = None
2837+
if len(inputs) > 1 and inputs[1] is not None:
2838+
axes_const = get_constant(inputs[1], params)
2839+
assert isinstance(axes_const, relax.Constant), "Only constant axes currently supported"
2840+
axes = axes_const.data.numpy().tolist()
2841+
2842+
# If axes is empty and noop_with_empty_axes is 0, reduce all dimensions
2843+
if not axes and not noop_with_empty_axes:
2844+
return relax.op.sqrt(relax.op.sum(relax.op.multiply(data, data), None, keepdims))
2845+
# If axes is empty and noop_with_empty_axes is 1, return the input data unchanged.
2846+
elif not axes and noop_with_empty_axes:
2847+
return data
2848+
# If axes is provided, reduce over the specified axes
2849+
else:
2850+
return relax.op.sqrt(relax.op.sum(relax.op.multiply(data, data), axes, keepdims))
26482851

26492852

26502853
class ArgMax(OnnxOpConverter):

tests/python/relax/test_frontend_onnx.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1580,23 +1580,22 @@ def verify_reduce_func(func, data, axis, keepdims):
15801580
def create_reduce_test_parameters_axes_input():
15811581
output = []
15821582
for dynamic in [True, False]:
1583-
# TODO(@vacu9708): Enable the tests after implementing other reduce ops
1584-
# output.append(("ReduceMax", dynamic, 20))
1585-
# output.append(("ReduceMean", dynamic, 18))
1586-
# output.append(("ReduceMin", dynamic, 20))
1587-
# output.append(("ReduceProd", dynamic, 18))
1588-
# output.append(("ReduceSum", dynamic, 13))
1589-
# output.append(("ReduceSumSquare", dynamic, 18))
1590-
# output.append(("ReduceLogSum", dynamic, 18))
1591-
# output.append(("ReduceLogSumExp", dynamic, 18))
1583+
output.append(("ReduceMax", dynamic, 18))
1584+
output.append(("ReduceMean", dynamic, 18))
1585+
output.append(("ReduceMin", dynamic, 18))
1586+
output.append(("ReduceProd", dynamic, 18))
1587+
output.append(("ReduceSum", dynamic, 13))
1588+
output.append(("ReduceSumSquare", dynamic, 18))
1589+
output.append(("ReduceLogSum", dynamic, 18))
1590+
output.append(("ReduceLogSumExp", dynamic, 18))
15921591
output.append(("ReduceL1", dynamic, 18))
1593-
# output.append(("ReduceL2", dynamic, 18))
1592+
output.append(("ReduceL2", dynamic, 18))
15941593
return output
15951594

15961595

15971596
@pytest.mark.parametrize("func, dynamic, opset", create_reduce_test_parameters_axes_input())
15981597
def test_all_reduce_funcs_axes_input(func, dynamic, opset):
1599-
def verify_reduce_func(func, data, axes, keepdims, noop_with_empty_axes):
1598+
def verify_reduce_func(func, data, axes, keepdims, noop_with_empty_axes=False):
16001599
inshape = data.shape
16011600

16021601
inputs = ["x"]
@@ -1698,10 +1697,8 @@ def verify_reduce_func(func, data, axes, keepdims, noop_with_empty_axes):
16981697
np.random.randn(3, 3, 3, 1).astype(np.float32),
16991698
axes=(1, 2),
17001699
keepdims=keepdims,
1701-
noop_with_empty_axes=True,
17021700
)
17031701

1704-
17051702
@pytest.mark.parametrize("in_dtype", [np.float32, np.int32])
17061703
@pytest.mark.parametrize("axis", [None, 0, 1, 2])
17071704
@pytest.mark.parametrize("keepdims", [None, True, False])

0 commit comments

Comments
 (0)