Commit 26041f8

[Relax][Frontend] Support max/min in frontend op interface (#17782)
This PR adds the max and min reduction operators to the Relax nn frontend op interface, which were previously missing.
1 parent e60fd80 commit 26041f8
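
With this change, the new reductions can be called like any other nn frontend op. A minimal usage sketch follows; the module, method name, and tensor shape are illustrative only (not part of this commit), and it assumes a TVM build that includes this change:

from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import op, spec

class Range(nn.Module):
    def forward(self, x: nn.Tensor):
        # Reduce over axes 1 and 2; keepdims=True preserves the rank.
        hi = op.max(x, axis=[1, 2], keepdims=True)  # shape (3, 1, 1, 4)
        lo = op.min(x, axis=[1, 2], keepdims=True)  # shape (3, 1, 1, 4)
        return hi, lo

irmodule, _ = Range().export_tvm(
    spec={"forward": {"x": spec.Tensor([3, 5, 2, 4], "float32")}}
)
irmodule.show()  # the printed Relax IR contains R.max and R.min calls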

File tree

2 files changed: +116 −0 lines


python/tvm/relax/frontend/nn/op.py

Lines changed: 68 additions & 0 deletions
@@ -245,6 +245,74 @@ def sum(
     return wrap_nested(_op.sum(x._expr, axis, keepdims), name)
 
 
+def max(
+    x: Tensor,
+    axis: Optional[Union[int, List[int]]] = None,
+    keepdims: bool = False,
+    name: str = "max",
+) -> Tensor:
+    """Computes the max of tensor elements over given axes.
+
+    Parameters
+    ----------
+    x : Tensor
+        The input data tensor
+
+    axis : Optional[Union[int, List[int]]]
+        Axis or axes along which a max is performed.
+        The default, axis=None, will max all of the elements of the input tensor.
+        Negative indexing is supported.
+
+    keepdims : bool
+        If this is set to True, the axes which are reduced are left in the result as
+        dimensions with size one.
+        With this option, the result will broadcast correctly against the input tensor.
+
+    name : str
+        Name hint for this operation.
+
+    Returns
+    -------
+    result : Tensor
+        The computed result.
+    """
+    return wrap_nested(_op.max(x._expr, axis, keepdims), name)
+
+
+def min(
+    x: Tensor,
+    axis: Optional[Union[int, List[int]]] = None,
+    keepdims: bool = False,
+    name: str = "min",
+) -> Tensor:
+    """Computes the min of tensor elements over given axes.
+
+    Parameters
+    ----------
+    x : Tensor
+        The input data tensor
+
+    axis : Optional[Union[int, List[int]]]
+        Axis or axes along which a min is performed.
+        The default, axis=None, will min all of the elements of the input tensor.
+        Negative indexing is supported.
+
+    keepdims : bool
+        If this is set to True, the axes which are reduced are left in the result as
+        dimensions with size one.
+        With this option, the result will broadcast correctly against the input tensor.
+
+    name : str
+        Name hint for this operation.
+
+    Returns
+    -------
+    result : Tensor
+        The computed result.
+    """
+    return wrap_nested(_op.min(x._expr, axis, keepdims), name)
+
+
 def matmul(a: Tensor, b: Tensor, out_dtype: Optional[str] = None, name: str = "matmul") -> Tensor:
     """General matrix multiplication of two tensors, with broadcasting on batched dimensions.

tests/python/relax/test_frontend_nn_op.py

Lines changed: 48 additions & 0 deletions
@@ -128,6 +128,54 @@ def test(x: R.Tensor((3, 5, 2, 4), dtype="float32"), _io: R.Object) -> R.Tuple(R
     tvm.ir.assert_structural_equal(irmodule["test"], test)
 
 
+def test_max():
+    class Model(Module):
+        def test(self, x: Tensor):
+            z0 = op.max(x, axis=[1, 2], keepdims=True)
+            return z0
+
+    # fmt: off
+    @R.function
+    def test(x: R.Tensor((3, 5, 2, 4), dtype="float32"), _io: R.Object) -> R.Tuple(R.Tensor((3, 1, 1, 4), dtype="float32"), R.Tuple(R.Object)):
+        R.func_attr({"num_input": 2})
+        with R.dataflow():
+            max: R.Tensor((3, 1, 1, 4), dtype="float32") = R.max(x, axis=[1, 2], keepdims=True)
+            gv1: R.Tuple(R.Tensor((3, 1, 1, 4), dtype="float32"), R.Tuple(R.Object)) = max, (_io,)
+            R.output(gv1)
+        return gv1
+    # fmt: on
+
+    m = Model()
+    irmodule, _ = m.export_tvm(
+        spec={"test": {"x": spec.Tensor([3, 5, 2, 4], "float32")}}, debug=True
+    )
+    tvm.ir.assert_structural_equal(irmodule["test"], test)
+
+
+def test_min():
+    class Model(Module):
+        def test(self, x: Tensor):
+            z0 = op.min(x, axis=[1, 2], keepdims=True)
+            return z0
+
+    # fmt: off
+    @R.function
+    def test(x: R.Tensor((3, 5, 2, 4), dtype="float32"), _io: R.Object) -> R.Tuple(R.Tensor((3, 1, 1, 4), dtype="float32"), R.Tuple(R.Object)):
+        R.func_attr({"num_input": 2})
+        with R.dataflow():
+            min: R.Tensor((3, 1, 1, 4), dtype="float32") = R.min(x, axis=[1, 2], keepdims=True)
+            gv1: R.Tuple(R.Tensor((3, 1, 1, 4), dtype="float32"), R.Tuple(R.Object)) = min, (_io,)
+            R.output(gv1)
+        return gv1
+    # fmt: on
+
+    m = Model()
+    irmodule, _ = m.export_tvm(
+        spec={"test": {"x": spec.Tensor([3, 5, 2, 4], "float32")}}, debug=True
+    )
+    tvm.ir.assert_structural_equal(irmodule["test"], test)
+
+
 def test_manipulate():
     class Model(Module):
         def test(self, x: Tensor):
