fix mkl offloading of batch matmul
masa committed Oct 28, 2020
1 parent f96ac41 commit cd90aa7
Showing 7 changed files with 118 additions and 102 deletions.
67 changes: 23 additions & 44 deletions include/tvm/topi/transform.h
@@ -39,6 +39,8 @@
#include <unordered_set>
#include <vector>

#include "detail/broadcast.h"

namespace tvm {
namespace topi {

@@ -887,53 +889,30 @@ inline Tensor take(const Tensor& a, const Tensor& indices, int axis, std::string
*/
inline Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y,
std::string name = "T_where", std::string tag = kBroadcast) {
CHECK_EQ(x->shape.size(), y->shape.size())
<< "x and y must have the same shape.Got different number of dimension: " << x->shape.size()
<< " vs " << y->shape.size();
CHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: " << x->dtype << " vs "
<< y->dtype;
auto get_out_shape = [&]() {
auto bh1 = detail::BroadcastShape(x->shape, y->shape);
Array<PrimExpr> common_shape1(bh1.common_shape.begin(), bh1.common_shape.end());
auto bh2 = detail::BroadcastShape(condition->shape, common_shape1);
Array<PrimExpr> common_shape2(bh2.common_shape.begin(), bh2.common_shape.end());
return common_shape2;
};

if (x->shape.size() == 0) {
return compute(
condition->shape,
[&](const Array<Var>& indices) {
PrimExpr cond;
if (condition->shape.size() == 0) {
cond = condition();
} else {
Array<PrimExpr> condition_idx{indices[0]};
cond = condition(condition_idx);
}
return tvm::tir::Select(cond != 0, x(), y());
},
name, tag);
} else if (condition->shape.size() != 1) {
CHECK_EQ(condition->shape.size(), x->shape.size())
<< "condition array must be either have the same shape as x or to be a "
"1-D array.Got different number of dimension: "
<< condition->shape.size() << " vs " << x->shape.size();
return compute(
x->shape,
[&](const Array<Var>& indices) {
return tvm::tir::Select(condition(indices) != 0, x(indices), y(indices));
},
name, tag);
} else {
int64_t cond_first_dim = topi::GetConstInt(condition->shape[0]);
int64_t x_first_dim = topi::GetConstInt(x->shape[0]);
if (cond_first_dim > 0 && x_first_dim > 0) {
CHECK_EQ(cond_first_dim, x_first_dim)
<< "If condition is 1-D, the first dimension must be the same as x: " << cond_first_dim
<< " vs " << x_first_dim;
}
return compute(
x->shape,
[&](const Array<Var>& indices) {
Array<PrimExpr> condition_idx{indices[0]};
return tvm::tir::Select(condition(condition_idx) != 0, x(indices), y(indices));
},
name, tag);
}
auto oshape = get_out_shape();

auto c_bh = detail::BroadcastShape(condition->shape, oshape);
auto x_bh = detail::BroadcastShape(x->shape, oshape);
auto y_bh = detail::BroadcastShape(y->shape, oshape);

auto select = [&](tvm::Array<tvm::tir::Var> ovars) {
auto c = condition(InputIndexFromBroadcast(ovars, condition, c_bh.vars1, c_bh.all_vars));
auto true_val = x(InputIndexFromBroadcast(ovars, x, x_bh.vars1, x_bh.all_vars));
auto false_val = y(InputIndexFromBroadcast(ovars, y, y_bh.vars1, y_bh.all_vars));
return tvm::tir::Select(c != 0, true_val, false_val);
};

return compute(oshape, select, name, tag);
}

/*!
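The rewritten `where` no longer special-cases scalar and 1-D conditions; it derives the output shape by broadcasting `x` against `y` and then against `condition` via `detail::BroadcastShape`. A minimal NumPy sketch of the resulting semantics (shapes borrowed from the new test cases further down):

```python
import numpy as np

x = np.random.randn(1, 12, 8, 8).astype("float32")   # full-rank tensor
y = np.array(-1.0, dtype="float32")                   # scalar, broadcast everywhere
cond = np.random.randn(1, 1, 8, 8) > 0                # broadcast along the second axis

# The output takes the common broadcast shape of condition, x, and y.
out = np.where(cond, x, y)
assert out.shape == (1, 12, 8, 8)
```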
7 changes: 7 additions & 0 deletions python/tvm/relay/op/strategy/x86.py
@@ -377,6 +377,13 @@ def batch_matmul_strategy_cpu(attrs, inputs, out_type, target):
name="batch_matmul_cblas.x86",
plevel=15,
)
if "mkl" in target.libs:
strategy.add_implementation(
wrap_compute_batch_matmul(topi.x86.batch_matmul_mkl),
wrap_topi_schedule(topi.x86.schedule_batch_matmul_mkl),
name="batch_matmul_mkl.x86",
plevel=15,
)
return strategy
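
The MKL implementation is registered only when `mkl` appears in `target.libs`, mirroring the existing cblas path at the same `plevel`. A hedged sketch of exercising the new path from Relay (assumes TVM was built with MKL enabled; the shapes are arbitrary):

```python
import tvm
from tvm import relay

# batch_matmul takes x of shape [batch, M, K] and y of shape [batch, N, K].
x = relay.var("x", shape=(8, 32, 64), dtype="float32")
y = relay.var("y", shape=(8, 16, 64), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x, y], relay.nn.batch_matmul(x, y)))

# "-libs=mkl" places "mkl" in target.libs, so batch_matmul_mkl.x86 becomes eligible.
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm -libs=mkl")
```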


24 changes: 20 additions & 4 deletions python/tvm/topi/x86/batch_matmul.py
@@ -19,7 +19,7 @@
from tvm import te
from tvm import autotvm
from tvm.autotvm.task.space import SplitEntity
from tvm.contrib import cblas
from tvm.contrib import cblas, mkl
from .. import generic
from ..util import traverse_inline, get_const_tuple, get_max_power2_factor

@@ -137,8 +137,7 @@ def _default_batch_matmul_config(cfg, M, N, K):
cfg["tile_y"] = SplitEntity([M // y_bn, y_bn])


@autotvm.register_topi_compute("batch_matmul_cblas.x86")
def batch_matmul_cblas(cfg, x, y, out_shape=None):
def batch_matmul_common(cfg, x, y, out_shape, lib):
"""Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
data in batch.
@@ -152,6 +151,8 @@ def batch_matmul_cblas(cfg, x, y, out_shape=None):
3-D with shape [batch, N, K]
out_shape : tuple or None
Shape of the output
lib : A contrib module
cblas or mkl are supported
Returns
-------
@@ -168,9 +169,24 @@ def batch_matmul_cblas(cfg, x, y, out_shape=None):
assert out_shape[1] == M, "got invalid output shape"
assert out_shape[2] == N, "got invalid output shape"
cfg.add_flop(XB * M * N * XK * 2)
return cblas.batch_matmul(x, y, False, True)
return lib.batch_matmul(x, y, False, True)


@autotvm.register_topi_compute("batch_matmul_cblas.x86")
def batch_matmul_cblas(cfg, x, y, out_shape=None):
return batch_matmul_common(cfg, x, y, out_shape, cblas)


@autotvm.register_topi_schedule("batch_matmul_cblas.x86")
def schedule_batch_matmul_cblas(_, outs):
return generic.schedule_extern(outs)


@autotvm.register_topi_compute("batch_matmul_mkl.x86")
def batch_matmul_mkl(cfg, x, y, out_shape=None):
return batch_matmul_common(cfg, x, y, out_shape, mkl)


@autotvm.register_topi_schedule("batch_matmul_mkl.x86")
def schedule_batch_matmul_mkl(_, outs):
return generic.schedule_extern(outs)
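
Both backends are invoked as `lib.batch_matmul(x, y, False, True)`, i.e. only the second operand is transposed, which matches the `[batch, N, K]` layout of `y` and the `XB * M * N * XK * 2` flop count above. A small NumPy sketch of the quantity being computed (hypothetical sizes):

```python
import numpy as np

batch, M, K, N = 4, 32, 64, 16
x = np.random.randn(batch, M, K).astype("float32")   # [batch, M, K]
y = np.random.randn(batch, N, K).astype("float32")   # [batch, N, K]

# batch_matmul(x, y, False, True): only y is transposed, so out[b] = x[b] @ y[b].T
out = np.einsum("bmk,bnk->bmn", x, y)                 # [batch, M, N]
assert out.shape == (batch, M, N)
```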
32 changes: 10 additions & 22 deletions src/relay/op/tensor/transform.cc
@@ -45,6 +45,7 @@
#include "../../transforms/pattern_utils.h"
#include "../make_op.h"
#include "../op_common.h"
#include "../type_relations.h"

namespace tvm {
namespace relay {
@@ -1685,30 +1686,17 @@ bool WhereRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
return false;
}

const auto& cond_shape = condition->shape;
const auto& x_shape = x->shape;
const auto& y_shape = y->shape;
CHECK(x_shape.size() == y_shape.size()) << "x and y must have the same size";
CHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: " << x->dtype << " vs "
<< y->dtype;

if (cond_shape.size() != x_shape.size()) {
CHECK_EQ(cond_shape.size(), 1) << "Shape of condition " << condition->shape
<< " must be either equal to x or has dimension of 1.";
}
for (size_t i = 0; i < x_shape.size(); i++) {
CHECK(reporter->AssertEQ(x_shape[i], y_shape[i]))
<< "x and y must have the same shape: " << x_shape << " vs " << y_shape;
auto tensor_ty_condition = GetRef<TensorType>(condition);
auto tensor_ty_x = GetRef<TensorType>(x);
auto tensor_ty_y = GetRef<TensorType>(y);

if (i < cond_shape.size()) {
CHECK(reporter->AssertEQ(cond_shape[i], x_shape[i]))
<< "condition and x must have the same shape: " << cond_shape << " vs " << x_shape;
}
}
if (x_shape.size() == 0) {
// if x and y are scalar, the condition shape becomes the output shape
reporter->Assign(types[3], TensorType(cond_shape, x->dtype));
} else {
reporter->Assign(types[3], TensorType(x_shape, x->dtype));
}
auto b_ty = ConcreteBroadcast(tensor_ty_x, tensor_ty_y, x->dtype);
auto ret_ty = ConcreteBroadcast(tensor_ty_condition, b_ty, b_ty->dtype);

reporter->Assign(types[3], ret_ty);
return true;
}

2 changes: 1 addition & 1 deletion src/relay/op/type_relations.cc
@@ -64,7 +64,7 @@ bool EqualConstInt(const IndexExpr& lhs, int64_t value) {
return false;
}

Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2, DataType output_dtype) {
TensorType ConcreteBroadcast(const TensorType& t1, const TensorType& t2, DataType output_dtype) {
std::vector<IndexExpr> oshape;
size_t ndim1 = t1->shape.size();
size_t ndim2 = t2->shape.size();
2 changes: 2 additions & 0 deletions src/relay/op/type_relations.h
@@ -57,6 +57,8 @@ bool IdentityRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
bool BroadcastRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter);

TensorType ConcreteBroadcast(const TensorType& t1, const TensorType& t2, DataType output_dtype);

/*!
* \brief The broadcast type relation, implements the broadcasting
* rule over the two input types producing the broadcasted type.
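
Changing the return type of `ConcreteBroadcast` from `Type` to `TensorType` lets `WhereRel` chain two broadcasts (x against y, then the result against condition) and assign the final type directly. The rule itself is the usual NumPy-style one: shapes are aligned from the trailing dimension, and each pair of dimensions must either match or have one side equal to 1. A rough Python sketch of that rule for static shapes (not the actual C++ implementation, which also handles symbolic dimensions):

```python
def broadcast_shape(s1, s2):
    """NumPy-style broadcast of two static shapes, aligned from the right."""
    out = []
    for i in range(1, max(len(s1), len(s2)) + 1):
        d1 = s1[-i] if i <= len(s1) else 1
        d2 = s2[-i] if i <= len(s2) else 1
        if d1 == d2 or d1 == 1 or d2 == 1:
            out.append(max(d1, d2))
        else:
            raise ValueError("incompatible dimensions: %d vs %d" % (d1, d2))
    return tuple(reversed(out))

assert broadcast_shape((3, 4), (3, 4)) == (3, 4)
assert broadcast_shape((2, 1), (2, 2)) == (2, 2)
assert broadcast_shape((1, 1, 8, 8), (1, 12, 8, 8)) == (1, 12, 8, 8)
```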
86 changes: 55 additions & 31 deletions tests/python/relay/test_op_level4.py
@@ -152,35 +152,59 @@ def run(func, inputs, ref_res):
op_res = intrp.evaluate(func)(*inputs)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)

shape = (3, 4)
dtype = "float32"
cond = relay.var("cond", relay.TensorType(shape, dtype))
x = relay.var("x", relay.TensorType(shape, dtype))
y = relay.var("y", relay.TensorType(shape, dtype))
z = relay.where(cond, x, y)
zz = run_infer_type(z)
assert zz.checked_type == relay.TensorType(shape, dtype)
def verify(x_np, y_np, cond_np):
ref_res = np.where(cond_np, x_np, y_np)

args = []
args_np = []
vs = []

cond = relay.var("cond", relay.TensorType(cond_np.shape, "bool"))

args.append(cond)
args_np.append(cond_np)

for v_name, v_np in [("x", x_np), ("y", y_np)]:
if len(v_np.shape) == 0:
v = relay.const(v_np.item())
else:
v = relay.var(v_name, relay.TensorType(v_np.shape, dtype))
args.append(v)
args_np.append(v_np)
vs.append(v)

func = relay.Function([cond, x, y], z)
condition = np.random.uniform(low=-1, high=1, size=shape).astype(dtype)
x = np.random.uniform(size=shape).astype(dtype)
y = np.random.uniform(size=shape).astype(dtype)
ref_res = np.where(condition, x, y)
z = relay.where(cond, vs[0], vs[1])

run(func, [condition, x, y], ref_res)
func = relay.Function(args, z)

run(func, args_np, ref_res)

x = relay.const(1)
y = relay.const(-1)
shape = (3,)
dtype = "float32"
cond = relay.var("cond", relay.TensorType(shape, "bool"))
z = relay.where(cond, x, y)

func = relay.Function([cond], z)
condition = np.array([1, 0, 1], dtype=np.bool)
ref_res = np.where(condition, 1, -1)
x_np = np.random.uniform(size=(3, 4)).astype(dtype)
y_np = np.random.uniform(size=(3, 4)).astype(dtype)
cond_np = np.random.uniform(low=-1, high=1, size=(3, 4)) > 0

verify(x_np, y_np, cond_np)

x_np = np.array(1.0, dtype)
y_np = np.array(-1.0, dtype)
cond_np = np.array([1, 0, 1], dtype=np.bool)

verify(x_np, y_np, cond_np)

x_np = np.array([[1, 2], [3, 4]], dtype)
y_np = np.array([[5, 6], [7, 8]], dtype)
cond_np = np.array([[1], [0]], dtype=np.bool)

verify(x_np, y_np, cond_np)
verify(x_np, y_np, cond_np.T)

x_np = np.random.randn(1, 12, 8, 8).astype(dtype)
y_np = np.array(-1.0, dtype)
cond_np = np.random.randn(1, 1, 8, 8) > 0

run(func, [condition], ref_res)
verify(x_np, y_np, cond_np)


def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32"):
@@ -498,12 +522,12 @@ def verify(dshape, begin, end, strides, vshape, test_ref=True):


if __name__ == "__main__":
test_strided_slice()
test_strided_set()
test_binary_op()
test_cmp_type()
test_binary_int_broadcast_1()
test_binary_int_broadcast_2()
# test_strided_slice()
# test_strided_set()
# test_binary_op()
# test_cmp_type()
# test_binary_int_broadcast_1()
# test_binary_int_broadcast_2()
test_where()
test_reduce_functions()
test_mean_var_std()
# test_reduce_functions()
# test_mean_var_std()
