diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py index 4c6d921db79b..23045f7c4ebf 100644 --- a/python/tvm/relax/frontend/nn/op.py +++ b/python/tvm/relax/frontend/nn/op.py @@ -245,6 +245,74 @@ def sum( return wrap_nested(_op.sum(x._expr, axis, keepdims), name) +def max( + x: Tensor, + axis: Optional[Union[int, List[int]]] = None, + keepdims: bool = False, + name: str = "max", +) -> Tensor: + """Computes the max of tensor elements over given axes. + + Parameters + ---------- + x : Tensor + The input data tensor + + axis : Optional[Union[int, List[int]]] + Axis or axes along which a max is performed. + The default, axis=None, will max all of the elements of the input tensor. + Negative indexing is supported. + + keepdims : bool + If this is set to True, the axes which are reduced are left in the result as + dimensions with size one. + With this option, the result will broadcast correctly against the input tensor. + + name : str + Name hint for this operation. + + Returns + ------- + result : Tensor + The computed result. + """ + return wrap_nested(_op.max(x._expr, axis, keepdims), name) + + +def min( + x: Tensor, + axis: Optional[Union[int, List[int]]] = None, + keepdims: bool = False, + name: str = "min", +) -> Tensor: + """Computes the min of tensor elements over given axes. + + Parameters + ---------- + x : Tensor + The input data tensor + + axis : Optional[Union[int, List[int]]] + Axis or axes along which a min is performed. + The default, axis=None, will min all of the elements of the input tensor. + Negative indexing is supported. + + keepdims : bool + If this is set to True, the axes which are reduced are left in the result as + dimensions with size one. + With this option, the result will broadcast correctly against the input tensor. + + name : str + Name hint for this operation. + + Returns + ------- + result : Tensor + The computed result. + """ + return wrap_nested(_op.min(x._expr, axis, keepdims), name) + + def matmul(a: Tensor, b: Tensor, out_dtype: Optional[str] = None, name: str = "matmul") -> Tensor: """General matrix multiplication of two tensors, with broadcasting on batched dimensions. diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py index 682f805026cb..6e63b0e4c069 100644 --- a/tests/python/relax/test_frontend_nn_op.py +++ b/tests/python/relax/test_frontend_nn_op.py @@ -128,6 +128,54 @@ def test(x: R.Tensor((3, 5, 2, 4), dtype="float32"), _io: R.Object) -> R.Tuple(R tvm.ir.assert_structural_equal(irmodule["test"], test) +def test_max(): + class Model(Module): + def test(self, x: Tensor): + z0 = op.max(x, axis=[1, 2], keepdims=True) + return z0 + + # fmt: off + @R.function + def test(x: R.Tensor((3, 5, 2, 4), dtype="float32"), _io: R.Object) -> R.Tuple(R.Tensor((3, 1, 1, 4), dtype="float32"), R.Tuple(R.Object)): + R.func_attr({"num_input": 2}) + with R.dataflow(): + max: R.Tensor((3, 1, 1, 4), dtype="float32") = R.max(x, axis=[1, 2], keepdims=True) + gv1: R.Tuple(R.Tensor((3, 1, 1, 4), dtype="float32"), R.Tuple(R.Object)) = max, (_io,) + R.output(gv1) + return gv1 + # fmt: on + + m = Model() + irmodule, _ = m.export_tvm( + spec={"test": {"x": spec.Tensor([3, 5, 2, 4], "float32")}}, debug=True + ) + tvm.ir.assert_structural_equal(irmodule["test"], test) + + +def test_min(): + class Model(Module): + def test(self, x: Tensor): + z0 = op.min(x, axis=[1, 2], keepdims=True) + return z0 + + # fmt: off + @R.function + def test(x: R.Tensor((3, 5, 2, 4), dtype="float32"), _io: R.Object) -> R.Tuple(R.Tensor((3, 1, 1, 4), dtype="float32"), R.Tuple(R.Object)): + R.func_attr({"num_input": 2}) + with R.dataflow(): + min: R.Tensor((3, 1, 1, 4), dtype="float32") = R.min(x, axis=[1, 2], keepdims=True) + gv1: R.Tuple(R.Tensor((3, 1, 1, 4), dtype="float32"), R.Tuple(R.Object)) = min, (_io,) + R.output(gv1) + return gv1 + # fmt: on + + m = Model() + irmodule, _ = m.export_tvm( + spec={"test": {"x": spec.Tensor([3, 5, 2, 4], "float32")}}, debug=True + ) + tvm.ir.assert_structural_equal(irmodule["test"], test) + + def test_manipulate(): class Model(Module): def test(self, x: Tensor):