217 changes: 212 additions & 5 deletions python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -2509,6 +2509,29 @@ def _impl_v11(cls, bb, inputs, attr, params):
keepdims = attr.get("keepdims", 1)
return relax.op.max(data, axes, keepdims)

@classmethod
def _impl_v18(cls, bb, inputs, attr, params):
data = inputs[0]
keepdims = attr.get("keepdims", 1)
noop_with_empty_axes = attr.get("noop_with_empty_axes", 0)

# Optional axes input
axes = None
if len(inputs) > 1 and inputs[1] is not None:
axes_const = get_constant(inputs[1], params)
assert isinstance(axes_const, relax.Constant), "Only constant axes currently supported"
axes = axes_const.data.numpy().tolist()

# If axes is empty and noop_with_empty_axes is False, reduce all dims
if not axes and not noop_with_empty_axes:
return relax.op.max(data, None, keepdims)
# If axes is empty and noop_with_empty_axes is True, return input unchanged
elif not axes and noop_with_empty_axes:
return data
# Otherwise reduce over specified axes
else:
return relax.op.max(data, axes, keepdims)
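
For reference, the three branches above mirror the opset-18 ReduceMax semantics; a minimal NumPy sketch of the expected behaviour (hypothetical input and shapes, not part of the patch):

import numpy as np

x = np.random.randn(3, 4, 5).astype(np.float32)

np.max(x, axis=(1, 2), keepdims=True)   # axes given: reduce those axes, shape (3, 1, 1)
np.max(x, axis=None, keepdims=True)     # empty axes, noop_with_empty_axes=0: reduce all, shape (1, 1, 1)
x                                       # empty axes, noop_with_empty_axes=1: identity, input returned unchanged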


class ReduceMin(OnnxOpConverter):
"""Converts an onnx ReduceMin node into an equivalent Relax expression."""
@@ -2520,6 +2543,29 @@ def _impl_v11(cls, bb, inputs, attr, params):
keepdims = attr.get("keepdims", 1)
return relax.op.min(data, axes, keepdims)

@classmethod
def _impl_v18(cls, bb, inputs, attr, params):
data = inputs[0]
keepdims = attr.get("keepdims", 1)
noop_with_empty_axes = attr.get("noop_with_empty_axes", 0)

# Optional axes input
axes = None
if len(inputs) > 1 and inputs[1] is not None:
axes_const = get_constant(inputs[1], params)
assert isinstance(axes_const, relax.Constant), "Only constant axes currently supported"
axes = axes_const.data.numpy().tolist()

# If axes is empty and noop_with_empty_axes is False, reduce all dims
if not axes and not noop_with_empty_axes:
return relax.op.min(data, None, keepdims)
# If axes is empty and noop_with_empty_axes is True, return input unchanged
elif not axes and noop_with_empty_axes:
return data
# Otherwise reduce over specified axes
else:
return relax.op.min(data, axes, keepdims)


class ReduceSum(OnnxOpConverter):
"""Converts an onnx ReduceSum node into an equivalent Relax expression."""
@@ -2534,11 +2580,25 @@ def _impl_v11(cls, bb, inputs, attr, params):
@classmethod
def _impl_v13(cls, bb, inputs, attr, params):
data = inputs[0]
axes = inputs[1]
keepdims = attr.get("keepdims", 1)
assert isinstance(axes, relax.Constant), "Only constant axes currently supported."
axes = axes.data.numpy().tolist()
return relax.op.sum(data, axes, keepdims)
noop_with_empty_axes = attr.get("noop_with_empty_axes", 0)

# Optional axes input
axes = None
if len(inputs) > 1 and inputs[1] is not None:
axes_const = get_constant(inputs[1], params)
assert isinstance(axes_const, relax.Constant), "Only constant axes currently supported"
axes = axes_const.data.numpy().tolist()

# If axes is empty and noop_with_empty_axes is 0, reduce all dimensions
if not axes and not noop_with_empty_axes:
return relax.op.sum(data, None, keepdims)
# If axes is empty and noop_with_empty_axes is 1, return the input data unchanged.
elif not axes and noop_with_empty_axes:
return data
# If axes is provided, reduce over the specified axes
else:
return relax.op.sum(data, axes, keepdims)


class ReduceMean(OnnxOpConverter):
@@ -2551,6 +2611,29 @@ def _impl_v13(cls, bb, inputs, attr, params):
keepdims = attr.get("keepdims", 1)
return relax.op.mean(data, axes, keepdims)

@classmethod
def _impl_v18(cls, bb, inputs, attr, params):
data = inputs[0]
keepdims = attr.get("keepdims", 1)
noop_with_empty_axes = attr.get("noop_with_empty_axes", 0)

# Optional axes input
axes = None
if len(inputs) > 1 and inputs[1] is not None:
axes_const = get_constant(inputs[1], params)
assert isinstance(axes_const, relax.Constant), "Only constant axes currently supported"
axes = axes_const.data.numpy().tolist()

# If axes is empty and noop_with_empty_axes is 0, reduce all dimensions
if not axes and not noop_with_empty_axes:
return relax.op.mean(data, None, keepdims)
# If axes is empty and noop_with_empty_axes is 1, return the input data unchanged.
elif not axes and noop_with_empty_axes:
return data
# If axes is provided, reduce over the specified axes
else:
return relax.op.mean(data, axes, keepdims)


class ReduceProd(OnnxOpConverter):
"""Converts an onnx ReduceProd node into an equivalent Relax expression."""
@@ -2562,6 +2645,29 @@ def _impl_v13(cls, bb, inputs, attr, params):
keepdims = attr.get("keepdims", 1)
return relax.op.prod(data, axes, keepdims)

@classmethod
def _impl_v18(cls, bb, inputs, attr, params):
data = inputs[0]
keepdims = attr.get("keepdims", 1)
noop_with_empty_axes = attr.get("noop_with_empty_axes", 0)

# Optional axes input
axes = None
if len(inputs) > 1 and inputs[1] is not None:
axes_const = get_constant(inputs[1], params)
assert isinstance(axes_const, relax.Constant), "Only constant axes currently supported"
axes = axes_const.data.numpy().tolist()

# If axes is empty and noop_with_empty_axes is 0, reduce all dimensions
if not axes and not noop_with_empty_axes:
return relax.op.prod(data, None, keepdims)
# If axes is empty and noop_with_empty_axes is 1, return the input data unchanged.
elif not axes and noop_with_empty_axes:
return data
# If axes is provided, reduce over the specified axes
else:
return relax.op.prod(data, axes, keepdims)


class ReduceLogSumExp(OnnxOpConverter):
"""Converts an onnx ReduceLogSumExp node into an equivalent Relax expression."""
@@ -2579,6 +2685,38 @@ def _impl_v13(cls, bb, inputs, attr, params):
out_x = relax.op.squeeze(out_x, axes)
return out_x

@classmethod
def _impl_v18(cls, bb, inputs, attr, params):
x = inputs[0]
keepdims = attr.get("keepdims", 1)
noop_with_empty_axes = attr.get("noop_with_empty_axes", 0)

# Optional axes input (second input)
axes = None
if len(inputs) > 1 and inputs[1] is not None:
axes_const = get_constant(inputs[1], params)
assert isinstance(axes_const, relax.Constant), "Only constant axes currently supported"
axes = axes_const.data.numpy().tolist()

# Compute log(sum(exp(x))) stably: subtract the running max before exponentiating
def log_sum_exp(axes):
    max_x = relax.op.max(x, axes, True)
    exp_x = relax.op.exp(relax.op.subtract(x, max_x))
    sum_x = relax.op.sum(exp_x, axes, True)
    out_x = relax.op.add(relax.op.log(sum_x), max_x)
    return relax.op.squeeze(out_x, axes) if not keepdims else out_x

# If axes is empty and noop_with_empty_axes is 0, reduce all dimensions
if not axes and not noop_with_empty_axes:
return log_sum_exp(None)
# If axes is empty and noop_with_empty_axes is 1, return the input data unchanged.
elif not axes and noop_with_empty_axes:
return x
# If axes is provided, reduce over the specified axes
else:
return log_sum_exp(axes)
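
The max-subtraction in log_sum_exp is the standard numerically stable identity log(sum(exp(x))) = max(x) + log(sum(exp(x - max(x)))). A small NumPy illustration of why it matters (not part of the patch):

import numpy as np

x = np.array([1000.0, 1000.5, 999.0])
naive = np.log(np.sum(np.exp(x)))            # exp() overflows, result is inf
m = np.max(x)
stable = m + np.log(np.sum(np.exp(x - m)))   # finite, roughly 1001.1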


class ReduceLogSum(OnnxOpConverter):
"""Converts an onnx ReduceLogSum node into an equivalent Relax expression."""
@@ -2590,6 +2728,29 @@ def _impl_v13(cls, bb, inputs, attr, params):
keepdims = attr.get("keepdims", 1)
return relax.op.log(relax.op.sum(data, axes, keepdims))

@classmethod
def _impl_v18(cls, bb, inputs, attr, params):
data = inputs[0]
keepdims = attr.get("keepdims", 1)
noop_with_empty_axes = attr.get("noop_with_empty_axes", 0)

# Optional axes input
axes = None
if len(inputs) > 1 and inputs[1] is not None:
axes_const = get_constant(inputs[1], params)
assert isinstance(axes_const, relax.Constant), "Only constant axes currently supported"
axes = axes_const.data.numpy().tolist()

# If axes is empty and noop_with_empty_axes is 0, reduce all dimensions
if not axes and not noop_with_empty_axes:
return relax.op.log(relax.op.sum(data, None, keepdims))
# If axes is empty and noop_with_empty_axes is 1, return the input data unchanged.
elif not axes and noop_with_empty_axes:
return data
# If axes is provided, reduce over the specified axes
else:
return relax.op.log(relax.op.sum(data, axes, keepdims))


class ReduceSumSquare(OnnxOpConverter):
"""Converts an onnx ReduceSumSquare node into an equivalent Relax expression."""
@@ -2601,6 +2762,29 @@ def _impl_v13(cls, bb, inputs, attr, params):
keepdims = attr.get("keepdims", 1)
return relax.op.sum(relax.op.multiply(data, data), axes, keepdims)

@classmethod
def _impl_v18(cls, bb, inputs, attr, params):
data = inputs[0]
keepdims = attr.get("keepdims", 1)
noop_with_empty_axes = attr.get("noop_with_empty_axes", 0)

# Optional axes input
axes = None
if len(inputs) > 1 and inputs[1] is not None:
axes_const = get_constant(inputs[1], params)
assert isinstance(axes_const, relax.Constant), "Only constant axes currently supported"
axes = axes_const.data.numpy().tolist()

# If axes is empty and noop_with_empty_axes is 0, reduce all dimensions
if not axes and not noop_with_empty_axes:
return relax.op.sum(relax.op.multiply(data, data), None, keepdims)
# If axes is empty and noop_with_empty_axes is 1, return the input data unchanged.
elif not axes and noop_with_empty_axes:
return data
# If axes is provided, reduce over the specified axes
else:
return relax.op.sum(relax.op.multiply(data, data), axes, keepdims)


class ReduceL1(OnnxOpConverter):
"""Converts an onnx ReduceL1 node into an equivalent Relax expression."""
@@ -2631,7 +2815,7 @@ def _impl_v18(cls, bb, inputs, attr, params):
# If axes is empty and noop_with_empty_axes is 1, return the input data unchanged.
elif not axes and noop_with_empty_axes:
return data
# If axes is provided, reduce over specified axes
# If axes is provided, reduce over the specified axes
else:
return relax.op.sum(relax.op.abs(data), axes, keepdims)

@@ -2646,6 +2830,29 @@ def _impl_v13(cls, bb, inputs, attr, params):
keepdims = attr.get("keepdims", 1)
return relax.op.sqrt(relax.op.sum(relax.op.multiply(data, data), axes, keepdims))

@classmethod
def _impl_v18(cls, bb, inputs, attr, params):
data = inputs[0]
keepdims = attr.get("keepdims", 1)
noop_with_empty_axes = attr.get("noop_with_empty_axes", 0)

# Optional axes input
axes = None
if len(inputs) > 1 and inputs[1] is not None:
axes_const = get_constant(inputs[1], params)
assert isinstance(axes_const, relax.Constant), "Only constant axes currently supported"
axes = axes_const.data.numpy().tolist()

# If axes is empty and noop_with_empty_axes is 0, reduce all dimensions
if not axes and not noop_with_empty_axes:
return relax.op.sqrt(relax.op.sum(relax.op.multiply(data, data), None, keepdims))
# If axes is empty and noop_with_empty_axes is 1, return the input data unchanged.
elif not axes and noop_with_empty_axes:
return data
# If axes is provided, reduce over the specified axes
else:
return relax.op.sqrt(relax.op.sum(relax.op.multiply(data, data), axes, keepdims))
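
With explicit axes, the ReduceL2 lowering above is just the Euclidean norm along those axes; a quick NumPy cross-check (illustrative only, not part of the patch):

import numpy as np

x = np.random.randn(3, 4).astype(np.float32)
ref = np.sqrt(np.sum(x * x, axis=1, keepdims=True))   # what the emitted Relax graph computes
assert np.allclose(ref, np.linalg.norm(x, axis=1, keepdims=True), rtol=1e-5)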


class ArgMax(OnnxOpConverter):
"""Converts an onnx ArgMax node into an equivalent Relax expression."""
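For context, one way to exercise a new opset-18 converter end to end might look like the following. This is a sketch only: the graph name, shapes, and the choice to omit the optional axes input are illustrative, and it assumes a TVM build with the Relax ONNX frontend available.

from onnx import TensorProto, helper
from tvm.relax.frontend.onnx import from_onnx

# ReduceMax-18 with the optional "axes" input omitted and noop_with_empty_axes=0,
# so the converter reduces over all dimensions.
node = helper.make_node("ReduceMax", ["x"], ["y"], keepdims=1, noop_with_empty_axes=0)
graph = helper.make_graph(
    [node],
    "reduce_max_all_dims",
    inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [3, 4, 5])],
    outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 1, 1])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])
mod = from_onnx(model)   # dispatches to ReduceMax._impl_v18
mod.show()
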
22 changes: 10 additions & 12 deletions tests/python/relax/test_frontend_onnx.py
@@ -1580,23 +1580,22 @@ def verify_reduce_func(func, data, axis, keepdims):
def create_reduce_test_parameters_axes_input():
output = []
for dynamic in [True, False]:
# TODO(@vacu9708): Enable the tests after implementing other reduce ops
# output.append(("ReduceMax", dynamic, 20))
# output.append(("ReduceMean", dynamic, 18))
# output.append(("ReduceMin", dynamic, 20))
# output.append(("ReduceProd", dynamic, 18))
# output.append(("ReduceSum", dynamic, 13))
# output.append(("ReduceSumSquare", dynamic, 18))
# output.append(("ReduceLogSum", dynamic, 18))
# output.append(("ReduceLogSumExp", dynamic, 18))
output.append(("ReduceMax", dynamic, 18))
output.append(("ReduceMean", dynamic, 18))
output.append(("ReduceMin", dynamic, 18))
output.append(("ReduceProd", dynamic, 18))
output.append(("ReduceSum", dynamic, 13))
output.append(("ReduceSumSquare", dynamic, 18))
output.append(("ReduceLogSum", dynamic, 18))
output.append(("ReduceLogSumExp", dynamic, 18))
output.append(("ReduceL1", dynamic, 18))
# output.append(("ReduceL2", dynamic, 18))
output.append(("ReduceL2", dynamic, 18))
return output


@pytest.mark.parametrize("func, dynamic, opset", create_reduce_test_parameters_axes_input())
def test_all_reduce_funcs_axes_input(func, dynamic, opset):
def verify_reduce_func(func, data, axes, keepdims, noop_with_empty_axes):
def verify_reduce_func(func, data, axes, keepdims, noop_with_empty_axes=False):
inshape = data.shape

inputs = ["x"]
@@ -1698,7 +1697,6 @@ def verify_reduce_func(func, data, axes, keepdims, noop_with_empty_axes):
np.random.randn(3, 3, 3, 1).astype(np.float32),
axes=(1, 2),
keepdims=keepdims,
noop_with_empty_axes=True,
)

