Add rocblas_sgemm_strided_batched impl. (apache#6579)
csullivan authored and Tushar Dey committed Oct 15, 2020
1 parent fdd0b4c commit 01554ef
Showing 7 changed files with 249 additions and 10 deletions.
23 changes: 23 additions & 0 deletions include/tvm/topi/contrib/rocblas.h
@@ -54,6 +54,29 @@ inline Tensor rocblas_matmul(const Tensor& lhs, const Tensor& rhs, bool transa,
      },
      "C", "", {})[0];
}

/*!
 * \brief Create an op that batch multiplies lhs and rhs with rocBLAS
 *
 * \param lhs The left matrix operand e.g. (batch_size, M, K)
 * \param rhs The right matrix operand e.g. (batch_size, K, N)
 * \param transa Whether to transpose lhs
 * \param transb Whether to transpose rhs
 *
 * \return The output tensor
 */
inline Tensor rocblas_batch_matmul(const Tensor& lhs, const Tensor& rhs, bool transa, bool transb) {
  auto batch_size = lhs->shape[0];
  auto n = transa ? lhs->shape[2] : lhs->shape[1];
  auto m = transb ? rhs->shape[1] : rhs->shape[2];

  return make_extern(
      {{batch_size, n, m}}, {lhs->dtype}, {lhs, rhs},
      [&](Array<Buffer> ins, Array<Buffer> outs) {
        return call_packed({StringImm("tvm.contrib.rocblas.batch_matmul"), pack_buffer(ins[0]),
                            pack_buffer(ins[1]), pack_buffer(outs[0]), transa, transb});
      },
      "C", "", {})[0];
}

} // namespace contrib
} // namespace topi
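For instance (illustrative shapes, not part of the diff): with transa and transb both false, an (8, 64, 128) lhs and an (8, 128, 96) rhs yield an (8, 64, 96) output; with transb true, the same product takes rhs stored as (8, 96, 128), since m is then read from rhs->shape[1].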
33 changes: 33 additions & 0 deletions python/tvm/contrib/rocblas.py
@@ -48,3 +48,36 @@ def matmul(lhs, rhs, transa=False, transb=False):
        ),
        name="C",
    )


def batch_matmul(lhs, rhs, transa=False, transb=False):
    """Create an extern op that computes batched matrix multiplication of lhs and rhs with rocBLAS.

    Parameters
    ----------
    lhs : Tensor
        The left batched matrix operand
    rhs : Tensor
        The right batched matrix operand
    transa : bool
        Whether to transpose lhs
    transb : bool
        Whether to transpose rhs

    Returns
    -------
    C : Tensor
        The result tensor.
    """
    batch_size = lhs.shape[0]
    assert batch_size == rhs.shape[0]
    n = lhs.shape[2] if transa else lhs.shape[1]
    m = rhs.shape[1] if transb else rhs.shape[2]
    return te.extern(
        (batch_size, n, m),
        [lhs, rhs],
        lambda ins, outs: tvm.tir.call_packed(
            "tvm.contrib.rocblas.batch_matmul", ins[0], ins[1], outs[0], transa, transb
        ),
        name="C",
    )
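A minimal usage sketch of this wrapper (not part of the diff; shapes are illustrative, and it assumes a TVM build with ROCm and rocBLAS enabled), mirroring the unit test added below:

import numpy as np
import tvm
from tvm import te
from tvm.contrib import rocblas

# A is (batch, M, K) and B is (batch, K, N), so transa=transb=False.
A = te.placeholder((8, 64, 128), name="A", dtype="float32")
B = te.placeholder((8, 128, 96), name="B", dtype="float32")
C = rocblas.batch_matmul(A, B, False, False)
s = te.create_schedule(C.op)
f = tvm.build(s, [A, B, C], "rocm")

ctx = tvm.rocm(0)
a = tvm.nd.array(np.random.uniform(size=(8, 64, 128)).astype("float32"), ctx)
b = tvm.nd.array(np.random.uniform(size=(8, 128, 96)).astype("float32"), ctx)
c = tvm.nd.array(np.zeros((8, 64, 96), dtype="float32"), ctx)
f(a, b, c)  # c now holds the rocBLAS-computed batched product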
21 changes: 21 additions & 0 deletions python/tvm/relay/op/strategy/rocm.py
@@ -160,3 +160,24 @@ def dense_strategy_rocm(attrs, inputs, out_type, target):
            plevel=15,
        )
    return strategy


@batch_matmul_strategy.register("rocm")
def batch_matmul_strategy_rocm(attrs, inputs, out_type, target):
    """Batch matmul strategy for ROCM"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_batch_matmul(topi.cuda.batch_matmul),
        wrap_topi_schedule(topi.cuda.schedule_batch_matmul),
        name="batch_matmul.cuda",
        plevel=10,
    )
    if target.kind.name == "rocm" and "rocblas" in target.libs:
        assert out_type.dtype == inputs[0].dtype, "Mixed precision not supported."
        strategy.add_implementation(
            wrap_compute_batch_matmul(topi.rocm.batch_matmul_rocblas),
            wrap_topi_schedule(topi.rocm.schedule_batch_matmul_rocblas),
            name="batch_matmul_rocblas.rocm",
            plevel=12,
        )
    return strategy
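For context, a hedged sketch of how this strategy is exercised from Relay (shapes and variable names are illustrative, not part of the diff): with "rocblas" listed in the target libs, the rocBLAS implementation at plevel=12 is preferred over the generic topi.cuda one at plevel=10.

import tvm
from tvm import relay

# nn.batch_matmul takes x as (batch, M, K) and y as (batch, N, K) and computes x @ y^T.
x = relay.var("x", shape=(8, 64, 128), dtype="float32")
y = relay.var("y", shape=(8, 96, 128), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x, y], relay.nn.batch_matmul(x, y)))

# "-libs=rocblas" is what makes the `"rocblas" in target.libs` check above succeed.
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="rocm -libs=rocblas")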
1 change: 1 addition & 0 deletions python/tvm/topi/rocm/__init__.py
@@ -19,6 +19,7 @@
"""rocm specific declaration and schedules."""
from __future__ import absolute_import as _abs

from .batch_matmul import *
from .conv2d import *
from .dense import *
from .nn import *
56 changes: 56 additions & 0 deletions python/tvm/topi/rocm/batch_matmul.py
@@ -0,0 +1,56 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-variable, unused-argument
"""Schedule for batch_matmul operator"""
from tvm import autotvm
from tvm.contrib import rocblas
from .. import generic
from ..util import get_const_tuple


@autotvm.register_topi_compute("batch_matmul_rocblas.rocm")
def batch_matmul_rocblas(cfg, x, y, out_shape=None):
    """Compute batched matrix multiplication of `x` and `y` with rocBLAS.

    Parameters
    ----------
    cfg : ConfigSpace
        Autotvm tuning space config file
    x : tvm.te.Tensor
        3-D with shape [batch, M, K]
    y : tvm.te.Tensor
        3-D with shape [batch, N, K]
    out_shape : tuple or None
        Expected output shape [batch, M, N], used only for validation

    Returns
    -------
    output : tvm.te.Tensor
        3-D with shape [batch, M, N]
    """
    batch, M, K = get_const_tuple(x.shape)
    _, N, _ = get_const_tuple(y.shape)
    if out_shape is not None:
        assert out_shape[0] == batch, "Input and output batch sizes must match"
        assert out_shape[1] == M and out_shape[2] == N, "Invalid output shape"
    result = rocblas.batch_matmul(x, y, False, True)
    cfg.add_flop(batch * M * N * K * 2)
    return result


@autotvm.register_topi_schedule("batch_matmul_rocblas.rocm")
def schedule_batch_matmul_rocblas(_, outs):
    """Schedule for batch_matmul with the rocBLAS extern op"""
    return generic.schedule_extern(outs)
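A note on the cost accounting above: each of the batch * M * N output elements takes K multiply-add pairs, which is what cfg.add_flop(batch * M * N * K * 2) records. For one of the shapes exercised by the tests (illustrative arithmetic):

batch, M, N, K = 128, 64, 512, 512
flops = 2 * batch * M * N * K
print(flops)  # 4294967296, i.e. ~4.3 GFLOP per call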
59 changes: 51 additions & 8 deletions src/runtime/contrib/rocblas/rocblas.cc
@@ -70,18 +70,61 @@ TVM_REGISTER_GLOBAL("tvm.contrib.rocblas.matmul").set_body([](TVMArgs args, TVMR
  CHECK_ROCBLAS_ERROR(rocblas_create_handle(&handle));
  float alpha = 1.0;
  float beta = 0.0;
-  float* A_ptr = reinterpret_cast<float*>(static_cast<char*>(B->data) + B->byte_offset);
-  float* B_ptr = reinterpret_cast<float*>(static_cast<char*>(A->data) + A->byte_offset);
+  float* A_ptr = reinterpret_cast<float*>(static_cast<char*>(A->data) + A->byte_offset);
+  float* B_ptr = reinterpret_cast<float*>(static_cast<char*>(B->data) + B->byte_offset);
  float* C_ptr = reinterpret_cast<float*>(static_cast<char*>(C->data) + C->byte_offset);

-  CHECK_ROCBLAS_ERROR(
-      rocblas_sgemm(handle, transb ? rocblas_operation_transpose : rocblas_operation_none,
-                    transa ? rocblas_operation_transpose : rocblas_operation_none,
-                    transb ? B->shape[0] : B->shape[1], transa ? A->shape[1] : A->shape[0],
-                    transb ? B->shape[1] : B->shape[0], &alpha, A_ptr, B->shape[1], B_ptr,
-                    A->shape[1], &beta, C_ptr, C->shape[1]));
+  rocblas_operation roc_trans_A = transa ? rocblas_operation_transpose : rocblas_operation_none;
+  rocblas_operation roc_trans_B = transb ? rocblas_operation_transpose : rocblas_operation_none;
+  size_t N = transb ? B->shape[0] : B->shape[1];
+  size_t M = transa ? A->shape[1] : A->shape[0];
+  size_t K = transb ? B->shape[1] : B->shape[0];
+  size_t lda = transa ? M : K;
+  size_t ldb = transb ? K : N;
+  size_t ldc = N;
+
+  CHECK_ROCBLAS_ERROR(rocblas_sgemm(handle, roc_trans_B, roc_trans_A, N, M, K, &alpha, B_ptr, ldb,
+                                    A_ptr, lda, &beta, C_ptr, ldc));

  CHECK_ROCBLAS_ERROR(rocblas_destroy_handle(handle));
});
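The operand swap in the call above is the usual row-major-over-BLAS trick: rocBLAS, like BLAS, assumes column-major storage, so the row-major product C = op(A) * op(B) is obtained by computing C^T = op(B)^T * op(A)^T in column-major terms, with each leading dimension equal to the row-major row length. A small NumPy sanity check of that identity (illustrative only, not part of the diff):

import numpy as np

A = np.random.rand(4, 6).astype("float32")  # (M, K)
B = np.random.rand(6, 3).astype("float32")  # (K, N)
C = A @ B                                   # (M, N)
# Reading the same row-major buffers as column-major transposes them, so a
# column-major B^T @ A^T produces C^T, i.e. C laid out row-major.
assert np.allclose(C.T, B.T @ A.T)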

TVM_REGISTER_GLOBAL("tvm.contrib.rocblas.batch_matmul")
.set_body([](TVMArgs args, TVMRetValue* ret) {
DLTensor* A = args[0];
DLTensor* B = args[1];
DLTensor* C = args[2];
bool transa = args[3];
bool transb = args[4];
// call gemm for simple compact code.
CHECK_EQ(A->ndim, 3);
CHECK_EQ(B->ndim, 3);
CHECK_EQ(C->ndim, 3);
CHECK(TypeMatch(A->dtype, kDLFloat, 32));
CHECK(TypeMatch(B->dtype, kDLFloat, 32));
CHECK(TypeMatch(C->dtype, kDLFloat, 32));

rocblas_handle handle;
CHECK_ROCBLAS_ERROR(rocblas_create_handle(&handle));
float alpha = 1.0;
float beta = 0.0;
float* A_ptr = reinterpret_cast<float*>(static_cast<char*>(A->data) + A->byte_offset);
float* B_ptr = reinterpret_cast<float*>(static_cast<char*>(B->data) + B->byte_offset);
float* C_ptr = reinterpret_cast<float*>(static_cast<char*>(C->data) + C->byte_offset);

rocblas_operation roc_trans_A = transa ? rocblas_operation_transpose : rocblas_operation_none;
rocblas_operation roc_trans_B = transb ? rocblas_operation_transpose : rocblas_operation_none;
size_t batch_size = C->shape[0];
size_t N = transb ? B->shape[1] : B->shape[2];
size_t M = transa ? A->shape[2] : A->shape[1];
size_t K = transb ? B->shape[2] : B->shape[1];
size_t lda = transa ? M : K;
size_t ldb = transb ? K : N;
size_t ldc = N;

CHECK_ROCBLAS_ERROR(rocblas_sgemm_strided_batched(
handle, roc_trans_B, roc_trans_A, N, M, K, &alpha, B_ptr, ldb, K * N, A_ptr, lda, M * K,
&beta, C_ptr, ldc, M * N, batch_size));
});
} // namespace contrib
} // namespace tvm
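In the strided-batched call, the stride arguments (K * N, M * K, M * N) are the per-matrix element counts, so consecutive batches are assumed to be densely packed; these counts are unchanged by the transpose flags. For one of the test shapes (illustrative arithmetic):

batch, M, K, N = 128, 64, 512, 512
stride_B = K * N  # 262144 elements between consecutive B matrices
stride_A = M * K  # 32768 elements between consecutive A matrices
stride_C = M * N  # 32768 elements between consecutive C matrices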
66 changes: 64 additions & 2 deletions tests/python/contrib/test_rocblas.py
@@ -18,11 +18,13 @@
import tvm.testing
from tvm import te
import numpy as np
import tvm.topi.testing
import tvm.testing
from tvm.contrib import rocblas


@tvm.testing.requires_rocm
-def test_matmul_add():
+def test_matmul():
    n = 1024
    l = 128
    m = 235
@@ -46,5 +48,65 @@ def verify(target="rocm"):
    verify()


def verify_batch_matmul(batch, m, k, n, lib, transa=False, transb=False, dtype="float32"):
    ashape = (batch, k, m) if transa else (batch, m, k)
    bshape = (batch, n, k) if transb else (batch, k, n)
    A = te.placeholder(ashape, name="A", dtype=dtype)
    B = te.placeholder(bshape, name="B", dtype=dtype)
    C = lib.batch_matmul(A, B, transa, transb)
    s = te.create_schedule(C.op)

    def get_numpy(a, b, transa, transb):
        if transa:
            a = a.transpose(0, 2, 1)
        if not transb:
            b = b.transpose(0, 2, 1)
        return tvm.topi.testing.batch_matmul(a, b)

    def verify(target="rocm"):
        if not tvm.testing.device_enabled(target):
            print("skip because %s is not enabled..." % target)
            return
        if not tvm.get_global_func(lib.__name__ + ".batch_matmul", True):
            print("skip because extern function is not available")
            return
        ctx = tvm.rocm(0)
        f = tvm.build(s, [A, B, C], target)
        a = tvm.nd.array(np.random.uniform(size=ashape).astype(A.dtype), ctx)
        b = tvm.nd.array(np.random.uniform(size=bshape).astype(B.dtype), ctx)
        c = tvm.nd.array(np.zeros((batch, m, n), dtype=C.dtype), ctx)
        f(a, b, c)
        tvm.testing.assert_allclose(
            c.asnumpy(), get_numpy(a.asnumpy(), b.asnumpy(), transa, transb), rtol=1e-5
        )

    verify()


@tvm.testing.requires_rocm
def test_batch_matmul():
    verify_batch_matmul(128, 64, 512, 512, rocblas, transa=False, transb=False)
    verify_batch_matmul(128, 64, 512, 512, rocblas, transa=False, transb=True)
    verify_batch_matmul(128, 64, 512, 512, rocblas, transa=True, transb=False)
    verify_batch_matmul(128, 64, 512, 512, rocblas, transa=True, transb=True)
    verify_batch_matmul(128, 512, 512, 64, rocblas, transa=False, transb=False)
    verify_batch_matmul(128, 512, 512, 64, rocblas, transa=False, transb=True)
    verify_batch_matmul(128, 512, 512, 64, rocblas, transa=True, transb=False)
    verify_batch_matmul(128, 512, 512, 64, rocblas, transa=True, transb=True)
    verify_batch_matmul(128, 512, 64, 512, rocblas, transa=False, transb=False)
    verify_batch_matmul(128, 512, 64, 512, rocblas, transa=False, transb=True)
    verify_batch_matmul(128, 512, 64, 512, rocblas, transa=True, transb=False)
    verify_batch_matmul(128, 512, 64, 512, rocblas, transa=True, transb=True)
    verify_batch_matmul(128, 64, 128, 128, rocblas, transa=False, transb=False)
    verify_batch_matmul(128, 64, 128, 128, rocblas, transa=False, transb=True)
    verify_batch_matmul(128, 64, 128, 128, rocblas, transa=True, transb=False)
    verify_batch_matmul(128, 64, 128, 128, rocblas, transa=True, transb=True)
    verify_batch_matmul(128, 128, 128, 64, rocblas, transa=False, transb=False)
    verify_batch_matmul(128, 128, 128, 64, rocblas, transa=False, transb=True)
    verify_batch_matmul(128, 128, 128, 64, rocblas, transa=True, transb=False)
    verify_batch_matmul(128, 128, 128, 64, rocblas, transa=True, transb=True)


if __name__ == "__main__":
-    test_matmul_add()
+    test_matmul()
    test_batch_matmul()
