[TOPI] Expose the interface of topi.collapse_sum #102

Merged
10 changes: 5 additions & 5 deletions python/tvm/relax/transform/legalize_ops.py
@@ -619,11 +619,11 @@ def _image_resize2d(bb: BlockBuilder, call: Call) -> Expr:
"relax.reshape": _reshape(topi.reshape, "reshape"),
"relax.split": _split,
"relax.squeeze": _squeeze,
-    # Todo(relax-team): Introduce TOPI collapse_sum for gradient
-    # "relax.collapse_sum_like": _reshape(topi.collapse_sum, "collapse_sum"),
-    # "relax.collapse_sum_to": _reshape(
-    #     topi.collapse_sum, "collapse_sum", is_collapse_sum_like=True
-    # ),
+    # TODO(relax-team): support symbolic shapes in collapse_sum
+    "relax.collapse_sum_like": _reshape(
+        topi.collapse_sum, "collapse_sum", is_collapse_sum_like=True
+    ),
+    "relax.collapse_sum_to": _reshape(topi.collapse_sum, "collapse_sum"),
    # Search
    "relax.where": _where,
    # Statistical
25 changes: 25 additions & 0 deletions python/tvm/topi/reduction.py
@@ -248,3 +248,28 @@ def prod(data, axis=None, keepdims=False):
    ret : tvm.te.Tensor
    """
    return cpp.prod(data, axis, keepdims)


def collapse_sum(data, target_shape):
    """Return a summation of data to the given shape.

    collapse_sum is intended as the backward operator of topi broadcast operators in the
    automatic differentiation process.

    We expect that data is the result of broadcasting some tensor of target_shape in some
    broadcast operation. Thus target_shape and data.shape must follow broadcast rules.

    During computation, the axes of data.shape and target_shape are checked from right to
    left. For every axis that either:

    - exists in data but not in target_shape, or
    - is larger than 1 in data and equal to 1 in target_shape,

    data is summed over that axis.

    Parameters
    ----------
    data : tvm.te.Tensor
        The input tensor.

    target_shape : Tuple[int]
        The shape to collapse to.

    Returns
    -------
    ret : tvm.te.Tensor
        The result tensor after summation.
    """
    return cpp.collapse_sum(data, target_shape)
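
As a sanity check of the semantics described in the docstring, here is a minimal usage sketch, assuming an LLVM-enabled TVM build; the shapes are illustrative and not taken from the PR. Collapsing (2, 3, 4, 5) to (3, 1, 5) sums over axis 0 (no counterpart in the target shape) and axis 2 (4 in data, 1 in the target):

    import numpy as np
    import tvm
    from tvm import te, topi

    # Checking axes right to left: axis 3 (5 vs 5) and axis 1 (3 vs 3) are
    # kept, while axis 2 (4 in data, 1 in the target) and axis 0 (absent
    # from the target shape) are summed over.
    A = te.placeholder((2, 3, 4, 5), name="A")
    B = topi.collapse_sum(A, (3, 1, 5))
    s = te.create_schedule([B.op])

    # Build with the CSE pass disabled, mirroring the test added in this PR.
    with tvm.transform.PassContext(opt_level=3, disabled_pass=["tir.CommonSubexprElimTIR"]):
        f = tvm.build(s, [A, B], "llvm", name="collapse_sum")

    a = tvm.nd.array(np.random.uniform(size=(2, 3, 4, 5)).astype("float32"))
    b = tvm.nd.array(np.zeros((3, 1, 5), dtype="float32"))
    f(a, b)
    np.testing.assert_allclose(
        b.numpy(), a.numpy().sum(axis=(0, 2)).reshape(3, 1, 5), rtol=1e-5
    )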
4 changes: 4 additions & 0 deletions src/topi/reduction.cc
@@ -64,5 +64,9 @@ TVM_REGISTER_GLOBAL("topi.any").set_body([](TVMArgs args, TVMRetValue* rv) {
  *rv = topi::any(args[0], ArrayOrInt(args[1]), args[2]);
});

TVM_REGISTER_GLOBAL("topi.collapse_sum").set_body([](TVMArgs args, TVMRetValue* rv) {
  *rv = topi::collapse_sum(args[0], args[1]);
});

} // namespace topi
} // namespace tvm
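
This registration is what makes the C++ implementation reachable from Python: `cpp.collapse_sum` in python/tvm/topi/reduction.py resolves to the packed function registered above. A minimal sketch of looking it up directly through the FFI, assuming `tvm.get_global_func` (the standard FFI lookup, not part of this diff):

    import tvm
    from tvm import te

    # Fetch the packed function registered by TVM_REGISTER_GLOBAL above.
    # tvm.topi.collapse_sum reaches the same symbol via tvm.topi.cpp.
    f = tvm.get_global_func("topi.collapse_sum")
    A = te.placeholder((2, 3), name="A")
    B = f(A, (3,))  # te.Tensor of shape (3,), summed over axis 0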
27 changes: 23 additions & 4 deletions tests/python/relax/test_transform_legalize_ops_manipulate.py
@@ -785,7 +785,6 @@ def squeeze(var_rxplaceholder: T.handle, var_T_squeeze: T.handle):
    tvm.ir.assert_structural_equal(mod, Expected)


-@pytest.mark.skip("TOPI has no collapse_sum. Waiting a fixing patch.")
def test_collapse_sum_like():
    # fmt: off
    @tvm.script.ir_module
@@ -819,7 +818,7 @@ def collapse_sum(rxplaceholder: T.Buffer[(T.int64(2), T.int64(3)), "float32"], r
    tvm.ir.assert_structural_equal(mod, Expected)


-@pytest.mark.skip("TOPI has no collapse_sum. Waiting a fixing patch.")
+@pytest.mark.skip("TOPI collapse_sum does not support symbolic shapes yet")
def test_collapse_sum_like_symbolic():
    # fmt: off
    @tvm.script.ir_module
@@ -836,7 +835,6 @@ def main(x: R.Tensor(("a", "b", "a"), "float32"), y: R.Tensor(("b", 1), "float32"
    tvm.ir.assert_structural_equal(mod, Expected)


-@pytest.mark.skip("TOPI has no collapse_sum. Waiting a fixing patch.")
def test_collapse_sum_to():
    # fmt: off
    @tvm.script.ir_module
@@ -846,13 +844,34 @@ def main(x: R.Tensor((3, 2, 3), "float32")) -> R.Tensor((2, 1), "float32"):
            gv: R.Tensor((2, 1), "float32") = R.collapse_sum_to(x, (2, 1))
            return gv

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor((3, 2, 3), dtype="float32")
        ) -> R.Tensor((2, 1), dtype="float32"):
            # block 0
            gv = R.call_tir(collapse_sum, (x,), (2, 1), dtype="float32")
            return gv

        @T.prim_func
        def collapse_sum(rxplaceholder: T.Buffer[(T.int64(3), T.int64(2), T.int64(3)), "float32"], rxplaceholder_red: T.Buffer[(T.int64(2), T.int64(1)), "float32"]):
            T.func_attr({"tir.noalias": True})
            for ax0, ax1, k0, k2 in T.grid(T.int64(2), T.int64(1), T.int64(3), T.int64(3)):
                with T.block("rxplaceholder_red"):
                    v_ax0, v_ax1, v_k0, v_k2 = T.axis.remap("SSRR", [ax0, ax1, k0, k2])
                    T.reads(rxplaceholder[v_k0, v_ax0, v_k2])
                    T.writes(rxplaceholder_red[v_ax0, v_ax1])
                    with T.init():
                        rxplaceholder_red[v_ax0, v_ax1] = T.float32(0)
                    rxplaceholder_red[v_ax0, v_ax1] = (rxplaceholder_red[v_ax0, v_ax1] + rxplaceholder[v_k0, v_ax0, v_k2])
    # fmt: on

    mod = LegalizeOps()(CollapseSumTo)
    tvm.ir.assert_structural_equal(mod, Expected)


-@pytest.mark.skip("TOPI has no collapse_sum. Waiting a fixing patch.")
+@pytest.mark.skip("TOPI collapse_sum does not support symbolic shapes yet")
def test_collapse_sum_to_symbolic():
    # fmt: off
    @tvm.script.ir_module
39 changes: 39 additions & 0 deletions tests/python/topi/python/test_topi_reduce.py
@@ -24,6 +24,7 @@
import tvm
import tvm.testing
import tvm.topi.testing
from tvm.topi.utils import get_const_tuple

from tvm import te, topi

@@ -183,5 +184,43 @@ def test_complex_reduce(target, dev):
    tvm.testing.assert_allclose(out_tvm.numpy(), out_npy, 1e-3, 1e-3)


data_shape, target_shape = tvm.testing.parameters(
    ((2, 3), (3,)),
    ((2, 3, 4), (2, 1, 4)),
    ((2, 3, 4, 5), (3, 1, 5)),
)


def _my_npy_collapse_sum(data, target_shape):
    # NumPy reference implementation: scan the axes of data and target_shape
    # from right to left and sum over every axis that was broadcast.
    reduce_axes = []
    i = data.ndim - 1
    j = len(target_shape) - 1
    while i >= 0:
        if j < 0:
            reduce_axes.append(i)
        elif target_shape[j] == 1 and data.shape[i] > 1:
            reduce_axes.append(i)
        i -= 1
        j -= 1
    return np.sum(data, tuple(reduce_axes)).reshape(target_shape)


def test_collapse_sum(data_shape, target_shape):
    A = te.placeholder(data_shape, name="A")
    B = topi.collapse_sum(A, target_shape)
    s = te.create_schedule([B.op])

    a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
    b_np = _my_npy_collapse_sum(a_np, target_shape)
    dev = tvm.cpu(0)
    a = tvm.nd.array(a_np, dev)
    b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev)
    # Building with the CSE pass disabled
    with tvm.transform.PassContext(opt_level=3, disabled_pass=["tir.CommonSubexprElimTIR"]):
        foo = tvm.build(s, [A, B], "llvm", name="collapse_sum")
    foo(a, b)
    tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5)


if __name__ == "__main__":
    tvm.testing.main()