Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dynamically load MKL as an FFT backend when it is available and requested #36414

Merged
merged 8 commits into from
Oct 15, 2021
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions paddle/fluid/operators/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,21 @@ else()
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
endif()


if (WITH_GPU AND (NOT WITH_ROCM))
op_library(spectral_op SRCS spectral_op.cc spectral_op.cu DEPS dynload_cuda ${OP_HEADER_DEPS})
if (MKL_FOUND AND WITH_ONEMKL)
op_library(spectral_op SRCS spectral_op.cc spectral_op.cu DEPS dynload_cuda dynload_mklrt ${OP_HEADER_DEPS})
target_include_directories(spectral_op PRIVATE ${MKL_INCLUDE})
else()
op_library(spectral_op SRCS spectral_op.cc spectral_op.cu DEPS dynload_cuda ${OP_HEADER_DEPS})
endif()
else()
op_library(spectral_op SRCS spectral_op.cc DEPS ${OP_HEADER_DEPS})
if (MKL_FOUND AND WITH_ONEMKL)
op_library(spectral_op SRCS spectral_op.cc DEPS dynload_mklrt ${OP_HEADER_DEPS})
target_include_directories(spectral_op PRIVATE ${MKL_INCLUDE})
else()
op_library(spectral_op SRCS spectral_op.cc DEPS ${OP_HEADER_DEPS})
endif()
endif()

op_library(lstm_op DEPS ${OP_HEADER_DEPS} lstm_compute)
Expand Down
113 changes: 54 additions & 59 deletions paddle/fluid/operators/spectral_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
#include "paddle/fluid/platform/complex.h"

#if defined(PADDLE_WITH_ONEMKL)
#include <mkl_dfti.h>
#include "paddle/fluid/platform/dynload/mklrt.h"
#elif defined(PADDLE_WITH_POCKETFFT)
#include "extern_pocketfft/pocketfft_hdronly.h"
#endif
Expand Down Expand Up @@ -357,46 +357,45 @@ FFTNormMode get_norm_from_string(const std::string& norm, bool forward) {
// FFT Functors
#if defined(PADDLE_WITH_ONEMKL)

// Evaluates an MKL DFTI call exactly once and raises a paddle External error,
// carrying the message MKL reports for that status, when the call failed.
//
// A status is treated as success when it is zero or when MKL classifies it as
// DFTI_NO_ERROR; every other status throws. MKL documents DftiErrorClass as the
// way to classify a status, so the status is only compared against zero as a
// fast path and is otherwise handed to DftiErrorClass.
//
// The expansion is a single statement without a trailing semicolon, so the
// caller supplies the semicolon and the macro is safe in an unbraced if/else.
#define MKL_DFTI_CHECK(expr)                                          \
  do {                                                                \
    const MKL_LONG status = (expr);                                   \
    if (status &&                                                     \
        !platform::dynload::DftiErrorClass(status, DFTI_NO_ERROR)) {  \
      PADDLE_THROW(platform::errors::External(                        \
          platform::dynload::DftiErrorMessage(status)));              \
    }                                                                 \
  } while (0)

namespace {
static inline void MKL_DFTI_CHECK(MKL_INT status) {
if (status && !DftiErrorClass(status, DFTI_NO_ERROR)) {
PADDLE_THROW(platform::errors::External(DftiErrorMessage(status)));
}
}

struct DftiDescriptorDeleter {
void operator()(DFTI_DESCRIPTOR_HANDLE handle) {
if (handle != nullptr) {
MKL_DFTI_CHECK(DftiFreeDescriptor(&handle));
MKL_DFTI_CHECK(platform::dynload::DftiFreeDescriptor(&handle));
}
}
};

// A RAII wrapper for MKL_DESCRIPTOR*
class DftiDescriptor {
public:
void init(DFTI_CONFIG_VALUE precision, DFTI_CONFIG_VALUE signal_type,
MKL_LONG signal_ndim, MKL_LONG* sizes) {
if (desc_ != nullptr) {
PADDLE_THROW(platform::errors::AlreadyExists(
"DFT DESCRIPTOR can only be initialized once."));
}
PADDLE_ENFORCE_EQ(desc_.get(), nullptr,
platform::errors::AlreadyExists(
"DftiDescriptor has already been initialized."));

DFTI_DESCRIPTOR* raw_desc;
if (signal_ndim == 1) {
MKL_DFTI_CHECK(
DftiCreateDescriptor(&raw_desc, precision, signal_type, 1, sizes[0]));
} else {
MKL_DFTI_CHECK(DftiCreateDescriptor(&raw_desc, precision, signal_type,
signal_ndim, sizes));
}
MKL_DFTI_CHECK(platform::dynload::DftiCreateDescriptorX(
&raw_desc, precision, signal_type, signal_ndim, sizes));
desc_.reset(raw_desc);
}

DFTI_DESCRIPTOR* get() const {
if (desc_ == nullptr) {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"DFTI DESCRIPTOR has not been initialized."));
}
return desc_.get();
DFTI_DESCRIPTOR* raw_desc = desc_.get();
PADDLE_ENFORCE_NOT_NULL(raw_desc,
platform::errors::PreconditionNotMet(
"DFTI DESCRIPTOR has not been initialized."));
return raw_desc;
}

private:
Expand All @@ -421,7 +420,9 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype,
return DFTI_DOUBLE;
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"Input data type should be FP32, FP64, COMPLEX64 or COMPLEX128."));
"Invalid input datatype (%s), input data type should be FP32, "
"FP64, COMPLEX64 or COMPLEX128.",
framework::DataTypeToString(in_dtype)));
}
}();

Expand All @@ -430,35 +431,27 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype,
const DFTI_CONFIG_VALUE domain =
(fft_type == FFTTransformType::C2C) ? DFTI_COMPLEX : DFTI_REAL;

// const bool complex_input = framework::IsComplexType(in_dtype);
// const bool complex_output = framework::IsComplexType(out_dtype);
// const DFTI_CONFIG_VALUE domain = [&] {
// if (forward) {
// return complex_input ? DFTI_COMPLEX : DFTI_REAL;
// } else {
// return complex_output ? DFTI_COMPLEX : DFTI_REAL;
// }
// }();

DftiDescriptor descriptor;
std::vector<MKL_LONG> fft_sizes(signal_sizes.cbegin(), signal_sizes.cend());
const MKL_LONG signal_ndim = fft_sizes.size() - 1;
descriptor.init(precision, domain, signal_ndim, fft_sizes.data() + 1);

// placement inplace or not inplace
MKL_DFTI_CHECK(
DftiSetValue(descriptor.get(), DFTI_PLACEMENT, DFTI_NOT_INPLACE));
MKL_DFTI_CHECK(platform::dynload::DftiSetValue(
descriptor.get(), DFTI_PLACEMENT, DFTI_NOT_INPLACE));

// number of transformations
const MKL_LONG batch_size = fft_sizes[0];
MKL_DFTI_CHECK(
DftiSetValue(descriptor.get(), DFTI_NUMBER_OF_TRANSFORMS, batch_size));
MKL_DFTI_CHECK(platform::dynload::DftiSetValue(
descriptor.get(), DFTI_NUMBER_OF_TRANSFORMS, batch_size));

// input & output distance
const MKL_LONG idist = in_strides[0];
const MKL_LONG odist = out_strides[0];
MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_INPUT_DISTANCE, idist));
MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_OUTPUT_DISTANCE, odist));
MKL_DFTI_CHECK(platform::dynload::DftiSetValue(descriptor.get(),
DFTI_INPUT_DISTANCE, idist));
MKL_DFTI_CHECK(platform::dynload::DftiSetValue(descriptor.get(),
DFTI_OUTPUT_DISTANCE, odist));

// input & output stride
std::vector<MKL_LONG> mkl_in_stride(1 + signal_ndim, 0);
Expand All @@ -467,15 +460,15 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype,
mkl_in_stride[i] = in_strides[i];
mkl_out_stride[i] = out_strides[i];
}
MKL_DFTI_CHECK(
DftiSetValue(descriptor.get(), DFTI_INPUT_STRIDES, mkl_in_stride.data()));
MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_OUTPUT_STRIDES,
mkl_out_stride.data()));
MKL_DFTI_CHECK(platform::dynload::DftiSetValue(
descriptor.get(), DFTI_INPUT_STRIDES, mkl_in_stride.data()));
MKL_DFTI_CHECK(platform::dynload::DftiSetValue(
descriptor.get(), DFTI_OUTPUT_STRIDES, mkl_out_stride.data()));

// conjugate even storage
if (!(fft_type == FFTTransformType::C2C)) {
MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_CONJUGATE_EVEN_STORAGE,
DFTI_COMPLEX_COMPLEX));
MKL_DFTI_CHECK(platform::dynload::DftiSetValue(
descriptor.get(), DFTI_CONJUGATE_EVEN_STORAGE, DFTI_COMPLEX_COMPLEX));
}

MKL_LONG signal_numel =
Expand All @@ -496,11 +489,12 @@ DftiDescriptor _plan_mkl_fft(const framework::proto::VarType::Type& in_dtype,
return DFTI_BACKWARD_SCALE;
}
}();
MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), scale_direction, scale));
MKL_DFTI_CHECK(platform::dynload::DftiSetValue(descriptor.get(),
scale_direction, scale));
}

// commit the descriptor
MKL_DFTI_CHECK(DftiCommitDescriptor(descriptor.get()));
MKL_DFTI_CHECK(platform::dynload::DftiCommitDescriptor(descriptor.get()));
return descriptor;
}

Expand Down Expand Up @@ -592,15 +586,16 @@ void exec_fft(const DeviceContext& ctx, const Tensor* x, Tensor* out,
collapsed_input.numel(),
collapsed_input_conj.data<Ti>());
for_range(functor);
MKL_DFTI_CHECK(DftiComputeBackward(desc.get(),
collapsed_input_conj.data<void>(),
collapsed_output.data<void>()));
MKL_DFTI_CHECK(platform::dynload::DftiComputeBackward(
desc.get(), collapsed_input_conj.data<void>(),
collapsed_output.data<void>()));
} else if (fft_type == FFTTransformType::R2C && !forward) {
framework::Tensor collapsed_output_conj(collapsed_output.type());
collapsed_output_conj.mutable_data<To>(collapsed_output.dims(),
ctx.GetPlace());
MKL_DFTI_CHECK(DftiComputeForward(desc.get(), collapsed_input.data<void>(),
collapsed_output_conj.data<void>()));
MKL_DFTI_CHECK(platform::dynload::DftiComputeForward(
desc.get(), collapsed_input.data<void>(),
collapsed_output_conj.data<void>()));
// conjugate the output
platform::ForRange<DeviceContext> for_range(ctx, collapsed_output.numel());
math::ConjFunctor<To> functor(collapsed_output_conj.data<To>(),
Expand All @@ -609,13 +604,13 @@ void exec_fft(const DeviceContext& ctx, const Tensor* x, Tensor* out,
for_range(functor);
} else {
if (forward) {
MKL_DFTI_CHECK(DftiComputeForward(desc.get(),
collapsed_input.data<void>(),
collapsed_output.data<void>()));
MKL_DFTI_CHECK(platform::dynload::DftiComputeForward(
desc.get(), collapsed_input.data<void>(),
collapsed_output.data<void>()));
} else {
MKL_DFTI_CHECK(DftiComputeBackward(desc.get(),
collapsed_input.data<void>(),
collapsed_output.data<void>()));
MKL_DFTI_CHECK(platform::dynload::DftiComputeBackward(
desc.get(), collapsed_input.data<void>(),
collapsed_output.data<void>()));
}
}

Expand Down
6 changes: 6 additions & 0 deletions paddle/fluid/platform/dynload/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,9 @@ endif()
cc_library(dynload_lapack SRCS lapack.cc DEPS dynamic_loader)
add_dependencies(dynload_lapack extern_lapack)
# TODO(TJ): add iomp, mkldnn?

# Build the dynamic-loading wrapper for oneMKL's runtime library only when MKL
# was located and the oneMKL backend was requested.
if (MKL_FOUND AND WITH_ONEMKL)
  # STATUS keeps this informational line in the normal configure log instead of
  # emitting it as a notice on stderr.
  message(STATUS "ONEMKL INCLUDE directory is ${MKL_INCLUDE}")
  cc_library(dynload_mklrt SRCS mklrt.cc DEPS dynamic_loader)
  # The MKL include directory is added for this target only (PRIVATE).
  target_include_directories(dynload_mklrt PRIVATE ${MKL_INCLUDE})
endif()
16 changes: 16 additions & 0 deletions paddle/fluid/platform/dynload/dynamic_loader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ DEFINE_string(mklml_dir, "", "Specify path for loading libmklml_intel.so.");

DEFINE_string(lapack_dir, "", "Specify path for loading liblapack.so.");

DEFINE_string(mkl_dir, "",
"Specify path for loading libmkl_core.so. "
iclementine marked this conversation as resolved.
Show resolved Hide resolved
"For instance, /opt/intel/oneapi/mkl/latest/lib/intel64/. "
"If default, "
"dlopen will search mkl from LD_LIBRARY_PATH");

DEFINE_string(op_dir, "", "Specify path for loading user-defined op library.");

#ifdef PADDLE_WITH_HIP
Expand Down Expand Up @@ -518,6 +524,16 @@ void* GetCUFFTDsoHandle() {
#endif
}

// Returns the handle that GetDsoHandleFromSearchPath yields for the mkl_rt
// library, using FLAGS_mkl_dir as the search root. Only the platform-specific
// file name of the library is selected here.
void* GetMKLRTDsoHandle() {
#if defined(__APPLE__) || defined(__OSX__)
  const char* const mklrt_lib_name = "libmkl_rt.dylib";
#elif defined(_WIN32)
  const char* const mklrt_lib_name = "mkl_rt.dll";
#else
  const char* const mklrt_lib_name = "libmkl_rt.so";
#endif
  return GetDsoHandleFromSearchPath(FLAGS_mkl_dir, mklrt_lib_name);
}

} // namespace dynload
} // namespace platform
} // namespace paddle
1 change: 1 addition & 0 deletions paddle/fluid/platform/dynload/dynamic_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ void* GetLAPACKDsoHandle();
void* GetOpDsoHandle(const std::string& dso_name);
void* GetNvtxDsoHandle();
void* GetCUFFTDsoHandle();
void* GetMKLRTDsoHandle();

void SetPaddleLibPath(const std::string&);
} // namespace dynload
Expand Down
51 changes: 51 additions & 0 deletions paddle/fluid/platform/dynload/mklrt.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/platform/dynload/mklrt.h"

namespace paddle {
namespace platform {
namespace dynload {

// Shared state for the mkl_rt library: a once-flag and the library handle,
// which starts out null.
// NOTE(review): presumably the DynLoad__* wrappers declared in mklrt.h use the
// flag to load the library a single time and cache the result in the handle;
// confirm against that header.
std::once_flag mklrt_dso_flag;
void* mklrt_dso_handle = nullptr;

// Defines, for a routine name, one object of type DynLoad__<name> that is
// itself called <name>.
#define DEFINE_WRAP(__name) DynLoad__##__name __name

// Applies DEFINE_WRAP through MKLDFTI_ROUTINE_EACH.
// NOTE(review): the list of routine names is not visible in this file; it is
// expected to come from mklrt.h.
MKLDFTI_ROUTINE_EACH(DEFINE_WRAP);

// Creates a DFTI descriptor by forwarding to the creation routine that matches
// the requested precision and dimensionality.
//
//   desc    receives the created descriptor handle.
//   prec    DFTI_SINGLE or DFTI_DOUBLE select the precision-specific routines;
//           any other value is forwarded to the generic DftiCreateDescriptor.
//   domain  forwarded unchanged to the selected routine.
//   dim     number of dimensions; exactly 1 selects the "_1d" routines, which
//           receive sizes[0], and every other value selects the "_md"
//           routines, which receive dim and the whole sizes array.
//   sizes   lengths of the transform, read as described above.
//
// Returns the status produced by the routine that was called.
DFTI_EXTERN MKL_LONG DftiCreateDescriptorX(DFTI_DESCRIPTOR_HANDLE* desc,
                                           enum DFTI_CONFIG_VALUE prec,
                                           enum DFTI_CONFIG_VALUE domain,
                                           MKL_LONG dim, MKL_LONG* sizes) {
  const bool is_one_dimensional = (dim == 1);

  switch (prec) {
    case DFTI_SINGLE:
      if (is_one_dimensional) {
        return DftiCreateDescriptor_s_1d(desc, domain, sizes[0]);
      }
      return DftiCreateDescriptor_s_md(desc, domain, dim, sizes);

    case DFTI_DOUBLE:
      if (is_one_dimensional) {
        return DftiCreateDescriptor_d_1d(desc, domain, sizes[0]);
      }
      return DftiCreateDescriptor_d_md(desc, domain, dim, sizes);

    default:
      // Unknown precision: let the generic entry point decide.
      return DftiCreateDescriptor(desc, prec, domain, dim, sizes);
  }
}

} // namespace dynload
} // namespace platform
} // namespace paddle
Loading