From bc40d8cbc98ef0674901c1f7dc07c25336b2306a Mon Sep 17 00:00:00 2001
From: hlu1 <14827759+hlu1@users.noreply.github.com>
Date: Tue, 21 May 2019 16:05:28 -0700
Subject: [PATCH] [Contrib] cblas batch_matmul (#3210)

---
 cmake/modules/contrib/BLAS.cmake   |  6 ++-
 python/tvm/contrib/cblas.py        | 55 +++++++++++++++++---
 src/contrib/cblas/cblas.cc         | 40 ++++++++++----
 src/contrib/cblas/gemm_common.h    |  6 ++-
 tests/python/contrib/test_cblas.py | 83 ++++++++++++++++++++++++++----
 5 files changed, 159 insertions(+), 31 deletions(-)

diff --git a/cmake/modules/contrib/BLAS.cmake b/cmake/modules/contrib/BLAS.cmake
index 09526ef38f6b..a9a15c01b209 100644
--- a/cmake/modules/contrib/BLAS.cmake
+++ b/cmake/modules/contrib/BLAS.cmake
@@ -10,7 +10,11 @@ elseif(USE_BLAS STREQUAL "mkl")
   if(NOT IS_DIRECTORY ${USE_MKL_PATH})
     set(USE_MKL_PATH /opt/intel/mkl)
   endif()
-  find_library(BLAS_LIBRARY NAMES mkl_rt mklml_gnu HINTS ${USE_MKL_PATH}/lib/ ${USE_MKL_PATH}/lib/intel64)
+  if(APPLE)
+    find_library(BLAS_LIBRARY NAMES mklml HINTS ${USE_MKL_PATH}/lib/ ${USE_MKL_PATH}/lib/intel64)
+  elseif(UNIX)
+    find_library(BLAS_LIBRARY NAMES mkl_rt mklml_gnu HINTS ${USE_MKL_PATH}/lib/ ${USE_MKL_PATH}/lib/intel64)
+  endif()
   include_directories(${USE_MKL_PATH}/include)
   list(APPEND TVM_RUNTIME_LINKER_LIBS ${BLAS_LIBRARY})
   list(APPEND RUNTIME_SRCS ${CBLAS_CONTRIB_SRC})
diff --git a/python/tvm/contrib/cblas.py b/python/tvm/contrib/cblas.py
index eb32cc490347..5fefcd90706d 100644
--- a/python/tvm/contrib/cblas.py
+++ b/python/tvm/contrib/cblas.py
@@ -1,10 +1,10 @@
 """External function interface to BLAS libraries."""
 from __future__ import absolute_import as _abs

-from .. import api as _api
-from .. import intrin as _intrin
+from .. import api as _api, intrin as _intrin

-def matmul(lhs, rhs, transa=False, transb=False):
+
+def matmul(lhs, rhs, transa=False, transb=False, **kwargs):
     """Create an extern op that compute matrix mult of A and rhs with CrhsLAS
     This function serves as an example on how to call external libraries.

@@ -28,7 +28,50 @@
     n = lhs.shape[1] if transa else lhs.shape[0]
     m = rhs.shape[0] if transb else rhs.shape[1]
     return _api.extern(
-        (n, m), [lhs, rhs],
+        (n, m),
+        [lhs, rhs],
+        lambda ins, outs: _intrin.call_packed(
+            "tvm.contrib.cblas.matmul", ins[0], ins[1], outs[0], transa, transb
+        ),
+        name="C",
+        **kwargs
+    )
+
+
+def batch_matmul(lhs, rhs, transa=False, transb=False, iterative=False, **kwargs):
+    """Create an extern op that computes batched matrix multiplication of lhs and rhs with CBLAS
+    This function serves as an example on how to call external libraries.
+    Parameters
+    ----------
+    lhs : Tensor
+        The left matrix operand
+    rhs : Tensor
+        The right matrix operand
+    transa : bool
+        Whether to transpose lhs
+    transb : bool
+        Whether to transpose rhs
+    Returns
+    -------
+    C : Tensor
+        The result tensor.
+    """
+    b = lhs.shape[0]
+    n = lhs.shape[2] if transa else lhs.shape[1]
+    m = rhs.shape[1] if transb else rhs.shape[2]
+    return _api.extern(
+        (b, n, m),
+        [lhs, rhs],
         lambda ins, outs: _intrin.call_packed(
-            "tvm.contrib.cblas.matmul",
-            ins[0], ins[1], outs[0], transa, transb), name="C")
+            "tvm.contrib.cblas.batch_matmul"
+            if not iterative
+            else "tvm.contrib.cblas.batch_matmul_iterative",
+            ins[0],
+            ins[1],
+            outs[0],
+            transa,
+            transb,
+        ),
+        name="C",
+        **kwargs
+    )
diff --git a/src/contrib/cblas/cblas.cc b/src/contrib/cblas/cblas.cc
index eeddcf8908ef..0f222e2f2a39 100644
--- a/src/contrib/cblas/cblas.cc
+++ b/src/contrib/cblas/cblas.cc
@@ -133,8 +133,25 @@ struct CblasDgemmBatchOp {
   }
 };

+struct CblasDgemmBatchIterativeOp {
+  typedef double TDatatype;
+  void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, double alpha, double* A,
+                  int a_stride, int lda, double* B, int b_stride, int ldb, double beta, double* C,
+                  int c_stride, int ldc) {
+    CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta);
+    CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb);
+    for (int i = 0; i < batch_size; ++i) {
+      cblas_dgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
+      A += a_stride;
+      B += b_stride;
+      C += c_stride;
+    }
+  }
+};
+
 // matrix multiplication for row major
-TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul").set_body([](TVMArgs args, TVMRetValue* ret) {
+TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
   DLTensor* A = args[0];
   CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64));

@@ -144,7 +161,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul").set_body([](TVMArgs args, TVMRet
     CallGemm(args, ret, CblasDgemmOp());
 });

-TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul").set_body([](TVMArgs args, TVMRetValue* ret) {
+TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
   DLTensor* A = args[0];
   CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64));
   if (TypeMatch(A->dtype, kDLFloat, 32)) {
@@ -155,14 +173,14 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul",
 });

 TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul_iterative")
-    .set_body([](TVMArgs args, TVMRetValue* ret) {
-      DLTensor* A = args[0];
-      CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64));
-      if (TypeMatch(A->dtype, kDLFloat, 32)) {
-        CallBatchGemm(args, ret, CblasSgemmBatchIterativeOp());
-      } else {
-        LOG(FATAL) << "Unhandled type";
-      }
-    });
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+  DLTensor* A = args[0];
+  CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64));
+  if (TypeMatch(A->dtype, kDLFloat, 32)) {
+    CallBatchGemm(args, ret, CblasSgemmBatchIterativeOp());
+  } else {
+    CallBatchGemm(args, ret, CblasDgemmBatchIterativeOp());
+  }
+});
 }  // namespace contrib
 }  // namespace tvm
diff --git a/src/contrib/cblas/gemm_common.h b/src/contrib/cblas/gemm_common.h
index 103b4ae236c1..7d6e3fa331fc 100644
--- a/src/contrib/cblas/gemm_common.h
+++ b/src/contrib/cblas/gemm_common.h
@@ -24,6 +24,8 @@
  */
 #pragma once

+#include
+#include
 #include
 #include
 #include
@@ -176,5 +178,5 @@ inline void CallBatchGemm(TVMArgs args, TVMRetValue *ret, TBatchGemmOp op) {
                            static_cast(beta), C_data, C_size, ColumnStride3D(C));
 }

-} // namespace contrib
-} // namespace tvm
+}  // namespace contrib
+}  // namespace tvm
diff --git a/tests/python/contrib/test_cblas.py b/tests/python/contrib/test_cblas.py
index 890820ba4519..f4a98cf071ba 100644
--- a/tests/python/contrib/test_cblas.py
+++ b/tests/python/contrib/test_cblas.py
@@ -1,18 +1,25 @@
 import tvm
 import numpy as np
+import topi.testing
 from tvm.contrib import cblas
-def test_matmul_add():
-    n = 1024
-    l = 128
-    m = 235
-    bias = tvm.var('bias', dtype=tvm.float32)
-    A = tvm.placeholder((n, l), name='A')
-    B = tvm.placeholder((l, m), name='B')
-    C = cblas.matmul(A, B)
+def verify_matmul_add(m, l, n, transa=False, transb=False, dtype=tvm.float32):
+    bias = tvm.var('bias', dtype=dtype)
+    ashape = (l, n) if transa else (n, l)
+    bshape = (m, l) if transb else (l, m)
+    A = tvm.placeholder(ashape, name='A', dtype=dtype)
+    B = tvm.placeholder(bshape, name='B', dtype=dtype)
+    C = cblas.matmul(A, B, transa, transb)
     D = tvm.compute(C.shape, lambda i, j: C[i,j] + bias, name="D")
     s = tvm.create_schedule(D.op)

+    def get_numpy(a, b, bb, transa, transb):
+        if transa:
+            a = a.transpose()
+        if transb:
+            b = b.transpose()
+        return np.dot(a, b) + bb
+
     def verify(target="llvm"):
         if not tvm.module.enabled(target):
             print("skip because %s is not enabled..." % target)
             return
@@ -22,15 +29,69 @@ def verify(target="llvm"):
         return
         ctx = tvm.cpu(0)
         f = tvm.build(s, [A, B, D, bias], target)
-        a = tvm.nd.array(np.random.uniform(size=(n, l)).astype(A.dtype), ctx)
-        b = tvm.nd.array(np.random.uniform(size=(l, m)).astype(B.dtype), ctx)
+        a = tvm.nd.array(np.random.uniform(size=ashape).astype(A.dtype), ctx)
+        b = tvm.nd.array(np.random.uniform(size=bshape).astype(B.dtype), ctx)
         d = tvm.nd.array(np.zeros((n, m), dtype=D.dtype), ctx)
         bb = 10.0
         f(a, b, d, bb)
         tvm.testing.assert_allclose(
-            d.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()) + bb, rtol=1e-5)
+            d.asnumpy(), get_numpy(a.asnumpy(), b.asnumpy(), bb, transa, transb), rtol=1e-5)
+    verify()
+
+def test_matmul_add():
+    verify_matmul_add(235, 128, 1024)
+    verify_matmul_add(235, 128, 1024, True, False)
+    verify_matmul_add(235, 128, 1024, False, True)
+    verify_matmul_add(235, 128, 1024, True, True)
+    verify_matmul_add(1, 16, 4)
+    verify_matmul_add(1, 16, 3, True, False)
+    verify_matmul_add(1, 16, 3, False, False)
+    verify_matmul_add(1, 16, 3, True, True)
+
+def verify_batch_matmul(batch, m, l, n, transa=False, transb=False, iterative=False, dtype=tvm.float32):
+    ashape = (batch, l, n) if transa else (batch, n, l)
+    bshape = (batch, m, l) if transb else (batch, l, m)
+    A = tvm.placeholder(ashape, name='A', dtype=dtype)
+    B = tvm.placeholder(bshape, name='B', dtype=dtype)
+    C = cblas.batch_matmul(A, B, transa, transb, iterative=iterative)
+    D = tvm.compute(C.shape, lambda k, i, j: C[k, i,j], name="D")
+    s = tvm.create_schedule(D.op)
+
+    def get_numpy(a, b, transa, transb):
+        if transa:
+            a = a.transpose(0, 2, 1)
+        if not transb:
+            b = b.transpose(0, 2, 1)
+        return topi.testing.batch_matmul(a, b)
+
+    def verify(target="llvm"):
+        if not tvm.module.enabled(target):
+            print("skip because %s is not enabled..." % target)
+            return
+        if not tvm.get_global_func("tvm.contrib.cblas.matmul", True):
+            print("skip because extern function is not available")
+            return
+        ctx = tvm.cpu(0)
+        f = tvm.build(s, [A, B, D], target)
+        a = tvm.nd.array(np.random.uniform(size=ashape).astype(A.dtype), ctx)
+        b = tvm.nd.array(np.random.uniform(size=bshape).astype(B.dtype), ctx)
+        d = tvm.nd.array(np.zeros((batch, n, m), dtype=D.dtype), ctx)
+        f(a, b, d)
+        tvm.testing.assert_allclose(
+            d.asnumpy(), get_numpy(a.asnumpy(), b.asnumpy(), transa, transb), rtol=1e-5)
     verify()

+def test_batch_matmul():
+    verify_batch_matmul(16, 235, 128, 1024)
+    verify_batch_matmul(16, 235, 128, 1024, True, False)
+    verify_batch_matmul(16, 235, 128, 1024, False, True)
+    verify_batch_matmul(16, 235, 128, 1024, True, True)
+    verify_batch_matmul(1, 1, 16, 3)
+    verify_batch_matmul(1, 1, 16, 3, True, False)
+    verify_batch_matmul(1, 1, 16, 3, False, False)
+    verify_batch_matmul(1, 1, 16, 3, True, True)
+    verify_batch_matmul(1, 1, 16, 3, iterative=True)

 if __name__ == "__main__":
     test_matmul_add()
+    test_batch_matmul()
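
Usage note (not part of the diff above): a minimal sketch of how the new cblas.batch_matmul entry point can be driven once the patch is applied. It assumes a TVM build with USE_BLAS enabled (e.g. mkl or openblas) so that the tvm.contrib.cblas packed functions are registered; the batch, n, l, m sizes are illustrative, and the layout mirrors the (batch, n, l) x (batch, l, m) -> (batch, n, m) shapes exercised by test_cblas.py. Passing iterative=True would instead dispatch to the tvm.contrib.cblas.batch_matmul_iterative kernel registered in cblas.cc.

    import numpy as np
    import tvm
    from tvm.contrib import cblas

    # C[k] = A[k] x B[k] through the CBLAS extern op (float32 placeholders by default).
    batch, n, l, m = 4, 32, 16, 8
    A = tvm.placeholder((batch, n, l), name="A")
    B = tvm.placeholder((batch, l, m), name="B")
    C = cblas.batch_matmul(A, B, transa=False, transb=False, iterative=False)
    s = tvm.create_schedule(C.op)
    f = tvm.build(s, [A, B, C], target="llvm")

    # Run on CPU and check against numpy's batched matmul.
    ctx = tvm.cpu(0)
    a = tvm.nd.array(np.random.uniform(size=(batch, n, l)).astype("float32"), ctx)
    b = tvm.nd.array(np.random.uniform(size=(batch, l, m)).astype("float32"), ctx)
    c = tvm.nd.array(np.zeros((batch, n, m), dtype="float32"), ctx)
    f(a, b, c)
    np.testing.assert_allclose(c.asnumpy(), np.matmul(a.asnumpy(), b.asnumpy()), rtol=1e-5)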