diff --git a/README.md b/README.md index 489661400..41e97f751 100644 --- a/README.md +++ b/README.md @@ -224,9 +224,14 @@ Supported domains: BLAS, LAPACK, RNG, DFT LLVM*, hipSYCL - DFT + DFT Intel GPU - Intel(R) oneAPI Math Kernel Library + Intel(R) oneAPI Math Kernel Library + Dynamic, Static + DPC++ + + + x86 CPU Dynamic, Static DPC++ diff --git a/examples/dft/compile_time_dispatching/CMakeLists.txt b/examples/dft/compile_time_dispatching/CMakeLists.txt index 4556a0cb0..59bf557a6 100644 --- a/examples/dft/compile_time_dispatching/CMakeLists.txt +++ b/examples/dft/compile_time_dispatching/CMakeLists.txt @@ -19,28 +19,45 @@ #Build object from all sources set(DFTI_CT_SOURCES "") + if(ENABLE_MKLGPU_BACKEND) list(APPEND DFTI_CT_SOURCES "complex_fwd_buffer_mklgpu") endif() +if(ENABLE_MKLCPU_BACKEND) + list(APPEND DFTI_CT_SOURCES "complex_fwd_buffer_mklcpu") +endif() + include(WarningsUtils) foreach(dfti_ct_sources ${DFTI_CT_SOURCES}) - add_executable(example_${domain}_${dfti_ct_sources} ${dfti_ct_sources}.cpp) + # add executable and define include directories + # add dependencies and link libraries + # register example as ctest + add_executable(example_${domain}_${dfti_ct_sources} ${dfti_ct_sources}.cpp) target_include_directories(example_${domain}_${dfti_ct_sources} PUBLIC ${PROJECT_SOURCE_DIR}/examples/include PUBLIC ${PROJECT_SOURCE_DIR}/include PUBLIC ${CMAKE_BINARY_DIR}/bin ) - if(domain STREQUAL "dft" AND ENABLE_MKLGPU_BACKEND) - add_dependencies(example_${domain}_${dfti_ct_sources} onemkl_${domain}_mklgpu) - list(APPEND ONEMKL_LIBRARIES_${domain} onemkl_${domain}_mklgpu) + + set(ONEMKL_LIBRARIES_${domain} "") + if(domain STREQUAL "dft") + if(dfti_ct_sources MATCHES "_mklcpu$") + add_dependencies(example_${domain}_${dfti_ct_sources} onemkl_${domain}_mklcpu) + list(APPEND ONEMKL_LIBRARIES_${domain} onemkl_${domain}_mklcpu) + endif() + if(dfti_ct_sources MATCHES "_mklgpu$") + add_dependencies(example_${domain}_${dfti_ct_sources} onemkl_${domain}_mklgpu) + list(APPEND ONEMKL_LIBRARIES_${domain} onemkl_${domain}_mklgpu) + endif() endif() + target_link_libraries(example_${domain}_${dfti_ct_sources} PUBLIC ${ONEMKL_LIBRARIES_${domain}} - PUBLIC ONEMKL::SYCL::SYCL - PRIVATE onemkl_warnings + ${ONEMKL_LIBRARIES_${domain}} + ONEMKL::SYCL::SYCL ) # Register example as ctest - add_test(NAME ${domain}/EXAMPLE/CT/${dfti_ct_sources} COMMAND example_${domain}_${dfti_ct_sources}) -endforeach(dfti_ct_sources) + add_test(NAME ${domain}/EXAMPLE/CT/${dfti_ct_sources} COMMAND example_${domain}_${dfti_ct_sources}) +endforeach(dfti_ct_sources) \ No newline at end of file diff --git a/examples/dft/compile_time_dispatching/complex_fwd_buffer_mklcpu.cpp b/examples/dft/compile_time_dispatching/complex_fwd_buffer_mklcpu.cpp new file mode 100644 index 000000000..26160aeed --- /dev/null +++ b/examples/dft/compile_time_dispatching/complex_fwd_buffer_mklcpu.cpp @@ -0,0 +1,132 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +// STL includes +#include + +// oneMKL/SYCL includes +#if __has_include() +#include +#else +#include +#endif +#include "oneapi/mkl.hpp" + +void run_example(const sycl::device& cpu_device) { + constexpr int N = 10; + + // Catch asynchronous exceptions for cpu + auto cpu_error_handler = [&](sycl::exception_list exceptions) { + for (auto const& e : exceptions) { + try { + std::rethrow_exception(e); + } + catch (sycl::exception const& e) { + // Handle not dft related exceptions that happened during asynchronous call + std::cerr << "Caught asynchronous SYCL exception:" << std::endl; + std::cerr << "\t" << e.what() << std::endl; + } + } + std::exit(2); + }; + + sycl::queue cpu_queue(cpu_device, cpu_error_handler); + + std::vector> input_data(N); + std::vector> output_data(N); + + // enabling + // 1. create descriptors + oneapi::mkl::dft::descriptor + desc(N); + + // 2. variadic set_value + desc.set_value(oneapi::mkl::dft::config_param::PLACEMENT, + oneapi::mkl::dft::config_value::NOT_INPLACE); + desc.set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, + static_cast(1)); + + // 3. commit_descriptor (compile_time MKLCPU) + desc.commit(oneapi::mkl::backend_selector{ cpu_queue }); + + // 4. compute_forward / compute_backward (MKLCPU) + { + sycl::buffer> input_buffer(input_data.data(), sycl::range<1>(N)); + sycl::buffer> output_buffer(output_data.data(), sycl::range<1>(N)); + oneapi::mkl::dft::compute_forward, + std::complex>(desc, input_buffer, output_buffer); + } +} + +// +// Description of example setup, apis used and supported floating point type precisions +// +void print_example_banner() { + std::cout << "\n" + "########################################################################\n" + "# Complex out-of-place forward transform for Buffer API's example:\n" + "#\n" + "# Using APIs:\n" + "# Compile-time dispatch API\n" + "# Buffer forward complex out-of-place\n" + "#\n" + "# Using double precision (double) data type\n" + "#\n" + "# For Intel CPU with Intel MKLCPU backend.\n" + "#\n" + "# The environment variable SYCL_DEVICE_FILTER can be used to specify\n" + "# SYCL device\n" + "########################################################################\n" + << std::endl; +} + +// +// Main entry point for example. +// +int main(int argc, char** argv) { + print_example_banner(); + + try { + sycl::device cpu_device((sycl::cpu_selector_v)); + std::cout << "Running DFT Complex forward out-of-place buffer example" << std::endl; + std::cout << "Using compile-time dispatch API with MKLCPU." << std::endl; + std::cout << "Running with double precision real data type on:" << std::endl; + std::cout << "\tCPU device :" << cpu_device.get_info() + << std::endl; + + run_example(cpu_device); + std::cout << "DFT Complex USM example ran OK on MKLCPU" << std::endl; + } + catch (sycl::exception const& e) { + // Handle not dft related exceptions that happened during synchronous call + std::cerr << "Caught synchronous SYCL exception:" << std::endl; + std::cerr << "\t" << e.what() << std::endl; + std::cerr << "\tSYCL error code: " << e.code().value() << std::endl; + return 1; + } + catch (std::exception const& e) { + // Handle not SYCL related exceptions that happened during synchronous call + std::cerr << "Caught synchronous std::exception:" << std::endl; + std::cerr << "\t" << e.what() << std::endl; + return 1; + } + + return 0; +} diff --git a/examples/dft/compile_time_dispatching/complex_fwd_buffer_mklgpu.cpp b/examples/dft/compile_time_dispatching/complex_fwd_buffer_mklgpu.cpp index 376cc7994..232d03758 100644 --- a/examples/dft/compile_time_dispatching/complex_fwd_buffer_mklgpu.cpp +++ b/examples/dft/compile_time_dispatching/complex_fwd_buffer_mklgpu.cpp @@ -31,7 +31,7 @@ void run_example(const sycl::device& gpu_device) { constexpr std::size_t N = 10; - // Catch asynchronous exceptions for cpu + // Catch asynchronous exceptions for gpu auto gpu_error_handler = [&](sycl::exception_list exceptions) { for (auto const& e : exceptions) { try { @@ -92,7 +92,7 @@ void print_example_banner() { "# For Intel GPU with Intel MKLGPU backend.\n" "#\n" "# The environment variable SYCL_DEVICE_FILTER can be used to specify\n" - "#SYCL device\n" + "# SYCL device\n" "########################################################################\n" << std::endl; } diff --git a/include/oneapi/mkl/dft/detail/types_impl.hpp b/include/oneapi/mkl/dft/detail/types_impl.hpp index 0cbffe9f7..85684fd07 100644 --- a/include/oneapi/mkl/dft/detail/types_impl.hpp +++ b/include/oneapi/mkl/dft/detail/types_impl.hpp @@ -31,8 +31,6 @@ namespace detail { typedef long DFT_ERROR; -#define DFT_NOTSET -1 - enum class precision { SINGLE, DOUBLE }; template diff --git a/src/dft/backends/mklcpu/CMakeLists.txt b/src/dft/backends/mklcpu/CMakeLists.txt index 5e4e18ef1..97af1fa25 100644 --- a/src/dft/backends/mklcpu/CMakeLists.txt +++ b/src/dft/backends/mklcpu/CMakeLists.txt @@ -41,7 +41,7 @@ target_include_directories(${LIB_OBJ} ${MKL_INCLUDE} ) -target_compile_options(${LIB_OBJ} PRIVATE ${ONEMKL_BUILD_COPT} ${MKL_COPT} -DBUILD_COMP) +target_compile_options(${LIB_OBJ} PRIVATE ${ONEMKL_BUILD_COPT} ${MKL_COPT}) if (USE_ADD_SYCL_TO_TARGET_INTEGRATION) add_sycl_to_target(TARGET ${LIB_OBJ} SOURCES ${SOURCES}) endif() diff --git a/src/dft/backends/mklcpu/backward.cpp b/src/dft/backends/mklcpu/backward.cpp index a26ad545e..ab95cb5af 100644 --- a/src/dft/backends/mklcpu/backward.cpp +++ b/src/dft/backends/mklcpu/backward.cpp @@ -30,81 +30,293 @@ #include "oneapi/mkl/dft/descriptor.hpp" #include "oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp" +#include "dft/backends/mklcpu/commit_derived_impl.hpp" + +// MKLCPU header +#include "mkl_dfti.h" + namespace oneapi { namespace mkl { namespace dft { namespace mklcpu { +namespace detail { // BUFFER version +// backward a MKLCPU DFT call to the backend, checking that the commit impl is valid. +template +inline void check_bwd_commit(dft::descriptor &desc) { + auto commit_handle = dft::detail::get_commit(desc); + if (commit_handle == nullptr || commit_handle->get_backend() != backend::mklcpu) { + throw mkl::invalid_argument("DFT", "computer_backward", + "DFT descriptor has not been commited for MKLCPU"); + } + + auto mklcpu_desc = reinterpret_cast(commit_handle->get_handle()); + MKL_LONG commit_status{ DFTI_UNCOMMITTED }; + DftiGetValue(mklcpu_desc[1], DFTI_COMMIT_STATUS, &commit_status); + if (commit_status != DFTI_COMMITTED) { + throw mkl::invalid_argument("DFT", "compute_backward", + "MKLCPU DFT descriptor was not successfully committed."); + } +} + +// Throw an mkl::invalid_argument if the runtime param in the descriptor does not match +// the expected value. +template +inline auto expect_config(DescT &desc, const char *message) { + dft::detail::config_value actual{ 0 }; + desc.get_value(Param, &actual); + if (actual != Expected) { + throw mkl::invalid_argument("DFT", "compute_backward", message); + } +} +// convert the base commit class to derived cpu commit class +template +auto get_buffer(commit_t *commit_handle) { + commit_derived_t *derived_commit = + static_cast *>(commit_handle); + return derived_commit->get_handle_buffer(); +} +} // namespace detail //In-place transform template -ONEMKL_EXPORT void compute_backward(descriptor_type& /*desc*/, - sycl::buffer& /*inout*/) { - throw mkl::unimplemented("DFT", "compute_backward", "Not implemented for MKLCPU"); +ONEMKL_EXPORT void compute_backward(descriptor_type &desc, sycl::buffer &inout) { + detail::expect_config( + desc, "Unexpected value for placement"); + auto commit_handle = dft::detail::get_commit(desc); + detail::check_bwd_commit(desc); + sycl::queue &cpu_queue{ commit_handle->get_queue() }; + + auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) }; + + cpu_queue.submit([&](sycl::handler &cgh) { + auto desc_acc = mklcpu_desc_buffer.template get_access(cgh); + auto inout_acc = inout.template get_access(cgh); + detail::host_task(cgh, [=]() { + DFT_ERROR status = + DftiComputeBackward(desc_acc[detail::DIR::bwd], inout_acc.get_pointer()); + if (status != DFTI_NO_ERROR) { + throw oneapi::mkl::exception( + "dft/backends/mklcpu", "compute_backward", + std::string("DftiComputeBackward failed : ") + DftiErrorMessage(status)); + } + }); + }); } //In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format template -ONEMKL_EXPORT void compute_backward(descriptor_type& /*desc*/, - sycl::buffer& /*inout_re*/, - sycl::buffer& /*inout_im*/) { - throw mkl::unimplemented("DFT", "compute_backward", "Not implemented for MKLCPU"); +ONEMKL_EXPORT void compute_backward(descriptor_type &desc, sycl::buffer &inout_re, + sycl::buffer &inout_im) { + detail::expect_config( + desc, "Unexpected value for complex storage"); + + auto commit_handle = dft::detail::get_commit(desc); + detail::check_bwd_commit(desc); + sycl::queue &cpu_queue{ commit_handle->get_queue() }; + + auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) }; + + cpu_queue.submit([&](sycl::handler &cgh) { + auto desc_acc = mklcpu_desc_buffer.template get_access(cgh); + auto re_acc = inout_re.template get_access(cgh); + auto im_acc = inout_im.template get_access(cgh); + + detail::host_task(cgh, [=]() { + DFT_ERROR status = DftiComputeBackward(desc_acc[detail::DIR::bwd], re_acc.get_pointer(), + im_acc.get_pointer()); + if (status != DFTI_NO_ERROR) { + throw oneapi::mkl::exception( + "dft/backends/mklcpu", "compute_backward", + std::string("DftiComputeBackward failed : ") + DftiErrorMessage(status)); + } + }); + }); } //Out-of-place transform template -ONEMKL_EXPORT void compute_backward(descriptor_type& /*desc*/, sycl::buffer& /*in*/, - sycl::buffer& /*out*/) { - throw mkl::unimplemented("DFT", "compute_backward", "Not implemented for MKLCPU"); +ONEMKL_EXPORT void compute_backward(descriptor_type &desc, sycl::buffer &in, + sycl::buffer &out) { + detail::expect_config(desc, + "Unexpected value for placement"); + + auto commit_handle = dft::detail::get_commit(desc); + detail::check_bwd_commit(desc); + sycl::queue &cpu_queue{ commit_handle->get_queue() }; + + auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) }; + + cpu_queue.submit([&](sycl::handler &cgh) { + auto desc_acc = mklcpu_desc_buffer.template get_access(cgh); + auto in_acc = in.template get_access(cgh); + auto out_acc = out.template get_access(cgh); + + detail::host_task(cgh, [=]() { + DFT_ERROR status = DftiComputeBackward(desc_acc[detail::DIR::bwd], in_acc.get_pointer(), + out_acc.get_pointer()); + if (status != DFTI_NO_ERROR) { + throw oneapi::mkl::exception( + "dft/backends/mklcpu", "compute_backward", + std::string("DftiComputeBackward failed : ") + DftiErrorMessage(status)); + } + }); + }); } //Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format template -ONEMKL_EXPORT void compute_backward(descriptor_type& /*desc*/, - sycl::buffer& /*in_re*/, - sycl::buffer& /*in_im*/, - sycl::buffer& /*out_re*/, - sycl::buffer& /*out_im*/) { - throw mkl::unimplemented("DFT", "compute_backward", "Not implemented for MKLCPU"); +ONEMKL_EXPORT void compute_backward(descriptor_type &desc, sycl::buffer &in_re, + sycl::buffer &in_im, + sycl::buffer &out_re, + sycl::buffer &out_im) { + detail::expect_config( + desc, "Unexpected value for complex storage"); + + auto commit_handle = dft::detail::get_commit(desc); + detail::check_bwd_commit(desc); + sycl::queue &cpu_queue{ commit_handle->get_queue() }; + + auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) }; + + cpu_queue.submit([&](sycl::handler &cgh) { + auto desc_acc = mklcpu_desc_buffer.template get_access(cgh); + auto inre_acc = in_re.template get_access(cgh); + auto inim_acc = in_im.template get_access(cgh); + auto outre_acc = out_re.template get_access(cgh); + auto outim_acc = out_im.template get_access(cgh); + + detail::host_task(cgh, [=]() { + DFT_ERROR status = DftiComputeBackward( + desc_acc[detail::DIR::bwd], inre_acc.get_pointer(), inim_acc.get_pointer(), + outre_acc.get_pointer(), outim_acc.get_pointer()); + if (status != DFTI_NO_ERROR) { + throw oneapi::mkl::exception( + "dft/backends/mklcpu", "compute_backward", + std::string("DftiComputeBackward failed : ") + DftiErrorMessage(status)); + } + }); + }); } //USM version //In-place transform template -ONEMKL_EXPORT sycl::event compute_backward(descriptor_type& /*desc*/, data_type* /*inout*/, - const std::vector& /*dependencies*/) { - throw mkl::unimplemented("DFT", "compute_backward", "Not implemented for MKLCPU"); - return sycl::event{}; +ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, data_type *inout, + const std::vector &dependencies) { + detail::expect_config( + desc, "Unexpected value for placement"); + + auto commit_handle = dft::detail::get_commit(desc); + detail::check_bwd_commit(desc); + sycl::queue &cpu_queue{ commit_handle->get_queue() }; + + auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) }; + + return cpu_queue.submit([&](sycl::handler &cgh) { + auto desc_acc = mklcpu_desc_buffer.template get_access(cgh); + cgh.depends_on(dependencies); + detail::host_task(cgh, [=]() { + DFT_ERROR status = DftiComputeBackward(desc_acc[detail::DIR::bwd], inout); + if (status != DFTI_NO_ERROR) { + throw oneapi::mkl::exception( + "dft/backends/mklcpu", "compute_backward", + std::string("DftiComputeBackward failed : ") + DftiErrorMessage(status)); + } + }); + }); } //In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format template -ONEMKL_EXPORT sycl::event compute_backward(descriptor_type& /*desc*/, data_type* /*inout_re*/, - data_type* /*inout_im*/, - const std::vector& /*dependencies*/) { - throw mkl::unimplemented("DFT", "compute_backward", "Not implemented for MKLCPU"); - return sycl::event{}; +ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, data_type *inout_re, + data_type *inout_im, + const std::vector &dependencies) { + detail::expect_config( + desc, "Unexpected value for complex storage"); + auto commit_handle = dft::detail::get_commit(desc); + detail::check_bwd_commit(desc); + sycl::queue &cpu_queue{ commit_handle->get_queue() }; + + auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) }; + + return cpu_queue.submit([&](sycl::handler &cgh) { + auto desc_acc = mklcpu_desc_buffer.template get_access(cgh); + cgh.depends_on(dependencies); + detail::host_task(cgh, [=]() { + DFT_ERROR status = DftiComputeBackward(desc_acc[detail::DIR::bwd], inout_re, inout_im); + if (status != DFTI_NO_ERROR) { + throw oneapi::mkl::exception( + "dft/backends/mklcpu", "compute_backward", + std::string("DftiComputeBackward failed : ") + DftiErrorMessage(status)); + } + }); + }); } //Out-of-place transform template -ONEMKL_EXPORT sycl::event compute_backward(descriptor_type& /*desc*/, input_type* /*in*/, - output_type* /*out*/, - const std::vector& /*dependencies*/) { - throw mkl::unimplemented("DFT", "compute_backward", "Not implemented for MKLCPU"); - return sycl::event{}; +ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, input_type *in, output_type *out, + const std::vector &dependencies) { + // Check: inplace, complex storage + detail::expect_config(desc, + "Unexpected value for placement"); + + auto commit_handle = dft::detail::get_commit(desc); + detail::check_bwd_commit(desc); + sycl::queue &cpu_queue{ commit_handle->get_queue() }; + + auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) }; + return cpu_queue.submit([&](sycl::handler &cgh) { + auto desc_acc = mklcpu_desc_buffer.template get_access(cgh); + + cgh.depends_on(dependencies); + detail::host_task(cgh, [=]() { + DFT_ERROR status = DftiComputeBackward(desc_acc[detail::DIR::bwd], in, out); + if (status != DFTI_NO_ERROR) { + throw oneapi::mkl::exception( + "dft/backends/mklcpu", "compute_backward", + std::string("DftiComputeBackward failed : ") + DftiErrorMessage(status)); + } + }); + }); } //Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format template -ONEMKL_EXPORT sycl::event compute_backward(descriptor_type& /*desc*/, input_type* /*in_re*/, - input_type* /*in_im*/, output_type* /*out_re*/, - output_type* /*out_im*/, - const std::vector& /*dependencies*/) { - throw mkl::unimplemented("DFT", "compute_backward", "Not implemented for MKLCPU"); - return sycl::event{}; +ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, input_type *in_re, + input_type *in_im, output_type *out_re, + output_type *out_im, + const std::vector &dependencies) { + detail::expect_config( + desc, "Unexpected value for complex storage"); + auto commit_handle = dft::detail::get_commit(desc); + detail::check_bwd_commit(desc); + sycl::queue &cpu_queue{ commit_handle->get_queue() }; + + auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) }; + return cpu_queue.submit([&](sycl::handler &cgh) { + auto desc_acc = mklcpu_desc_buffer.template get_access(cgh); + + cgh.depends_on(dependencies); + detail::host_task(cgh, [=]() { + DFT_ERROR status = + DftiComputeBackward(desc_acc[detail::DIR::bwd], in_re, in_im, out_re, out_im); + if (status != DFTI_NO_ERROR) { + throw oneapi::mkl::exception( + "dft/backends/mklcpu", "compute_backward", + std::string("DftiComputeBackward failed : ") + DftiErrorMessage(status)); + } + }); + }); } // Template function instantiations diff --git a/src/dft/backends/mklcpu/commit.cpp b/src/dft/backends/mklcpu/commit.cpp index 5010250da..fe05625bd 100644 --- a/src/dft/backends/mklcpu/commit.cpp +++ b/src/dft/backends/mklcpu/commit.cpp @@ -31,6 +31,8 @@ #include "oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp" #include "oneapi/mkl/dft/detail/commit_impl.hpp" + +#include "dft/backends/mklcpu/commit_derived_impl.hpp" #include "mkl_service.h" #include "mkl_dfti.h" @@ -38,106 +40,152 @@ namespace oneapi { namespace mkl { namespace dft { namespace mklcpu { +namespace detail { + +template +commit_derived_impl::commit_derived_impl( + sycl::queue queue, const dft::detail::dft_values& config_values) + : oneapi::mkl::dft::detail::commit_impl(queue, backend::mklcpu) { + // create the descriptor once for the lifetime of the descriptor class + DFT_ERROR status[2] = { DFTI_BAD_DESCRIPTOR, DFTI_BAD_DESCRIPTOR }; -template -class commit_derived_impl final : public detail::commit_impl { -public: - commit_derived_impl(sycl::queue queue, const detail::dft_values& config_values) - : detail::commit_impl(queue, backend::mklcpu) { - DFT_ERROR status = DFT_NOTSET; + for (auto dir : { DIR::fwd, DIR::bwd }) { const auto rank = static_cast(config_values.dimensions.size()); - if (rank == 1) { - status = DftiCreateDescriptor(&handle, get_precision(prec), get_domain(dom), 1, - config_values.dimensions[0]); + if (config_values.dimensions.size() == 1) { + status[dir] = DftiCreateDescriptor(&bidirection_handle[dir], mklcpu_prec, mklcpu_dom, 1, + config_values.dimensions[0]); } else { - status = DftiCreateDescriptor(&handle, get_precision(prec), get_domain(dom), rank, - config_values.dimensions.data()); - } - if (status != DFTI_NO_ERROR) { - throw oneapi::mkl::exception( - "dft/backends/mklcpu", "commit", - "DftiCreateDescriptor failed with status: " + std::to_string(status)); + status[dir] = DftiCreateDescriptor(&bidirection_handle[dir], mklcpu_prec, mklcpu_dom, + rank, config_values.dimensions.data()); } } - void commit(const detail::dft_values& config_values) override { - set_value(handle, config_values); - auto status = DftiCommitDescriptor(handle); - if (status != DFTI_NO_ERROR) { - throw oneapi::mkl::exception( - "dft/backends/mklcpu", "commit", - "DftiCommitDescriptor failed with status: " + std::to_string(status)); - } + if (status[0] != DFTI_NO_ERROR || status[1] != DFTI_NO_ERROR) { + std::string err = std::string("DftiCreateDescriptor failed with status : ") + + DftiErrorMessage(status[0]) + std::string(", ") + + DftiErrorMessage(status[1]); + throw oneapi::mkl::exception("dft/backends/mklcpu", "create_descriptor", err); } +} - virtual void* get_handle() noexcept override { - return handle; +template +commit_derived_impl::~commit_derived_impl() { + for (auto dir : { DIR::fwd, DIR::bwd }) { + DftiFreeDescriptor(&bidirection_handle[dir]); } +} - virtual ~commit_derived_impl() override { - DftiFreeDescriptor((DFTI_DESCRIPTOR_HANDLE*)&handle); - } +template +void commit_derived_impl::commit( + const dft::detail::dft_values& config_values) { + set_value(bidirection_handle.data(), config_values); + + this->get_queue() + .submit([&](sycl::handler& cgh) { + auto bidir_handle_obj = + bidirection_buffer.get_access(cgh); + + host_task>(cgh, [=]() { + DFT_ERROR status[2] = { DFTI_BAD_DESCRIPTOR, DFTI_BAD_DESCRIPTOR }; + + for (auto dir : { DIR::fwd, DIR::bwd }) + status[dir] = DftiCommitDescriptor(bidir_handle_obj[dir]); + + // this is important for real-batched transforms, as the backward transform would + // be inconsistent based on the stride setup, but once recommited before backward + // it should work just fine. so we error out only if there is a issue with both. + if (status[0] != DFTI_NO_ERROR && status[1] != DFTI_NO_ERROR) { + std::string err = std::string("DftiCommitDescriptor failed with status : ") + + DftiErrorMessage(status[0]) + std::string(", ") + + DftiErrorMessage(status[1]); + throw oneapi::mkl::exception("dft/backends/mklcpu", "commit", err); + } + }); + }) + .wait(); +} -private: - DFTI_DESCRIPTOR_HANDLE handle = nullptr; +template +void* commit_derived_impl::get_handle() noexcept { + return reinterpret_cast(bidirection_handle.data()); +} - constexpr DFTI_CONFIG_VALUE get_domain(domain d) { - if (d == domain::COMPLEX) { - return DFTI_COMPLEX; - } - else { - return DFTI_REAL; - } +template +template +void commit_derived_impl::set_value_item(mklcpu_desc_t hand, enum DFTI_CONFIG_PARAM name, + Args... args) { + DFT_ERROR value_err = DftiSetValue(hand, name, args...); + if (value_err != DFTI_NO_ERROR) { + throw oneapi::mkl::exception("dft/backends/mklcpu", "set_value_item", + DftiErrorMessage(value_err)); } +} - constexpr DFTI_CONFIG_VALUE get_precision(precision p) { - if (p == precision::SINGLE) { - return DFTI_SINGLE; +template +void commit_derived_impl::set_value(mklcpu_desc_t* descHandle, + const dft::detail::dft_values& config) { + for (auto dir : { DIR::fwd, DIR::bwd }) { + set_value_item(descHandle[dir], DFTI_INPUT_STRIDES, config.input_strides.data()); + set_value_item(descHandle[dir], DFTI_OUTPUT_STRIDES, config.output_strides.data()); + set_value_item(descHandle[dir], DFTI_BACKWARD_SCALE, config.bwd_scale); + set_value_item(descHandle[dir], DFTI_FORWARD_SCALE, config.fwd_scale); + set_value_item(descHandle[dir], DFTI_NUMBER_OF_TRANSFORMS, config.number_of_transforms); + set_value_item(descHandle[dir], DFTI_INPUT_DISTANCE, + (dir == detail::DIR::fwd) ? config.fwd_dist : config.bwd_dist); + set_value_item(descHandle[dir], DFTI_OUTPUT_DISTANCE, + (dir == detail::DIR::fwd) ? config.bwd_dist : config.fwd_dist); + set_value_item(descHandle[dir], DFTI_COMPLEX_STORAGE, + to_mklcpu(config.complex_storage)); + set_value_item(descHandle[dir], DFTI_REAL_STORAGE, + to_mklcpu(config.real_storage)); + set_value_item(descHandle[dir], DFTI_CONJUGATE_EVEN_STORAGE, + to_mklcpu(config.conj_even_storage)); + set_value_item(descHandle[dir], DFTI_PLACEMENT, + to_mklcpu(config.placement)); + set_value_item(descHandle[dir], DFTI_PACKED_FORMAT, + to_mklcpu(config.packed_format)); + // Setting the workspace causes an FFT_INVALID_DESCRIPTOR. + if (config.workspace != config_value::ALLOW) { + throw mkl::invalid_argument("dft/backends/mklcpu", "commit", + "MKLCPU only supports workspace set to allow"); } - else { - return DFTI_DOUBLE; + // Setting the ordering causes an FFT_INVALID_DESCRIPTOR. Check that default is used: + if (config.ordering != dft::detail::config_value::ORDERED) { + throw mkl::invalid_argument("dft/backends/mklcpu", "commit", + "MKLCPU only supports ordered ordering."); } - } - - template - void set_value_item(DFTI_DESCRIPTOR_HANDLE hand, enum DFTI_CONFIG_PARAM name, Args... args) { - if (auto ret = DftiSetValue(hand, name, args...); ret != DFTI_NO_ERROR) { - throw oneapi::mkl::exception( - "dft/backends/mklcpu", "set_value_item", - "name: " + std::to_string(name) + " error: " + std::to_string(ret)); + // Setting the transpose causes an FFT_INVALID_DESCRIPTOR. Check that default is used: + if (config.transpose != false) { + throw mkl::invalid_argument("dft/backends/mklcpu", "commit", + "MKLCPU only supports non-transposed."); } } +} +} // namespace detail - void set_value(DFTI_DESCRIPTOR_HANDLE& descHandle, - const detail::dft_values& config) { - set_value_item(descHandle, DFTI_INPUT_STRIDES, config.input_strides.data()); - set_value_item(descHandle, DFTI_OUTPUT_STRIDES, config.output_strides.data()); - set_value_item(descHandle, DFTI_BACKWARD_SCALE, config.bwd_scale); - set_value_item(descHandle, DFTI_FORWARD_SCALE, config.fwd_scale); - set_value_item(descHandle, DFTI_NUMBER_OF_TRANSFORMS, config.number_of_transforms); - set_value_item(descHandle, DFTI_INPUT_DISTANCE, config.fwd_dist); - set_value_item(descHandle, DFTI_OUTPUT_DISTANCE, config.bwd_dist); - set_value_item( - descHandle, DFTI_PLACEMENT, - (config.placement == config_value::INPLACE) ? DFTI_INPLACE : DFTI_NOT_INPLACE); - } -}; - -template -detail::commit_impl* create_commit(const descriptor& desc, - sycl::queue& sycl_queue) { - return new commit_derived_impl(sycl_queue, desc.get_values()); +template +dft::detail::commit_impl* create_commit(const dft::detail::descriptor& desc, + sycl::queue& sycl_queue) { + return new detail::commit_derived_impl(sycl_queue, desc.get_values()); } -template detail::commit_impl* create_commit( - const descriptor&, sycl::queue&); -template detail::commit_impl* create_commit( - const descriptor&, sycl::queue&); -template detail::commit_impl* create_commit( - const descriptor&, sycl::queue&); -template detail::commit_impl* create_commit( - const descriptor&, sycl::queue&); +template dft::detail::commit_impl* +create_commit( + const dft::detail::descriptor&, + sycl::queue&); +template dft::detail::commit_impl* +create_commit( + const dft::detail::descriptor&, + sycl::queue&); +template dft::detail::commit_impl* +create_commit( + const dft::detail::descriptor&, + sycl::queue&); +template dft::detail::commit_impl* +create_commit( + const dft::detail::descriptor&, + sycl::queue&); } // namespace mklcpu } // namespace dft diff --git a/src/dft/backends/mklcpu/commit_derived_impl.hpp b/src/dft/backends/mklcpu/commit_derived_impl.hpp new file mode 100644 index 000000000..9b418d94c --- /dev/null +++ b/src/dft/backends/mklcpu/commit_derived_impl.hpp @@ -0,0 +1,83 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _ONEMKL_DFT_COMMIT_DERIVED_IMPL_HPP_ +#define _ONEMKL_DFT_COMMIT_DERIVED_IMPL_HPP_ + +#include "oneapi/mkl/exceptions.hpp" +#include "oneapi/mkl/dft/detail/types_impl.hpp" +#include "dft/backends/mklcpu/mklcpu_helpers.hpp" + +// MKLCPU header +#include "mkl_dfti.h" + +namespace oneapi { +namespace mkl { +namespace dft { +namespace mklcpu { +namespace detail { + +// this is used for indexing bidirectional_handle +enum DIR { fwd = 0, bwd = 1 }; + +template +class commit_derived_impl final : public dft::detail::commit_impl { +private: + static constexpr DFTI_CONFIG_VALUE mklcpu_prec = to_mklcpu(prec); + static constexpr DFTI_CONFIG_VALUE mklcpu_dom = to_mklcpu(dom); + using mklcpu_desc_t = DFTI_DESCRIPTOR_HANDLE; + +public: + commit_derived_impl(sycl::queue queue, const dft::detail::dft_values& config_values); + + virtual void commit(const dft::detail::dft_values& config_values) override; + + virtual void* get_handle() noexcept override; + + virtual ~commit_derived_impl() override; + + sycl::buffer get_handle_buffer() noexcept { + return bidirection_buffer; + }; + +private: + // bidirectional_handle[0] is the forward handle, bidirectional_handle[1] is the backward handle + std::array bidirection_handle{ nullptr, nullptr }; + sycl::buffer bidirection_buffer{ bidirection_handle.data(), + sycl::range<1>{ 2 } }; + + template + void set_value_item(mklcpu_desc_t hand, enum DFTI_CONFIG_PARAM name, Args... args); + + void set_value(mklcpu_desc_t* descHandle, const dft::detail::dft_values& config); +}; + +template +using commit_t = dft::detail::commit_impl; + +template +using commit_derived_t = detail::commit_derived_impl; + +} // namespace detail +} // namespace mklcpu +} // namespace dft +} // namespace mkl +} // namespace oneapi + +#endif // _ONEMKL_DFT_COMMIT_DERIVED_IMPL_HPP_ diff --git a/src/dft/backends/mklcpu/forward.cpp b/src/dft/backends/mklcpu/forward.cpp index a852f43a0..9f7eea851 100644 --- a/src/dft/backends/mklcpu/forward.cpp +++ b/src/dft/backends/mklcpu/forward.cpp @@ -30,79 +30,300 @@ #include "oneapi/mkl/dft/descriptor.hpp" #include "oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp" +#include "dft/backends/mklcpu/commit_derived_impl.hpp" + +// MKLCPU header +#include "mkl_dfti.h" + namespace oneapi { namespace mkl { namespace dft { namespace mklcpu { +namespace detail { + +// BUFFER version +// Forward a MKLCPU DFT call to the backend, checking that the commit impl is valid. +template +inline void check_fwd_commit(dft::descriptor &desc) { + auto commit_handle = dft::detail::get_commit(desc); + if (commit_handle == nullptr || commit_handle->get_backend() != backend::mklcpu) { + throw mkl::invalid_argument("DFT", "computer_forward", + "DFT descriptor has not been commited for MKLCPU"); + } + + auto mklcpu_desc = reinterpret_cast(commit_handle->get_handle()); + MKL_LONG commit_status{ DFTI_UNCOMMITTED }; + DftiGetValue(mklcpu_desc[0], DFTI_COMMIT_STATUS, &commit_status); + if (commit_status != DFTI_COMMITTED) { + throw mkl::invalid_argument("DFT", "compute_forward", + "MKLCPU DFT descriptor was not successfully committed."); + } +} + +// Throw an mkl::invalid_argument if the runtime param in the descriptor does not match +// the expected value. +template +inline auto expect_config(DescT &desc, const char *message) { + dft::detail::config_value actual{ 0 }; + desc.get_value(Param, &actual); + if (actual != Expected) { + throw mkl::invalid_argument("DFT", "compute_forward", message); + } +} + +// convert the base commit class to derived cpu commit class +template +auto get_buffer(commit_t *commit_handle) { + commit_derived_t *derived_commit = + static_cast *>(commit_handle); + return derived_commit->get_handle_buffer(); +} +} // namespace detail //In-place transform template -ONEMKL_EXPORT void compute_forward(descriptor_type& /*desc*/, - sycl::buffer& /*inout*/) { - throw mkl::unimplemented("DFT", "compute_forward", "Not implemented for MKLCPU"); +ONEMKL_EXPORT void compute_forward(descriptor_type &desc, sycl::buffer &inout) { + detail::expect_config( + desc, "Unexpected value for placement"); + auto commit_handle = dft::detail::get_commit(desc); + detail::check_fwd_commit(desc); + sycl::queue &cpu_queue{ commit_handle->get_queue() }; + + auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) }; + + cpu_queue.submit([&](sycl::handler &cgh) { + auto desc_acc = mklcpu_desc_buffer.template get_access(cgh); + auto inout_acc = inout.template get_access(cgh); + detail::host_task(cgh, [=]() { + DFT_ERROR status = + DftiComputeForward(desc_acc[detail::DIR::fwd], inout_acc.get_pointer()); + if (status != DFTI_NO_ERROR) { + throw oneapi::mkl::exception( + "dft/forward/mklcpu", "compute_forward", + std::string("DftiComputeForward failed : ") + DftiErrorMessage(status)); + } + }); + }); } //In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format template -ONEMKL_EXPORT void compute_forward(descriptor_type& /*desc*/, - sycl::buffer& /*inout_re*/, - sycl::buffer& /*inout_im*/) { - throw mkl::unimplemented("DFT", "compute_forward", "Not implemented for MKLCPU"); +ONEMKL_EXPORT void compute_forward(descriptor_type &desc, sycl::buffer &inout_re, + sycl::buffer &inout_im) { + detail::expect_config( + desc, "Unexpected value for complex storage"); + + auto commit_handle = dft::detail::get_commit(desc); + detail::check_fwd_commit(desc); + sycl::queue &cpu_queue{ commit_handle->get_queue() }; + + auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) }; + + cpu_queue.submit([&](sycl::handler &cgh) { + auto desc_acc = mklcpu_desc_buffer.template get_access(cgh); + auto re_acc = inout_re.template get_access(cgh); + auto im_acc = inout_im.template get_access(cgh); + + detail::host_task(cgh, [=]() { + DFT_ERROR status = DftiComputeForward(desc_acc[detail::DIR::fwd], re_acc.get_pointer(), + im_acc.get_pointer()); + if (status != DFTI_NO_ERROR) { + throw oneapi::mkl::exception( + "dft/forward/mklcpu", "compute_forward", + std::string("DftiComputeForward failed : ") + DftiErrorMessage(status)); + } + }); + }); } //Out-of-place transform template -ONEMKL_EXPORT void compute_forward(descriptor_type& /*desc*/, sycl::buffer& /*in*/, - sycl::buffer& /*out*/) { - throw mkl::unimplemented("DFT", "compute_forward", "Not implemented for MKLCPU"); +ONEMKL_EXPORT void compute_forward(descriptor_type &desc, sycl::buffer &in, + sycl::buffer &out) { + detail::expect_config(desc, + "Unexpected value for placement"); + + auto commit_handle = dft::detail::get_commit(desc); + detail::check_fwd_commit(desc); + sycl::queue &cpu_queue{ commit_handle->get_queue() }; + + auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) }; + + cpu_queue.submit([&](sycl::handler &cgh) { + auto desc_acc = mklcpu_desc_buffer.template get_access(cgh); + auto in_acc = in.template get_access(cgh); + auto out_acc = out.template get_access(cgh); + + detail::host_task(cgh, [=]() { + DFT_ERROR status = DftiComputeForward(desc_acc[detail::DIR::fwd], in_acc.get_pointer(), + out_acc.get_pointer()); + if (status != DFTI_NO_ERROR) { + throw oneapi::mkl::exception( + "dft/forward/mklcpu", "compute_forward", + std::string("DftiComputeForward failed : ") + DftiErrorMessage(status)); + } + }); + }); } //Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format template -ONEMKL_EXPORT void compute_forward(descriptor_type& /*desc*/, - sycl::buffer& /*in_re*/, - sycl::buffer& /*in_im*/, - sycl::buffer& /*out_re*/, - sycl::buffer& /*out_im*/) { - throw mkl::unimplemented("DFT", "compute_forward", "Not implemented for MKLCPU"); +ONEMKL_EXPORT void compute_forward(descriptor_type &desc, sycl::buffer &in_re, + sycl::buffer &in_im, + sycl::buffer &out_re, + sycl::buffer &out_im) { + detail::expect_config( + desc, "Unexpected value for complex storage"); + + auto commit_handle = dft::detail::get_commit(desc); + detail::check_fwd_commit(desc); + sycl::queue &cpu_queue{ commit_handle->get_queue() }; + + auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) }; + + cpu_queue.submit([&](sycl::handler &cgh) { + auto desc_acc = mklcpu_desc_buffer.template get_access(cgh); + auto inre_acc = in_re.template get_access(cgh); + auto inim_acc = in_im.template get_access(cgh); + auto outre_acc = out_re.template get_access(cgh); + auto outim_acc = out_im.template get_access(cgh); + + detail::host_task(cgh, [=]() { + DFT_ERROR status = DftiComputeForward(desc_acc[detail::DIR::fwd], + inre_acc.get_pointer(), inim_acc.get_pointer(), + outre_acc.get_pointer(), outim_acc.get_pointer()); + if (status != DFTI_NO_ERROR) { + throw oneapi::mkl::exception( + "dft/forward/mklcpu", "compute_forward", + std::string("DftiComputeForward failed : ") + DftiErrorMessage(status)); + } + }); + }); } //USM version //In-place transform template -ONEMKL_EXPORT sycl::event compute_forward(descriptor_type& /*desc*/, data_type* /*inout*/, - const std::vector& /*dependencies*/) { - throw mkl::unimplemented("DFT", "compute_forward", "Not implemented for MKLCPU"); - return sycl::event{}; +ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, data_type *inout, + const std::vector &dependencies) { + detail::expect_config( + desc, "Unexpected value for placement"); + + auto commit_handle = dft::detail::get_commit(desc); + detail::check_fwd_commit(desc); + sycl::queue &cpu_queue{ commit_handle->get_queue() }; + + auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) }; + + return cpu_queue.submit([&](sycl::handler &cgh) { + auto desc_acc = mklcpu_desc_buffer.template get_access(cgh); + + cgh.depends_on(dependencies); + detail::host_task(cgh, [=]() { + DFT_ERROR status = DftiComputeForward(desc_acc[detail::DIR::fwd], inout); + if (status != DFTI_NO_ERROR) { + throw oneapi::mkl::exception( + "dft/forward/mklcpu", "compute_forward", + std::string("DftiComputeForward failed : ") + DftiErrorMessage(status)); + } + }); + }); } //In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format template -ONEMKL_EXPORT sycl::event compute_forward(descriptor_type& /*desc*/, data_type* /*inout_re*/, - data_type* /*inout_im*/, - const std::vector& /*dependencies*/) { - throw mkl::unimplemented("DFT", "compute_forward", "Not implemented for MKLCPU"); - return sycl::event{}; +ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, data_type *inout_re, + data_type *inout_im, + const std::vector &dependencies) { + detail::expect_config( + desc, "Unexpected value for complex storage"); + + auto commit_handle = dft::detail::get_commit(desc); + detail::check_fwd_commit(desc); + sycl::queue &cpu_queue{ commit_handle->get_queue() }; + + auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) }; + + return cpu_queue.submit([&](sycl::handler &cgh) { + auto desc_acc = mklcpu_desc_buffer.template get_access(cgh); + + cgh.depends_on(dependencies); + detail::host_task(cgh, [=]() { + DFT_ERROR status = DftiComputeForward(desc_acc[detail::DIR::fwd], inout_re, inout_im); + if (status != DFTI_NO_ERROR) { + throw oneapi::mkl::exception( + "dft/forward/mklcpu", "compute_forward", + std::string("DftiComputeForward failed : ") + DftiErrorMessage(status)); + } + }); + }); } //Out-of-place transform template -ONEMKL_EXPORT sycl::event compute_forward(descriptor_type& /*desc*/, input_type* /*in*/, - output_type* /*out*/, - const std::vector& /*dependencies*/) { - throw mkl::unimplemented("DFT", "compute_forward", "Not implemented for MKLCPU"); - return sycl::event{}; +ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, input_type *in, output_type *out, + const std::vector &dependencies) { + // Check: inplace + detail::expect_config(desc, + "Unexpected value for placement"); + + auto commit_handle = dft::detail::get_commit(desc); + detail::check_fwd_commit(desc); + sycl::queue &cpu_queue{ commit_handle->get_queue() }; + + auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) }; + + return cpu_queue.submit([&](sycl::handler &cgh) { + auto desc_acc = mklcpu_desc_buffer.template get_access(cgh); + + cgh.depends_on(dependencies); + detail::host_task(cgh, [=]() { + DFT_ERROR status = DftiComputeForward(desc_acc[detail::DIR::fwd], in, out); + if (status != DFTI_NO_ERROR) { + throw oneapi::mkl::exception( + "dft/forward/mklcpu", "compute_forward", + std::string("DftiComputeForward failed : ") + DftiErrorMessage(status)); + } + }); + }); } //Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format template -ONEMKL_EXPORT sycl::event compute_forward(descriptor_type& /*desc*/, input_type* /*in_re*/, - input_type* /*in_im*/, output_type* /*out_re*/, - output_type* /*out_im*/, - const std::vector& /*dependencies*/) { - throw mkl::unimplemented("DFT", "compute_forward", "Not implemented for MKLCPU"); - return sycl::event{}; +ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, input_type *in_re, + input_type *in_im, output_type *out_re, + output_type *out_im, + const std::vector &dependencies) { + detail::expect_config( + desc, "Unexpected value for complex storage"); + + auto commit_handle = dft::detail::get_commit(desc); + detail::check_fwd_commit(desc); + sycl::queue &cpu_queue{ commit_handle->get_queue() }; + + auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) }; + + return cpu_queue.submit([&](sycl::handler &cgh) { + auto desc_acc = mklcpu_desc_buffer.template get_access(cgh); + + cgh.depends_on(dependencies); + detail::host_task(cgh, [=]() { + DFT_ERROR status = + DftiComputeForward(desc_acc[detail::DIR::fwd], in_re, in_im, out_re, out_im); + if (status != DFTI_NO_ERROR) { + throw oneapi::mkl::exception( + "dft/forward/mklcpu", "compute_forward", + std::string("DftiComputeForward failed : ") + DftiErrorMessage(status)); + } + }); + }); } // Template function instantiations diff --git a/src/dft/backends/mklcpu/mklcpu_helpers.hpp b/src/dft/backends/mklcpu/mklcpu_helpers.hpp new file mode 100644 index 000000000..e4aad0dde --- /dev/null +++ b/src/dft/backends/mklcpu/mklcpu_helpers.hpp @@ -0,0 +1,182 @@ +/******************************************************************************* +* Copyright Codeplay Software Ltd. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _ONEMKL_DFT_SRC_MKLCPU_HELPERS_HPP_ +#define _ONEMKL_DFT_SRC_MKLCPU_HELPERS_HPP_ + +#include "oneapi/mkl/exceptions.hpp" +#include "oneapi/mkl/dft/detail/types_impl.hpp" + +// MKLCPU header +#include "mkl_dfti.h" + +namespace oneapi { +namespace mkl { +namespace dft { +namespace mklcpu { +namespace detail { + +template +static inline auto host_task_internal(H& cgh, F f, int) -> decltype(cgh.host_task(f)) { + return cgh.host_task(f); +} + +template +static inline void host_task(H& cgh, F f) { + (void)host_task_internal(cgh, f, 0); +} + +template +class kernel_name {}; + +/// Convert domain to equivalent backend native value. +inline constexpr DFTI_CONFIG_VALUE to_mklcpu(dft::detail::domain dom) { + if (dom == dft::detail::domain::REAL) { + return DFTI_REAL; + } + else { + return DFTI_COMPLEX; + } +} + +/// Convert precision to equivalent backend native value. +inline constexpr DFTI_CONFIG_VALUE to_mklcpu(dft::detail::precision dom) { + if (dom == dft::detail::precision::SINGLE) { + return DFTI_SINGLE; + } + else { + return DFTI_DOUBLE; + } +} + +/// Convert a config_param to equivalent backend native value. +inline constexpr DFTI_CONFIG_PARAM to_mklcpu(dft::detail::config_param param) { + using iparam = dft::detail::config_param; + switch (param) { + case iparam::FORWARD_DOMAIN: return DFTI_FORWARD_DOMAIN; + case iparam::DIMENSION: return DFTI_DIMENSION; + case iparam::LENGTHS: return DFTI_LENGTHS; + case iparam::PRECISION: return DFTI_PRECISION; + case iparam::FORWARD_SCALE: return DFTI_FORWARD_SCALE; + case iparam::NUMBER_OF_TRANSFORMS: return DFTI_NUMBER_OF_TRANSFORMS; + case iparam::COMPLEX_STORAGE: return DFTI_COMPLEX_STORAGE; + case iparam::REAL_STORAGE: return DFTI_REAL_STORAGE; + case iparam::CONJUGATE_EVEN_STORAGE: return DFTI_CONJUGATE_EVEN_STORAGE; + case iparam::INPUT_STRIDES: return DFTI_INPUT_STRIDES; + case iparam::OUTPUT_STRIDES: return DFTI_OUTPUT_STRIDES; + case iparam::FWD_DISTANCE: return DFTI_FWD_DISTANCE; + case iparam::BWD_DISTANCE: return DFTI_BWD_DISTANCE; + case iparam::WORKSPACE: return DFTI_WORKSPACE; + case iparam::ORDERING: return DFTI_ORDERING; + case iparam::TRANSPOSE: return DFTI_TRANSPOSE; + case iparam::PACKED_FORMAT: return DFTI_PACKED_FORMAT; + case iparam::COMMIT_STATUS: return DFTI_COMMIT_STATUS; + default: + throw mkl::invalid_argument("dft", "MKLCPU descriptor set_value()", + "Invalid config param."); + return static_cast(0); + } +} + +/** Convert a config_value to the backend's native value. Throw on invalid input. + * @tparam Param The config param the value is for. + * @param value The config value to convert. +**/ +template +inline constexpr int to_mklcpu(dft::detail::config_value value); + +template <> +inline constexpr int to_mklcpu( + dft::detail::config_value value) { + if (value == dft::detail::config_value::COMPLEX_COMPLEX) { + return DFTI_COMPLEX_COMPLEX; + } + else if (value == dft::detail::config_value::REAL_REAL) { + return DFTI_REAL_REAL; + } + else { + throw mkl::invalid_argument("dft", "MKLCPU descriptor set_value()", + "Invalid config value for complex storage."); + return 0; + } +} + +template <> +inline constexpr int to_mklcpu( + dft::detail::config_value value) { + if (value == dft::detail::config_value::REAL_REAL) { + return DFTI_REAL_REAL; + } + else { + throw mkl::invalid_argument("dft", "MKLCPU descriptor set_value()", + "Invalid config value for real storage."); + return 0; + } +} +template <> +inline constexpr int to_mklcpu( + dft::detail::config_value value) { + if (value == dft::detail::config_value::COMPLEX_COMPLEX) { + return DFTI_COMPLEX_COMPLEX; + } + else { + throw mkl::invalid_argument("dft", "MKLCPU descriptor set_value()", + "Invalid config value for conjugate even storage."); + return 0; + } +} + +template <> +inline constexpr int to_mklcpu( + dft::detail::config_value value) { + if (value == dft::detail::config_value::INPLACE) { + return DFTI_INPLACE; + } + else if (value == dft::detail::config_value::NOT_INPLACE) { + return DFTI_NOT_INPLACE; + } + else { + throw mkl::invalid_argument("dft", "MKLCPU descriptor set_value()", + "Invalid config value for inplace."); + return 0; + } +} + +template <> +inline constexpr int to_mklcpu( + dft::detail::config_value value) { + if (value == dft::detail::config_value::CCE_FORMAT) { + return DFTI_CCE_FORMAT; + } + else { + throw mkl::invalid_argument("dft", "MKLCPU descriptor set_value()", + "Invalid config value for packed format."); + return 0; + } +} + +using mklcpu_desc_t = DFTI_DESCRIPTOR_HANDLE; + +} // namespace detail +} // namespace mklcpu +} // namespace dft +} // namespace mkl +} // namespace oneapi + +#endif // _ONEMKL_DFT_SRC_MKLCPU_HELPERS_HPP_ diff --git a/src/dft/descriptor.cxx b/src/dft/descriptor.cxx index d77d012b6..13dac0b2f 100644 --- a/src/dft/descriptor.cxx +++ b/src/dft/descriptor.cxx @@ -172,7 +172,9 @@ void descriptor::get_value(config_param param, ...) const { va_start(vl, param); switch (param) { case config_param::FORWARD_DOMAIN: *va_arg(vl, dft::domain*) = dom; break; - case config_param::DIMENSION: *va_arg(vl, std::int64_t*) = static_cast(values_.dimensions.size()); break; + case config_param::DIMENSION: + *va_arg(vl, std::int64_t*) = static_cast(values_.dimensions.size()); + break; case config_param::LENGTHS: std::copy(values_.dimensions.begin(), values_.dimensions.end(), va_arg(vl, std::int64_t*)); diff --git a/tests/unit_tests/dft/include/compute_inplace.hpp b/tests/unit_tests/dft/include/compute_inplace.hpp index cfcd465e7..e5bad3cff 100644 --- a/tests/unit_tests/dft/include/compute_inplace.hpp +++ b/tests/unit_tests/dft/include/compute_inplace.hpp @@ -256,6 +256,7 @@ int DFT_Test::test_in_place_USM() { if constexpr (domain == oneapi::mkl::dft::domain::REAL) { const auto real_strides = get_conjugate_even_real_component_strides(sizes); const auto complex_strides = get_conjugate_even_complex_strides(sizes); + descriptor.set_value(oneapi::mkl::dft::config_param::INPUT_STRIDES, complex_strides.data()); descriptor.set_value(oneapi::mkl::dft::config_param::OUTPUT_STRIDES, real_strides.data()); commit_descriptor(descriptor, sycl_queue); diff --git a/tests/unit_tests/dft/include/compute_inplace_real_real.hpp b/tests/unit_tests/dft/include/compute_inplace_real_real.hpp index a7ab36da4..f2a30eff8 100644 --- a/tests/unit_tests/dft/include/compute_inplace_real_real.hpp +++ b/tests/unit_tests/dft/include/compute_inplace_real_real.hpp @@ -22,18 +22,19 @@ #include "compute_tester.hpp" -/* Test is not implemented because currently there are no available dft implementations. - * These are stubs to make sure that dft::oneapi::mkl::unimplemented exception is thrown */ template int DFT_Test::test_in_place_real_real_USM() { if (!init(MemoryAccessModel::usm)) { return test_skipped; } + if constexpr (domain == oneapi::mkl::dft::domain::REAL) { + std::cout << "skipping real split tests as they are not supported" << std::endl; - try { + return test_skipped; + } + else { descriptor_t descriptor{ sizes }; PrecisionType backward_scale = 1.f / static_cast(forward_elements); - descriptor.set_value(oneapi::mkl::dft::config_param::PLACEMENT, oneapi::mkl::dft::config_value::INPLACE); descriptor.set_value(oneapi::mkl::dft::config_param::COMPLEX_STORAGE, @@ -42,6 +43,7 @@ int DFT_Test::test_in_place_real_real_USM() { descriptor.set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, forward_elements); descriptor.set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, forward_elements); descriptor.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, backward_scale); + commit_descriptor(descriptor, sycl_queue); auto ua_input = usm_allocator_t(cxt, *dev); @@ -52,38 +54,54 @@ int DFT_Test::test_in_place_real_real_USM() { std::copy(input_im.begin(), input_im.end(), inout_im.begin()); std::vector dependencies; - sycl::event done = oneapi::mkl::dft::compute_forward( - descriptor, inout_re.data(), inout_im.data(), dependencies); - done.wait(); - - done = oneapi::mkl::dft::compute_backward, - PrecisionType>(descriptor, inout_re.data(), - inout_im.data(), dependencies); - done.wait(); - } - catch (oneapi::mkl::unimplemented &e) { - std::cout << "Skipping test because: \"" << e.what() << "\"" << std::endl; - return test_skipped; - } + try { + oneapi::mkl::dft::compute_forward( + descriptor, inout_re.data(), inout_im.data(), dependencies) + .wait(); + } + catch (oneapi::mkl::unimplemented &e) { + std::cout << "Skipping test because: \"" << e.what() << "\"" << std::endl; + return test_skipped; + } + + std::vector output_data(size_total); + for (std::size_t i = 0; i < output_data.size(); ++i) { + output_data[i] = { inout_re[i], inout_im[i] }; + } + EXPECT_TRUE(check_equal_vector(output_data.data(), out_host_ref.data(), output_data.size(), + abs_error_margin, rel_error_margin, std::cout)); + + oneapi::mkl::dft::compute_backward, + PrecisionType>(descriptor, inout_re.data(), + inout_im.data(), dependencies) + .wait(); - /* Once implementations exist, results will need to be verified */ - EXPECT_TRUE(false); + for (std::size_t i = 0; i < output_data.size(); ++i) { + output_data[i] = { inout_re[i], inout_im[i] }; + } - return !::testing::Test::HasFailure(); + EXPECT_TRUE(check_equal_vector(output_data.data(), input.data(), input.size(), + abs_error_margin, rel_error_margin, std::cout)); + + return !::testing::Test::HasFailure(); + } } -/* Test is not implemented because currently there are no available dft implementations. - * These are stubs to make sure that dft::oneapi::mkl::unimplemented exception is thrown */ template int DFT_Test::test_in_place_real_real_buffer() { if (!init(MemoryAccessModel::buffer)) { return test_skipped; } - try { + if constexpr (domain == oneapi::mkl::dft::domain::REAL) { + std::cout << "skipping real split tests as they are not supported" << std::endl; + + return test_skipped; + } + else { descriptor_t descriptor{ sizes }; - PrecisionType backward_scale = 1.f / static_cast(forward_elements); + PrecisionType backward_scale = 1.f / static_cast(forward_elements); descriptor.set_value(oneapi::mkl::dft::config_param::PLACEMENT, oneapi::mkl::dft::config_value::INPLACE); descriptor.set_value(oneapi::mkl::dft::config_param::COMPLEX_STORAGE, @@ -92,26 +110,62 @@ int DFT_Test::test_in_place_real_real_buffer() { descriptor.set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, forward_elements); descriptor.set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, forward_elements); descriptor.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, backward_scale); - commit_descriptor(descriptor, sycl_queue); - - sycl::buffer inout_re_buf{ input_re.data(), sycl::range<1>(size_total) }; - sycl::buffer inout_im_buf{ input_im.data(), sycl::range<1>(size_total) }; - oneapi::mkl::dft::compute_forward(descriptor, inout_re_buf, - inout_im_buf); + commit_descriptor(descriptor, sycl_queue); - oneapi::mkl::dft::compute_backward, - PrecisionType>(descriptor, inout_re_buf, inout_im_buf); - } - catch (oneapi::mkl::unimplemented &e) { - std::cout << "Skipping test because: \"" << e.what() << "\"" << std::endl; - return test_skipped; + std::vector host_inout_re(size_total, static_cast(0)); + std::vector host_inout_im(size_total, static_cast(0)); + std::copy(input_re.begin(), input_re.end(), host_inout_re.begin()); + std::copy(input_im.begin(), input_im.end(), host_inout_im.begin()); + + sycl::buffer inout_re_buf{ host_inout_re.data(), + sycl::range<1>(size_total) }; + sycl::buffer inout_im_buf{ host_inout_im.data(), + sycl::range<1>(size_total) }; + + try { + oneapi::mkl::dft::compute_forward(descriptor, inout_re_buf, + inout_im_buf); + } + catch (oneapi::mkl::unimplemented &e) { + std::cout << "Skipping test because: \"" << e.what() << "\"" << std::endl; + return test_skipped; + } + + { + auto acc_inout_re = inout_re_buf.template get_host_access(); + auto acc_inout_im = inout_im_buf.template get_host_access(); + std::vector output_data(size_total, static_cast(0)); + for (std::size_t i = 0; i < output_data.size(); ++i) { + output_data[i] = { acc_inout_re[i], acc_inout_im[i] }; + } + EXPECT_TRUE(check_equal_vector(output_data.data(), out_host_ref.data(), + output_data.size(), abs_error_margin, rel_error_margin, + std::cout)); + } + + try { + oneapi::mkl::dft::compute_backward, + PrecisionType>(descriptor, inout_re_buf, + inout_im_buf); + } + catch (oneapi::mkl::unimplemented &e) { + std::cout << "Skipping test because: \"" << e.what() << "\"" << std::endl; + return test_skipped; + } + + { + auto acc_inout_re = inout_re_buf.template get_host_access(); + auto acc_inout_im = inout_im_buf.template get_host_access(); + std::vector output_data(size_total, static_cast(0)); + for (std::size_t i = 0; i < output_data.size(); ++i) { + output_data[i] = { acc_inout_re[i], acc_inout_im[i] }; + } + EXPECT_TRUE(check_equal_vector(output_data.data(), input.data(), input.size(), + abs_error_margin, rel_error_margin, std::cout)); + } + return !::testing::Test::HasFailure(); } - - /* Once implementations exist, results will need to be verified */ - EXPECT_TRUE(false); - - return !::testing::Test::HasFailure(); } #endif //ONEMKL_COMPUTE_INPLACE_REAL_REAL_HPP diff --git a/tests/unit_tests/dft/include/compute_out_of_place_real_real.hpp b/tests/unit_tests/dft/include/compute_out_of_place_real_real.hpp index 9883e06da..04e0c5561 100644 --- a/tests/unit_tests/dft/include/compute_out_of_place_real_real.hpp +++ b/tests/unit_tests/dft/include/compute_out_of_place_real_real.hpp @@ -22,18 +22,21 @@ #include "compute_tester.hpp" -/* Test is not implemented because currently there are no available dft implementations. - * These are stubs to make sure that dft::oneapi::mkl::unimplemented exception is thrown */ template int DFT_Test::test_out_of_place_real_real_USM() { if (!init(MemoryAccessModel::usm)) { return test_skipped; } - try { + if constexpr (domain == oneapi::mkl::dft::domain::REAL) { + std::cout << "skipping real split tests as they are not supported" << std::endl; + + return test_skipped; + } + else { descriptor_t descriptor{ sizes }; - PrecisionType backward_scale = 1.f / static_cast(forward_elements); + PrecisionType backward_scale = 1.f / static_cast(forward_elements); descriptor.set_value(oneapi::mkl::dft::config_param::PLACEMENT, oneapi::mkl::dft::config_value::NOT_INPLACE); descriptor.set_value(oneapi::mkl::dft::config_param::COMPLEX_STORAGE, @@ -42,6 +45,7 @@ int DFT_Test::test_out_of_place_real_real_USM() { descriptor.set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, forward_elements); descriptor.set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, forward_elements); descriptor.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, backward_scale); + commit_descriptor(descriptor, sycl_queue); auto ua_input = usm_allocator_t(cxt, *dev); @@ -58,39 +62,60 @@ int DFT_Test::test_out_of_place_real_real_USM() { std::copy(input_im.begin(), input_im.end(), in_im.begin()); std::vector dependencies; - sycl::event done = - oneapi::mkl::dft::compute_forward( - descriptor, in_re.data(), in_im.data(), out_re.data(), out_im.data(), dependencies); - done.wait(); - done = oneapi::mkl::dft::compute_backward, - PrecisionType, PrecisionType>( - descriptor, out_re.data(), out_im.data(), out_back_re.data(), out_back_im.data()); - done.wait(); - } - catch (oneapi::mkl::unimplemented &e) { - std::cout << "Skipping test because: \"" << e.what() << "\"" << std::endl; - return test_skipped; + try { + oneapi::mkl::dft::compute_forward( + descriptor, in_re.data(), in_im.data(), out_re.data(), out_im.data(), dependencies) + .wait(); + } + catch (oneapi::mkl::unimplemented &e) { + std::cout << "Skipping test because: \"" << e.what() << "\"" << std::endl; + return test_skipped; + } + std::vector output_data(size_total); + for (std::size_t i = 0; i < output_data.size(); ++i) { + output_data[i] = { out_re[i], out_im[i] }; + } + EXPECT_TRUE(check_equal_vector(output_data.data(), out_host_ref.data(), output_data.size(), + abs_error_margin, rel_error_margin, std::cout)); + + try { + oneapi::mkl::dft::compute_backward, + PrecisionType, PrecisionType>( + descriptor, out_re.data(), out_im.data(), out_back_re.data(), out_back_im.data()) + .wait(); + } + catch (oneapi::mkl::unimplemented &e) { + std::cout << "Skipping test because: \"" << e.what() << "\"" << std::endl; + return test_skipped; + } + + for (std::size_t i = 0; i < output_data.size(); ++i) { + output_data[i] = { out_back_re[i], out_back_im[i] }; + } + + EXPECT_TRUE(check_equal_vector(output_data.data(), input.data(), input.size(), + abs_error_margin, rel_error_margin, std::cout)); } - /* Once implementations exist, results will need to be verified */ - EXPECT_TRUE(false); - return !::testing::Test::HasFailure(); } -/* Test is not implemented because currently there are no available dft implementations. - * These are stubs to make sure that dft::oneapi::mkl::unimplemented exception is thrown */ template int DFT_Test::test_out_of_place_real_real_buffer() { if (!init(MemoryAccessModel::buffer)) { return test_skipped; } - try { + if constexpr (domain == oneapi::mkl::dft::domain::REAL) { + std::cout << "skipping real split tests as they are not supported" << std::endl; + + return test_skipped; + } + else { descriptor_t descriptor{ sizes }; - PrecisionType backward_scale = 1.f / static_cast(forward_elements); + PrecisionType backward_scale = 1.f / static_cast(forward_elements); descriptor.set_value(oneapi::mkl::dft::config_param::PLACEMENT, oneapi::mkl::dft::config_value::NOT_INPLACE); descriptor.set_value(oneapi::mkl::dft::config_param::COMPLEX_STORAGE, @@ -99,6 +124,7 @@ int DFT_Test::test_out_of_place_real_real_buffer() { descriptor.set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, forward_elements); descriptor.set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, forward_elements); descriptor.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, backward_scale); + commit_descriptor(descriptor, sycl_queue); sycl::buffer in_dev_re{ input_re.data(), sycl::range<1>(size_total) }; @@ -108,21 +134,49 @@ int DFT_Test::test_out_of_place_real_real_buffer() { sycl::buffer out_back_dev_re{ sycl::range<1>(size_total) }; sycl::buffer out_back_dev_im{ sycl::range<1>(size_total) }; - oneapi::mkl::dft::compute_forward( - descriptor, in_dev_re, in_dev_im, out_dev_re, out_dev_im); - - oneapi::mkl::dft::compute_backward, - PrecisionType, PrecisionType>( - descriptor, out_dev_re, out_dev_im, out_back_dev_re, out_back_dev_im); - } - catch (oneapi::mkl::unimplemented &e) { - std::cout << "Skipping test because: \"" << e.what() << "\"" << std::endl; - return test_skipped; + try { + oneapi::mkl::dft::compute_forward( + descriptor, in_dev_re, in_dev_im, out_dev_re, out_dev_im); + } + catch (oneapi::mkl::unimplemented &e) { + std::cout << "Skipping test because: \"" << e.what() << "\"" << std::endl; + return test_skipped; + } + + { + auto acc_out_re = out_dev_re.template get_host_access(); + auto acc_out_im = out_dev_im.template get_host_access(); + std::vector output_data(size_total, static_cast(0)); + for (std::size_t i = 0; i < output_data.size(); ++i) { + output_data[i] = { acc_out_re[i], acc_out_im[i] }; + } + EXPECT_TRUE(check_equal_vector(output_data.data(), out_host_ref.data(), + output_data.size(), abs_error_margin, rel_error_margin, + std::cout)); + } + + try { + oneapi::mkl::dft::compute_backward, + PrecisionType, PrecisionType>( + descriptor, out_dev_re, out_dev_im, out_back_dev_re, out_back_dev_im); + } + catch (oneapi::mkl::unimplemented &e) { + std::cout << "Skipping test because: \"" << e.what() << "\"" << std::endl; + return test_skipped; + } + + { + auto acc_back_out_re = out_back_dev_re.template get_host_access(); + auto acc_back_out_im = out_back_dev_im.template get_host_access(); + std::vector output_data(size_total, static_cast(0)); + for (std::size_t i = 0; i < output_data.size(); ++i) { + output_data[i] = { acc_back_out_re[i], acc_back_out_im[i] }; + } + EXPECT_TRUE(check_equal_vector(output_data.data(), input.data(), input.size(), + abs_error_margin, rel_error_margin, std::cout)); + } } - /* Once implementations exist, results will need to be verified */ - EXPECT_TRUE(false); - return !::testing::Test::HasFailure(); } diff --git a/tests/unit_tests/dft/source/compute_tests.cpp b/tests/unit_tests/dft/source/compute_tests.cpp index 1627dbb42..e5685dced 100644 --- a/tests/unit_tests/dft/source/compute_tests.cpp +++ b/tests/unit_tests/dft/source/compute_tests.cpp @@ -87,16 +87,14 @@ std::vector test_params{ { shape{ 3, 7, 2 }, i64{ 1 } }, { shape{ 8, 8, 9 }, i64{ 1 } }, }; -// not currently implemented apis -std::vector no_tests{}; - INSTANTIATE_TEST_SUITE_P(ComputeTestSuite, ComputeTests_in_place, testing::Combine(testing::ValuesIn(devices), testing::ValuesIn(test_params)), DFTParamsPrint{}); INSTANTIATE_TEST_SUITE_P(ComputeTestSuite, ComputeTests_real_real_in_place, - testing::Combine(testing::ValuesIn(devices), testing::ValuesIn(no_tests)), + testing::Combine(testing::ValuesIn(devices), + testing::ValuesIn(test_params)), DFTParamsPrint{}); INSTANTIATE_TEST_SUITE_P(ComputeTestSuite, ComputeTests_out_of_place, @@ -105,7 +103,8 @@ INSTANTIATE_TEST_SUITE_P(ComputeTestSuite, ComputeTests_out_of_place, DFTParamsPrint{}); INSTANTIATE_TEST_SUITE_P(ComputeTestSuite, ComputeTests_real_real_out_of_place, - testing::Combine(testing::ValuesIn(devices), testing::ValuesIn(no_tests)), + testing::Combine(testing::ValuesIn(devices), + testing::ValuesIn(test_params)), DFTParamsPrint{}); } // anonymous namespace diff --git a/tests/unit_tests/dft/source/descriptor_tests.cpp b/tests/unit_tests/dft/source/descriptor_tests.cpp index fc27c1653..b3955fec0 100644 --- a/tests/unit_tests/dft/source/descriptor_tests.cpp +++ b/tests/unit_tests/dft/source/descriptor_tests.cpp @@ -432,13 +432,13 @@ inline void recommit_values(sycl::queue& sycl_queue) { // not changeable // FORWARD_DOMAIN, PRECISION, DIMENSION, COMMIT_STATUS { std::make_pair(config_param::LENGTHS, std::int64_t{ 10 }), - std::make_pair(config_param::FORWARD_SCALE, PrecisionType(1.2)), - std::make_pair(config_param::BACKWARD_SCALE, PrecisionType(3.4)) }, - { std::make_pair(config_param::NUMBER_OF_TRANSFORMS, std::int64_t{ 5 }), - std::make_pair(config_param::COMPLEX_STORAGE, config_value::COMPLEX_COMPLEX), + std::make_pair(config_param::FORWARD_SCALE, PrecisionType{ 1.2f }), + std::make_pair(config_param::BACKWARD_SCALE, PrecisionType{ 3.4f }) }, + { std::make_pair(config_param::COMPLEX_STORAGE, config_value::COMPLEX_COMPLEX), std::make_pair(config_param::REAL_STORAGE, config_value::REAL_REAL), std::make_pair(config_param::CONJUGATE_EVEN_STORAGE, config_value::COMPLEX_COMPLEX) }, { std::make_pair(config_param::PLACEMENT, config_value::NOT_INPLACE), + std::make_pair(config_param::NUMBER_OF_TRANSFORMS, std::int64_t{ 5 }), std::make_pair(config_param::INPUT_STRIDES, strides.data()), std::make_pair(config_param::OUTPUT_STRIDES, strides.data()), std::make_pair(config_param::FWD_DISTANCE, std::int64_t{ 60 }), @@ -565,13 +565,6 @@ int test_commit(sycl::device* dev) { } } - // TODO remove after #288 - if (sycl_queue.get_device().get_info() == - sycl::info::device_type::cpu) { - std::cout << "MKLCPU not implemented, skipping.\n"; - return test_skipped; - } - get_commited(sycl_queue); recommit_values(sycl_queue); change_queue_causes_wait(sycl_queue);