diff --git a/README.md b/README.md
index 489661400..41e97f751 100644
--- a/README.md
+++ b/README.md
@@ -224,9 +224,14 @@ Supported domains: BLAS, LAPACK, RNG, DFT
LLVM*, hipSYCL |
- | DFT |
+ DFT |
Intel GPU |
- Intel(R) oneAPI Math Kernel Library |
+ Intel(R) oneAPI Math Kernel Library |
+ Dynamic, Static |
+ DPC++ |
+
+
+ | x86 CPU |
Dynamic, Static |
DPC++ |
diff --git a/examples/dft/compile_time_dispatching/CMakeLists.txt b/examples/dft/compile_time_dispatching/CMakeLists.txt
index 4556a0cb0..59bf557a6 100644
--- a/examples/dft/compile_time_dispatching/CMakeLists.txt
+++ b/examples/dft/compile_time_dispatching/CMakeLists.txt
@@ -19,28 +19,45 @@
#Build object from all sources
set(DFTI_CT_SOURCES "")
+
if(ENABLE_MKLGPU_BACKEND)
list(APPEND DFTI_CT_SOURCES "complex_fwd_buffer_mklgpu")
endif()
+if(ENABLE_MKLCPU_BACKEND)
+ list(APPEND DFTI_CT_SOURCES "complex_fwd_buffer_mklcpu")
+endif()
+
include(WarningsUtils)
foreach(dfti_ct_sources ${DFTI_CT_SOURCES})
- add_executable(example_${domain}_${dfti_ct_sources} ${dfti_ct_sources}.cpp)
+ # add executable and define include directories
+ # add dependencies and link libraries
+ # register example as ctest
+ add_executable(example_${domain}_${dfti_ct_sources} ${dfti_ct_sources}.cpp)
target_include_directories(example_${domain}_${dfti_ct_sources}
PUBLIC ${PROJECT_SOURCE_DIR}/examples/include
PUBLIC ${PROJECT_SOURCE_DIR}/include
PUBLIC ${CMAKE_BINARY_DIR}/bin
)
- if(domain STREQUAL "dft" AND ENABLE_MKLGPU_BACKEND)
- add_dependencies(example_${domain}_${dfti_ct_sources} onemkl_${domain}_mklgpu)
- list(APPEND ONEMKL_LIBRARIES_${domain} onemkl_${domain}_mklgpu)
+
+ set(ONEMKL_LIBRARIES_${domain} "")
+ if(domain STREQUAL "dft")
+ if(dfti_ct_sources MATCHES "_mklcpu$")
+ add_dependencies(example_${domain}_${dfti_ct_sources} onemkl_${domain}_mklcpu)
+ list(APPEND ONEMKL_LIBRARIES_${domain} onemkl_${domain}_mklcpu)
+ endif()
+ if(dfti_ct_sources MATCHES "_mklgpu$")
+ add_dependencies(example_${domain}_${dfti_ct_sources} onemkl_${domain}_mklgpu)
+ list(APPEND ONEMKL_LIBRARIES_${domain} onemkl_${domain}_mklgpu)
+ endif()
endif()
+
target_link_libraries(example_${domain}_${dfti_ct_sources}
PUBLIC ${ONEMKL_LIBRARIES_${domain}}
- PUBLIC ONEMKL::SYCL::SYCL
- PRIVATE onemkl_warnings
+ ${ONEMKL_LIBRARIES_${domain}}
+ ONEMKL::SYCL::SYCL
)
# Register example as ctest
- add_test(NAME ${domain}/EXAMPLE/CT/${dfti_ct_sources} COMMAND example_${domain}_${dfti_ct_sources})
-endforeach(dfti_ct_sources)
+ add_test(NAME ${domain}/EXAMPLE/CT/${dfti_ct_sources} COMMAND example_${domain}_${dfti_ct_sources})
+endforeach(dfti_ct_sources)
\ No newline at end of file
diff --git a/examples/dft/compile_time_dispatching/complex_fwd_buffer_mklcpu.cpp b/examples/dft/compile_time_dispatching/complex_fwd_buffer_mklcpu.cpp
new file mode 100644
index 000000000..26160aeed
--- /dev/null
+++ b/examples/dft/compile_time_dispatching/complex_fwd_buffer_mklcpu.cpp
@@ -0,0 +1,132 @@
+/*******************************************************************************
+* Copyright 2023 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions
+* and limitations under the License.
+*
+*
+* SPDX-License-Identifier: Apache-2.0
+*******************************************************************************/
+
+// STL includes
+#include
+
+// oneMKL/SYCL includes
+#if __has_include()
+#include
+#else
+#include
+#endif
+#include "oneapi/mkl.hpp"
+
+void run_example(const sycl::device& cpu_device) {
+ constexpr int N = 10;
+
+ // Catch asynchronous exceptions for cpu
+ auto cpu_error_handler = [&](sycl::exception_list exceptions) {
+ for (auto const& e : exceptions) {
+ try {
+ std::rethrow_exception(e);
+ }
+ catch (sycl::exception const& e) {
+ // Handle not dft related exceptions that happened during asynchronous call
+ std::cerr << "Caught asynchronous SYCL exception:" << std::endl;
+ std::cerr << "\t" << e.what() << std::endl;
+ }
+ }
+ std::exit(2);
+ };
+
+ sycl::queue cpu_queue(cpu_device, cpu_error_handler);
+
+ std::vector> input_data(N);
+ std::vector> output_data(N);
+
+ // enabling
+ // 1. create descriptors
+ oneapi::mkl::dft::descriptor
+ desc(N);
+
+ // 2. variadic set_value
+ desc.set_value(oneapi::mkl::dft::config_param::PLACEMENT,
+ oneapi::mkl::dft::config_value::NOT_INPLACE);
+ desc.set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS,
+ static_cast(1));
+
+ // 3. commit_descriptor (compile_time MKLCPU)
+ desc.commit(oneapi::mkl::backend_selector{ cpu_queue });
+
+ // 4. compute_forward / compute_backward (MKLCPU)
+ {
+ sycl::buffer> input_buffer(input_data.data(), sycl::range<1>(N));
+ sycl::buffer> output_buffer(output_data.data(), sycl::range<1>(N));
+ oneapi::mkl::dft::compute_forward,
+ std::complex>(desc, input_buffer, output_buffer);
+ }
+}
+
+//
+// Description of example setup, apis used and supported floating point type precisions
+//
+void print_example_banner() {
+ std::cout << "\n"
+ "########################################################################\n"
+ "# Complex out-of-place forward transform for Buffer API's example:\n"
+ "#\n"
+ "# Using APIs:\n"
+ "# Compile-time dispatch API\n"
+ "# Buffer forward complex out-of-place\n"
+ "#\n"
+ "# Using double precision (double) data type\n"
+ "#\n"
+ "# For Intel CPU with Intel MKLCPU backend.\n"
+ "#\n"
+ "# The environment variable SYCL_DEVICE_FILTER can be used to specify\n"
+ "# SYCL device\n"
+ "########################################################################\n"
+ << std::endl;
+}
+
+//
+// Main entry point for example.
+//
+int main(int argc, char** argv) {
+ print_example_banner();
+
+ try {
+ sycl::device cpu_device((sycl::cpu_selector_v));
+ std::cout << "Running DFT Complex forward out-of-place buffer example" << std::endl;
+ std::cout << "Using compile-time dispatch API with MKLCPU." << std::endl;
+ std::cout << "Running with double precision real data type on:" << std::endl;
+ std::cout << "\tCPU device :" << cpu_device.get_info()
+ << std::endl;
+
+ run_example(cpu_device);
+ std::cout << "DFT Complex USM example ran OK on MKLCPU" << std::endl;
+ }
+ catch (sycl::exception const& e) {
+ // Handle not dft related exceptions that happened during synchronous call
+ std::cerr << "Caught synchronous SYCL exception:" << std::endl;
+ std::cerr << "\t" << e.what() << std::endl;
+ std::cerr << "\tSYCL error code: " << e.code().value() << std::endl;
+ return 1;
+ }
+ catch (std::exception const& e) {
+ // Handle not SYCL related exceptions that happened during synchronous call
+ std::cerr << "Caught synchronous std::exception:" << std::endl;
+ std::cerr << "\t" << e.what() << std::endl;
+ return 1;
+ }
+
+ return 0;
+}
diff --git a/examples/dft/compile_time_dispatching/complex_fwd_buffer_mklgpu.cpp b/examples/dft/compile_time_dispatching/complex_fwd_buffer_mklgpu.cpp
index 376cc7994..232d03758 100644
--- a/examples/dft/compile_time_dispatching/complex_fwd_buffer_mklgpu.cpp
+++ b/examples/dft/compile_time_dispatching/complex_fwd_buffer_mklgpu.cpp
@@ -31,7 +31,7 @@
void run_example(const sycl::device& gpu_device) {
constexpr std::size_t N = 10;
- // Catch asynchronous exceptions for cpu
+ // Catch asynchronous exceptions for gpu
auto gpu_error_handler = [&](sycl::exception_list exceptions) {
for (auto const& e : exceptions) {
try {
@@ -92,7 +92,7 @@ void print_example_banner() {
"# For Intel GPU with Intel MKLGPU backend.\n"
"#\n"
"# The environment variable SYCL_DEVICE_FILTER can be used to specify\n"
- "#SYCL device\n"
+ "# SYCL device\n"
"########################################################################\n"
<< std::endl;
}
diff --git a/include/oneapi/mkl/dft/detail/types_impl.hpp b/include/oneapi/mkl/dft/detail/types_impl.hpp
index 0cbffe9f7..85684fd07 100644
--- a/include/oneapi/mkl/dft/detail/types_impl.hpp
+++ b/include/oneapi/mkl/dft/detail/types_impl.hpp
@@ -31,8 +31,6 @@ namespace detail {
typedef long DFT_ERROR;
-#define DFT_NOTSET -1
-
enum class precision { SINGLE, DOUBLE };
template
diff --git a/src/dft/backends/mklcpu/CMakeLists.txt b/src/dft/backends/mklcpu/CMakeLists.txt
index 5e4e18ef1..97af1fa25 100644
--- a/src/dft/backends/mklcpu/CMakeLists.txt
+++ b/src/dft/backends/mklcpu/CMakeLists.txt
@@ -41,7 +41,7 @@ target_include_directories(${LIB_OBJ}
${MKL_INCLUDE}
)
-target_compile_options(${LIB_OBJ} PRIVATE ${ONEMKL_BUILD_COPT} ${MKL_COPT} -DBUILD_COMP)
+target_compile_options(${LIB_OBJ} PRIVATE ${ONEMKL_BUILD_COPT} ${MKL_COPT})
if (USE_ADD_SYCL_TO_TARGET_INTEGRATION)
add_sycl_to_target(TARGET ${LIB_OBJ} SOURCES ${SOURCES})
endif()
diff --git a/src/dft/backends/mklcpu/backward.cpp b/src/dft/backends/mklcpu/backward.cpp
index a26ad545e..ab95cb5af 100644
--- a/src/dft/backends/mklcpu/backward.cpp
+++ b/src/dft/backends/mklcpu/backward.cpp
@@ -30,81 +30,293 @@
#include "oneapi/mkl/dft/descriptor.hpp"
#include "oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp"
+#include "dft/backends/mklcpu/commit_derived_impl.hpp"
+
+// MKLCPU header
+#include "mkl_dfti.h"
+
namespace oneapi {
namespace mkl {
namespace dft {
namespace mklcpu {
+namespace detail {
// BUFFER version
+// backward a MKLCPU DFT call to the backend, checking that the commit impl is valid.
+template
+inline void check_bwd_commit(dft::descriptor &desc) {
+ auto commit_handle = dft::detail::get_commit(desc);
+ if (commit_handle == nullptr || commit_handle->get_backend() != backend::mklcpu) {
+ throw mkl::invalid_argument("DFT", "computer_backward",
+ "DFT descriptor has not been commited for MKLCPU");
+ }
+
+ auto mklcpu_desc = reinterpret_cast(commit_handle->get_handle());
+ MKL_LONG commit_status{ DFTI_UNCOMMITTED };
+ DftiGetValue(mklcpu_desc[1], DFTI_COMMIT_STATUS, &commit_status);
+ if (commit_status != DFTI_COMMITTED) {
+ throw mkl::invalid_argument("DFT", "compute_backward",
+ "MKLCPU DFT descriptor was not successfully committed.");
+ }
+}
+
+// Throw an mkl::invalid_argument if the runtime param in the descriptor does not match
+// the expected value.
+template
+inline auto expect_config(DescT &desc, const char *message) {
+ dft::detail::config_value actual{ 0 };
+ desc.get_value(Param, &actual);
+ if (actual != Expected) {
+ throw mkl::invalid_argument("DFT", "compute_backward", message);
+ }
+}
+// convert the base commit class to derived cpu commit class
+template
+auto get_buffer(commit_t *commit_handle) {
+ commit_derived_t *derived_commit =
+ static_cast *>(commit_handle);
+ return derived_commit->get_handle_buffer();
+}
+} // namespace detail
//In-place transform
template
-ONEMKL_EXPORT void compute_backward(descriptor_type& /*desc*/,
- sycl::buffer& /*inout*/) {
- throw mkl::unimplemented("DFT", "compute_backward", "Not implemented for MKLCPU");
+ONEMKL_EXPORT void compute_backward(descriptor_type &desc, sycl::buffer &inout) {
+ detail::expect_config(
+ desc, "Unexpected value for placement");
+ auto commit_handle = dft::detail::get_commit(desc);
+ detail::check_bwd_commit(desc);
+ sycl::queue &cpu_queue{ commit_handle->get_queue() };
+
+ auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) };
+
+ cpu_queue.submit([&](sycl::handler &cgh) {
+ auto desc_acc = mklcpu_desc_buffer.template get_access(cgh);
+ auto inout_acc = inout.template get_access(cgh);
+ detail::host_task(cgh, [=]() {
+ DFT_ERROR status =
+ DftiComputeBackward(desc_acc[detail::DIR::bwd], inout_acc.get_pointer());
+ if (status != DFTI_NO_ERROR) {
+ throw oneapi::mkl::exception(
+ "dft/backends/mklcpu", "compute_backward",
+ std::string("DftiComputeBackward failed : ") + DftiErrorMessage(status));
+ }
+ });
+ });
}
//In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format
template
-ONEMKL_EXPORT void compute_backward(descriptor_type& /*desc*/,
- sycl::buffer& /*inout_re*/,
- sycl::buffer& /*inout_im*/) {
- throw mkl::unimplemented("DFT", "compute_backward", "Not implemented for MKLCPU");
+ONEMKL_EXPORT void compute_backward(descriptor_type &desc, sycl::buffer &inout_re,
+ sycl::buffer &inout_im) {
+ detail::expect_config(
+ desc, "Unexpected value for complex storage");
+
+ auto commit_handle = dft::detail::get_commit(desc);
+ detail::check_bwd_commit(desc);
+ sycl::queue &cpu_queue{ commit_handle->get_queue() };
+
+ auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) };
+
+ cpu_queue.submit([&](sycl::handler &cgh) {
+ auto desc_acc = mklcpu_desc_buffer.template get_access(cgh);
+ auto re_acc = inout_re.template get_access(cgh);
+ auto im_acc = inout_im.template get_access(cgh);
+
+ detail::host_task(cgh, [=]() {
+ DFT_ERROR status = DftiComputeBackward(desc_acc[detail::DIR::bwd], re_acc.get_pointer(),
+ im_acc.get_pointer());
+ if (status != DFTI_NO_ERROR) {
+ throw oneapi::mkl::exception(
+ "dft/backends/mklcpu", "compute_backward",
+ std::string("DftiComputeBackward failed : ") + DftiErrorMessage(status));
+ }
+ });
+ });
}
//Out-of-place transform
template
-ONEMKL_EXPORT void compute_backward(descriptor_type& /*desc*/, sycl::buffer& /*in*/,
- sycl::buffer& /*out*/) {
- throw mkl::unimplemented("DFT", "compute_backward", "Not implemented for MKLCPU");
+ONEMKL_EXPORT void compute_backward(descriptor_type &desc, sycl::buffer &in,
+ sycl::buffer &out) {
+ detail::expect_config(desc,
+ "Unexpected value for placement");
+
+ auto commit_handle = dft::detail::get_commit(desc);
+ detail::check_bwd_commit(desc);
+ sycl::queue &cpu_queue{ commit_handle->get_queue() };
+
+ auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) };
+
+ cpu_queue.submit([&](sycl::handler &cgh) {
+ auto desc_acc = mklcpu_desc_buffer.template get_access(cgh);
+ auto in_acc = in.template get_access(cgh);
+ auto out_acc = out.template get_access(cgh);
+
+ detail::host_task(cgh, [=]() {
+ DFT_ERROR status = DftiComputeBackward(desc_acc[detail::DIR::bwd], in_acc.get_pointer(),
+ out_acc.get_pointer());
+ if (status != DFTI_NO_ERROR) {
+ throw oneapi::mkl::exception(
+ "dft/backends/mklcpu", "compute_backward",
+ std::string("DftiComputeBackward failed : ") + DftiErrorMessage(status));
+ }
+ });
+ });
}
//Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format
template
-ONEMKL_EXPORT void compute_backward(descriptor_type& /*desc*/,
- sycl::buffer& /*in_re*/,
- sycl::buffer& /*in_im*/,
- sycl::buffer& /*out_re*/,
- sycl::buffer& /*out_im*/) {
- throw mkl::unimplemented("DFT", "compute_backward", "Not implemented for MKLCPU");
+ONEMKL_EXPORT void compute_backward(descriptor_type &desc, sycl::buffer &in_re,
+ sycl::buffer &in_im,
+ sycl::buffer &out_re,
+ sycl::buffer &out_im) {
+ detail::expect_config(
+ desc, "Unexpected value for complex storage");
+
+ auto commit_handle = dft::detail::get_commit(desc);
+ detail::check_bwd_commit(desc);
+ sycl::queue &cpu_queue{ commit_handle->get_queue() };
+
+ auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) };
+
+ cpu_queue.submit([&](sycl::handler &cgh) {
+ auto desc_acc = mklcpu_desc_buffer.template get_access(cgh);
+ auto inre_acc = in_re.template get_access(cgh);
+ auto inim_acc = in_im.template get_access(cgh);
+ auto outre_acc = out_re.template get_access(cgh);
+ auto outim_acc = out_im.template get_access(cgh);
+
+ detail::host_task(cgh, [=]() {
+ DFT_ERROR status = DftiComputeBackward(
+ desc_acc[detail::DIR::bwd], inre_acc.get_pointer(), inim_acc.get_pointer(),
+ outre_acc.get_pointer(), outim_acc.get_pointer());
+ if (status != DFTI_NO_ERROR) {
+ throw oneapi::mkl::exception(
+ "dft/backends/mklcpu", "compute_backward",
+ std::string("DftiComputeBackward failed : ") + DftiErrorMessage(status));
+ }
+ });
+ });
}
//USM version
//In-place transform
template
-ONEMKL_EXPORT sycl::event compute_backward(descriptor_type& /*desc*/, data_type* /*inout*/,
- const std::vector& /*dependencies*/) {
- throw mkl::unimplemented("DFT", "compute_backward", "Not implemented for MKLCPU");
- return sycl::event{};
+ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, data_type *inout,
+ const std::vector &dependencies) {
+ detail::expect_config(
+ desc, "Unexpected value for placement");
+
+ auto commit_handle = dft::detail::get_commit(desc);
+ detail::check_bwd_commit(desc);
+ sycl::queue &cpu_queue{ commit_handle->get_queue() };
+
+ auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) };
+
+ return cpu_queue.submit([&](sycl::handler &cgh) {
+ auto desc_acc = mklcpu_desc_buffer.template get_access(cgh);
+ cgh.depends_on(dependencies);
+ detail::host_task(cgh, [=]() {
+ DFT_ERROR status = DftiComputeBackward(desc_acc[detail::DIR::bwd], inout);
+ if (status != DFTI_NO_ERROR) {
+ throw oneapi::mkl::exception(
+ "dft/backends/mklcpu", "compute_backward",
+ std::string("DftiComputeBackward failed : ") + DftiErrorMessage(status));
+ }
+ });
+ });
}
//In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format
template
-ONEMKL_EXPORT sycl::event compute_backward(descriptor_type& /*desc*/, data_type* /*inout_re*/,
- data_type* /*inout_im*/,
- const std::vector& /*dependencies*/) {
- throw mkl::unimplemented("DFT", "compute_backward", "Not implemented for MKLCPU");
- return sycl::event{};
+ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, data_type *inout_re,
+ data_type *inout_im,
+ const std::vector &dependencies) {
+ detail::expect_config(
+ desc, "Unexpected value for complex storage");
+ auto commit_handle = dft::detail::get_commit(desc);
+ detail::check_bwd_commit(desc);
+ sycl::queue &cpu_queue{ commit_handle->get_queue() };
+
+ auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) };
+
+ return cpu_queue.submit([&](sycl::handler &cgh) {
+ auto desc_acc = mklcpu_desc_buffer.template get_access(cgh);
+ cgh.depends_on(dependencies);
+ detail::host_task(cgh, [=]() {
+ DFT_ERROR status = DftiComputeBackward(desc_acc[detail::DIR::bwd], inout_re, inout_im);
+ if (status != DFTI_NO_ERROR) {
+ throw oneapi::mkl::exception(
+ "dft/backends/mklcpu", "compute_backward",
+ std::string("DftiComputeBackward failed : ") + DftiErrorMessage(status));
+ }
+ });
+ });
}
//Out-of-place transform
template
-ONEMKL_EXPORT sycl::event compute_backward(descriptor_type& /*desc*/, input_type* /*in*/,
- output_type* /*out*/,
- const std::vector& /*dependencies*/) {
- throw mkl::unimplemented("DFT", "compute_backward", "Not implemented for MKLCPU");
- return sycl::event{};
+ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, input_type *in, output_type *out,
+ const std::vector &dependencies) {
+ // Check: inplace, complex storage
+ detail::expect_config(desc,
+ "Unexpected value for placement");
+
+ auto commit_handle = dft::detail::get_commit(desc);
+ detail::check_bwd_commit(desc);
+ sycl::queue &cpu_queue{ commit_handle->get_queue() };
+
+ auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) };
+ return cpu_queue.submit([&](sycl::handler &cgh) {
+ auto desc_acc = mklcpu_desc_buffer.template get_access(cgh);
+
+ cgh.depends_on(dependencies);
+ detail::host_task(cgh, [=]() {
+ DFT_ERROR status = DftiComputeBackward(desc_acc[detail::DIR::bwd], in, out);
+ if (status != DFTI_NO_ERROR) {
+ throw oneapi::mkl::exception(
+ "dft/backends/mklcpu", "compute_backward",
+ std::string("DftiComputeBackward failed : ") + DftiErrorMessage(status));
+ }
+ });
+ });
}
//Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format
template
-ONEMKL_EXPORT sycl::event compute_backward(descriptor_type& /*desc*/, input_type* /*in_re*/,
- input_type* /*in_im*/, output_type* /*out_re*/,
- output_type* /*out_im*/,
- const std::vector& /*dependencies*/) {
- throw mkl::unimplemented("DFT", "compute_backward", "Not implemented for MKLCPU");
- return sycl::event{};
+ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, input_type *in_re,
+ input_type *in_im, output_type *out_re,
+ output_type *out_im,
+ const std::vector &dependencies) {
+ detail::expect_config(
+ desc, "Unexpected value for complex storage");
+ auto commit_handle = dft::detail::get_commit(desc);
+ detail::check_bwd_commit(desc);
+ sycl::queue &cpu_queue{ commit_handle->get_queue() };
+
+ auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) };
+ return cpu_queue.submit([&](sycl::handler &cgh) {
+ auto desc_acc = mklcpu_desc_buffer.template get_access(cgh);
+
+ cgh.depends_on(dependencies);
+ detail::host_task(cgh, [=]() {
+ DFT_ERROR status =
+ DftiComputeBackward(desc_acc[detail::DIR::bwd], in_re, in_im, out_re, out_im);
+ if (status != DFTI_NO_ERROR) {
+ throw oneapi::mkl::exception(
+ "dft/backends/mklcpu", "compute_backward",
+ std::string("DftiComputeBackward failed : ") + DftiErrorMessage(status));
+ }
+ });
+ });
}
// Template function instantiations
diff --git a/src/dft/backends/mklcpu/commit.cpp b/src/dft/backends/mklcpu/commit.cpp
index 5010250da..fe05625bd 100644
--- a/src/dft/backends/mklcpu/commit.cpp
+++ b/src/dft/backends/mklcpu/commit.cpp
@@ -31,6 +31,8 @@
#include "oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp"
#include "oneapi/mkl/dft/detail/commit_impl.hpp"
+
+#include "dft/backends/mklcpu/commit_derived_impl.hpp"
#include "mkl_service.h"
#include "mkl_dfti.h"
@@ -38,106 +40,152 @@ namespace oneapi {
namespace mkl {
namespace dft {
namespace mklcpu {
+namespace detail {
+
+template
+commit_derived_impl::commit_derived_impl(
+ sycl::queue queue, const dft::detail::dft_values& config_values)
+ : oneapi::mkl::dft::detail::commit_impl(queue, backend::mklcpu) {
+ // create the descriptor once for the lifetime of the descriptor class
+ DFT_ERROR status[2] = { DFTI_BAD_DESCRIPTOR, DFTI_BAD_DESCRIPTOR };
-template
-class commit_derived_impl final : public detail::commit_impl {
-public:
- commit_derived_impl(sycl::queue queue, const detail::dft_values& config_values)
- : detail::commit_impl(queue, backend::mklcpu) {
- DFT_ERROR status = DFT_NOTSET;
+ for (auto dir : { DIR::fwd, DIR::bwd }) {
const auto rank = static_cast(config_values.dimensions.size());
- if (rank == 1) {
- status = DftiCreateDescriptor(&handle, get_precision(prec), get_domain(dom), 1,
- config_values.dimensions[0]);
+ if (config_values.dimensions.size() == 1) {
+ status[dir] = DftiCreateDescriptor(&bidirection_handle[dir], mklcpu_prec, mklcpu_dom, 1,
+ config_values.dimensions[0]);
}
else {
- status = DftiCreateDescriptor(&handle, get_precision(prec), get_domain(dom), rank,
- config_values.dimensions.data());
- }
- if (status != DFTI_NO_ERROR) {
- throw oneapi::mkl::exception(
- "dft/backends/mklcpu", "commit",
- "DftiCreateDescriptor failed with status: " + std::to_string(status));
+ status[dir] = DftiCreateDescriptor(&bidirection_handle[dir], mklcpu_prec, mklcpu_dom,
+ rank, config_values.dimensions.data());
}
}
- void commit(const detail::dft_values& config_values) override {
- set_value(handle, config_values);
- auto status = DftiCommitDescriptor(handle);
- if (status != DFTI_NO_ERROR) {
- throw oneapi::mkl::exception(
- "dft/backends/mklcpu", "commit",
- "DftiCommitDescriptor failed with status: " + std::to_string(status));
- }
+ if (status[0] != DFTI_NO_ERROR || status[1] != DFTI_NO_ERROR) {
+ std::string err = std::string("DftiCreateDescriptor failed with status : ") +
+ DftiErrorMessage(status[0]) + std::string(", ") +
+ DftiErrorMessage(status[1]);
+ throw oneapi::mkl::exception("dft/backends/mklcpu", "create_descriptor", err);
}
+}
- virtual void* get_handle() noexcept override {
- return handle;
+template
+commit_derived_impl::~commit_derived_impl() {
+ for (auto dir : { DIR::fwd, DIR::bwd }) {
+ DftiFreeDescriptor(&bidirection_handle[dir]);
}
+}
- virtual ~commit_derived_impl() override {
- DftiFreeDescriptor((DFTI_DESCRIPTOR_HANDLE*)&handle);
- }
+template
+void commit_derived_impl::commit(
+ const dft::detail::dft_values& config_values) {
+ set_value(bidirection_handle.data(), config_values);
+
+ this->get_queue()
+ .submit([&](sycl::handler& cgh) {
+ auto bidir_handle_obj =
+ bidirection_buffer.get_access(cgh);
+
+ host_task>(cgh, [=]() {
+ DFT_ERROR status[2] = { DFTI_BAD_DESCRIPTOR, DFTI_BAD_DESCRIPTOR };
+
+ for (auto dir : { DIR::fwd, DIR::bwd })
+ status[dir] = DftiCommitDescriptor(bidir_handle_obj[dir]);
+
+ // this is important for real-batched transforms, as the backward transform would
+ // be inconsistent based on the stride setup, but once recommited before backward
+ // it should work just fine. so we error out only if there is a issue with both.
+ if (status[0] != DFTI_NO_ERROR && status[1] != DFTI_NO_ERROR) {
+ std::string err = std::string("DftiCommitDescriptor failed with status : ") +
+ DftiErrorMessage(status[0]) + std::string(", ") +
+ DftiErrorMessage(status[1]);
+ throw oneapi::mkl::exception("dft/backends/mklcpu", "commit", err);
+ }
+ });
+ })
+ .wait();
+}
-private:
- DFTI_DESCRIPTOR_HANDLE handle = nullptr;
+template
+void* commit_derived_impl::get_handle() noexcept {
+ return reinterpret_cast(bidirection_handle.data());
+}
- constexpr DFTI_CONFIG_VALUE get_domain(domain d) {
- if (d == domain::COMPLEX) {
- return DFTI_COMPLEX;
- }
- else {
- return DFTI_REAL;
- }
+template
+template
+void commit_derived_impl::set_value_item(mklcpu_desc_t hand, enum DFTI_CONFIG_PARAM name,
+ Args... args) {
+ DFT_ERROR value_err = DftiSetValue(hand, name, args...);
+ if (value_err != DFTI_NO_ERROR) {
+ throw oneapi::mkl::exception("dft/backends/mklcpu", "set_value_item",
+ DftiErrorMessage(value_err));
}
+}
- constexpr DFTI_CONFIG_VALUE get_precision(precision p) {
- if (p == precision::SINGLE) {
- return DFTI_SINGLE;
+template
+void commit_derived_impl::set_value(mklcpu_desc_t* descHandle,
+ const dft::detail::dft_values& config) {
+ for (auto dir : { DIR::fwd, DIR::bwd }) {
+ set_value_item(descHandle[dir], DFTI_INPUT_STRIDES, config.input_strides.data());
+ set_value_item(descHandle[dir], DFTI_OUTPUT_STRIDES, config.output_strides.data());
+ set_value_item(descHandle[dir], DFTI_BACKWARD_SCALE, config.bwd_scale);
+ set_value_item(descHandle[dir], DFTI_FORWARD_SCALE, config.fwd_scale);
+ set_value_item(descHandle[dir], DFTI_NUMBER_OF_TRANSFORMS, config.number_of_transforms);
+ set_value_item(descHandle[dir], DFTI_INPUT_DISTANCE,
+ (dir == detail::DIR::fwd) ? config.fwd_dist : config.bwd_dist);
+ set_value_item(descHandle[dir], DFTI_OUTPUT_DISTANCE,
+ (dir == detail::DIR::fwd) ? config.bwd_dist : config.fwd_dist);
+ set_value_item(descHandle[dir], DFTI_COMPLEX_STORAGE,
+ to_mklcpu(config.complex_storage));
+ set_value_item(descHandle[dir], DFTI_REAL_STORAGE,
+ to_mklcpu(config.real_storage));
+ set_value_item(descHandle[dir], DFTI_CONJUGATE_EVEN_STORAGE,
+ to_mklcpu(config.conj_even_storage));
+ set_value_item(descHandle[dir], DFTI_PLACEMENT,
+ to_mklcpu(config.placement));
+ set_value_item(descHandle[dir], DFTI_PACKED_FORMAT,
+ to_mklcpu(config.packed_format));
+ // Setting the workspace causes an FFT_INVALID_DESCRIPTOR.
+ if (config.workspace != config_value::ALLOW) {
+ throw mkl::invalid_argument("dft/backends/mklcpu", "commit",
+ "MKLCPU only supports workspace set to allow");
}
- else {
- return DFTI_DOUBLE;
+ // Setting the ordering causes an FFT_INVALID_DESCRIPTOR. Check that default is used:
+ if (config.ordering != dft::detail::config_value::ORDERED) {
+ throw mkl::invalid_argument("dft/backends/mklcpu", "commit",
+ "MKLCPU only supports ordered ordering.");
}
- }
-
- template
- void set_value_item(DFTI_DESCRIPTOR_HANDLE hand, enum DFTI_CONFIG_PARAM name, Args... args) {
- if (auto ret = DftiSetValue(hand, name, args...); ret != DFTI_NO_ERROR) {
- throw oneapi::mkl::exception(
- "dft/backends/mklcpu", "set_value_item",
- "name: " + std::to_string(name) + " error: " + std::to_string(ret));
+ // Setting the transpose causes an FFT_INVALID_DESCRIPTOR. Check that default is used:
+ if (config.transpose != false) {
+ throw mkl::invalid_argument("dft/backends/mklcpu", "commit",
+ "MKLCPU only supports non-transposed.");
}
}
+}
+} // namespace detail
- void set_value(DFTI_DESCRIPTOR_HANDLE& descHandle,
- const detail::dft_values& config) {
- set_value_item(descHandle, DFTI_INPUT_STRIDES, config.input_strides.data());
- set_value_item(descHandle, DFTI_OUTPUT_STRIDES, config.output_strides.data());
- set_value_item(descHandle, DFTI_BACKWARD_SCALE, config.bwd_scale);
- set_value_item(descHandle, DFTI_FORWARD_SCALE, config.fwd_scale);
- set_value_item(descHandle, DFTI_NUMBER_OF_TRANSFORMS, config.number_of_transforms);
- set_value_item(descHandle, DFTI_INPUT_DISTANCE, config.fwd_dist);
- set_value_item(descHandle, DFTI_OUTPUT_DISTANCE, config.bwd_dist);
- set_value_item(
- descHandle, DFTI_PLACEMENT,
- (config.placement == config_value::INPLACE) ? DFTI_INPLACE : DFTI_NOT_INPLACE);
- }
-};
-
-template
-detail::commit_impl* create_commit(const descriptor& desc,
- sycl::queue& sycl_queue) {
- return new commit_derived_impl(sycl_queue, desc.get_values());
+template
+dft::detail::commit_impl* create_commit(const dft::detail::descriptor& desc,
+ sycl::queue& sycl_queue) {
+ return new detail::commit_derived_impl(sycl_queue, desc.get_values());
}
-template detail::commit_impl* create_commit(
- const descriptor&, sycl::queue&);
-template detail::commit_impl* create_commit(
- const descriptor&, sycl::queue&);
-template detail::commit_impl* create_commit(
- const descriptor&, sycl::queue&);
-template detail::commit_impl* create_commit(
- const descriptor&, sycl::queue&);
+template dft::detail::commit_impl*
+create_commit(
+ const dft::detail::descriptor&,
+ sycl::queue&);
+template dft::detail::commit_impl*
+create_commit(
+ const dft::detail::descriptor&,
+ sycl::queue&);
+template dft::detail::commit_impl*
+create_commit(
+ const dft::detail::descriptor&,
+ sycl::queue&);
+template dft::detail::commit_impl*
+create_commit(
+ const dft::detail::descriptor&,
+ sycl::queue&);
} // namespace mklcpu
} // namespace dft
diff --git a/src/dft/backends/mklcpu/commit_derived_impl.hpp b/src/dft/backends/mklcpu/commit_derived_impl.hpp
new file mode 100644
index 000000000..9b418d94c
--- /dev/null
+++ b/src/dft/backends/mklcpu/commit_derived_impl.hpp
@@ -0,0 +1,83 @@
+/*******************************************************************************
+* Copyright 2023 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions
+* and limitations under the License.
+*
+*
+* SPDX-License-Identifier: Apache-2.0
+*******************************************************************************/
+
+#ifndef _ONEMKL_DFT_COMMIT_DERIVED_IMPL_HPP_
+#define _ONEMKL_DFT_COMMIT_DERIVED_IMPL_HPP_
+
+#include "oneapi/mkl/exceptions.hpp"
+#include "oneapi/mkl/dft/detail/types_impl.hpp"
+#include "dft/backends/mklcpu/mklcpu_helpers.hpp"
+
+// MKLCPU header
+#include "mkl_dfti.h"
+
+namespace oneapi {
+namespace mkl {
+namespace dft {
+namespace mklcpu {
+namespace detail {
+
+// this is used for indexing bidirectional_handle
+enum DIR { fwd = 0, bwd = 1 };
+
+template
+class commit_derived_impl final : public dft::detail::commit_impl {
+private:
+ static constexpr DFTI_CONFIG_VALUE mklcpu_prec = to_mklcpu(prec);
+ static constexpr DFTI_CONFIG_VALUE mklcpu_dom = to_mklcpu(dom);
+ using mklcpu_desc_t = DFTI_DESCRIPTOR_HANDLE;
+
+public:
+ commit_derived_impl(sycl::queue queue, const dft::detail::dft_values& config_values);
+
+ virtual void commit(const dft::detail::dft_values& config_values) override;
+
+ virtual void* get_handle() noexcept override;
+
+ virtual ~commit_derived_impl() override;
+
+ sycl::buffer get_handle_buffer() noexcept {
+ return bidirection_buffer;
+ };
+
+private:
+ // bidirectional_handle[0] is the forward handle, bidirectional_handle[1] is the backward handle
+ std::array bidirection_handle{ nullptr, nullptr };
+ sycl::buffer bidirection_buffer{ bidirection_handle.data(),
+ sycl::range<1>{ 2 } };
+
+ template
+ void set_value_item(mklcpu_desc_t hand, enum DFTI_CONFIG_PARAM name, Args... args);
+
+ void set_value(mklcpu_desc_t* descHandle, const dft::detail::dft_values& config);
+};
+
+template
+using commit_t = dft::detail::commit_impl;
+
+template
+using commit_derived_t = detail::commit_derived_impl;
+
+} // namespace detail
+} // namespace mklcpu
+} // namespace dft
+} // namespace mkl
+} // namespace oneapi
+
+#endif // _ONEMKL_DFT_COMMIT_DERIVED_IMPL_HPP_
diff --git a/src/dft/backends/mklcpu/forward.cpp b/src/dft/backends/mklcpu/forward.cpp
index a852f43a0..9f7eea851 100644
--- a/src/dft/backends/mklcpu/forward.cpp
+++ b/src/dft/backends/mklcpu/forward.cpp
@@ -30,79 +30,300 @@
#include "oneapi/mkl/dft/descriptor.hpp"
#include "oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp"
+#include "dft/backends/mklcpu/commit_derived_impl.hpp"
+
+// MKLCPU header
+#include "mkl_dfti.h"
+
namespace oneapi {
namespace mkl {
namespace dft {
namespace mklcpu {
+namespace detail {
+
+// BUFFER version
+// Forward a MKLCPU DFT call to the backend, checking that the commit impl is valid.
+template
+inline void check_fwd_commit(dft::descriptor &desc) {
+ auto commit_handle = dft::detail::get_commit(desc);
+ if (commit_handle == nullptr || commit_handle->get_backend() != backend::mklcpu) {
+ throw mkl::invalid_argument("DFT", "computer_forward",
+ "DFT descriptor has not been commited for MKLCPU");
+ }
+
+ auto mklcpu_desc = reinterpret_cast(commit_handle->get_handle());
+ MKL_LONG commit_status{ DFTI_UNCOMMITTED };
+ DftiGetValue(mklcpu_desc[0], DFTI_COMMIT_STATUS, &commit_status);
+ if (commit_status != DFTI_COMMITTED) {
+ throw mkl::invalid_argument("DFT", "compute_forward",
+ "MKLCPU DFT descriptor was not successfully committed.");
+ }
+}
+
+// Throw an mkl::invalid_argument if the runtime param in the descriptor does not match
+// the expected value.
+template
+inline auto expect_config(DescT &desc, const char *message) {
+ dft::detail::config_value actual{ 0 };
+ desc.get_value(Param, &actual);
+ if (actual != Expected) {
+ throw mkl::invalid_argument("DFT", "compute_forward", message);
+ }
+}
+
+// convert the base commit class to derived cpu commit class
+template
+auto get_buffer(commit_t *commit_handle) {
+ commit_derived_t *derived_commit =
+ static_cast *>(commit_handle);
+ return derived_commit->get_handle_buffer();
+}
+} // namespace detail
//In-place transform
template
-ONEMKL_EXPORT void compute_forward(descriptor_type& /*desc*/,
- sycl::buffer& /*inout*/) {
- throw mkl::unimplemented("DFT", "compute_forward", "Not implemented for MKLCPU");
+ONEMKL_EXPORT void compute_forward(descriptor_type &desc, sycl::buffer &inout) {
+ detail::expect_config(
+ desc, "Unexpected value for placement");
+ auto commit_handle = dft::detail::get_commit(desc);
+ detail::check_fwd_commit(desc);
+ sycl::queue &cpu_queue{ commit_handle->get_queue() };
+
+ auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) };
+
+ cpu_queue.submit([&](sycl::handler &cgh) {
+ auto desc_acc = mklcpu_desc_buffer.template get_access(cgh);
+ auto inout_acc = inout.template get_access(cgh);
+ detail::host_task(cgh, [=]() {
+ DFT_ERROR status =
+ DftiComputeForward(desc_acc[detail::DIR::fwd], inout_acc.get_pointer());
+ if (status != DFTI_NO_ERROR) {
+ throw oneapi::mkl::exception(
+ "dft/forward/mklcpu", "compute_forward",
+ std::string("DftiComputeForward failed : ") + DftiErrorMessage(status));
+ }
+ });
+ });
}
//In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format
template
-ONEMKL_EXPORT void compute_forward(descriptor_type& /*desc*/,
- sycl::buffer& /*inout_re*/,
- sycl::buffer& /*inout_im*/) {
- throw mkl::unimplemented("DFT", "compute_forward", "Not implemented for MKLCPU");
+ONEMKL_EXPORT void compute_forward(descriptor_type &desc, sycl::buffer &inout_re,
+ sycl::buffer &inout_im) {
+ detail::expect_config(
+ desc, "Unexpected value for complex storage");
+
+ auto commit_handle = dft::detail::get_commit(desc);
+ detail::check_fwd_commit(desc);
+ sycl::queue &cpu_queue{ commit_handle->get_queue() };
+
+ auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) };
+
+ cpu_queue.submit([&](sycl::handler &cgh) {
+ auto desc_acc = mklcpu_desc_buffer.template get_access(cgh);
+ auto re_acc = inout_re.template get_access(cgh);
+ auto im_acc = inout_im.template get_access(cgh);
+
+ detail::host_task(cgh, [=]() {
+ DFT_ERROR status = DftiComputeForward(desc_acc[detail::DIR::fwd], re_acc.get_pointer(),
+ im_acc.get_pointer());
+ if (status != DFTI_NO_ERROR) {
+ throw oneapi::mkl::exception(
+ "dft/forward/mklcpu", "compute_forward",
+ std::string("DftiComputeForward failed : ") + DftiErrorMessage(status));
+ }
+ });
+ });
}
//Out-of-place transform
template
-ONEMKL_EXPORT void compute_forward(descriptor_type& /*desc*/, sycl::buffer& /*in*/,
- sycl::buffer& /*out*/) {
- throw mkl::unimplemented("DFT", "compute_forward", "Not implemented for MKLCPU");
+ONEMKL_EXPORT void compute_forward(descriptor_type &desc, sycl::buffer &in,
+ sycl::buffer &out) {
+ detail::expect_config(desc,
+ "Unexpected value for placement");
+
+ auto commit_handle = dft::detail::get_commit(desc);
+ detail::check_fwd_commit(desc);
+ sycl::queue &cpu_queue{ commit_handle->get_queue() };
+
+ auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) };
+
+ cpu_queue.submit([&](sycl::handler &cgh) {
+ auto desc_acc = mklcpu_desc_buffer.template get_access(cgh);
+ auto in_acc = in.template get_access(cgh);
+ auto out_acc = out.template get_access(cgh);
+
+ detail::host_task(cgh, [=]() {
+ DFT_ERROR status = DftiComputeForward(desc_acc[detail::DIR::fwd], in_acc.get_pointer(),
+ out_acc.get_pointer());
+ if (status != DFTI_NO_ERROR) {
+ throw oneapi::mkl::exception(
+ "dft/forward/mklcpu", "compute_forward",
+ std::string("DftiComputeForward failed : ") + DftiErrorMessage(status));
+ }
+ });
+ });
}
//Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format
template
-ONEMKL_EXPORT void compute_forward(descriptor_type& /*desc*/,
- sycl::buffer& /*in_re*/,
- sycl::buffer& /*in_im*/,
- sycl::buffer& /*out_re*/,
- sycl::buffer& /*out_im*/) {
- throw mkl::unimplemented("DFT", "compute_forward", "Not implemented for MKLCPU");
+ONEMKL_EXPORT void compute_forward(descriptor_type &desc, sycl::buffer &in_re,
+ sycl::buffer &in_im,
+ sycl::buffer &out_re,
+ sycl::buffer &out_im) {
+ detail::expect_config(
+ desc, "Unexpected value for complex storage");
+
+ auto commit_handle = dft::detail::get_commit(desc);
+ detail::check_fwd_commit(desc);
+ sycl::queue &cpu_queue{ commit_handle->get_queue() };
+
+ auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) };
+
+ cpu_queue.submit([&](sycl::handler &cgh) {
+ auto desc_acc = mklcpu_desc_buffer.template get_access(cgh);
+ auto inre_acc = in_re.template get_access(cgh);
+ auto inim_acc = in_im.template get_access(cgh);
+ auto outre_acc = out_re.template get_access(cgh);
+ auto outim_acc = out_im.template get_access(cgh);
+
+ detail::host_task(cgh, [=]() {
+ DFT_ERROR status = DftiComputeForward(desc_acc[detail::DIR::fwd],
+ inre_acc.get_pointer(), inim_acc.get_pointer(),
+ outre_acc.get_pointer(), outim_acc.get_pointer());
+ if (status != DFTI_NO_ERROR) {
+ throw oneapi::mkl::exception(
+ "dft/forward/mklcpu", "compute_forward",
+ std::string("DftiComputeForward failed : ") + DftiErrorMessage(status));
+ }
+ });
+ });
}
//USM version
//In-place transform
template
-ONEMKL_EXPORT sycl::event compute_forward(descriptor_type& /*desc*/, data_type* /*inout*/,
- const std::vector& /*dependencies*/) {
- throw mkl::unimplemented("DFT", "compute_forward", "Not implemented for MKLCPU");
- return sycl::event{};
+ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, data_type *inout,
+ const std::vector &dependencies) {
+ detail::expect_config(
+ desc, "Unexpected value for placement");
+
+ auto commit_handle = dft::detail::get_commit(desc);
+ detail::check_fwd_commit(desc);
+ sycl::queue &cpu_queue{ commit_handle->get_queue() };
+
+ auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) };
+
+ return cpu_queue.submit([&](sycl::handler &cgh) {
+ auto desc_acc = mklcpu_desc_buffer.template get_access(cgh);
+
+ cgh.depends_on(dependencies);
+ detail::host_task(cgh, [=]() {
+ DFT_ERROR status = DftiComputeForward(desc_acc[detail::DIR::fwd], inout);
+ if (status != DFTI_NO_ERROR) {
+ throw oneapi::mkl::exception(
+ "dft/forward/mklcpu", "compute_forward",
+ std::string("DftiComputeForward failed : ") + DftiErrorMessage(status));
+ }
+ });
+ });
}
//In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format
template
-ONEMKL_EXPORT sycl::event compute_forward(descriptor_type& /*desc*/, data_type* /*inout_re*/,
- data_type* /*inout_im*/,
- const std::vector& /*dependencies*/) {
- throw mkl::unimplemented("DFT", "compute_forward", "Not implemented for MKLCPU");
- return sycl::event{};
+ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, data_type *inout_re,
+ data_type *inout_im,
+ const std::vector &dependencies) {
+ detail::expect_config(
+ desc, "Unexpected value for complex storage");
+
+ auto commit_handle = dft::detail::get_commit(desc);
+ detail::check_fwd_commit(desc);
+ sycl::queue &cpu_queue{ commit_handle->get_queue() };
+
+ auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) };
+
+ return cpu_queue.submit([&](sycl::handler &cgh) {
+ auto desc_acc = mklcpu_desc_buffer.template get_access(cgh);
+
+ cgh.depends_on(dependencies);
+ detail::host_task(cgh, [=]() {
+ DFT_ERROR status = DftiComputeForward(desc_acc[detail::DIR::fwd], inout_re, inout_im);
+ if (status != DFTI_NO_ERROR) {
+ throw oneapi::mkl::exception(
+ "dft/forward/mklcpu", "compute_forward",
+ std::string("DftiComputeForward failed : ") + DftiErrorMessage(status));
+ }
+ });
+ });
}
//Out-of-place transform
template
-ONEMKL_EXPORT sycl::event compute_forward(descriptor_type& /*desc*/, input_type* /*in*/,
- output_type* /*out*/,
- const std::vector& /*dependencies*/) {
- throw mkl::unimplemented("DFT", "compute_forward", "Not implemented for MKLCPU");
- return sycl::event{};
+ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, input_type *in, output_type *out,
+ const std::vector &dependencies) {
+ // Check: inplace
+ detail::expect_config(desc,
+ "Unexpected value for placement");
+
+ auto commit_handle = dft::detail::get_commit(desc);
+ detail::check_fwd_commit(desc);
+ sycl::queue &cpu_queue{ commit_handle->get_queue() };
+
+ auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) };
+
+ return cpu_queue.submit([&](sycl::handler &cgh) {
+ auto desc_acc = mklcpu_desc_buffer.template get_access(cgh);
+
+ cgh.depends_on(dependencies);
+ detail::host_task(cgh, [=]() {
+ DFT_ERROR status = DftiComputeForward(desc_acc[detail::DIR::fwd], in, out);
+ if (status != DFTI_NO_ERROR) {
+ throw oneapi::mkl::exception(
+ "dft/forward/mklcpu", "compute_forward",
+ std::string("DftiComputeForward failed : ") + DftiErrorMessage(status));
+ }
+ });
+ });
}
//Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format
template
-ONEMKL_EXPORT sycl::event compute_forward(descriptor_type& /*desc*/, input_type* /*in_re*/,
- input_type* /*in_im*/, output_type* /*out_re*/,
- output_type* /*out_im*/,
- const std::vector& /*dependencies*/) {
- throw mkl::unimplemented("DFT", "compute_forward", "Not implemented for MKLCPU");
- return sycl::event{};
+ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, input_type *in_re,
+ input_type *in_im, output_type *out_re,
+ output_type *out_im,
+ const std::vector &dependencies) {
+ detail::expect_config(
+ desc, "Unexpected value for complex storage");
+
+ auto commit_handle = dft::detail::get_commit(desc);
+ detail::check_fwd_commit(desc);
+ sycl::queue &cpu_queue{ commit_handle->get_queue() };
+
+ auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) };
+
+ return cpu_queue.submit([&](sycl::handler &cgh) {
+ auto desc_acc = mklcpu_desc_buffer.template get_access(cgh);
+
+ cgh.depends_on(dependencies);
+ detail::host_task(cgh, [=]() {
+ DFT_ERROR status =
+ DftiComputeForward(desc_acc[detail::DIR::fwd], in_re, in_im, out_re, out_im);
+ if (status != DFTI_NO_ERROR) {
+ throw oneapi::mkl::exception(
+ "dft/forward/mklcpu", "compute_forward",
+ std::string("DftiComputeForward failed : ") + DftiErrorMessage(status));
+ }
+ });
+ });
}
// Template function instantiations
diff --git a/src/dft/backends/mklcpu/mklcpu_helpers.hpp b/src/dft/backends/mklcpu/mklcpu_helpers.hpp
new file mode 100644
index 000000000..e4aad0dde
--- /dev/null
+++ b/src/dft/backends/mklcpu/mklcpu_helpers.hpp
@@ -0,0 +1,182 @@
+/*******************************************************************************
+* Copyright Codeplay Software Ltd.
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions
+* and limitations under the License.
+*
+*
+* SPDX-License-Identifier: Apache-2.0
+*******************************************************************************/
+
+#ifndef _ONEMKL_DFT_SRC_MKLCPU_HELPERS_HPP_
+#define _ONEMKL_DFT_SRC_MKLCPU_HELPERS_HPP_
+
+#include "oneapi/mkl/exceptions.hpp"
+#include "oneapi/mkl/dft/detail/types_impl.hpp"
+
+// MKLCPU header
+#include "mkl_dfti.h"
+
+namespace oneapi {
+namespace mkl {
+namespace dft {
+namespace mklcpu {
+namespace detail {
+
+template
+static inline auto host_task_internal(H& cgh, F f, int) -> decltype(cgh.host_task(f)) {
+ return cgh.host_task(f);
+}
+
+template
+static inline void host_task(H& cgh, F f) {
+ (void)host_task_internal(cgh, f, 0);
+}
+
+template
+class kernel_name {};
+
+/// Convert domain to equivalent backend native value.
+inline constexpr DFTI_CONFIG_VALUE to_mklcpu(dft::detail::domain dom) {
+ if (dom == dft::detail::domain::REAL) {
+ return DFTI_REAL;
+ }
+ else {
+ return DFTI_COMPLEX;
+ }
+}
+
+/// Convert precision to equivalent backend native value.
+inline constexpr DFTI_CONFIG_VALUE to_mklcpu(dft::detail::precision dom) {
+ if (dom == dft::detail::precision::SINGLE) {
+ return DFTI_SINGLE;
+ }
+ else {
+ return DFTI_DOUBLE;
+ }
+}
+
+/// Convert a config_param to equivalent backend native value.
+inline constexpr DFTI_CONFIG_PARAM to_mklcpu(dft::detail::config_param param) {
+ using iparam = dft::detail::config_param;
+ switch (param) {
+ case iparam::FORWARD_DOMAIN: return DFTI_FORWARD_DOMAIN;
+ case iparam::DIMENSION: return DFTI_DIMENSION;
+ case iparam::LENGTHS: return DFTI_LENGTHS;
+ case iparam::PRECISION: return DFTI_PRECISION;
+ case iparam::FORWARD_SCALE: return DFTI_FORWARD_SCALE;
+ case iparam::NUMBER_OF_TRANSFORMS: return DFTI_NUMBER_OF_TRANSFORMS;
+ case iparam::COMPLEX_STORAGE: return DFTI_COMPLEX_STORAGE;
+ case iparam::REAL_STORAGE: return DFTI_REAL_STORAGE;
+ case iparam::CONJUGATE_EVEN_STORAGE: return DFTI_CONJUGATE_EVEN_STORAGE;
+ case iparam::INPUT_STRIDES: return DFTI_INPUT_STRIDES;
+ case iparam::OUTPUT_STRIDES: return DFTI_OUTPUT_STRIDES;
+ case iparam::FWD_DISTANCE: return DFTI_FWD_DISTANCE;
+ case iparam::BWD_DISTANCE: return DFTI_BWD_DISTANCE;
+ case iparam::WORKSPACE: return DFTI_WORKSPACE;
+ case iparam::ORDERING: return DFTI_ORDERING;
+ case iparam::TRANSPOSE: return DFTI_TRANSPOSE;
+ case iparam::PACKED_FORMAT: return DFTI_PACKED_FORMAT;
+ case iparam::COMMIT_STATUS: return DFTI_COMMIT_STATUS;
+ default:
+ throw mkl::invalid_argument("dft", "MKLCPU descriptor set_value()",
+ "Invalid config param.");
+ return static_cast(0);
+ }
+}
+
+/** Convert a config_value to the backend's native value. Throw on invalid input.
+ * @tparam Param The config param the value is for.
+ * @param value The config value to convert.
+**/
+template
+inline constexpr int to_mklcpu(dft::detail::config_value value);
+
+template <>
+inline constexpr int to_mklcpu(
+ dft::detail::config_value value) {
+ if (value == dft::detail::config_value::COMPLEX_COMPLEX) {
+ return DFTI_COMPLEX_COMPLEX;
+ }
+ else if (value == dft::detail::config_value::REAL_REAL) {
+ return DFTI_REAL_REAL;
+ }
+ else {
+ throw mkl::invalid_argument("dft", "MKLCPU descriptor set_value()",
+ "Invalid config value for complex storage.");
+ return 0;
+ }
+}
+
+template <>
+inline constexpr int to_mklcpu(
+ dft::detail::config_value value) {
+ if (value == dft::detail::config_value::REAL_REAL) {
+ return DFTI_REAL_REAL;
+ }
+ else {
+ throw mkl::invalid_argument("dft", "MKLCPU descriptor set_value()",
+ "Invalid config value for real storage.");
+ return 0;
+ }
+}
+template <>
+inline constexpr int to_mklcpu(
+ dft::detail::config_value value) {
+ if (value == dft::detail::config_value::COMPLEX_COMPLEX) {
+ return DFTI_COMPLEX_COMPLEX;
+ }
+ else {
+ throw mkl::invalid_argument("dft", "MKLCPU descriptor set_value()",
+ "Invalid config value for conjugate even storage.");
+ return 0;
+ }
+}
+
+template <>
+inline constexpr int to_mklcpu(
+ dft::detail::config_value value) {
+ if (value == dft::detail::config_value::INPLACE) {
+ return DFTI_INPLACE;
+ }
+ else if (value == dft::detail::config_value::NOT_INPLACE) {
+ return DFTI_NOT_INPLACE;
+ }
+ else {
+ throw mkl::invalid_argument("dft", "MKLCPU descriptor set_value()",
+ "Invalid config value for inplace.");
+ return 0;
+ }
+}
+
+template <>
+inline constexpr int to_mklcpu(
+ dft::detail::config_value value) {
+ if (value == dft::detail::config_value::CCE_FORMAT) {
+ return DFTI_CCE_FORMAT;
+ }
+ else {
+ throw mkl::invalid_argument("dft", "MKLCPU descriptor set_value()",
+ "Invalid config value for packed format.");
+ return 0;
+ }
+}
+
+using mklcpu_desc_t = DFTI_DESCRIPTOR_HANDLE;
+
+} // namespace detail
+} // namespace mklcpu
+} // namespace dft
+} // namespace mkl
+} // namespace oneapi
+
+#endif // _ONEMKL_DFT_SRC_MKLCPU_HELPERS_HPP_
diff --git a/src/dft/descriptor.cxx b/src/dft/descriptor.cxx
index d77d012b6..13dac0b2f 100644
--- a/src/dft/descriptor.cxx
+++ b/src/dft/descriptor.cxx
@@ -172,7 +172,9 @@ void descriptor::get_value(config_param param, ...) const {
va_start(vl, param);
switch (param) {
case config_param::FORWARD_DOMAIN: *va_arg(vl, dft::domain*) = dom; break;
- case config_param::DIMENSION: *va_arg(vl, std::int64_t*) = static_cast(values_.dimensions.size()); break;
+ case config_param::DIMENSION:
+ *va_arg(vl, std::int64_t*) = static_cast(values_.dimensions.size());
+ break;
case config_param::LENGTHS:
std::copy(values_.dimensions.begin(), values_.dimensions.end(),
va_arg(vl, std::int64_t*));
diff --git a/tests/unit_tests/dft/include/compute_inplace.hpp b/tests/unit_tests/dft/include/compute_inplace.hpp
index cfcd465e7..e5bad3cff 100644
--- a/tests/unit_tests/dft/include/compute_inplace.hpp
+++ b/tests/unit_tests/dft/include/compute_inplace.hpp
@@ -256,6 +256,7 @@ int DFT_Test::test_in_place_USM() {
if constexpr (domain == oneapi::mkl::dft::domain::REAL) {
const auto real_strides = get_conjugate_even_real_component_strides(sizes);
const auto complex_strides = get_conjugate_even_complex_strides(sizes);
+
descriptor.set_value(oneapi::mkl::dft::config_param::INPUT_STRIDES, complex_strides.data());
descriptor.set_value(oneapi::mkl::dft::config_param::OUTPUT_STRIDES, real_strides.data());
commit_descriptor(descriptor, sycl_queue);
diff --git a/tests/unit_tests/dft/include/compute_inplace_real_real.hpp b/tests/unit_tests/dft/include/compute_inplace_real_real.hpp
index a7ab36da4..f2a30eff8 100644
--- a/tests/unit_tests/dft/include/compute_inplace_real_real.hpp
+++ b/tests/unit_tests/dft/include/compute_inplace_real_real.hpp
@@ -22,18 +22,19 @@
#include "compute_tester.hpp"
-/* Test is not implemented because currently there are no available dft implementations.
- * These are stubs to make sure that dft::oneapi::mkl::unimplemented exception is thrown */
template
int DFT_Test::test_in_place_real_real_USM() {
if (!init(MemoryAccessModel::usm)) {
return test_skipped;
}
+ if constexpr (domain == oneapi::mkl::dft::domain::REAL) {
+ std::cout << "skipping real split tests as they are not supported" << std::endl;
- try {
+ return test_skipped;
+ }
+ else {
descriptor_t descriptor{ sizes };
PrecisionType backward_scale = 1.f / static_cast(forward_elements);
-
descriptor.set_value(oneapi::mkl::dft::config_param::PLACEMENT,
oneapi::mkl::dft::config_value::INPLACE);
descriptor.set_value(oneapi::mkl::dft::config_param::COMPLEX_STORAGE,
@@ -42,6 +43,7 @@ int DFT_Test