Skip to content
Merged
8 changes: 8 additions & 0 deletions include/oneapi/mkl/dft/detail/commit_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ enum class backend;

namespace oneapi::mkl::dft::detail {

enum class precision;
enum class domain;
template <precision prec, domain dom>
class dft_values;

template <precision prec, domain dom>
class commit_impl {
public:
commit_impl(sycl::queue queue, mkl::backend backend) : queue_(queue), backend_(backend) {}
Expand All @@ -51,6 +57,8 @@ class commit_impl {

virtual void* get_handle() noexcept = 0;

virtual void commit(const dft_values<prec, dom>&) = 0;

private:
mkl::backend backend_;
sycl::queue queue_;
Expand Down
8 changes: 4 additions & 4 deletions include/oneapi/mkl/dft/detail/descriptor_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ template <precision prec, domain dom>
class descriptor;

template <precision prec, domain dom>
inline commit_impl* get_commit(descriptor<prec, dom>& desc);
inline commit_impl<prec, dom>* get_commit(descriptor<prec, dom>& desc);

template <precision prec, domain dom>
class descriptor {
Expand Down Expand Up @@ -74,16 +74,16 @@ class descriptor {

private:
// Has a value when the descriptor is committed.
std::unique_ptr<commit_impl> pimpl_;
std::unique_ptr<commit_impl<prec, dom>> pimpl_;

// descriptor configuration values_ and structs
dft_values<prec, dom> values_;

friend commit_impl* get_commit<prec, dom>(descriptor<prec, dom>&);
friend commit_impl<prec, dom>* get_commit<prec, dom>(descriptor<prec, dom>&);
};

template <precision prec, domain dom>
inline commit_impl* get_commit(descriptor<prec, dom>& desc) {
inline commit_impl<prec, dom>* get_commit(descriptor<prec, dom>& desc) {
return desc.pimpl_.get();
}

Expand Down
2 changes: 1 addition & 1 deletion include/oneapi/mkl/dft/detail/dft_ct.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
// Commit

template <dft::detail::precision prec, dft::detail::domain dom>
ONEMKL_EXPORT dft::detail::commit_impl *create_commit(
ONEMKL_EXPORT dft::detail::commit_impl<prec, dom> *create_commit(
const dft::detail::descriptor<prec, dom> &desc, sycl::queue &sycl_queue);

// BUFFER version
Expand Down
4 changes: 3 additions & 1 deletion include/oneapi/mkl/dft/detail/dft_loader.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,15 @@ namespace mkl {
namespace dft {
namespace detail {

template <precision prec, domain dom>
class commit_impl;

template <precision prec, domain dom>
class descriptor;

template <precision prec, domain dom>
ONEMKL_EXPORT commit_impl* create_commit(const descriptor<prec, dom>& desc, sycl::queue& queue);
ONEMKL_EXPORT commit_impl<prec, dom>* create_commit(const descriptor<prec, dom>& desc,
sycl::queue& queue);

} // namespace detail
} // namespace dft
Expand Down
1 change: 1 addition & 0 deletions include/oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ namespace dft {

namespace detail {
// Forward declarations
template <precision prec, domain dom>
class commit_impl;

template <precision prec, domain dom>
Expand Down
1 change: 1 addition & 0 deletions include/oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ namespace dft {

namespace detail {
// Forward declarations
template <precision prec, domain dom>
class commit_impl;

template <precision prec, domain dom>
Expand Down
8 changes: 7 additions & 1 deletion src/dft/backends/descriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,13 @@ namespace dft {

/// Commit this descriptor for use with `queue`.
///
/// The backend implementation held in `pimpl_` is reused when it already
/// exists and is bound to the same queue; otherwise a new one is obtained from
/// detail::create_commit. In both cases the current configuration (`values_`)
/// is then pushed to the implementation, so a descriptor can be re-committed
/// after its values have been changed.
///
/// @param queue  The SYCL queue the descriptor should be committed to.
template <precision prec, domain dom>
void descriptor<prec, dom>::commit(sycl::queue &queue) {
    // A new implementation is only needed on the first commit or when the
    // caller switches to a different queue.
    if (!pimpl_ || pimpl_->get_queue() != queue) {
        if (pimpl_) {
            // Wait on the previously used queue before reset() below destroys
            // the implementation that was committed to it.
            pimpl_->get_queue().wait();
        }
        pimpl_.reset(detail::create_commit(*this, queue));
    }
    // Apply the current configuration values and (re)commit the backend handle.
    pimpl_->commit(values_);
}

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe it's worth adding a recommit_values test which would change the queue between commit calls?

Copy link
Contributor Author

@FMarno FMarno Mar 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added a test to show that there is a wait on the old queue here: 088fa08
I was unable to prove that work is submitted to the new queue or not submitted to the old.

  • I don't believe there is any direct way to query whether a queue still has uncompleted tasks
  • I don't believe there is a way to query if an event is associated with a task.
  • When a command group function object is submitted to a queue, I don't believe there are any guarantees that it will complete before slow work on other queues.

Because of all this I can't think of how to test that reliably.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for that. I think you're right for every bullet point you mentioned and I'm sorry if my original suggestion generated complications/misunderstanding on the testing scope. Those changes look fine to me.
To be clearer, my original concern was about a use case like

  1. create a descriptor desc;
  2. create a queue q1 and commit desc to q1;
  3. delete q1;
  4. create a queue q2 and commit desc to q2;
  5. submit a compute task to desc.
    I am not 100% sure this works as expected for the closed-source mkl. But that is something we can/should check on our side as that would be an issue for us to fix anyways. We may consider this thread resolved.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, that's interesting. I think this should work because objects like sycl::queue are kept alive by the SYCL runtime with reference counting semantics. I guess I could have a test that follows the steps you've detailed, waits on q2 and then shows that the event returned from the task is complete. It's not perfect but it's something. I'm about to finish for the day but I'll do this sometime tomorrow.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added a test for that in 8d0a608

template void descriptor<precision::SINGLE, domain::COMPLEX>::commit(sycl::queue &);
template void descriptor<precision::SINGLE, domain::REAL>::commit(sycl::queue &);
Expand Down
38 changes: 21 additions & 17 deletions src/dft/backends/mklcpu/commit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ namespace dft {
namespace mklcpu {

template <precision prec, domain dom>
class commit_derived_impl final : public detail::commit_impl {
class commit_derived_impl final : public detail::commit_impl<prec, dom> {
public:
commit_derived_impl(sycl::queue queue, const detail::dft_values<prec, dom>& config_values)
: detail::commit_impl(queue, backend::mklcpu) {
: detail::commit_impl<prec, dom>(queue, backend::mklcpu) {
DFT_ERROR status = DFT_NOTSET;
if (config_values.dimensions.size() == 1) {
status = DftiCreateDescriptor(&handle, get_precision(prec), get_domain(dom), 1,
Expand All @@ -55,16 +55,19 @@ class commit_derived_impl final : public detail::commit_impl {
config_values.dimensions.data());
}
if (status != DFTI_NO_ERROR) {
throw oneapi::mkl::exception("dft/backends/mklcpu", "commit",
"DftiCreateDescriptor failed");
throw oneapi::mkl::exception(
"dft/backends/mklcpu", "commit",
"DftiCreateDescriptor failed with status: " + std::to_string(status));
}
}

void commit(const detail::dft_values<prec, dom>& config_values) override {
set_value(handle, config_values);

status = DftiCommitDescriptor(handle);
auto status = DftiCommitDescriptor(handle);
if (status != DFTI_NO_ERROR) {
throw oneapi::mkl::exception("dft/backends/mklcpu", "commit",
"DftiCommitDescriptor failed");
throw oneapi::mkl::exception(
"dft/backends/mklcpu", "commit",
"DftiCommitDescriptor failed with status: " + std::to_string(status));
}
}

Expand Down Expand Up @@ -122,18 +125,19 @@ class commit_derived_impl final : public detail::commit_impl {
};

template <precision prec, domain dom>
detail::commit_impl* create_commit(const descriptor<prec, dom>& desc, sycl::queue& sycl_queue) {
detail::commit_impl<prec, dom>* create_commit(const descriptor<prec, dom>& desc,
sycl::queue& sycl_queue) {
return new commit_derived_impl<prec, dom>(sycl_queue, desc.get_values());
}

template detail::commit_impl* create_commit(const descriptor<precision::SINGLE, domain::REAL>&,
sycl::queue&);
template detail::commit_impl* create_commit(const descriptor<precision::SINGLE, domain::COMPLEX>&,
sycl::queue&);
template detail::commit_impl* create_commit(const descriptor<precision::DOUBLE, domain::REAL>&,
sycl::queue&);
template detail::commit_impl* create_commit(const descriptor<precision::DOUBLE, domain::COMPLEX>&,
sycl::queue&);
template detail::commit_impl<precision::SINGLE, domain::REAL>* create_commit(
const descriptor<precision::SINGLE, domain::REAL>&, sycl::queue&);
template detail::commit_impl<precision::SINGLE, domain::COMPLEX>* create_commit(
const descriptor<precision::SINGLE, domain::COMPLEX>&, sycl::queue&);
template detail::commit_impl<precision::DOUBLE, domain::REAL>* create_commit(
const descriptor<precision::DOUBLE, domain::REAL>&, sycl::queue&);
template detail::commit_impl<precision::DOUBLE, domain::COMPLEX>* create_commit(
const descriptor<precision::DOUBLE, domain::COMPLEX>&, sycl::queue&);

} // namespace mklcpu
} // namespace dft
Expand Down
8 changes: 7 additions & 1 deletion src/dft/backends/mklcpu/descriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,13 @@ namespace dft {

/// Commit this descriptor to the MKLCPU backend using the selector's queue.
///
/// The backend implementation held in `pimpl_` is reused when it already
/// exists and is bound to the same queue; otherwise a new one is created with
/// mklcpu::create_commit. In both cases the current configuration (`values_`)
/// is then pushed to the implementation, so a descriptor can be re-committed
/// after its values have been changed.
///
/// @param selector  Compile-time backend selector carrying the target queue.
template <precision prec, domain dom>
void descriptor<prec, dom>::commit(backend_selector<backend::mklcpu> selector) {
    // A new implementation is only needed on the first commit or when the
    // caller switches to a different queue.
    if (!pimpl_ || pimpl_->get_queue() != selector.get_queue()) {
        if (pimpl_) {
            // Wait on the previously used queue before reset() below destroys
            // the implementation that was committed to it.
            pimpl_->get_queue().wait();
        }
        // This is the MKLCPU entry point, so the implementation must come
        // from the mklcpu factory (not the mklgpu one).
        pimpl_.reset(mklcpu::create_commit(*this, selector.get_queue()));
    }
    // Apply the current configuration values and (re)commit the backend handle.
    pimpl_->commit(values_);
}

template void descriptor<precision::SINGLE, domain::COMPLEX>::commit(
Expand Down
28 changes: 16 additions & 12 deletions src/dft/backends/mklgpu/commit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ namespace detail {

/// Commit impl class specialization for MKLGPU.
template <dft::detail::precision prec, dft::detail::domain dom>
class commit_derived_impl final : public dft::detail::commit_impl {
class commit_derived_impl final : public dft::detail::commit_impl<prec, dom> {
private:
// Equivalent MKLGPU precision and domain from OneMKL's precision / domain.
static constexpr dft::precision mklgpu_prec = to_mklgpu(prec);
Expand All @@ -60,19 +60,21 @@ class commit_derived_impl final : public dft::detail::commit_impl {

public:
commit_derived_impl(sycl::queue queue, const dft::detail::dft_values<prec, dom>& config_values)
: oneapi::mkl::dft::detail::commit_impl(queue, backend::mklgpu),
: oneapi::mkl::dft::detail::commit_impl<prec, dom>(queue, backend::mklgpu),
handle(config_values.dimensions) {
set_value(handle, config_values);
// MKLGPU does not throw an informative exception for the following:
if constexpr (prec == dft::detail::precision::DOUBLE) {
if (!queue.get_device().has(sycl::aspect::fp64)) {
throw mkl::exception("dft/backends/mklgpu", "commit",
"Device does not support double precision.");
}
}
}

virtual void commit(const dft::detail::dft_values<prec, dom>& config_values) override {
set_value(handle, config_values);
try {
handle.commit(queue);
handle.commit(this->get_queue());
}
catch (const std::exception& mkl_exception) {
// Catching the real MKL exception causes headaches with naming.
Expand Down Expand Up @@ -125,28 +127,30 @@ class commit_derived_impl final : public dft::detail::commit_impl {
throw mkl::invalid_argument("dft/backends/mklgpu", "commit",
"MKLGPU only supports non-transposed.");
}
desc.set_value(backend_param::PACKED_FORMAT,
to_mklgpu<onemkl_param::PACKED_FORMAT>(config.packed_format));
}
};
} // namespace detail

template <dft::detail::precision prec, dft::detail::domain dom>
dft::detail::commit_impl* create_commit(const dft::detail::descriptor<prec, dom>& desc,
sycl::queue& sycl_queue) {
dft::detail::commit_impl<prec, dom>* create_commit(const dft::detail::descriptor<prec, dom>& desc,
sycl::queue& sycl_queue) {
return new detail::commit_derived_impl<prec, dom>(sycl_queue, desc.get_values());
}

template dft::detail::commit_impl* create_commit(
template dft::detail::commit_impl<dft::detail::precision::SINGLE, dft::detail::domain::REAL>*
create_commit(
const dft::detail::descriptor<dft::detail::precision::SINGLE, dft::detail::domain::REAL>&,
sycl::queue&);
template dft::detail::commit_impl* create_commit(
template dft::detail::commit_impl<dft::detail::precision::SINGLE, dft::detail::domain::COMPLEX>*
create_commit(
const dft::detail::descriptor<dft::detail::precision::SINGLE, dft::detail::domain::COMPLEX>&,
sycl::queue&);
template dft::detail::commit_impl* create_commit(
template dft::detail::commit_impl<dft::detail::precision::DOUBLE, dft::detail::domain::REAL>*
create_commit(
const dft::detail::descriptor<dft::detail::precision::DOUBLE, dft::detail::domain::REAL>&,
sycl::queue&);
template dft::detail::commit_impl* create_commit(
template dft::detail::commit_impl<dft::detail::precision::DOUBLE, dft::detail::domain::COMPLEX>*
create_commit(
const dft::detail::descriptor<dft::detail::precision::DOUBLE, dft::detail::domain::COMPLEX>&,
sycl::queue&);

Expand Down
8 changes: 7 additions & 1 deletion src/dft/backends/mklgpu/descriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,13 @@ namespace dft {

/// Commit this descriptor to the MKLGPU backend using the selector's queue.
///
/// The backend implementation held in `pimpl_` is reused when it already
/// exists and is bound to the same queue; otherwise a new one is created with
/// mklgpu::create_commit. In both cases the current configuration (`values_`)
/// is then pushed to the implementation, so a descriptor can be re-committed
/// after its values have been changed.
///
/// @param selector  Compile-time backend selector carrying the target queue.
template <precision prec, domain dom>
void descriptor<prec, dom>::commit(backend_selector<backend::mklgpu> selector) {
    // A new implementation is only needed on the first commit or when the
    // caller switches to a different queue.
    if (!pimpl_ || pimpl_->get_queue() != selector.get_queue()) {
        if (pimpl_) {
            // Wait on the previously used queue before reset() below destroys
            // the implementation that was committed to it.
            pimpl_->get_queue().wait();
        }
        pimpl_.reset(mklgpu::create_commit(*this, selector.get_queue()));
    }
    // Apply the current configuration values and (re)commit the backend handle.
    pimpl_->commit(values_);
}

template void descriptor<precision::SINGLE, domain::COMPLEX>::commit(
Expand Down
4 changes: 0 additions & 4 deletions src/dft/descriptor.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,6 @@ void compute_default_strides(const std::vector<std::int64_t>& dimensions,

template <precision prec, domain dom>
void descriptor<prec, dom>::set_value(config_param param, ...) {
if (pimpl_) {
throw mkl::invalid_argument("DFT", "set_value",
"Cannot set value on committed descriptor.");
}
va_list vl;
va_start(vl, param);
switch (param) {
Expand Down
8 changes: 4 additions & 4 deletions src/dft/dft_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,28 +34,28 @@ static oneapi::mkl::detail::table_initializer<mkl::domain::dft, dft_function_tab
function_tables;

template <>
commit_impl* create_commit<precision::SINGLE, domain::COMPLEX>(
commit_impl<precision::SINGLE, domain::COMPLEX>* create_commit<precision::SINGLE, domain::COMPLEX>(
const descriptor<precision::SINGLE, domain::COMPLEX>& desc, sycl::queue& sycl_queue) {
auto libkey = get_device_id(sycl_queue);
return function_tables[libkey].create_commit_sycl_fz(desc, sycl_queue);
}

template <>
commit_impl* create_commit<precision::DOUBLE, domain::COMPLEX>(
commit_impl<precision::DOUBLE, domain::COMPLEX>* create_commit<precision::DOUBLE, domain::COMPLEX>(
const descriptor<precision::DOUBLE, domain::COMPLEX>& desc, sycl::queue& sycl_queue) {
auto libkey = get_device_id(sycl_queue);
return function_tables[libkey].create_commit_sycl_dz(desc, sycl_queue);
}

template <>
commit_impl* create_commit<precision::SINGLE, domain::REAL>(
commit_impl<precision::SINGLE, domain::REAL>* create_commit<precision::SINGLE, domain::REAL>(
const descriptor<precision::SINGLE, domain::REAL>& desc, sycl::queue& sycl_queue) {
auto libkey = get_device_id(sycl_queue);
return function_tables[libkey].create_commit_sycl_fr(desc, sycl_queue);
}

template <>
commit_impl* create_commit<precision::DOUBLE, domain::REAL>(
commit_impl<precision::DOUBLE, domain::REAL>* create_commit<precision::DOUBLE, domain::REAL>(
const descriptor<precision::DOUBLE, domain::REAL>& desc, sycl::queue& sycl_queue) {
auto libkey = get_device_id(sycl_queue);
return function_tables[libkey].create_commit_sycl_dr(desc, sycl_queue);
Expand Down
14 changes: 10 additions & 4 deletions src/dft/function_table.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,25 @@

typedef struct {
int version;
oneapi::mkl::dft::detail::commit_impl* (*create_commit_sycl_fz)(
oneapi::mkl::dft::detail::commit_impl<oneapi::mkl::dft::precision::SINGLE,
oneapi::mkl::dft::domain::COMPLEX>* (
*create_commit_sycl_fz)(
const oneapi::mkl::dft::descriptor<oneapi::mkl::dft::precision::SINGLE,
oneapi::mkl::dft::domain::COMPLEX>& desc,
sycl::queue& sycl_queue);
oneapi::mkl::dft::detail::commit_impl* (*create_commit_sycl_dz)(
oneapi::mkl::dft::detail::commit_impl<oneapi::mkl::dft::precision::DOUBLE,
oneapi::mkl::dft::domain::COMPLEX>* (
*create_commit_sycl_dz)(
const oneapi::mkl::dft::descriptor<oneapi::mkl::dft::precision::DOUBLE,
oneapi::mkl::dft::domain::COMPLEX>& desc,
sycl::queue& sycl_queue);
oneapi::mkl::dft::detail::commit_impl* (*create_commit_sycl_fr)(
oneapi::mkl::dft::detail::commit_impl<oneapi::mkl::dft::precision::SINGLE,
oneapi::mkl::dft::domain::REAL>* (*create_commit_sycl_fr)(
const oneapi::mkl::dft::descriptor<oneapi::mkl::dft::precision::SINGLE,
oneapi::mkl::dft::domain::REAL>& desc,
sycl::queue& sycl_queue);
oneapi::mkl::dft::detail::commit_impl* (*create_commit_sycl_dr)(
oneapi::mkl::dft::detail::commit_impl<oneapi::mkl::dft::precision::DOUBLE,
oneapi::mkl::dft::domain::REAL>* (*create_commit_sycl_dr)(
const oneapi::mkl::dft::descriptor<oneapi::mkl::dft::precision::DOUBLE,
oneapi::mkl::dft::domain::REAL>& desc,
sycl::queue& sycl_queue);
Expand Down
Loading