[TOPI] Expose topi::collapse_sum to Python and support symbolic shape (#14541)

TOPI has an internal implementation of collapse_sum (include/tvm/topi/reduction.h), but it is not exposed through the FFI and cannot be called from the Python side. This patch exposes it, adds related tests, and extends the implementation of topi::collapse_sum to support symbolic shapes.
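
For orientation, a minimal sketch of what the exposed operator looks like from the Python side once this patch is in place (shapes chosen for illustration):

    import tvm
    from tvm import te, topi

    A = te.placeholder((2, 3, 4), name="A")  # e.g. the result of broadcasting a (2, 1, 4) tensor
    B = topi.collapse_sum(A, (2, 1, 4))      # sums over axis 1; B.shape == (2, 1, 4)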
SiriusNEO authored Apr 9, 2023
1 parent a84a2cb commit 15f9be5
Showing 4 changed files with 97 additions and 7 deletions.
19 changes: 13 additions & 6 deletions include/tvm/topi/reduction.h
@@ -333,21 +333,28 @@ inline Tensor sum(const Tensor& data, const Array<Integer>& axis, bool keepdims
 }
 
 inline Tensor collapse_sum(const Tensor& data, Array<PrimExpr> target_shape) {
-  ICHECK_GE(data->shape.size(), target_shape.size());
-  auto ishape = detail::GetConstIntValues(data->shape, "ishape");
-  auto oshape = detail::GetConstIntValues(target_shape, "oshape");
+  const auto& ishape = data->shape;
+  const auto& oshape = target_shape;
+  int isize = data->shape.size();
+  int osize = target_shape.size();
+
+  ICHECK_GE(isize, osize)
+      << "Invalid collapse: input dimensionality smaller than output dimensionality.\ninput shape: "
+      << data->shape << "\nvs\noutput shape: " << target_shape;
 
   std::vector<int> reduce_axes;
   std::vector<int> squeeze_axes;
-  for (int i_ax = ishape.size() - 1, o_ax = oshape.size() - 1; i_ax >= 0; --i_ax) {
-    if (o_ax >= 0 && ishape[i_ax] == oshape[o_ax]) {
+  tvm::PrimExpr one(1);
+
+  for (int i_ax = isize - 1, o_ax = osize - 1; i_ax >= 0; --i_ax) {
+    if (o_ax >= 0 && topi::detail::EqualCheck(ishape[i_ax], oshape[o_ax])) {
       --o_ax;
       continue;
     }
     reduce_axes.push_back(i_ax);
     if (o_ax < 0) {  // squeeze o_ax if was added during expansion
       squeeze_axes.push_back(i_ax);
-    } else if (oshape[o_ax] == 1) {
+    } else if (topi::detail::EqualCheck(one, oshape[o_ax])) {
       --o_ax;
     }
   }
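
Replacing GetConstIntValues with EqualCheck is what lifts the constant-shape restriction; a minimal sketch of the symbolic case this enables, mirroring the shapes used in the tests below:

    from tvm import te, tir, topi

    n = tir.Var("n", "int32")
    m = tir.Var("m", "int32")
    A = te.placeholder((2, n, 4, m), name="A")
    B = topi.collapse_sum(A, (n, 1, m))  # axes 0 and 2 are reduced; B.shape == (n, 1, m)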
31 changes: 31 additions & 0 deletions python/tvm/topi/reduction.py
@@ -248,3 +248,34 @@ def prod(data, axis=None, keepdims=False):
     ret : tvm.te.Tensor
     """
     return cpp.prod(data, axis, keepdims)
+
+
+def collapse_sum(data, target_shape):
+    """Return a summation of data to the given shape.
+
+    collapse_sum is intended as the backward operator of topi broadcast operators in the
+    automatic differentiation process.
+
+    We expect that data is the result of broadcasting some tensor of target_shape in some
+    broadcast operation. Thus target_shape and data.shape must follow broadcast rules.
+
+    During computation, the axes of data.shape and target_shape are checked from right to left.
+    For every axis, if it either:
+    - exists in data but not in target_shape, or
+    - is larger than 1 in data and equals 1 in target_shape,
+    then data will be summed over this axis.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        The input tensor.
+
+    target_shape : Tuple[int]
+        The shape to collapse to.
+
+    Returns
+    -------
+    ret : tvm.te.Tensor
+        The result tensor after summation.
+    """
+    return cpp.collapse_sum(data, target_shape)
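
As a worked example of the right-to-left rule described in the docstring (shapes taken from the test cases below), a NumPy sketch of the expected result:

    import numpy as np

    data = np.random.rand(2, 3, 4, 5)  # imagine this was broadcast from a (3, 1, 5) tensor
    # right to left: 5 == 5 -> keep; 4 vs 1 -> sum over axis 2;
    # 3 == 3 -> keep; axis 0 has no counterpart -> sum over axis 0
    out = data.sum(axis=(0, 2)).reshape(3, 1, 5)  # what collapse_sum(data, (3, 1, 5)) computes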
4 changes: 4 additions & 0 deletions src/topi/reduction.cc
@@ -64,5 +64,9 @@ TVM_REGISTER_GLOBAL("topi.any").set_body([](TVMArgs args, TVMRetValue* rv) {
   *rv = topi::any(args[0], ArrayOrInt(args[1]), args[2]);
 });
 
+TVM_REGISTER_GLOBAL("topi.collapse_sum").set_body([](TVMArgs args, TVMRetValue* rv) {
+  *rv = topi::collapse_sum(args[0], args[1]);
+});
+
 }  // namespace topi
 }  // namespace tvm
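
The registration above also makes the operator reachable as a packed function under its global name; a minimal sketch of calling it through the FFI directly (assuming a TVM build that includes this patch):

    import tvm
    from tvm import te

    f = tvm.get_global_func("topi.collapse_sum")
    A = te.placeholder((2, 3), name="A")
    # the Python tuple is converted to an Array of PrimExpr by the FFI,
    # so this is equivalent to topi.collapse_sum(A, (3,))
    B = f(A, (3,))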
50 changes: 49 additions & 1 deletion tests/python/topi/python/test_topi_reduce.py
@@ -25,7 +25,7 @@
 import tvm.testing
 import tvm.topi.testing
 
-from tvm import te, topi
+from tvm import te, topi, tir
 
 in_shape, axis, keepdims, reduce_type, dtype = tvm.testing.parameters(
     ((32,), 0, False, "argmax", "float32"),
@@ -191,5 +191,53 @@ def test_complex_reduce(target, dev):
     tvm.testing.assert_allclose(out_tvm.numpy(), out_npy, 1e-3, 1e-3)
 
 
+n = tir.Var("n", "int32")
+m = tir.Var("m", "int32")
+true_value_map = {n: 3, m: 5}
+
+data_shape, target_shape = tvm.testing.parameters(
+    ((2, 3), (3,)),
+    ((2, 3, 4), (2, 1, 4)),
+    ((2, 3, 4, 5), (3, 1, 5)),
+    ((2, n, 4, m), (n, 1, m)),
+)
+
+
+def _my_npy_collapse_sum(data, target_shape):
+    reduce_axes = []
+    i = data.ndim - 1
+    j = len(target_shape) - 1
+    while i >= 0:
+        if j < 0:
+            reduce_axes.append(i)
+        elif target_shape[j] == 1 and data.shape[i] > 1:
+            reduce_axes.append(i)
+        i -= 1
+        j -= 1
+    return np.sum(data, tuple(reduce_axes)).reshape(target_shape)
+
+
+def test_collapse_sum(data_shape, target_shape):
+    A = te.placeholder(data_shape, name="A")
+    B = topi.collapse_sum(A, target_shape)
+    s = te.create_schedule([B.op])
+
+    data_shape_const = [int(s) if s not in true_value_map else true_value_map[s] for s in A.shape]
+    target_shape_const = [
+        int(s) if s not in true_value_map else true_value_map[s] for s in target_shape
+    ]
+    a_np = np.random.uniform(size=data_shape_const).astype(A.dtype)
+    b_np = _my_npy_collapse_sum(a_np, target_shape_const)
+    dev = tvm.cpu(0)
+    a = tvm.nd.array(a_np, dev)
+    B_shape_const = [int(s) if s not in true_value_map else true_value_map[s] for s in B.shape]
+    b = tvm.nd.array(np.zeros(B_shape_const, dtype=B.dtype), dev)
+    # Building with the CSE pass disabled
+    with tvm.transform.PassContext(opt_level=3, disabled_pass=["tir.CommonSubexprElimTIR"]):
+        foo = tvm.build(s, [A, B], "llvm", name="collapse_sum")
+    foo(a, b)
+    tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
