diff --git a/include/tvm/topi/contrib/rocblas.h b/include/tvm/topi/contrib/rocblas.h index a4fa26f34aa5..4f0b887fb178 100644 --- a/include/tvm/topi/contrib/rocblas.h +++ b/include/tvm/topi/contrib/rocblas.h @@ -54,6 +54,29 @@ inline Tensor rocblas_matmul(const Tensor& lhs, const Tensor& rhs, bool transa, }, "C", "", {})[0]; } +/*! + * \brief Create an op that batch multiplies lhs and rhs with rocBLAS + * + * \param lhs The left matrix operand e.g. (batch_size, M, K) + * \param rhs The right matrix operand e.g. (batch_size, K, N) + * \param transa Whether to transpose lhs + * \param transb Whether to transpose rhs + * + * \return The output tensor + */ +inline Tensor rocblas_batch_matmul(const Tensor& lhs, const Tensor& rhs, bool transa, bool transb) { + auto batch_size = lhs->shape[0]; + auto n = transa ? lhs->shape[2] : lhs->shape[1]; + auto m = transb ? rhs->shape[1] : rhs->shape[2]; + + return make_extern( + {{batch_size, n, m}}, {lhs->dtype}, {lhs, rhs}, + [&](Array ins, Array outs) { + return call_packed({StringImm("tvm.contrib.rocblas.batch_matmul"), pack_buffer(ins[0]), + pack_buffer(ins[1]), pack_buffer(outs[0]), transa, transb}); + }, + "C", "", {})[0]; +} } // namespace contrib } // namespace topi diff --git a/python/tvm/contrib/rocblas.py b/python/tvm/contrib/rocblas.py index 03ea2b52fff0..70791dca3152 100644 --- a/python/tvm/contrib/rocblas.py +++ b/python/tvm/contrib/rocblas.py @@ -48,3 +48,36 @@ def matmul(lhs, rhs, transa=False, transb=False): ), name="C", ) + + +def batch_matmul(lhs, rhs, transa=False, transb=False): + """Create an extern op that compute matrix mult of A and rhs with rocBLAS + + Parameters + ---------- + lhs : Tensor + The left batched matrix operand + rhs : Tensor + The right batched matrix operand + transa : bool + Whether transpose lhs + transb : bool + Whether transpose rhs + + Returns + ------- + C : Tensor + The result tensor. + """ + batch_size = lhs.shape[0] + assert batch_size == rhs.shape[0] + n = lhs.shape[2] if transa else lhs.shape[1] + m = rhs.shape[1] if transb else rhs.shape[2] + return te.extern( + (batch_size, n, m), + [lhs, rhs], + lambda ins, outs: tvm.tir.call_packed( + "tvm.contrib.rocblas.batch_matmul", ins[0], ins[1], outs[0], transa, transb + ), + name="C", + ) diff --git a/python/tvm/relay/op/strategy/rocm.py b/python/tvm/relay/op/strategy/rocm.py index 2410260e78ae..f52bbc36f12b 100644 --- a/python/tvm/relay/op/strategy/rocm.py +++ b/python/tvm/relay/op/strategy/rocm.py @@ -160,3 +160,24 @@ def dense_strategy_rocm(attrs, inputs, out_type, target): plevel=15, ) return strategy + + +@batch_matmul_strategy.register("rocm") +def batch_matmul_strategy_rocm(attrs, inputs, out_type, target): + """Batch matmul strategy for ROCM""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_batch_matmul(topi.cuda.batch_matmul), + wrap_topi_schedule(topi.cuda.schedule_batch_matmul), + name="batch_matmul.cuda", + plevel=10, + ) + if target.kind.name == "rocm" and "rocblas" in target.libs: + assert out_type.dtype == inputs[0].dtype, "Mixed precision not supported." + strategy.add_implementation( + wrap_compute_batch_matmul(topi.rocm.batch_matmul_rocblas), + wrap_topi_schedule(topi.rocm.schedule_batch_matmul_rocblas), + name="batch_matmul_rocblas.rocm", + plevel=12, + ) + return strategy diff --git a/python/tvm/topi/rocm/__init__.py b/python/tvm/topi/rocm/__init__.py index 4efdab4aed41..1ea4c79aaea7 100644 --- a/python/tvm/topi/rocm/__init__.py +++ b/python/tvm/topi/rocm/__init__.py @@ -19,6 +19,7 @@ """rocm specific declaration and schedules.""" from __future__ import absolute_import as _abs +from .batch_matmul import * from .conv2d import * from .dense import * from .nn import * diff --git a/python/tvm/topi/rocm/batch_matmul.py b/python/tvm/topi/rocm/batch_matmul.py new file mode 100644 index 000000000000..fa4dd457f3ed --- /dev/null +++ b/python/tvm/topi/rocm/batch_matmul.py @@ -0,0 +1,56 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-variable, unused-argument +"""Schedule for batch_matmul operator""" +from tvm import autotvm +from tvm.contrib import rocblas +from .. import generic +from ..util import get_const_tuple + + +@autotvm.register_topi_compute("batch_matmul_rocblas.rocm") +def batch_matmul_rocblas(cfg, x, y, out_shape=None): + """Computes matrix multiplication of `x` and `y` via rocblas when + `x` and `y` are batched matrices. + + Parameters + ---------- + cfg : ConfigSpace + Autotvm tuning space config file + x : tvm.te.Tensor + 3-D with shape [batch, M, K] + y : tvm.te.Tensor + 3-D with shape [batch, N, K] + Returns + ------- + output : tvm.te.Tensor + 3-D with shape [batch, M, N] + """ + batch, M, K = get_const_tuple(x.shape) + _, N, _ = get_const_tuple(y.shape) + if out_shape is not None: + assert out_shape[0] == batch, "Input and output batch sizes must match" + assert out_shape[1] == M and out_shape[2] == N, "Invalid output shape" + result = rocblas.batch_matmul(x, y, False, True) + cfg.add_flop(batch * M * N * K * 2) + return result + + +@autotvm.register_topi_schedule("batch_matmul_rocblas.rocm") +def schedule_batch_matmul_rocblas(_, outs): + """Schedule for batch_matmul operator with rocm cblas""" + return generic.schedule_extern(outs) diff --git a/src/runtime/contrib/rocblas/rocblas.cc b/src/runtime/contrib/rocblas/rocblas.cc index 0e6f4bd69686..bca00a591d48 100644 --- a/src/runtime/contrib/rocblas/rocblas.cc +++ b/src/runtime/contrib/rocblas/rocblas.cc @@ -70,18 +70,61 @@ TVM_REGISTER_GLOBAL("tvm.contrib.rocblas.matmul").set_body([](TVMArgs args, TVMR CHECK_ROCBLAS_ERROR(rocblas_create_handle(&handle)); float alpha = 1.0; float beta = 0.0; - float* A_ptr = reinterpret_cast(static_cast(B->data) + B->byte_offset); - float* B_ptr = reinterpret_cast(static_cast(A->data) + A->byte_offset); + float* A_ptr = reinterpret_cast(static_cast(A->data) + A->byte_offset); + float* B_ptr = reinterpret_cast(static_cast(B->data) + B->byte_offset); float* C_ptr = reinterpret_cast(static_cast(C->data) + C->byte_offset); - CHECK_ROCBLAS_ERROR( - rocblas_sgemm(handle, transb ? rocblas_operation_transpose : rocblas_operation_none, - transa ? rocblas_operation_transpose : rocblas_operation_none, - transb ? B->shape[0] : B->shape[1], transa ? A->shape[1] : A->shape[0], - transb ? B->shape[1] : B->shape[0], &alpha, A_ptr, B->shape[1], B_ptr, - A->shape[1], &beta, C_ptr, C->shape[1])); + rocblas_operation roc_trans_A = transa ? rocblas_operation_transpose : rocblas_operation_none; + rocblas_operation roc_trans_B = transb ? rocblas_operation_transpose : rocblas_operation_none; + size_t N = transb ? B->shape[0] : B->shape[1]; + size_t M = transa ? A->shape[1] : A->shape[0]; + size_t K = transb ? B->shape[1] : B->shape[0]; + size_t lda = transa ? M : K; + size_t ldb = transb ? K : N; + size_t ldc = N; + + CHECK_ROCBLAS_ERROR(rocblas_sgemm(handle, roc_trans_B, roc_trans_A, N, M, K, &alpha, B_ptr, ldb, + A_ptr, lda, &beta, C_ptr, ldc)); CHECK_ROCBLAS_ERROR(rocblas_destroy_handle(handle)); }); + +TVM_REGISTER_GLOBAL("tvm.contrib.rocblas.batch_matmul") + .set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* A = args[0]; + DLTensor* B = args[1]; + DLTensor* C = args[2]; + bool transa = args[3]; + bool transb = args[4]; + // call gemm for simple compact code. + CHECK_EQ(A->ndim, 3); + CHECK_EQ(B->ndim, 3); + CHECK_EQ(C->ndim, 3); + CHECK(TypeMatch(A->dtype, kDLFloat, 32)); + CHECK(TypeMatch(B->dtype, kDLFloat, 32)); + CHECK(TypeMatch(C->dtype, kDLFloat, 32)); + + rocblas_handle handle; + CHECK_ROCBLAS_ERROR(rocblas_create_handle(&handle)); + float alpha = 1.0; + float beta = 0.0; + float* A_ptr = reinterpret_cast(static_cast(A->data) + A->byte_offset); + float* B_ptr = reinterpret_cast(static_cast(B->data) + B->byte_offset); + float* C_ptr = reinterpret_cast(static_cast(C->data) + C->byte_offset); + + rocblas_operation roc_trans_A = transa ? rocblas_operation_transpose : rocblas_operation_none; + rocblas_operation roc_trans_B = transb ? rocblas_operation_transpose : rocblas_operation_none; + size_t batch_size = C->shape[0]; + size_t N = transb ? B->shape[1] : B->shape[2]; + size_t M = transa ? A->shape[2] : A->shape[1]; + size_t K = transb ? B->shape[2] : B->shape[1]; + size_t lda = transa ? M : K; + size_t ldb = transb ? K : N; + size_t ldc = N; + + CHECK_ROCBLAS_ERROR(rocblas_sgemm_strided_batched( + handle, roc_trans_B, roc_trans_A, N, M, K, &alpha, B_ptr, ldb, K * N, A_ptr, lda, M * K, + &beta, C_ptr, ldc, M * N, batch_size)); + }); } // namespace contrib } // namespace tvm diff --git a/tests/python/contrib/test_rocblas.py b/tests/python/contrib/test_rocblas.py index 9b8bacbb7191..6f1783daa74c 100644 --- a/tests/python/contrib/test_rocblas.py +++ b/tests/python/contrib/test_rocblas.py @@ -18,11 +18,13 @@ import tvm.testing from tvm import te import numpy as np +import tvm.topi.testing +import tvm.testing from tvm.contrib import rocblas @tvm.testing.requires_rocm -def test_matmul_add(): +def test_matmul(): n = 1024 l = 128 m = 235 @@ -46,5 +48,65 @@ def verify(target="rocm"): verify() +def verify_batch_matmul(batch, m, k, n, lib, transa=False, transb=False, dtype="float32"): + ashape = (batch, k, m) if transa else (batch, m, k) + bshape = (batch, n, k) if transb else (batch, k, n) + A = te.placeholder(ashape, name="A", dtype=dtype) + B = te.placeholder(bshape, name="B", dtype=dtype) + C = lib.batch_matmul(A, B, transa, transb) + s = te.create_schedule(C.op) + + def get_numpy(a, b, transa, transb): + if transa: + a = a.transpose(0, 2, 1) + if not transb: + b = b.transpose(0, 2, 1) + return tvm.topi.testing.batch_matmul(a, b) + + def verify(target="rocm"): + if not tvm.testing.device_enabled(target): + print("skip because %s is not enabled..." % target) + return + if not tvm.get_global_func(lib.__name__ + ".batch_matmul", True): + print("skip because extern function is not available") + return + ctx = tvm.rocm(0) + f = tvm.build(s, [A, B, C], target) + a = tvm.nd.array(np.random.uniform(size=ashape).astype(A.dtype), ctx) + b = tvm.nd.array(np.random.uniform(size=bshape).astype(B.dtype), ctx) + c = tvm.nd.array(np.zeros((batch, m, n), dtype=C.dtype), ctx) + f(a, b, c) + tvm.testing.assert_allclose( + c.asnumpy(), get_numpy(a.asnumpy(), b.asnumpy(), transa, transb), rtol=1e-5 + ) + + verify() + + +@tvm.testing.requires_rocm +def test_batch_matmul(): + verify_batch_matmul(128, 64, 512, 512, rocblas, transa=False, transb=False) + verify_batch_matmul(128, 64, 512, 512, rocblas, transa=False, transb=True) + verify_batch_matmul(128, 64, 512, 512, rocblas, transa=True, transb=False) + verify_batch_matmul(128, 64, 512, 512, rocblas, transa=True, transb=True) + verify_batch_matmul(128, 512, 512, 64, rocblas, transa=False, transb=False) + verify_batch_matmul(128, 512, 512, 64, rocblas, transa=False, transb=True) + verify_batch_matmul(128, 512, 512, 64, rocblas, transa=True, transb=False) + verify_batch_matmul(128, 512, 512, 64, rocblas, transa=True, transb=True) + verify_batch_matmul(128, 512, 64, 512, rocblas, transa=False, transb=False) + verify_batch_matmul(128, 512, 64, 512, rocblas, transa=False, transb=True) + verify_batch_matmul(128, 512, 64, 512, rocblas, transa=True, transb=False) + verify_batch_matmul(128, 512, 64, 512, rocblas, transa=True, transb=True) + verify_batch_matmul(128, 64, 128, 128, rocblas, transa=False, transb=False) + verify_batch_matmul(128, 64, 128, 128, rocblas, transa=False, transb=True) + verify_batch_matmul(128, 64, 128, 128, rocblas, transa=True, transb=False) + verify_batch_matmul(128, 64, 128, 128, rocblas, transa=True, transb=True) + verify_batch_matmul(128, 128, 128, 64, rocblas, transa=False, transb=False) + verify_batch_matmul(128, 128, 128, 64, rocblas, transa=False, transb=True) + verify_batch_matmul(128, 128, 128, 64, rocblas, transa=True, transb=False) + verify_batch_matmul(128, 128, 128, 64, rocblas, transa=True, transb=True) + + if __name__ == "__main__": - test_matmul_add() + test_matmul() + test_batch_matmul()