[Contrib] cblas batch_matmul (apache#3210)
hlu1 committed May 24, 2019
1 parent 862dae7 commit bc40d8c
Showing 5 changed files with 159 additions and 31 deletions.
6 changes: 5 additions & 1 deletion cmake/modules/contrib/BLAS.cmake
@@ -10,7 +10,11 @@ elseif(USE_BLAS STREQUAL "mkl")
if(NOT IS_DIRECTORY ${USE_MKL_PATH})
set(USE_MKL_PATH /opt/intel/mkl)
endif()
find_library(BLAS_LIBRARY NAMES mkl_rt mklml_gnu HINTS ${USE_MKL_PATH}/lib/ ${USE_MKL_PATH}/lib/intel64)
if(APPLE)
find_library(BLAS_LIBRARY NAMES mklml HINTS ${USE_MKL_PATH}/lib/ ${USE_MKL_PATH}/lib/intel64)
elseif(UNIX)
find_library(BLAS_LIBRARY NAMES mkl_rt mklml_gnu HINTS ${USE_MKL_PATH}/lib/ ${USE_MKL_PATH}/lib/intel64)
endif()
include_directories(${USE_MKL_PATH}/include)
list(APPEND TVM_RUNTIME_LINKER_LIBS ${BLAS_LIBRARY})
list(APPEND RUNTIME_SRCS ${CBLAS_CONTRIB_SRC})
55 changes: 49 additions & 6 deletions python/tvm/contrib/cblas.py
@@ -1,10 +1,10 @@
"""External function interface to BLAS libraries."""
from __future__ import absolute_import as _abs

from .. import api as _api
from .. import intrin as _intrin
from .. import api as _api, intrin as _intrin

def matmul(lhs, rhs, transa=False, transb=False):

def matmul(lhs, rhs, transa=False, transb=False, **kwargs):
"""Create an extern op that compute matrix mult of A and rhs with CrhsLAS
This function serves as an example on how to call external libraries.
@@ -28,7 +28,50 @@ def matmul(lhs, rhs, transa=False, transb=False):
n = lhs.shape[1] if transa else lhs.shape[0]
m = rhs.shape[0] if transb else rhs.shape[1]
return _api.extern(
(n, m), [lhs, rhs],
(n, m),
[lhs, rhs],
lambda ins, outs: _intrin.call_packed(
"tvm.contrib.cblas.matmul", ins[0], ins[1], outs[0], transa, transb
),
name="C",
**kwargs
)
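
As a usage sketch (not part of this diff), the wrapper is consumed like any other extern op; the shapes, dtype and "llvm" target below are illustrative and assume a TVM build with USE_BLAS enabled:

import tvm
import numpy as np
from tvm.contrib import cblas

A = tvm.placeholder((1024, 128), name="A")
B = tvm.placeholder((128, 256), name="B")
C = cblas.matmul(A, B)          # extern op backed by tvm.contrib.cblas.matmul

s = tvm.create_schedule(C.op)
f = tvm.build(s, [A, B, C], "llvm")

ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.uniform(size=(1024, 128)).astype("float32"), ctx)
b = tvm.nd.array(np.random.uniform(size=(128, 256)).astype("float32"), ctx)
c = tvm.nd.array(np.zeros((1024, 256), dtype="float32"), ctx)
f(a, b, c)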


def batch_matmul(lhs, rhs, transa=False, transb=False, iterative=False, **kwargs):
"""Create an extern op that compute batched matrix mult of A and rhs with CBLAS
This function serves as an example on how to call external libraries.
Parameters
----------
lhs : Tensor
The left matrix operand
rhs : Tensor
The right matrix operand
transa : bool
Whether transpose lhs
transb : bool
Whether transpose rhs
iterative : bool
Whether to dispatch to the looped GEMM variant (tvm.contrib.cblas.batch_matmul_iterative) instead of the batched kernel
Returns
-------
C : Tensor
The result tensor.
"""
b = lhs.shape[0]
n = lhs.shape[2] if transa else lhs.shape[1]
m = rhs.shape[1] if transb else rhs.shape[2]
return _api.extern(
(b, n, m),
[lhs, rhs],
lambda ins, outs: _intrin.call_packed(
"tvm.contrib.cblas.matmul",
ins[0], ins[1], outs[0], transa, transb), name="C")
"tvm.contrib.cblas.batch_matmul"
if not iterative
else "tvm.contrib.cblas.batch_matmul_iterative",
ins[0],
ins[1],
outs[0],
transa,
transb,
),
name="C",
**kwargs
)
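
A corresponding sketch for the batched case (shapes illustrative; the iterative flag only changes which packed function the extern op dispatches to):

X = tvm.placeholder((16, 64, 32), name="X")
Y = tvm.placeholder((16, 32, 128), name="Y")
Z = cblas.batch_matmul(X, Y)                      # -> tvm.contrib.cblas.batch_matmul
Zi = cblas.batch_matmul(X, Y, iterative=True)     # -> tvm.contrib.cblas.batch_matmul_iterative
# Z has shape (16, 64, 128); scheduling and building proceed as in the matmul sketch above.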
40 changes: 29 additions & 11 deletions src/contrib/cblas/cblas.cc
@@ -133,8 +133,25 @@ struct CblasDgemmBatchOp {
}
};

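// Double-precision batch GEMM realized as a plain loop over cblas_dgemm,
// stepping the A, B and C pointers by their per-matrix strides on each iteration.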
struct CblasDgemmBatchIterativeOp {
typedef double TDatatype;
void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, double alpha, double* A,
int a_stride, int lda, double* B, int b_stride, int ldb, double beta, double* C,
int c_stride, int ldc) {
CBLAS_TRANSPOSE trans_a = BooleanToTranspose(ta);
CBLAS_TRANSPOSE trans_b = BooleanToTranspose(tb);
for (int i = 0; i < batch_size; ++i) {
cblas_dgemm(CblasColMajor, trans_a, trans_b, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
A += a_stride;
B += b_stride;
C += c_stride;
}
}
};
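
Ignoring the row-/column-major bookkeeping done by the callers in gemm_common.h, the loop above amounts to this NumPy sketch (names are illustrative):

# A3d, B3d, C3d hold batch_size matrices spaced a_stride/b_stride/c_stride elements apart
for i in range(batch_size):
    C3d[i] = alpha * (A3d[i] @ B3d[i]) + beta * C3d[i]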

// matrix multiplication for row major
TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul").set_body([](TVMArgs args, TVMRetValue* ret) {
TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul")
.set_body([](TVMArgs args, TVMRetValue* ret) {
DLTensor* A = args[0];
CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64));

@@ -144,7 +161,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul").set_body([](TVMArgs args, TVMRet
CallGemm(args, ret, CblasDgemmOp());
});

TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul").set_body([](TVMArgs args, TVMRetValue* ret) {
TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul")
.set_body([](TVMArgs args, TVMRetValue* ret) {
DLTensor* A = args[0];
CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64));
if (TypeMatch(A->dtype, kDLFloat, 32)) {
Expand All @@ -155,14 +173,14 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul").set_body([](TVMArgs args,
});

TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul_iterative")
.set_body([](TVMArgs args, TVMRetValue* ret) {
DLTensor* A = args[0];
CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64));
if (TypeMatch(A->dtype, kDLFloat, 32)) {
CallBatchGemm(args, ret, CblasSgemmBatchIterativeOp());
} else {
LOG(FATAL) << "Unhandled type";
}
});
.set_body([](TVMArgs args, TVMRetValue* ret) {
DLTensor* A = args[0];
CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64));
if (TypeMatch(A->dtype, kDLFloat, 32)) {
CallBatchGemm(args, ret, CblasSgemmBatchIterativeOp());
} else {
CallBatchGemm(args, ret, CblasDgemmBatchIterativeOp());
}
});
} // namespace contrib
} // namespace tvm
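
Once registered, these packed functions can also be exercised directly through the runtime registry; a minimal sketch (tvm/numpy imports as in the sketches above; argument order matches the call_packed sites in cblas.py: A, B, C, transa, transb):

fbatch = tvm.get_global_func("tvm.contrib.cblas.batch_matmul_iterative")
a = tvm.nd.array(np.random.uniform(size=(4, 8, 16)).astype("float32"))
b = tvm.nd.array(np.random.uniform(size=(4, 16, 32)).astype("float32"))
c = tvm.nd.array(np.zeros((4, 8, 32), dtype="float32"))
fbatch(a, b, c, False, False)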
6 changes: 4 additions & 2 deletions src/contrib/cblas/gemm_common.h
@@ -24,6 +24,8 @@
*/
#pragma once

#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <algorithm>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
@@ -176,5 +178,5 @@ inline void CallBatchGemm(TVMArgs args, TVMRetValue *ret, TBatchGemmOp op) {
static_cast<float>(beta), C_data, C_size, ColumnStride3D(C));
}

} // namespace contrib
} // namespace tvm
} // namespace contrib
} // namespace tvm
83 changes: 72 additions & 11 deletions tests/python/contrib/test_cblas.py
@@ -1,18 +1,25 @@
import tvm
import numpy as np
import topi.testing
from tvm.contrib import cblas

def test_matmul_add():
n = 1024
l = 128
m = 235
bias = tvm.var('bias', dtype=tvm.float32)
A = tvm.placeholder((n, l), name='A')
B = tvm.placeholder((l, m), name='B')
C = cblas.matmul(A, B)
def verify_matmul_add(m, l, n, transa=False, transb=False, dtype=tvm.float32):
bias = tvm.var('bias', dtype=dtype)
ashape = (l, n) if transa else (n, l)
bshape = (m, l) if transb else (l, m)
A = tvm.placeholder(ashape, name='A', dtype=dtype)
B = tvm.placeholder(bshape, name='B', dtype=dtype)
C = cblas.matmul(A, B, transa, transb)
D = tvm.compute(C.shape, lambda i, j: C[i,j] + bias, name="D")
s = tvm.create_schedule(D.op)

def get_numpy(a, b, bb, transa, transb):
if transa:
a = a.transpose()
if transb:
b = b.transpose()
return np.dot(a, b) + bb

def verify(target="llvm"):
if not tvm.module.enabled(target):
print("skip because %s is not enabled..." % target)
@@ -22,15 +29,69 @@ def verify(target="llvm"):
return
ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, D, bias], target)
a = tvm.nd.array(np.random.uniform(size=(n, l)).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=(l, m)).astype(B.dtype), ctx)
a = tvm.nd.array(np.random.uniform(size=ashape).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=bshape).astype(B.dtype), ctx)
d = tvm.nd.array(np.zeros((n, m), dtype=D.dtype), ctx)
bb = 10.0
f(a, b, d, bb)
tvm.testing.assert_allclose(
d.asnumpy(), np.dot(a.asnumpy(), b.asnumpy()) + bb, rtol=1e-5)
d.asnumpy(), get_numpy(a.asnumpy(), b.asnumpy(), bb, transa, transb), rtol=1e-5)
verify()

def test_matmul_add():
verify_matmul_add(235, 128, 1024)
verify_matmul_add(235, 128, 1024, True, False)
verify_matmul_add(235, 128, 1024, False, True)
verify_matmul_add(235, 128, 1024, True, True)
verify_matmul_add(1, 16, 4)
verify_matmul_add(1, 16, 3, True, False)
verify_matmul_add(1, 16, 3, False, False)
verify_matmul_add(1, 16, 3, True, True)

def verify_batch_matmul(batch, m, l, n, transa=False, transb=False, iterative=False, dtype=tvm.float32):
ashape = (batch, l, n) if transa else (batch, n, l)
bshape = (batch, m, l) if transb else (batch, l, m)
A = tvm.placeholder(ashape, name='A', dtype=dtype)
B = tvm.placeholder(bshape, name='B', dtype=dtype)
C = cblas.batch_matmul(A, B, transa, transb, iterative=iterative)
D = tvm.compute(C.shape, lambda k, i, j: C[k, i, j], name="D")
s = tvm.create_schedule(D.op)

def get_numpy(a, b, transa, transb):
if transa:
a = a.transpose(0, 2, 1)
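# topi.testing.batch_matmul multiplies by the transpose of its second argument,
# so b is pre-transposed here whenever transb is False.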
if not transb:
b = b.transpose(0, 2, 1)
return topi.testing.batch_matmul(a, b)

def verify(target="llvm"):
if not tvm.module.enabled(target):
print("skip because %s is not enabled..." % target)
return
if not tvm.get_global_func("tvm.contrib.cblas.matmul", True):
print("skip because extern function is not available")
return
ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, D], target)
a = tvm.nd.array(np.random.uniform(size=ashape).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=bshape).astype(B.dtype), ctx)
d = tvm.nd.array(np.zeros((batch, n, m), dtype=D.dtype), ctx)
f(a, b, d)
tvm.testing.assert_allclose(
d.asnumpy(), get_numpy(a.asnumpy(), b.asnumpy(), transa, transb), rtol=1e-5)
verify()

def test_batch_matmul():
verify_batch_matmul(16, 235, 128, 1024)
verify_batch_matmul(16, 235, 128, 1024, True, False)
verify_batch_matmul(16, 235, 128, 1024, False, True)
verify_batch_matmul(16, 235, 128, 1024, True, True)
verify_batch_matmul(1, 1, 16, 3)
verify_batch_matmul(1, 1, 16, 3, True, False)
verify_batch_matmul(1, 1, 16, 3, False, False)
verify_batch_matmul(1, 1, 16, 3, True, True)
verify_batch_matmul(1, 1, 16, 3, iterative=True)

if __name__ == "__main__":
test_matmul_add()
test_batch_matmul()
