Skip to content

Commit

Permalink
[DFT] Add FWD/BWD_STRIDES to public API, deprecate INPUT/OUTPUT_STRID…
Browse files Browse the repository at this point in the history
…ES (#528)
  • Loading branch information
hjabird authored Jul 18, 2024
1 parent 30aedaf commit f2d2dcb
Show file tree
Hide file tree
Showing 18 changed files with 576 additions and 312 deletions.
11 changes: 8 additions & 3 deletions include/oneapi/mkl/dft/detail/types_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ enum class config_param {

PLACEMENT,

INPUT_STRIDES,
OUTPUT_STRIDES,
INPUT_STRIDES [[deprecated("Use FWD/BWD_STRIDES")]],
OUTPUT_STRIDES [[deprecated("Use FWD/BWD_STRIDES")]],

FWD_DISTANCE,
BWD_DISTANCE,
Expand All @@ -160,7 +160,10 @@ enum class config_param {
ORDERING,
TRANSPOSE,
PACKED_FORMAT,
COMMIT_STATUS
COMMIT_STATUS,

FWD_STRIDES,
BWD_STRIDES
};

enum class config_value {
Expand Down Expand Up @@ -204,6 +207,8 @@ class dft_values {
public:
std::vector<std::int64_t> input_strides;
std::vector<std::int64_t> output_strides;
std::vector<std::int64_t> fwd_strides;
std::vector<std::int64_t> bwd_strides;
real_t bwd_scale;
real_t fwd_scale;
std::int64_t number_of_transforms;
Expand Down
10 changes: 5 additions & 5 deletions src/dft/backends/cufft/backward.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ namespace oneapi::mkl::dft::cufft {
namespace detail {
//forward declaration
template <dft::precision prec, dft::domain dom>
std::array<std::int64_t, 2> get_offsets(dft::detail::commit_impl<prec, dom> *commit);
std::array<std::int64_t, 2> get_offsets_bwd(dft::detail::commit_impl<prec, dom> *commit);

template <dft::precision prec, dft::domain dom>
cufftHandle get_bwd_plan(dft::detail::commit_impl<prec, dom> *commit) {
Expand All @@ -56,7 +56,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc,
auto commit = detail::checked_get_commit(desc);
auto queue = commit->get_queue();
auto plan = detail::get_bwd_plan(commit);
auto offsets = detail::get_offsets(commit);
auto offsets = detail::get_offsets_bwd(commit);

if constexpr (std::is_floating_point_v<fwd<descriptor_type>>) {
offsets[0] *= 2; // offset is supplied in complex but we offset scalar pointer
Expand Down Expand Up @@ -102,7 +102,7 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc,
auto commit = detail::checked_get_commit(desc);
auto queue = commit->get_queue();
auto plan = detail::get_bwd_plan(commit);
auto offsets = detail::get_offsets(commit);
auto offsets = detail::get_offsets_bwd(commit);

if constexpr (std::is_floating_point_v<fwd<descriptor_type>>) {
if (offsets[1] % 2 != 0) {
Expand Down Expand Up @@ -156,7 +156,7 @@ ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, fwd<descriptor
auto commit = detail::checked_get_commit(desc);
auto queue = commit->get_queue();
auto plan = detail::get_bwd_plan(commit);
auto offsets = detail::get_offsets(commit);
auto offsets = detail::get_offsets_bwd(commit);

if constexpr (std::is_floating_point_v<fwd<descriptor_type>>) {
offsets[0] *= 2; // offset is supplied in complex but we offset scalar pointer
Expand Down Expand Up @@ -203,7 +203,7 @@ ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, bwd<descriptor
auto commit = detail::checked_get_commit(desc);
auto queue = commit->get_queue();
auto plan = detail::get_bwd_plan(commit);
auto offsets = detail::get_offsets(commit);
auto offsets = detail::get_offsets_bwd(commit);

if constexpr (std::is_floating_point_v<fwd<descriptor_type>>) {
if (offsets[1] % 2 != 0) {
Expand Down
203 changes: 104 additions & 99 deletions src/dft/backends/cufft/commit.cpp

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions src/dft/backends/cufft/forward.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ namespace oneapi::mkl::dft::cufft {
namespace detail {
//forward declaration
template <dft::precision prec, dft::domain dom>
std::array<std::int64_t, 2> get_offsets(dft::detail::commit_impl<prec, dom> *commit);
std::array<std::int64_t, 2> get_offsets_fwd(dft::detail::commit_impl<prec, dom> *commit);

template <dft::precision prec, dft::domain dom>
cufftHandle get_fwd_plan(dft::detail::commit_impl<prec, dom> *commit) {
Expand All @@ -59,7 +59,7 @@ ONEMKL_EXPORT void compute_forward(descriptor_type &desc,
auto commit = detail::checked_get_commit(desc);
auto queue = commit->get_queue();
auto plan = detail::get_fwd_plan(commit);
auto offsets = detail::get_offsets(commit);
auto offsets = detail::get_offsets_fwd(commit);

if constexpr (std::is_floating_point_v<fwd<descriptor_type>>) {
if (offsets[0] % 2 != 0) {
Expand Down Expand Up @@ -104,7 +104,7 @@ ONEMKL_EXPORT void compute_forward(descriptor_type &desc, sycl::buffer<fwd<descr
auto commit = detail::checked_get_commit(desc);
auto queue = commit->get_queue();
auto plan = detail::get_fwd_plan(commit);
auto offsets = detail::get_offsets(commit);
auto offsets = detail::get_offsets_fwd(commit);

if constexpr (std::is_floating_point_v<fwd<descriptor_type>>) {
if (offsets[0] % 2 != 0) {
Expand Down Expand Up @@ -158,7 +158,7 @@ ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, fwd<descriptor_
auto commit = detail::checked_get_commit(desc);
auto queue = commit->get_queue();
auto plan = detail::get_fwd_plan(commit);
auto offsets = detail::get_offsets(commit);
auto offsets = detail::get_offsets_fwd(commit);

if constexpr (std::is_floating_point_v<fwd<descriptor_type>>) {
if (offsets[0] % 2 != 0) {
Expand Down Expand Up @@ -205,7 +205,7 @@ ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, fwd<descriptor_
auto commit = detail::checked_get_commit(desc);
auto queue = commit->get_queue();
auto plan = detail::get_fwd_plan(commit);
auto offsets = detail::get_offsets(commit);
auto offsets = detail::get_offsets_fwd(commit);

if constexpr (std::is_floating_point_v<fwd<descriptor_type>>) {
if (offsets[0] % 2 != 0) {
Expand Down
19 changes: 17 additions & 2 deletions src/dft/backends/mklcpu/commit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "oneapi/mkl/dft/detail/commit_impl.hpp"

#include "dft/backends/mklcpu/commit_derived_impl.hpp"
#include "../stride_helper.hpp"
#include "mkl_service.h"
#include "mkl_dfti.h"

Expand Down Expand Up @@ -129,9 +130,23 @@ void commit_derived_impl<prec, dom>::set_value_item(mklcpu_desc_t hand, enum DFT
template <dft::detail::precision prec, dft::detail::domain dom>
void commit_derived_impl<prec, dom>::set_value(mklcpu_desc_t* descHandle,
const dft::detail::dft_values<prec, dom>& config) {
auto stride_choice = dft::detail::get_stride_api(config);
dft::detail::throw_on_invalid_stride_api("MKLCPU commit", stride_choice);
for (auto dir : { DIR::fwd, DIR::bwd }) {
set_value_item(descHandle[dir], DFTI_INPUT_STRIDES, config.input_strides.data());
set_value_item(descHandle[dir], DFTI_OUTPUT_STRIDES, config.output_strides.data());
if (stride_choice == dft::detail::stride_api::IO_STRIDES) {
set_value_item(descHandle[dir], DFTI_INPUT_STRIDES, config.input_strides.data());
set_value_item(descHandle[dir], DFTI_OUTPUT_STRIDES, config.output_strides.data());
}
else { // Forward / backward strides
if (dir == DIR::fwd) {
set_value_item(descHandle[dir], DFTI_INPUT_STRIDES, config.fwd_strides.data());
set_value_item(descHandle[dir], DFTI_OUTPUT_STRIDES, config.bwd_strides.data());
}
else {
set_value_item(descHandle[dir], DFTI_INPUT_STRIDES, config.bwd_strides.data());
set_value_item(descHandle[dir], DFTI_OUTPUT_STRIDES, config.fwd_strides.data());
}
}
set_value_item(descHandle[dir], DFTI_BACKWARD_SCALE, config.bwd_scale);
set_value_item(descHandle[dir], DFTI_FORWARD_SCALE, config.fwd_scale);
set_value_item(descHandle[dir], DFTI_NUMBER_OF_TRANSFORMS, config.number_of_transforms);
Expand Down
2 changes: 0 additions & 2 deletions src/dft/backends/mklcpu/mklcpu_helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,6 @@ inline constexpr DFTI_CONFIG_PARAM to_mklcpu(dft::detail::config_param param) {
case iparam::COMPLEX_STORAGE: return DFTI_COMPLEX_STORAGE;
case iparam::REAL_STORAGE: return DFTI_REAL_STORAGE;
case iparam::CONJUGATE_EVEN_STORAGE: return DFTI_CONJUGATE_EVEN_STORAGE;
case iparam::INPUT_STRIDES: return DFTI_INPUT_STRIDES;
case iparam::OUTPUT_STRIDES: return DFTI_OUTPUT_STRIDES;
case iparam::FWD_DISTANCE: return DFTI_FWD_DISTANCE;
case iparam::BWD_DISTANCE: return DFTI_BWD_DISTANCE;
case iparam::WORKSPACE: return DFTI_WORKSPACE;
Expand Down
45 changes: 30 additions & 15 deletions src/dft/backends/mklgpu/commit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp"

#include "dft/backends/mklgpu/mklgpu_helpers.hpp"
#include "../stride_helper.hpp"

// MKLGPU header
#include "oneapi/mkl/dfti.hpp"
Expand Down Expand Up @@ -90,14 +91,18 @@ class mklgpu_commit final : public dft::detail::commit_impl<prec, dom> {
config_values.workspace_placement ==
oneapi::mkl::dft::detail::config_value::WORKSPACE_EXTERNAL);

auto stride_choice = dft::detail::get_stride_api(config_values);
throw_on_invalid_stride_api("MKLGPU commit", stride_choice);
// A separate descriptor for each direction may not be required.
bool one_descriptor = config_values.input_strides == config_values.output_strides;
bool one_descriptor = (stride_choice == dft::detail::stride_api::FB_STRIDES) ||
(config_values.input_strides == config_values.output_strides);
bool forward_good = true;
// Make sure that second is always pointing to something new if this is a recommit.
// Make sure that second is always pointing to something new if this is a recommit.
handle.second = handle.first;

// Generate forward DFT descriptor.
set_value(*handle.first, config_values, true);
// Generate forward DFT descriptor. If using FWD/BWD_STRIDES API, only
// one descriptor is needed.
set_value(*handle.first, config_values, true, stride_choice);
try {
handle.first->commit(this->get_queue());
}
Expand All @@ -114,7 +119,7 @@ class mklgpu_commit final : public dft::detail::commit_impl<prec, dom> {
// Generate backward DFT descriptor only if required.
if (!one_descriptor) {
handle.second = std::make_shared<mklgpu_descriptor_t>(config_values.dimensions);
set_value(*handle.second, config_values, false);
set_value(*handle.second, config_values, false, stride_choice);
try {
handle.second->commit(this->get_queue());
}
Expand Down Expand Up @@ -160,7 +165,7 @@ class mklgpu_commit final : public dft::detail::commit_impl<prec, dom> {
handle_t handle;

void set_value(mklgpu_descriptor_t& desc, const dft::detail::dft_values<prec, dom>& config,
bool assume_fwd_dft) {
bool assume_fwd_dft, dft::detail::stride_api stride_choice) {
using onemkl_param = dft::detail::config_param;
using backend_param = dft::config_param;

Expand All @@ -181,17 +186,27 @@ class mklgpu_commit final : public dft::detail::commit_impl<prec, dom> {
desc.set_value(backend_param::PLACEMENT,
to_mklgpu<onemkl_param::PLACEMENT>(config.placement));

if (config.input_strides[0] != 0 || config.output_strides[0] != 0) {
throw mkl::unimplemented("dft/backends/mklgpu", "commit",
"MKLGPU does not support nonzero offsets.");
}
if (assume_fwd_dft) {
desc.set_value(backend_param::FWD_STRIDES, config.input_strides.data());
desc.set_value(backend_param::BWD_STRIDES, config.output_strides.data());
if (stride_choice == dft::detail::stride_api::FB_STRIDES) {
if (config.fwd_strides[0] != 0 || config.fwd_strides[0] != 0) {
throw mkl::unimplemented("dft/backends/mklgpu", "commit",
"MKLGPU does not support nonzero offsets.");
}
desc.set_value(backend_param::FWD_STRIDES, config.fwd_strides.data());
desc.set_value(backend_param::BWD_STRIDES, config.bwd_strides.data());
}
else {
desc.set_value(backend_param::FWD_STRIDES, config.output_strides.data());
desc.set_value(backend_param::BWD_STRIDES, config.input_strides.data());
if (config.input_strides[0] != 0 || config.output_strides[0] != 0) {
throw mkl::unimplemented("dft/backends/mklgpu", "commit",
"MKLGPU does not support nonzero offsets.");
}
if (assume_fwd_dft) {
desc.set_value(backend_param::FWD_STRIDES, config.input_strides.data());
desc.set_value(backend_param::BWD_STRIDES, config.output_strides.data());
}
else {
desc.set_value(backend_param::FWD_STRIDES, config.output_strides.data());
desc.set_value(backend_param::BWD_STRIDES, config.input_strides.data());
}
}
desc.set_value(backend_param::FWD_DISTANCE, config.fwd_dist);
desc.set_value(backend_param::BWD_DISTANCE, config.bwd_dist);
Expand Down
2 changes: 0 additions & 2 deletions src/dft/backends/mklgpu/mklgpu_helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,6 @@ inline constexpr dft::config_param to_mklgpu(dft::detail::config_param param) {
case iparam::COMPLEX_STORAGE: return oparam::COMPLEX_STORAGE;
case iparam::REAL_STORAGE: return oparam::REAL_STORAGE;
case iparam::CONJUGATE_EVEN_STORAGE: return oparam::CONJUGATE_EVEN_STORAGE;
case iparam::INPUT_STRIDES: return oparam::INPUT_STRIDES;
case iparam::OUTPUT_STRIDES: return oparam::OUTPUT_STRIDES;
case iparam::FWD_DISTANCE: return oparam::FWD_DISTANCE;
case iparam::BWD_DISTANCE: return oparam::BWD_DISTANCE;
case iparam::WORKSPACE: return oparam::WORKSPACE;
Expand Down
28 changes: 16 additions & 12 deletions src/dft/backends/portfft/commit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
#include "oneapi/mkl/dft/detail/portfft/onemkl_dft_portfft.hpp"
#include "oneapi/mkl/dft/types.hpp"

#include "../stride_helper.hpp"

#include "portfft_helper.hpp"

// alias to avoid ambiguity
Expand Down Expand Up @@ -87,6 +89,10 @@ class portfft_commit final : public dft::detail::commit_impl<prec, dom> {
"portFFT does not supported transposed output");
}

auto stride_api_choice = dft::detail::get_stride_api(config_values);
dft::detail::throw_on_invalid_stride_api("portFFT commit", stride_api_choice);
dft::detail::stride_vectors<std::int64_t> stride_vecs(config_values, stride_api_choice);

// forward descriptor
pfft::descriptor<scalar_type, domain> fwd_desc(
{ config_values.dimensions.cbegin(), config_values.dimensions.cend() });
Expand All @@ -100,12 +106,11 @@ class portfft_commit final : public dft::detail::commit_impl<prec, dom> {
fwd_desc.placement = config_values.placement == config_value::INPLACE
? pfft::placement::IN_PLACE
: pfft::placement::OUT_OF_PLACE;
fwd_desc.forward_offset = static_cast<std::size_t>(config_values.input_strides[0]);
fwd_desc.backward_offset = static_cast<std::size_t>(config_values.output_strides[0]);
fwd_desc.forward_strides = { config_values.input_strides.cbegin() + 1,
config_values.input_strides.cend() };
fwd_desc.backward_strides = { config_values.output_strides.cbegin() + 1,
config_values.output_strides.cend() };
fwd_desc.forward_offset = static_cast<std::size_t>(stride_vecs.offset_fwd_in);
fwd_desc.backward_offset = static_cast<std::size_t>(stride_vecs.offset_fwd_out);
fwd_desc.forward_strides = { stride_vecs.fwd_in.cbegin() + 1, stride_vecs.fwd_in.cend() };
fwd_desc.backward_strides = { stride_vecs.fwd_out.cbegin() + 1,
stride_vecs.fwd_out.cend() };
fwd_desc.forward_distance = static_cast<std::size_t>(config_values.fwd_dist);
fwd_desc.backward_distance = static_cast<std::size_t>(config_values.bwd_dist);

Expand All @@ -122,12 +127,11 @@ class portfft_commit final : public dft::detail::commit_impl<prec, dom> {
bwd_desc.placement = config_values.placement == config_value::INPLACE
? pfft::placement::IN_PLACE
: pfft::placement::OUT_OF_PLACE;
bwd_desc.forward_offset = static_cast<std::size_t>(config_values.output_strides[0]);
bwd_desc.backward_offset = static_cast<std::size_t>(config_values.input_strides[0]);
bwd_desc.forward_strides = { config_values.output_strides.cbegin() + 1,
config_values.output_strides.cend() };
bwd_desc.backward_strides = { config_values.input_strides.cbegin() + 1,
config_values.input_strides.cend() };
bwd_desc.forward_offset = static_cast<std::size_t>(stride_vecs.offset_bwd_out);
bwd_desc.backward_offset = static_cast<std::size_t>(stride_vecs.offset_bwd_in);
bwd_desc.forward_strides = { stride_vecs.bwd_out.cbegin() + 1, stride_vecs.bwd_out.cend() };
bwd_desc.backward_strides = { stride_vecs.bwd_in.cbegin() + 1,
stride_vecs.bwd_in.cend() };
bwd_desc.forward_distance = static_cast<std::size_t>(config_values.fwd_dist);
bwd_desc.backward_distance = static_cast<std::size_t>(config_values.bwd_dist);

Expand Down
Loading

0 comments on commit f2d2dcb

Please sign in to comment.