Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend][Tensorflow] Sparse_Dense Op CSR scheduling issue resolved for Cuda & X86 #7148

Merged
merged 4 commits into from
Jan 15, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 38 additions & 7 deletions python/tvm/topi/cuda/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@
from tvm import relay, te

from .. import nn
from ..utils import traverse_inline
from ..utils import traverse_inline, get_const_tuple, prod, get_const_int


def sparse_dense(data, weight_data, weight_indices, weight_indptr):
def sparse_dense(data, weight_data, weight_indices, weight_indptr, sparse_lhs=False):
"""
Computes sparse-dense matrix multiplication of `data` and
`(weight_data, weight_indices, weight_indptr).T`
Expand Down Expand Up @@ -57,19 +57,21 @@ def sparse_dense(data, weight_data, weight_indices, weight_indptr):
2-D with shape [M, N]
"""
# pylint:disable=unused-argument
return nn.sparse_dense(data, weight_data, weight_indices, weight_indptr)
return nn.sparse_dense(data, weight_data, weight_indices, weight_indptr, sparse_lhs)


def schedule_sparse_dense(outs):
"""Create schedule for sparse dense"""
# pylint:disable=invalid-name
s = te.create_schedule([x.op for x in outs])

# TODO(ANSHUMAN87): Add for sparse_dense_bsrmm_v1 also
def _callback(op):
if op.tag == "sparse_dense_bsrmm_v2":
if op.tag == "sparse_dense_sp_rhs_bsrmm" or op.tag == "sparse_dense_sp_lhs_bsrmm":
y_bsrmm = op.input_tensors[0]
assert y_bsrmm.op.tag == "sparse_dense_bsrmm_block_v2"
assert (
y_bsrmm.op.tag == "sparse_dense_sp_rhs_bsrmm_block"
or y_bsrmm.op.tag == "sparse_dense_sp_lhs_bsrmm_block"
)
out = s.outputs[0].output(0)

if op not in s.outputs:
Expand All @@ -91,6 +93,13 @@ def _callback(op):
s[y_bsrmm_factored].compute_at(s[y_bsrmm], tx)
s[y_bsrmm].set_store_predicate(thread_x.var.equal(0))
s[out].set_store_predicate(thread_x.var.equal(0))
elif op.tag == "sparse_dense_sp_lhs_csrmm" or op.tag == "sparse_dense_sp_rhs_csrmm":
out = op.output(0)
const_size = get_const_int(prod(out.shape))
fused = s[out].fuse(*s[out].op.axis)
bx, tx = s[out].split(fused, factor=const_size)
s[out].bind(tx, te.thread_axis("threadIdx.x"))
s[out].bind(bx, te.thread_axis("blockIdx.x"))

traverse_inline(s, outs[0].op, _callback)
return s
Expand Down Expand Up @@ -279,7 +288,26 @@ def gen_ir(data, w_data, w_indices, w_indptr, out):
return out


def sparse_dense_padded(data, weight_data, weight_indices, weight_indptr):
def is_valid_for_sparse_dense_padded(data, weight_data):
"""
Check whether input is applicable for sparse_dense_padded op.
If not we should fall back to default scheduling.
"""
# pylint:disable=invalid-name
warp_size = int(tvm.target.Target.current(allow_none=False).thread_warp_size)
m = get_const_tuple(data.checked_type.shape)[1]
if len(weight_data.shape) == 1:
bs_m = 1
else:
bs_m = weight_data.shape[1]

mb = m // bs_m
if mb >= warp_size:
return True
return False


def sparse_dense_padded(data, weight_data, weight_indices, weight_indptr, sparse_lhs=False):
"""
Computes sparse-dense matrix multiplication of `data` and
`(weight_data, weight_indices, weight_indptr).T`
Expand Down Expand Up @@ -311,6 +339,8 @@ def sparse_dense_padded(data, weight_data, weight_indices, weight_indptr):
output : tvm.te.Tensor
2-D with shape [M, N]
"""
# TODO(ANSHUMAN87): Handle for sparse_lhs case too
assert not sparse_lhs, "Currently only sparse weight is supported."
return sparse_dense_tir(data, weight_data, weight_indices, weight_indptr)


Expand Down Expand Up @@ -368,6 +398,7 @@ def _alter_sparse_dense_layout(_attrs, inputs, _tinfos, _out_type):
isinstance(inputs[1], relay.Constant)
and isinstance(inputs[2], relay.Constant)
and isinstance(inputs[3], relay.Constant)
and is_valid_for_sparse_dense_padded(inputs[0], inputs[1].data.asnumpy())
):
if len(inputs[1].data.asnumpy().shape) == 1:
sparse_matrix = sp.csr_matrix(
Expand Down
36 changes: 18 additions & 18 deletions python/tvm/topi/nn/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from ..utils import get_const_tuple


def sparse_dense_v2(data, weight_data, weight_indices, weight_indptr):
def sparse_dense_sp_rhs(data, weight_data, weight_indices, weight_indptr):
"""
Computes sparse-dense matrix multiplication of `data` and
`(weight_data, weight_indices, weight_indptr).T`
Expand Down Expand Up @@ -52,13 +52,13 @@ def sparse_dense_v2(data, weight_data, weight_indices, weight_indptr):
"""
assert len(weight_data.shape) in (1, 3)
if len(weight_data.shape) == 1:
func = _sparse_dense_csrmm_v2
func = _sparse_dense_sp_rhs_csrmm
if len(weight_data.shape) == 3:
func = _sparse_dense_bsrmm_v2
func = _sparse_dense_sp_rhs_bsrmm
return func(data, weight_data, weight_indices, weight_indptr)


def sparse_dense_v1(data_data, data_indices, data_indptr, weight):
def sparse_dense_sp_lhs(data_data, data_indices, data_indptr, weight):
"""
Computes sparse-dense matrix multiplication of
`(data_data, data_indices, data_indptr)` and `weight.T`
Expand Down Expand Up @@ -87,9 +87,9 @@ def sparse_dense_v1(data_data, data_indices, data_indptr, weight):
"""
assert len(data_data.shape) in (1, 3)
if len(data_data.shape) == 1:
func = _sparse_dense_csrmm_v1
func = _sparse_dense_sp_lhs_csrmm
if len(data_data.shape) == 3:
func = _sparse_dense_bsrmm_v1
func = _sparse_dense_sp_lhs_bsrmm
return func(data_data, data_indices, data_indptr, weight)


Expand Down Expand Up @@ -128,12 +128,12 @@ def sparse_dense(dense_data, sparse_data, sparse_indices, sparse_indptr, sparse_
2-D with shape [M, N]
"""
if sparse_lhs:
return sparse_dense_v1(sparse_data, sparse_indices, sparse_indptr, dense_data)
return sparse_dense_sp_lhs(sparse_data, sparse_indices, sparse_indptr, dense_data)
else:
return sparse_dense_v2(dense_data, sparse_data, sparse_indices, sparse_indptr)
return sparse_dense_sp_rhs(dense_data, sparse_data, sparse_indices, sparse_indptr)


def _sparse_dense_csrmm_v1(data_data, data_indices, data_indptr, weight):
def _sparse_dense_sp_lhs_csrmm(data_data, data_indices, data_indptr, weight):
oshape = (get_const_tuple(data_indptr.shape)[0] - 1, get_const_tuple(weight.shape)[0])

def f(row, i):
Expand All @@ -146,10 +146,10 @@ def f(row, i):
weight_val = weight[i, data_indices[elem]]
return te.sum(a_val * weight_val, axis=elem_idx)

return te.compute(oshape, f, tag="sparse_dense_csrmm_v1")
return te.compute(oshape, f, tag="sparse_dense_sp_lhs_csrmm")


def _sparse_dense_csrmm_v2(data, weight_data, weight_indices, weight_indptr):
def _sparse_dense_sp_rhs_csrmm(data, weight_data, weight_indices, weight_indptr):
oshape = (get_const_tuple(data.shape)[0], get_const_tuple(weight_indptr.shape)[0] - 1)

def f(i, row):
Expand All @@ -162,10 +162,10 @@ def f(i, row):
weight_val = data[i, weight_indices[elem]]
return te.sum(a_val * weight_val, axis=elem_idx)

return te.compute(oshape, f, tag="sparse_dense_csrmm_v2")
return te.compute(oshape, f, tag="sparse_dense_sp_rhs_csrmm")


def _sparse_dense_bsrmm_v1(data_data, data_indices, data_indptr, weight):
def _sparse_dense_sp_lhs_bsrmm(data_data, data_indices, data_indptr, weight):
(m, _) = get_const_tuple(weight.shape)
(_, bs_r, bs_c) = get_const_tuple(data_data.shape)
(num_blocks_plus_1,) = get_const_tuple(data_indptr.shape)
Expand All @@ -187,16 +187,16 @@ def _compute_block(nb_j, j, i):
idxm = tvm.tir.indexmod

bsrmm_block = te.compute(
(num_blocks, bs_r, m), _compute_block, tag="sparse_dense_bsrmm_block_v1"
(num_blocks, bs_r, m), _compute_block, tag="sparse_dense_sp_lhs_bsrmm_block"
)
return te.compute(
(num_blocks * bs_r, m),
lambda m, n: bsrmm_block[idxd(m, bs_r), idxm(m, bs_r), n],
tag="sparse_dense_bsrmm_v1",
tag="sparse_dense_sp_lhs_bsrmm",
)


def _sparse_dense_bsrmm_v2(data, weight_data, weight_indices, weight_indptr):
def _sparse_dense_sp_rhs_bsrmm(data, weight_data, weight_indices, weight_indptr):
(m, _) = get_const_tuple(data.shape)
(_, bs_r, bs_c) = get_const_tuple(weight_data.shape)
(num_blocks_plus_1,) = get_const_tuple(weight_indptr.shape)
Expand All @@ -218,12 +218,12 @@ def _compute_block(i, nb_j, j):
idxm = tvm.tir.indexmod

bsrmm_block = te.compute(
(m, num_blocks, bs_r), _compute_block, tag="sparse_dense_bsrmm_block_v2"
(m, num_blocks, bs_r), _compute_block, tag="sparse_dense_sp_rhs_bsrmm_block"
)
return te.compute(
(m, num_blocks * bs_r),
lambda m, n: bsrmm_block[m, idxd(n, bs_r), idxm(n, bs_r)],
tag="sparse_dense_bsrmm_v2",
tag="sparse_dense_sp_rhs_bsrmm",
)


Expand Down
18 changes: 10 additions & 8 deletions python/tvm/topi/x86/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,17 @@ def schedule_sparse_dense(outs):

def _callback(op):
simd_width = get_fp32_len()
if op.tag == "sparse_dense_csrmm" and op != outs[0].op:
(_, v_i) = s[op].op.axis
s[op].vectorize(v_i)
(y_o, y_i) = s[outs[0].op].split(s[outs[0].op].op.axis[1], 2 * simd_width)
s[op].compute_at(s[outs[0]], y_o)
s[outs[0].op].vectorize(y_i)
if op.tag == "sparse_dense_bsrmm":
if op.tag == "sparse_dense_sp_lhs_csrmm" or op.tag == "sparse_dense_sp_lhs_csrmm":
(y_o, y_i) = s[op].split(s[op].op.axis[1], 2)
fused = s[op].fuse(s[op].op.axis[0], y_o)
s[op].parallel(fused)
s[op].vectorize(y_i)
elif op.tag == "sparse_dense_sp_rhs_bsrmm" or op.tag == "sparse_dense_sp_rhs_bsrmm":
y_bsrmm = op.input_tensors[0]
assert y_bsrmm.op.tag == "sparse_dense_bsrmm_block"
assert (
y_bsrmm.op.tag == "sparse_dense_sp_rhs_bsrmm_block"
or y_bsrmm.op.tag == "sparse_dense_sp_lhs_bsrmm_block"
)
y_reshape = op
(m, num_blocks, b_r) = s[y_bsrmm].op.axis
bs_r = get_const_int(b_r.dom.extent)
Expand Down
3 changes: 1 addition & 2 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1776,8 +1776,7 @@ def _test_sparse_dense_matmul(indices, values, A_shape, B_shape, dtype, flip=Fal

B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype)

# TODO(ANSHUMAN87): There is an issue in cuda scheduling for csr, work in progress
compare_tf_with_tvm([B_np], [B.name], result.name, no_gpu=True)
compare_tf_with_tvm([B_np], [B.name], result.name)


def test_forward_sparse_dense_matmul():
Expand Down
28 changes: 28 additions & 0 deletions tests/python/topi/python/test_topi_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,33 @@ def test_sparse_dense_padded_alter_op():
assert f_["main"].body.op.name == "nn.internal.sparse_dense_padded"


@tvm.testing.requires_cuda
def test_sparse_dense_padded_alter_op_var_inp():
with tvm.target.Target("cuda"):
M = 128
N = 16
K = 128
X_np = np.random.randn(M, K).astype("float32")
W_sp_np = random_bsr_matrix(N, K, 2, 2, density=0.01, dtype="float32")
x = relay.var("x", relay.TensorType(X_np.shape, "float32"))
mult = relay.op.nn.sparse_dense(
x,
(
relay.Constant(tvm.nd.array(W_sp_np.data)),
relay.Constant(tvm.nd.array(W_sp_np.indices)),
relay.Constant(tvm.nd.array(W_sp_np.indptr)),
),
)
f = relay.Function([x], mult)
f_ = relay.transform.InferType()(tvm.IRModule.from_expr(f))
f_ = relay.transform.AlterOpLayout()(f_)
assert f_["main"].body.op.name == "nn.internal.sparse_dense_padded"

# build with cuda and AlterOpLayout to ensure that sparse_dense_padded is in action
with tvm.transform.PassContext(opt_level=3, required_pass="AlterOpLayout"):
x = relay.build(tvm.IRModule.from_expr(f), target=tvm.target.Target("cuda"))


if __name__ == "__main__":
test_csrmv()
test_csrmm()
Expand All @@ -530,5 +557,6 @@ def test_sparse_dense_padded_alter_op():
test_sparse_transpose_csr()
test_sparse_dense_padded_cuda()
test_sparse_dense_padded_alter_op()
ANSHUMAN87 marked this conversation as resolved.
Show resolved Hide resolved
test_sparse_dense_padded_alter_op_var_inp()
test_sparse_dense_csr_reverse()
test_sparse_dense_bsr_reverse()