Call to mkldnn_matmul from aten::addmm on AArch64 (pytorch#91763)

We have noticed that for BERT_pytorch in torchbenchmark, the majority of time is spent running GEMM in aten::addmm. At the moment this calls into a BLAS routine, but on AArch64 it is faster to call into mkldnn_matmul. Compared to a build with OpenBLAS, it runs 1.2x faster on 16 cores with a batch size of 8 on Graviton3, and 2.3x faster when fast math mode is enabled (mkldnn_matmul exposes, through oneDNN and the Arm Compute Library, an option to run GEMM with FP32 inputs using BF16 operations).
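As a rough illustration of the kind of measurement behind these numbers, the sketch below times torch.addmm on fp32 CPU inputs. The matrix sizes, iteration count, and thread count are assumptions for the example (not the BERT_pytorch benchmark configuration); the ONEDNN_DEFAULT_FPMATH_MODE=BF16 setting is the fast-math switch referred to in the oneDNN comments in this change, and it has to be set before torch (and thus oneDNN) is loaded.

# Minimal timing sketch for an fp32 addmm on CPU; shapes and counts are illustrative only.
import os
os.environ.setdefault("ONEDNN_DEFAULT_FPMATH_MODE", "BF16")  # assumption: must be set before importing torch

import time
import torch

torch.set_num_threads(16)  # the commit message quotes 16 cores; adjust to your machine

m, k, n = 1024, 1024, 1024  # illustrative sizes, not the BERT_pytorch shapes
a = torch.randn(m, k, dtype=torch.float32)
b = torch.randn(k, n, dtype=torch.float32)
c = torch.randn(m, n, dtype=torch.float32)

torch.addmm(c, a, b)  # warm-up
start = time.perf_counter()
for _ in range(10):
    torch.addmm(c, a, b)
print("avg addmm time:", (time.perf_counter() - start) / 10, "s")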

Fixes #ISSUE_NUMBER

Pull Request resolved: pytorch#91763
Approved by: https://github.com/jgong5, https://github.com/ngimel, https://github.com/malfet
milpuz01 authored and snadampal committed Apr 4, 2023
1 parent c263bd4 commit b9f6f1e
Showing 7 changed files with 61 additions and 28 deletions.
1 change: 1 addition & 0 deletions BUILD.bazel
@@ -260,6 +260,7 @@ header_template_rule(
     include = "aten/src",
     substitutions = {
         "@AT_MKLDNN_ENABLED@": "1",
+        "@AT_MKLDNN_ACL_ENABLED@": "0",
         "@AT_MKL_ENABLED@": "1",
         "@AT_MKL_SEQUENTIAL@": "0",
         "@AT_FFTW_ENABLED@": "0",
1 change: 1 addition & 0 deletions aten/src/ATen/Config.h.in
@@ -7,6 +7,7 @@
 // DO NOT put the macros for CUDA libraries in this file; they belong in cuda/CUDAConfig.h
 
 #define AT_MKLDNN_ENABLED() @AT_MKLDNN_ENABLED@
+#define AT_MKLDNN_ACL_ENABLED() @AT_MKLDNN_ACL_ENABLED@
 #define AT_MKL_ENABLED() @AT_MKL_ENABLED@
 #define AT_MKL_SEQUENTIAL() @AT_MKL_SEQUENTIAL@
 #define AT_FFTW_ENABLED() @AT_FFTW_ENABLED@
46 changes: 31 additions & 15 deletions aten/src/ATen/native/LinearAlgebra.cpp
@@ -1416,21 +1416,37 @@ static void addmm_impl_cpu_(
   // Always ensure the conjugation for c is resolved since there's no way to specify c's conjugation in the gemm call
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!c.is_conj());
 
-  // Apply BLAS routine
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16,
-      result.scalar_type(), "addmm_impl_cpu_",
-      [&]{
-        using opmath_t = at::opmath_type<scalar_t>;
-        at::native::cpublas::gemm(
-            transpose_a ? a.is_conj() ? TransposeType::ConjTranspose : TransposeType::Transpose : TransposeType::NoTranspose,
-            transpose_b ? b.is_conj() ? TransposeType::ConjTranspose : TransposeType::Transpose : TransposeType::NoTranspose,
-            m, n, k,
-            alpha.to<opmath_t>(),
-            a.data_ptr<scalar_t>(), lda,
-            b.data_ptr<scalar_t>(), ldb,
-            beta.to<opmath_t>(),
-            c.data_ptr<scalar_t>(), ldc);
-      });
+  bool dispatched = false;
+#if defined(__aarch64__) && AT_MKLDNN_ACL_ENABLED()
+  // On AArch64 if LHS matrix in BLAS routine is transposed but RHS is not then
+  // it is faster to call oneDNN matrix multiplication primitive with RHS*LHS
+  // that will call then into Arm® Compute Library (ACL) GEMM kernel and also
+  // additionally have support for running kernel with BF16 instructions
+  if(transpose_a && !transpose_b && result.scalar_type() == at::ScalarType::Float) {
+    mkldnn_matmul(b, a, c, beta.to<float>(), alpha.to<float>());
+    // We have dispatched to ACL GEMM for single precision float
+    // so do not need to dispatch to BLAS GEMM below
+    dispatched = true;
+  }
+#endif
+
+  if(!dispatched) {
+    // Apply BLAS routine
+    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16,
+        result.scalar_type(), "addmm_impl_cpu_",
+        [&]{
+          using opmath_t = at::opmath_type<scalar_t>;
+          at::native::cpublas::gemm(
+              transpose_a ? a.is_conj() ? TransposeType::ConjTranspose : TransposeType::Transpose : TransposeType::NoTranspose,
+              transpose_b ? b.is_conj() ? TransposeType::ConjTranspose : TransposeType::Transpose : TransposeType::NoTranspose,
+              m, n, k,
+              alpha.to<opmath_t>(),
+              a.data_ptr<scalar_t>(), lda,
+              b.data_ptr<scalar_t>(), ldb,
+              beta.to<opmath_t>(),
+              c.data_ptr<scalar_t>(), ldc);
+        });
+  }
 
   if (!c.is_same(result)) {
     result.copy_(c);
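A note on the operand swap described in the comment above: when the left operand of the product is transposed and the right one is not, the same result can be obtained by multiplying in the opposite order and transposing, since (AᵀB) = (BᵀA)ᵀ. The snippet below only verifies that identity numerically with arbitrary example shapes; it is a sketch of the linear-algebra fact behind handing oneDNN an "RHS*LHS"-shaped problem, not the internal call sequence of addmm_impl_cpu_.

# Check (A^T @ B) == (B^T @ A)^T numerically; shapes are arbitrary examples.
import torch

A = torch.randn(64, 32)   # A.t() is 32 x 64
B = torch.randn(64, 48)

lhs_transposed = A.t() @ B          # product with a transposed left operand
swapped_order  = (B.t() @ A).t()    # same result computed in the opposite order

print(torch.allclose(lhs_transposed, swapped_order, atol=1e-5))  # expected: True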
28 changes: 15 additions & 13 deletions aten/src/ATen/native/mkldnn/Matmul.cpp
@@ -130,23 +130,25 @@ void mkldnn_matmul(
       (mat1.dim() == 1 && mat2.dim() == 1), // aten::dot
       "mkldnn_matmul: unsupported dims for mat and mat2");
 
+#if defined(__aarch64__)
+  // oneDNN fast-maths mode (enabled by setting the environment variable ONEDNN_DEFAULT_FPMATH_MODE=BF16) will dispatch
+  // fp32 inputs to bf16 kernels where HW permits. So, both fp32 and bf16 inputs are permitted.
+  TORCH_CHECK((mat1.scalar_type() == mat2.scalar_type()) && (mat1.scalar_type() == result.scalar_type()) &&
+              ((mat1.scalar_type() == at::kFloat) || (mat1.scalar_type() == at::kBFloat16)),
+              "mkldnn_matmul: only enabled for fp32 and bf16 path");
+  // device needs to support bf16 if the inputs are of bf16 type
+  if (mat1.scalar_type() == at::kBFloat16) {
+    TORCH_CHECK(mkldnn_bf16_device_check_arm(),
+                "mkldnn_matmul: mkldnn_matmul bf16 path needs a cpu with bf16 support");
+  }
+#else
   TORCH_CHECK(mkldnn_bf16_device_check(),
       "mkldnn_matmul: mkldnn_matmul bf16 path needs the cpu support avx512bw, avx512vl and avx512dq, or AWS Graviton3");
 
-#if defined(__aarch64__)
-  if (mkldnn_bf16_device_check_arm()) {
-    //onednn fastmath mode can leverage bf16 HW even for the fp32 input, e.g. Arm Neoverse V1
-    //so, don't restrict the mkldnn_matmul only for bf16 inputs, allow it for float as well
-    TORCH_CHECK((mat1.scalar_type() == mat2.scalar_type()) && (mat1.scalar_type() == result.scalar_type()) &&
-                ((mat1.scalar_type() == at::kFloat) || (mat1.scalar_type() == at::kBFloat16)),
-                "mkldnn_matmul: only enabled for fp32 and bf16 path");
-  } else
+  TORCH_CHECK(mat1.scalar_type() == at::kBFloat16 &&
+      mat2.scalar_type() == at::kBFloat16 &&
+      result.scalar_type() == at::kBFloat16, "mkldnn_matmul: only enabled for bf16 path");
 #endif
-  {
-    TORCH_CHECK(mat1.scalar_type() == at::kBFloat16 &&
-        mat2.scalar_type() == at::kBFloat16 &&
-        result.scalar_type() == at::kBFloat16, "mkldnn_matmul: only enabled for bf16 path");
-  }
 
   auto mat1_unsqueezed = mat1.dim() == 1 ? mat1.unsqueeze(0) : mat1;
   auto mat2_unsqueezed = mat2.dim() == 1 ? mat2.unsqueeze(1) : mat2;
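The checks above admit both fp32 and bf16 operands on AArch64 (with a device-capability check for bf16), and only bf16 elsewhere. A minimal sketch of exercising the bf16 case from Python, assuming a PyTorch build with MKLDNN enabled and a CPU with BF16 support; the shapes are arbitrary:

# bfloat16 operands on CPU; on an AArch64 MKLDNN/ACL build with BF16-capable hardware,
# these are the kind of inputs the bf16 checks above are guarding.
import torch

a = torch.randn(128, 256).to(torch.bfloat16)
b = torch.randn(256, 64).to(torch.bfloat16)
c = torch.zeros(128, 64, dtype=torch.bfloat16)

out = torch.addmm(c, a, b)
print(out.dtype, out.shape)  # torch.bfloat16 torch.Size([128, 64])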
4 changes: 4 additions & 0 deletions aten/src/ATen/test/verify_api_visibility.cpp
@@ -12,6 +12,10 @@
 #error "AT_MKLDNN_ENABLED should not be visible in public headers"
 #endif
 
+#ifdef AT_MKLDNN_ACL_ENABLED
+#error "AT_MKLDNN_ACL_ENABLED should not be visible in public headers"
+#endif
+
 #ifdef CAFFE2_STATIC_LINK_CUDA
 #error "CAFFE2_STATIC_LINK_CUDA should not be visible in public headers"
 #endif
4 changes: 4 additions & 0 deletions buckbuild.bzl
@@ -245,6 +245,7 @@ def get_aten_preprocessor_flags():
         "-DCAFFE2_USE_LITE_PROTO",
         "-DATEN_CUDNN_ENABLED_FBXPLAT=0",
         "-DATEN_MKLDNN_ENABLED_FBXPLAT=0",
+        "-DATEN_MKLDNN_ACL_ENABLED_FBXPLAT=0",
         "-DATEN_NNPACK_ENABLED_FBXPLAT=0",
         "-DATEN_MKL_ENABLED_FBXPLAT=0",
         "-DATEN_MKL_SEQUENTIAL_FBXPLAT=0",
@@ -1042,6 +1043,9 @@ def define_buck_targets(
             "@AT_MKLDNN_ENABLED@",
             "ATEN_MKLDNN_ENABLED_FBXPLAT",
             "--replace",
+            "@AT_MKLDNN_ACL_ENABLED@",
+            "ATEN_MKLDNN_ACL_ENABLED_FBXPLAT",
+            "--replace",
             "@AT_MKL_ENABLED@",
             "ATEN_MKL_ENABLED_FBXPLAT",
             "--replace",
5 changes: 5 additions & 0 deletions cmake/Dependencies.cmake
@@ -171,6 +171,7 @@ endif()
 
 # ---[ BLAS
 
+set(AT_MKLDNN_ACL_ENABLED 0)
 # setting default preferred BLAS options if not already present.
 if(NOT INTERN_BUILD_MOBILE)
   set(BLAS "MKL" CACHE STRING "Selected BLAS library")
@@ -1741,6 +1742,7 @@ if(NOT INTERN_BUILD_MOBILE)
   endif()
 
   set(AT_MKLDNN_ENABLED 0)
+  set(AT_MKLDNN_ACL_ENABLED 0)
   if(USE_MKLDNN)
     if(NOT CMAKE_SIZEOF_VOID_P EQUAL 8)
       message(WARNING
@@ -1749,6 +1751,9 @@
         "Turn this warning off by USE_MKLDNN=OFF.")
       set(USE_MKLDNN OFF)
     endif()
+    if(USE_MKLDNN_ACL)
+      set(AT_MKLDNN_ACL_ENABLED 1)
+    endif()
   endif()
   if(USE_MKLDNN)
     include(${CMAKE_CURRENT_LIST_DIR}/public/mkldnn.cmake)
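To check from Python whether a given build picked up oneDNN at all, one rough approach is sketched below; whether any ACL-specific flag shows up in the printed configuration string is an assumption that may vary by build.

# Inspect the installed build: torch.backends.mkldnn reports oneDNN (MKLDNN) availability,
# and torch.__config__.show() returns the compile-time configuration string for inspection.
import torch

print(torch.backends.mkldnn.is_available())  # True on a build with USE_MKLDNN=1
print(torch.__config__.show())               # look for MKLDNN-related flags in this output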