[Relay, TOPI] Complete rewrite of where op to support broadcasting #6759

Merged · 13 commits · Oct 28, 2020
40 changes: 40 additions & 0 deletions include/tvm/topi/broadcast.h
@@ -69,6 +69,46 @@ inline tvm::te::Tensor broadcast_to(const tvm::te::Tensor& t,
return tvm::te::compute(oshape, l, name, tag);
}

inline tvm::te::Tensor broadcast_shape_tensors(const tvm::te::Tensor& shape_tensor1,
const tvm::te::Tensor& shape_tensor2,
std::string name = "T_broadcast_shape_tensors",
std::string tag = kBroadcast) {
const auto rank1 = detail::GetConstInt(shape_tensor1->shape[0]);
const auto rank2 = detail::GetConstInt(shape_tensor2->shape[0]);
const auto out_rank = std::max<int32_t>(rank1, rank2);
const tvm::PrimExpr one = tvm::cast(shape_tensor1->dtype, PrimExpr(1));

auto select_dim = [&](const tvm::te::Tensor& shape_tensor, int rank,
tvm::tir::Var index) -> PrimExpr {
if (rank < out_rank) {
// if the rank is smaller, dimension 1 is prepended according to
// the numpy broadcasting semantics.
return tvm::tir::Select(rank - (out_rank - index) < 0, one,
shape_tensor[rank - (out_rank - index)]);
} else {
// rank == out_rank, safe to index directly
return shape_tensor[index];
}
};

auto func = [&](tvm::Array<tvm::tir::Var> ovars) {
auto index = ovars[0];
PrimExpr dim1 = select_dim(shape_tensor1, rank1, index);
PrimExpr dim2 = select_dim(shape_tensor2, rank2, index);
if (topi::detail::EqualCheck(one, dim1)) {
return dim2;
} else if (topi::detail::EqualCheck(one, dim2)) {
return dim1;
}
return tvm::max(dim1, dim2);
};

Array<PrimExpr> oshape;
oshape.push_back(PrimExpr(out_rank));

return tvm::te::compute(oshape, func, name, tag);
}

#define TOPI_DEFINE_BCAST_OP(Name, ComputeRule) \
inline tvm::PrimExpr Name(const tvm::PrimExpr& a, const tvm::PrimExpr& b) { ComputeRule; } \
inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::te::Tensor& B, \
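The compute above implements the numpy broadcast-shape rule at the TIR level: the lower-rank shape tensor is padded with leading 1s, a dimension of 1 yields the other operand's extent, and otherwise the two extents are merged with `tvm::max` (incompatible static extents are not rejected here; that is left to the type relation). A minimal numpy-level sketch of the same rule — `broadcast_shapes_np` is a hypothetical reference helper, not part of this PR:

```python
import numpy as np

def broadcast_shapes_np(shape1, shape2):
    # Hypothetical reference for broadcast_shape_tensors: pad the
    # shorter shape with leading 1s, then merge dimension by dimension.
    out_rank = max(len(shape1), len(shape2))
    s1 = (1,) * (out_rank - len(shape1)) + tuple(shape1)
    s2 = (1,) * (out_rank - len(shape2)) + tuple(shape2)
    out = []
    for d1, d2 in zip(s1, s2):
        if d1 == 1:
            out.append(d2)           # dim1 == 1: take the other extent
        elif d2 == 1:
            out.append(d1)           # dim2 == 1: take the other extent
        else:
            out.append(max(d1, d2))  # mirrors the tvm::max fallback
    return tuple(out)

assert broadcast_shapes_np((5, 1), (4,)) == (5, 4)
assert broadcast_shapes_np((3,), (2, 3)) == (2, 3)
```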
67 changes: 23 additions & 44 deletions include/tvm/topi/transform.h
@@ -39,6 +39,8 @@
#include <unordered_set>
#include <vector>

#include "detail/broadcast.h"

namespace tvm {
namespace topi {

@@ -887,53 +889,30 @@ inline Tensor take(const Tensor& a, const Tensor& indices, int axis, std::string
*/
inline Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y,
std::string name = "T_where", std::string tag = kBroadcast) {
ICHECK_EQ(x->shape.size(), y->shape.size())
<< "x and y must have the same shape.Got different number of dimension: " << x->shape.size()
<< " vs " << y->shape.size();
ICHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: " << x->dtype << " vs "
<< y->dtype;
auto get_out_shape = [&]() {
auto bh1 = detail::BroadcastShape(x->shape, y->shape);
Array<PrimExpr> common_shape1(bh1.common_shape.begin(), bh1.common_shape.end());
auto bh2 = detail::BroadcastShape(condition->shape, common_shape1);
Array<PrimExpr> common_shape2(bh2.common_shape.begin(), bh2.common_shape.end());
return common_shape2;
};

if (x->shape.size() == 0) {
return compute(
condition->shape,
[&](const Array<Var>& indices) {
PrimExpr cond;
if (condition->shape.size() == 0) {
cond = condition();
} else {
Array<PrimExpr> condition_idx{indices[0]};
cond = condition(condition_idx);
}
return tvm::tir::Select(cond != 0, x(), y());
},
name, tag);
} else if (condition->shape.size() != 1) {
ICHECK_EQ(condition->shape.size(), x->shape.size())
<< "condition array must be either have the same shape as x or to be a "
"1-D array.Got different number of dimension: "
<< condition->shape.size() << " vs " << x->shape.size();
return compute(
x->shape,
[&](const Array<Var>& indices) {
return tvm::tir::Select(condition(indices) != 0, x(indices), y(indices));
},
name, tag);
} else {
int64_t cond_first_dim = topi::GetConstInt(condition->shape[0]);
int64_t x_first_dim = topi::GetConstInt(x->shape[0]);
if (cond_first_dim > 0 && x_first_dim > 0) {
ICHECK_EQ(cond_first_dim, x_first_dim)
<< "If condition is 1-D, the first dimension must be the same as x: " << cond_first_dim
<< " vs " << x_first_dim;
}
return compute(
x->shape,
[&](const Array<Var>& indices) {
Array<PrimExpr> condition_idx{indices[0]};
return tvm::tir::Select(condition(condition_idx) != 0, x(indices), y(indices));
},
name, tag);
}
auto oshape = get_out_shape();

auto c_bh = detail::BroadcastShape(condition->shape, oshape);
auto x_bh = detail::BroadcastShape(x->shape, oshape);
auto y_bh = detail::BroadcastShape(y->shape, oshape);

auto select = [&](tvm::Array<tvm::tir::Var> ovars) {
auto c = condition(InputIndexFromBroadcast(ovars, condition, c_bh.vars1, c_bh.all_vars));
auto true_val = x(InputIndexFromBroadcast(ovars, x, x_bh.vars1, x_bh.all_vars));
auto false_val = y(InputIndexFromBroadcast(ovars, y, y_bh.vars1, y_bh.all_vars));
return tvm::tir::Select(c != 0, true_val, false_val);
};

return compute(oshape, select, name, tag);
}

/*!
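The rewritten `where` computes a single broadcast output shape (x against y, then the result against condition) and remaps every output index back into each input with `InputIndexFromBroadcast`, so the op now follows full numpy broadcasting. A quick numpy cross-check of the intended semantics (shapes chosen arbitrarily for illustration):

```python
import numpy as np

cond = np.array([[1], [0]])              # shape (2, 1)
x = np.array([1.0, 2.0])                 # shape (2,)
y = np.array([[7.0, 8.0], [9.0, 10.0]])  # shape (2, 2)

# All three shapes broadcast to (2, 2); the second row of cond is 0,
# so that row is taken from y.
out = np.where(cond != 0, x, y)
print(out)  # [[ 1.  2.]
            #  [ 9. 10.]]
```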
6 changes: 4 additions & 2 deletions python/tvm/relay/op/_transform.py
@@ -817,6 +817,8 @@ def where_shape_func(attrs, inputs, _):
"""
cond_shape = inputs[0]
x_shape = inputs[1]
out_shape = x_shape if x_shape.shape else cond_shape
y_shape = inputs[2]
bcast_shape = topi.broadcast.broadcast_shape_tensors(x_shape, y_shape)
out_shape = topi.broadcast.broadcast_shape_tensors(bcast_shape, cond_shape)

return [topi.math.identity(out_shape)]
return [out_shape]
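In effect the shape function composes two broadcasts over shape tensors. A numpy sketch of the same composition — `where_out_shape` is a hypothetical illustration, and `np.broadcast_shapes` requires NumPy 1.20+:

```python
import numpy as np

def where_out_shape(cond_shape, x_shape, y_shape):
    # Same two-step composition as where_shape_func above:
    # broadcast x with y first, then the result with condition.
    xy = np.broadcast_shapes(x_shape, y_shape)
    return np.broadcast_shapes(xy, cond_shape)

assert where_out_shape((5,), (5, 5), (1, 5)) == (5, 5)
assert where_out_shape((3, 4), (3, 1), (1, 4)) == (3, 4)
```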
17 changes: 9 additions & 8 deletions python/tvm/relay/op/transform.py
@@ -649,25 +649,26 @@ def where(condition, x, y):
condition.

.. note::
The shape of condition, x, and y needs to be the same.
Shapes of condition, x, and y must be broadcastable to a common shape.
Semantics follow the numpy where function:
https://numpy.org/doc/stable/reference/generated/numpy.where.html

Parameters
----------
condition : relay.Expr
The condition array. The n-th element in `y` is selected when the n-th
value in the `condition` array is zero. Otherwise, the corresponding
element from `x` will be picked.
Where True, yield x, otherwise yield y

x : relay.Expr
The first array to be selected.
The first array or scalar to be selected.

y : relay.Expr
The second array to be selected.
The second array or scalar to be selected.

Returns
-------
result : relay.Expr
The selected array.
The selected array. The output shape is the broadcasted shape from
condition, x, and y.

Examples
--------
Expand All @@ -678,7 +679,7 @@ def where(condition, x, y):
condition = [[0, 1], [-1, 0]]
relay.where(condition, x, y) = [[5, 2], [3, 8]]

condition = [1, 0]
condition = [[1], [0]]
relay.where(condition, x, y) = [[1, 2], [7, 8]]
"""
return _make.where(condition, x, y)
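A hedged end-to-end sketch of the broadcasting behavior the updated docstring describes; the executor wiring follows the TVM Python API as of this PR (e.g. `asnumpy()`), which may differ in later releases:

```python
import numpy as np
import tvm
from tvm import relay

cond = relay.var("cond", shape=(2, 1), dtype="bool")
x = relay.var("x", shape=(2,), dtype="float32")
y = relay.var("y", shape=(2, 2), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([cond, x, y], relay.where(cond, x, y)))

cond_np = np.array([[True], [False]])
x_np = np.array([1.0, 2.0], dtype="float32")
y_np = np.array([[7.0, 8.0], [9.0, 10.0]], dtype="float32")

# Evaluate with the interpreter-backed executor and compare against numpy.
result = relay.create_executor(mod=mod).evaluate()(cond_np, x_np, y_np)
np.testing.assert_allclose(result.asnumpy(), np.where(cond_np, x_np, y_np))
```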
22 changes: 21 additions & 1 deletion python/tvm/topi/broadcast.py
@@ -22,7 +22,7 @@
def broadcast_to(data, shape):
"""Broadcast the src to the target shape

We follows the numpy broadcasting rule.
We follow the numpy broadcasting rule.
See also https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html

Parameters
@@ -40,6 +40,26 @@ def broadcast_to(data, shape):
return _cpp.broadcast_to(data, shape)


def broadcast_shape_tensors(shape_tensor1, shape_tensor2):
"""Compute a shape tensor whose values represents the broadcasted shape
of two input shape tensors

Parameters
----------
shape_tensor1 : tvm.te.Tensor
One of the input shape tensors

shape_tensor2 : tvm.te.Tensor
One of the input shape tensors

Returns
-------
ret : tvm.te.Tensor
A shape tensor whose values represent the broadcasted shape
"""
return _cpp.broadcast_shape_tensors(shape_tensor1, shape_tensor2)


def add(lhs, rhs):
"""Addition with auto-broadcasting

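A small usage sketch at the TE level (the ranks of the placeholder shape tensors must be static; names and dtypes here are illustrative assumptions):

```python
import tvm
from tvm import te, topi

# Shape tensors: 1-D tensors whose entries are the (possibly dynamic)
# extents of some other tensor's shape.
shape1 = te.placeholder((2,), dtype="int64", name="shape1")  # a rank-2 shape
shape2 = te.placeholder((3,), dtype="int64", name="shape2")  # a rank-3 shape

bcast = topi.broadcast.broadcast_shape_tensors(shape1, shape2)
print(bcast.shape)  # [3] -- output rank is the larger of the two input ranks
```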
47 changes: 14 additions & 33 deletions src/relay/op/tensor/transform.cc
@@ -45,6 +45,7 @@
#include "../../transforms/pattern_utils.h"
#include "../make_op.h"
#include "../op_common.h"
#include "../type_relations.h"

namespace tvm {
namespace relay {
@@ -1685,30 +1686,17 @@ bool WhereRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
return false;
}

const auto& cond_shape = condition->shape;
const auto& x_shape = x->shape;
const auto& y_shape = y->shape;
ICHECK(x_shape.size() == y_shape.size()) << "x and y must have the same size";
ICHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: " << x->dtype << " vs "
<< y->dtype;

if (cond_shape.size() != x_shape.size()) {
ICHECK_EQ(cond_shape.size(), 1) << "Shape of condition " << condition->shape
<< " must be either equal to x or has dimension of 1.";
}
for (size_t i = 0; i < x_shape.size(); i++) {
ICHECK(reporter->AssertEQ(x_shape[i], y_shape[i]))
<< "x and y must have the same shape: " << x_shape << " vs " << y_shape;
auto tensor_ty_condition = GetRef<TensorType>(condition);
auto tensor_ty_x = GetRef<TensorType>(x);
auto tensor_ty_y = GetRef<TensorType>(y);

if (i < cond_shape.size()) {
ICHECK(reporter->AssertEQ(cond_shape[i], x_shape[i]))
<< "condition and x must have the same shape: " << cond_shape << " vs " << x_shape;
}
}
if (x_shape.size() == 0) {
// if x and y are scalar, the condition shape becomes the output shape
reporter->Assign(types[3], TensorType(cond_shape, x->dtype));
} else {
reporter->Assign(types[3], TensorType(x_shape, x->dtype));
}
auto b_ty = ConcreteBroadcast(tensor_ty_x, tensor_ty_y, x->dtype);
auto ret_ty = ConcreteBroadcast(tensor_ty_condition, b_ty, b_ty->dtype);

reporter->Assign(types[3], ret_ty);
return true;
}

@@ -1731,17 +1719,10 @@ Return the elements, either from x or y, depending on the condition.

Given three ndarrays, condition, x, and y, return an ndarray with the elements
from x or y, depending on whether the corresponding elements from condition are true or false.
x and y must have the same shape. If condition has the same shape as x,
each element in the output array is from x if the corresponding element
in the condition is true, and from y if false.

If condition does not have the same shape as x, it must be a 1D array whose
size is the same as x’s first dimension size. Each row of the output array
is from x’s row if the corresponding element from condition is true, and
from y’s row if false.

When x and y are scalars, condition must be an 1D array. The output shape
is the same as condition's shape.
Shapes of condition, x, and y must be broadcastable to a common shape, which
is the output shape of this op. Semantics follow the numpy where function:
https://numpy.org/doc/stable/reference/generated/numpy.where.html

Note that all non-zero values are interpreted as True in condition.

Expand All @@ -1753,7 +1734,7 @@ Examples::
where(cond, x, y) = [[5, 2], [3, 8]]


cond = [1, 0]
cond = [[1], [0]]
where(cond, x, y) = [[1, 2], [7, 8]]

cond = [0, 1]
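Since the relation now just applies ConcreteBroadcast twice, the broadcast output type is visible through ordinary Relay type inference; a sketch with illustrative shapes:

```python
import tvm
from tvm import relay

cond = relay.var("cond", shape=(3, 4), dtype="bool")
x = relay.var("x", shape=(3, 1), dtype="float32")
y = relay.var("y", shape=(1, 4), dtype="float32")

mod = tvm.IRModule.from_expr(relay.Function([cond, x, y], relay.where(cond, x, y)))
mod = relay.transform.InferType()(mod)
print(mod["main"].ret_type)  # Tensor[(3, 4), float32]
```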
2 changes: 1 addition & 1 deletion src/relay/op/type_relations.cc
@@ -64,7 +64,7 @@ bool EqualConstInt(const IndexExpr& lhs, int64_t value) {
return false;
}

Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2, DataType output_dtype) {
TensorType ConcreteBroadcast(const TensorType& t1, const TensorType& t2, DataType output_dtype) {
std::vector<IndexExpr> oshape;
size_t ndim1 = t1->shape.size();
size_t ndim2 = t2->shape.size();
9 changes: 9 additions & 0 deletions src/relay/op/type_relations.h
@@ -57,6 +57,15 @@ bool IdentityRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
bool BroadcastRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter);

/*!
* \brief Determine the broadcasted shape from two input shapes
* \param t1 One of the two TensorTypes whose shapes are broadcasted
* \param t2 One of the two TensorTypes whose shapes are broadcasted
* \param output_dtype dtype of the output TensorType
* \return A TensorType whose shape is broadcasted from the two input TensorTypes.
*/
TensorType ConcreteBroadcast(const TensorType& t1, const TensorType& t2, DataType output_dtype);

/*!
* \brief The broadcast type relation, implements the broadcasting
* rule over the two input types producing the broadcasted type.
4 changes: 4 additions & 0 deletions src/topi/broadcast.cc
@@ -76,5 +76,9 @@ TVM_REGISTER_GLOBAL("topi.broadcast_to").set_body([](TVMArgs args, TVMRetValue*
*rv = broadcast_to(args[0], args[1]);
});

TVM_REGISTER_GLOBAL("topi.broadcast_shape_tensors").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = broadcast_shape_tensors(args[0], args[1]);
});

} // namespace topi
} // namespace tvm
31 changes: 31 additions & 0 deletions tests/python/relay/test_any.py
@@ -1236,5 +1236,36 @@ def test_any_stack():
verify_any_stack(any_dims(4), (2, 1, 1, 4), 2, 2)


def verify_any_where(cond_shape, x_shape, y_shape, cond_np_shape, x_np_shape, y_np_shape):
dtype = "float32"
cond = relay.var("cond", shape=cond_shape, dtype="bool")
x = relay.var("x", shape=x_shape, dtype=dtype)
y = relay.var("y", shape=y_shape, dtype=dtype)
z = relay.where(cond, x, y)
mod = tvm.IRModule()
mod["main"] = relay.Function([cond, x, y], z)

cond_np = np.random.randn(*cond_np_shape) > 0
x_np = np.random.randn(*x_np_shape).astype(dtype)
y_np = np.random.randn(*y_np_shape).astype(dtype)
expected = np.where(cond_np, x_np, y_np)

check_result([cond_np, x_np, y_np], mod, expected)


@tvm.testing.uses_gpu
def test_any_where():
verify_any_where(any_dims(1), (5,), (5,), (5,), (5,), (5,))
verify_any_where(any_dims(1), any_dims(1), (5,), (5,), (5,), (5,))
verify_any_where(any_dims(1), any_dims(1), any_dims(1), (5,), (5,), (5,))
verify_any_where((5,), any_dims(1), any_dims(1), (5,), (5,), (5,))

# where with broadcast
verify_any_where(any_dims(1), any_dims(1), any_dims(1), (5,), (1,), (5,))
verify_any_where(any_dims(1), any_dims(2), any_dims(2), (5,), (5, 5), (5, 5))
verify_any_where(any_dims(1), any_dims(1), any_dims(2), (5,), (5,), (5, 5))
verify_any_where(any_dims(2), any_dims(2), any_dims(2), (3, 4), (3, 1), (1, 4))


if __name__ == "__main__":
pytest.main([__file__])