68 changes: 68 additions & 0 deletions python/tvm/relax/frontend/nn/op.py
@@ -245,6 +245,74 @@ def sum(
return wrap_nested(_op.sum(x._expr, axis, keepdims), name)


def max(
x: Tensor,
axis: Optional[Union[int, List[int]]] = None,
keepdims: bool = False,
name: str = "max",
) -> Tensor:
"""Computes the max of tensor elements over given axes.

Parameters
----------
x : Tensor
The input data tensor.

axis : Optional[Union[int, List[int]]]
Axis or axes along which the max is computed.
The default, axis=None, computes the max over all elements of the input tensor.
Negative indexing is supported.

keepdims : bool
If this is set to True, the axes which are reduced are left in the result as
dimensions with size one.
With this option, the result will broadcast correctly against the input tensor.

name : str
Name hint for this operation.

Returns
-------
result : Tensor
The computed result.
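
Examples
--------
A minimal sketch of typical usage inside an ``nn.Module`` (illustrative
only; the module name below is hypothetical, and the export call follows
the pattern used in this change's tests):

.. code-block:: python

    from tvm.relax.frontend.nn import Module, Tensor, op, spec

    class Reduce(Module):
        def forward(self, x: Tensor):
            # Reduce over axes 1 and 2, keeping them as size-1 dims.
            return op.max(x, axis=[1, 2], keepdims=True)

    # A (3, 5, 2, 4) input produces a (3, 1, 1, 4) output.
    mod, _ = Reduce().export_tvm(
        spec={"forward": {"x": spec.Tensor([3, 5, 2, 4], "float32")}}
    )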
"""
return wrap_nested(_op.max(x._expr, axis, keepdims), name)


def min(
x: Tensor,
axis: Optional[Union[int, List[int]]] = None,
keepdims: bool = False,
name: str = "min",
) -> Tensor:
"""Computes the min of tensor elements over given axes.

Parameters
----------
x : Tensor
The input data tensor.

axis : Optional[Union[int, List[int]]]
Axis or axes along which the min is computed.
The default, axis=None, computes the min over all elements of the input tensor.
Negative indexing is supported.

keepdims : bool
If this is set to True, the axes which are reduced are left in the result as
dimensions with size one.
With this option, the result will broadcast correctly against the input tensor.

name : str
Name hint for this operation.

Returns
-------
result : Tensor
The computed result.
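
Examples
--------
A minimal sketch mirroring the ``max`` example above (the module name is
hypothetical):

.. code-block:: python

    from tvm.relax.frontend.nn import Module, Tensor, op

    class Reduce(Module):
        def forward(self, x: Tensor):
            # A (3, 5, 2, 4) input yields a (3, 1, 1, 4) output.
            return op.min(x, axis=[1, 2], keepdims=True)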
"""
return wrap_nested(_op.min(x._expr, axis, keepdims), name)


def matmul(a: Tensor, b: Tensor, out_dtype: Optional[str] = None, name: str = "matmul") -> Tensor:
"""General matrix multiplication of two tensors, with broadcasting on batched dimensions.

48 changes: 48 additions & 0 deletions tests/python/relax/test_frontend_nn_op.py
@@ -128,6 +128,54 @@ def test(x: R.Tensor((3, 5, 2, 4), dtype="float32"), _io: R.Object) -> R.Tuple(R
tvm.ir.assert_structural_equal(irmodule["test"], test)


def test_max():
class Model(Module):
def test(self, x: Tensor):
z0 = op.max(x, axis=[1, 2], keepdims=True)
return z0

# fmt: off
@R.function
def test(x: R.Tensor((3, 5, 2, 4), dtype="float32"), _io: R.Object) -> R.Tuple(R.Tensor((3, 1, 1, 4), dtype="float32"), R.Tuple(R.Object)):
R.func_attr({"num_input": 2})
with R.dataflow():
max: R.Tensor((3, 1, 1, 4), dtype="float32") = R.max(x, axis=[1, 2], keepdims=True)
gv1: R.Tuple(R.Tensor((3, 1, 1, 4), dtype="float32"), R.Tuple(R.Object)) = max, (_io,)
R.output(gv1)
return gv1
# fmt: on

m = Model()
irmodule, _ = m.export_tvm(
spec={"test": {"x": spec.Tensor([3, 5, 2, 4], "float32")}}, debug=True
)
tvm.ir.assert_structural_equal(irmodule["test"], test)


def test_min():
class Model(Module):
def test(self, x: Tensor):
z0 = op.min(x, axis=[1, 2], keepdims=True)
return z0

# fmt: off
@R.function
def test(x: R.Tensor((3, 5, 2, 4), dtype="float32"), _io: R.Object) -> R.Tuple(R.Tensor((3, 1, 1, 4), dtype="float32"), R.Tuple(R.Object)):
R.func_attr({"num_input": 2})
with R.dataflow():
min: R.Tensor((3, 1, 1, 4), dtype="float32") = R.min(x, axis=[1, 2], keepdims=True)
gv1: R.Tuple(R.Tensor((3, 1, 1, 4), dtype="float32"), R.Tuple(R.Object)) = min, (_io,)
R.output(gv1)
return gv1
# fmt: on

m = Model()
irmodule, _ = m.export_tvm(
spec={"test": {"x": spec.Tensor([3, 5, 2, 4], "float32")}}, debug=True
)
tvm.ir.assert_structural_equal(irmodule["test"], test)


def test_manipulate():
class Model(Module):
def test(self, x: Tensor):