Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement dpnp.fft.rfft and dpnp.fft.irfft #1928

Merged
merged 11 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 25 additions & 7 deletions dpnp/backend/extensions/fft/fft_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,39 +68,57 @@ void register_descriptor(py::module &m, const char *name)
PYBIND11_MODULE(_fft_impl, m)
{
constexpr mkl_dft::domain complex_dom = mkl_dft::domain::COMPLEX;
constexpr mkl_dft::domain real_dom = mkl_dft::domain::REAL;

constexpr mkl_dft::precision single_prec = mkl_dft::precision::SINGLE;
constexpr mkl_dft::precision double_prec = mkl_dft::precision::DOUBLE;

register_descriptor<single_prec, complex_dom>(m, "Complex64Descriptor");
register_descriptor<double_prec, complex_dom>(m, "Complex128Descriptor");
register_descriptor<single_prec, real_dom>(m, "Real32Descriptor");
register_descriptor<double_prec, real_dom>(m, "Real64Descriptor");

// out-of-place c2c FFT, both SINGLE and DOUBLE precisions are supported
// with overloading of "_fft_out_of_place" function on python side
m.def("_fft_out_of_place",
// out-of-place FFT, all possible combination (single/double precisions and
// real/complex domains) are supported with overloading of
// "_fft_out_of_place" function on python side
m.def("_fft_out_of_place", // single precision c2c out-of-place FFT
&fft_ns::compute_fft_out_of_place<single_prec, complex_dom>,
"Compute out-of-place complex-to-complex fft using OneMKL DFT "
"library for complex64 data types.",
py::arg("descriptor"), py::arg("input"), py::arg("output"),
py::arg("is_forward"), py::arg("depends") = py::list());

m.def("_fft_out_of_place",
m.def("_fft_out_of_place", // double precision c2c out-of-place FFT
&fft_ns::compute_fft_out_of_place<double_prec, complex_dom>,
"Compute out-of-place complex-to-complex fft using OneMKL DFT "
"library for complex128 data types.",
py::arg("descriptor"), py::arg("input"), py::arg("output"),
py::arg("is_forward"), py::arg("depends") = py::list());

// in-place c2c FFT, both SINGLE and DOUBLE precisions are supported with
m.def("_fft_out_of_place", // single precision r2c/c2r out-of-place FFT
&fft_ns::compute_fft_out_of_place<single_prec, real_dom>,
"Compute out-of-place real-to-complex fft using OneMKL DFT library "
"for float32 data types.",
py::arg("descriptor"), py::arg("input"), py::arg("output"),
py::arg("is_forward"), py::arg("depends") = py::list());

m.def("_fft_out_of_place", // double precision r2c/c2r out-of-place FFT
&fft_ns::compute_fft_out_of_place<double_prec, real_dom>,
"Compute out-of-place real-to-complex fft using OneMKL DFT library "
"for float64 data types.",
py::arg("descriptor"), py::arg("input"), py::arg("output"),
py::arg("is_forward"), py::arg("depends") = py::list());

// in-place c2c FFT, both single and double precisions are supported with
// overloading of "_fft_in_place" function on python side
m.def("_fft_in_place",
m.def("_fft_in_place", // single precision c2c in-place FFT
&fft_ns::compute_fft_in_place<single_prec, complex_dom>,
"Compute in-place complex-to-complex fft using OneMKL DFT library "
"for complex64 data types.",
py::arg("descriptor"), py::arg("input-output"), py::arg("is_forward"),
py::arg("depends") = py::list());

m.def("_fft_in_place",
m.def("_fft_in_place", // double precision c2c in-place FFT
&fft_ns::compute_fft_in_place<double_prec, complex_dom>,
"Compute in-place complex-to-complex fft using OneMKL DFT library "
"for complex128 data types.",
Expand Down
51 changes: 44 additions & 7 deletions dpnp/backend/extensions/fft/fft_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,21 +31,58 @@ namespace dpnp::extensions::fft
{
namespace mkl_dft = oneapi::mkl::dft;

// Structure to map MKL precision to float/double types
template <mkl_dft::precision prec>
struct ScaleType
struct PrecisionType;

template <>
struct PrecisionType<mkl_dft::precision::SINGLE>
{
using value_type = void;
using type = float;
};

template <>
struct ScaleType<mkl_dft::precision::SINGLE>
struct PrecisionType<mkl_dft::precision::DOUBLE>
{
using value_type = float;
using type = double;
};

template <>
struct ScaleType<mkl_dft::precision::DOUBLE>
// Structure to map combination of precision, domain, and is_forward flag to
// in/out types
template <mkl_dft::precision prec, mkl_dft::domain dom, bool is_forward>
struct ScaleType
vtavana marked this conversation as resolved.
Show resolved Hide resolved
{
using type_in = void;
using type_out = void;
};

// for r2c FFT, type_in is real and type_out is complex
// is_forward is true
template <mkl_dft::precision prec>
struct ScaleType<prec, mkl_dft::domain::REAL, true>
{
using prec_type = typename PrecisionType<prec>::type;
using type_in = prec_type;
using type_out = std::complex<prec_type>;
};

// for c2r FFT, type_in is complex and type_out is real
// is_forward is false
template <mkl_dft::precision prec>
struct ScaleType<prec, mkl_dft::domain::REAL, false>
{
using prec_type = typename PrecisionType<prec>::type;
using type_in = std::complex<prec_type>;
using type_out = prec_type;
};

// for c2c FFT, both type_in and type_out are complex
// regardless of is_fwd
template <mkl_dft::precision prec, bool is_fwd>
struct ScaleType<prec, mkl_dft::domain::COMPLEX, is_fwd>
{
using value_type = double;
using prec_type = typename PrecisionType<prec>::type;
using type_in = std::complex<prec_type>;
using type_out = std::complex<prec_type>;
};
} // namespace dpnp::extensions::fft
8 changes: 6 additions & 2 deletions dpnp/backend/extensions/fft/in_place.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,10 @@ std::pair<sycl::event, sycl::event>

dpctl::tensor::validation::CheckWritable::throw_if_not_writable(in_out);

using ScaleT = typename ScaleType<prec>::value_type;
std::complex<ScaleT> *in_out_ptr = in_out.get_data<std::complex<ScaleT>>();
// in-place is only used for c2c FFT at this time, passing true or false is
// indifferent
using ScaleT = typename ScaleType<prec, dom, true>::type_in;
ScaleT *in_out_ptr = in_out.get_data<ScaleT>();

sycl::event fft_event = {};
std::stringstream error_msg;
Expand Down Expand Up @@ -104,13 +106,15 @@ std::pair<sycl::event, sycl::event>
}

// Explicit instantiations
// single precision c2c FFT
template std::pair<sycl::event, sycl::event> compute_fft_in_place(
DescriptorWrapper<mkl_dft::precision::SINGLE, mkl_dft::domain::COMPLEX>
&descr,
const dpctl::tensor::usm_ndarray &in_out,
const bool is_forward,
const std::vector<sycl::event> &depends);

// double precision c2c FFT
template std::pair<sycl::event, sycl::event> compute_fft_in_place(
DescriptorWrapper<mkl_dft::precision::DOUBLE, mkl_dft::domain::COMPLEX>
&descr,
Expand Down
71 changes: 60 additions & 11 deletions dpnp/backend/extensions/fft/out_of_place.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,31 +84,63 @@ std::pair<sycl::event, sycl::event>
"execution queue of the descriptor.");
}

py::ssize_t in_size = in.get_size();
py::ssize_t out_size = out.get_size();
if (in_size != out_size) {
throw py::value_error("The size of the input vector must be "
"equal to the size of the output vector.");
const py::ssize_t *in_shape = in.get_shape_raw();
const py::ssize_t *out_shape = out.get_shape_raw();
const std::int64_t m = in_shape[in_nd - 1];
const std::int64_t n = out_shape[out_nd - 1];

std::int64_t in_size = 1;
if (in_nd > 1) {
for (int i = 0; i < in_nd - 1; ++i) {
if (in_shape[i] != out_shape[i]) {
throw py::value_error("The shape of the input and output "
"arrays must be the same.");
}
in_size *= in_shape[i];
}
}

size_t src_nelems = in_size;
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(out);
dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(out, src_nelems);
std::int64_t N;
if (dom == mkl_dft::domain::REAL && is_forward) {
// r2c FFT
N = m / 2 + 1; // integer divide
if (n != N) {
throw py::value_error("The shape of the output array is not "
"correct for real to complex transform.");
}
}
else {
// c2c and c2r FFT. For c2r FFT, input is zero-padded in python side to
// have the same size as output before calling this function
N = m;
if (n != N) {
throw py::value_error("The shape of the input array must be "
"the same as the shape of the output array.");
}
}

using ScaleT = typename ScaleType<prec>::value_type;
std::complex<ScaleT> *in_ptr = in.get_data<std::complex<ScaleT>>();
std::complex<ScaleT> *out_ptr = out.get_data<std::complex<ScaleT>>();
const std::size_t n_elems = in_size * N;
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(out);
dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(out, n_elems);

sycl::event fft_event = {};
std::stringstream error_msg;
bool is_exception_caught = false;

try {
if (is_forward) {
using ScaleT_in = typename ScaleType<prec, dom, true>::type_in;
using ScaleT_out = typename ScaleType<prec, dom, true>::type_out;
ScaleT_in *in_ptr = in.get_data<ScaleT_in>();
ScaleT_out *out_ptr = out.get_data<ScaleT_out>();
fft_event = mkl_dft::compute_forward(descr.get_descriptor(), in_ptr,
out_ptr, depends);
}
else {
using ScaleT_in = typename ScaleType<prec, dom, false>::type_in;
using ScaleT_out = typename ScaleType<prec, dom, false>::type_out;
ScaleT_in *in_ptr = in.get_data<ScaleT_in>();
ScaleT_out *out_ptr = out.get_data<ScaleT_out>();
fft_event = mkl_dft::compute_backward(descr.get_descriptor(),
in_ptr, out_ptr, depends);
}
Expand All @@ -133,6 +165,7 @@ std::pair<sycl::event, sycl::event>
}

// Explicit instantiations
// single precision c2c FFT
template std::pair<sycl::event, sycl::event> compute_fft_out_of_place(
DescriptorWrapper<mkl_dft::precision::SINGLE, mkl_dft::domain::COMPLEX>
&descr,
Expand All @@ -141,6 +174,7 @@ template std::pair<sycl::event, sycl::event> compute_fft_out_of_place(
const bool is_forward,
const std::vector<sycl::event> &depends);

// double precision c2c FFT
template std::pair<sycl::event, sycl::event> compute_fft_out_of_place(
DescriptorWrapper<mkl_dft::precision::DOUBLE, mkl_dft::domain::COMPLEX>
&descr,
Expand All @@ -149,4 +183,19 @@ template std::pair<sycl::event, sycl::event> compute_fft_out_of_place(
const bool is_forward,
const std::vector<sycl::event> &depends);

// single precision r2c/c2r FFT
template std::pair<sycl::event, sycl::event> compute_fft_out_of_place(
DescriptorWrapper<mkl_dft::precision::SINGLE, mkl_dft::domain::REAL> &descr,
const dpctl::tensor::usm_ndarray &in,
const dpctl::tensor::usm_ndarray &out,
const bool is_forward,
const std::vector<sycl::event> &depends);

// double precision r2c/c2r FFT
template std::pair<sycl::event, sycl::event> compute_fft_out_of_place(
DescriptorWrapper<mkl_dft::precision::DOUBLE, mkl_dft::domain::REAL> &descr,
const dpctl::tensor::usm_ndarray &in,
const dpctl::tensor::usm_ndarray &out,
const bool is_forward,
const std::vector<sycl::event> &depends);
} // namespace dpnp::extensions::fft
19 changes: 8 additions & 11 deletions dpnp/backend/include/dpnp_iface_fptr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,17 +76,14 @@ enum class DPNPFuncName : size_t
DPNP_FN_DOT, /**< Used in numpy.dot() impl */
DPNP_FN_DOT_EXT, /**< Used in numpy.dot() impl, requires extra parameters */
DPNP_FN_EDIFF1D, /**< Used in numpy.ediff1d() impl */
DPNP_FN_EDIFF1D_EXT, /**< Used in numpy.ediff1d() impl, requires extra
parameters */
DPNP_FN_ERF, /**< Used in scipy.special.erf impl */
DPNP_FN_ERF_EXT, /**< Used in scipy.special.erf impl, requires extra
parameters */
DPNP_FN_FFT_FFT, /**< Used in numpy.fft.fft() impl */
DPNP_FN_FFT_FFT_EXT, /**< Used in numpy.fft.fft() impl, requires extra
parameters */
DPNP_FN_FFT_RFFT, /**< Used in numpy.fft.rfft() impl */
vtavana marked this conversation as resolved.
Show resolved Hide resolved
DPNP_FN_FFT_RFFT_EXT, /**< Used in numpy.fft.rfft() impl, requires extra
parameters */
DPNP_FN_EDIFF1D_EXT, /**< Used in numpy.ediff1d() impl, requires extra
parameters */
DPNP_FN_ERF, /**< Used in scipy.special.erf impl */
DPNP_FN_ERF_EXT, /**< Used in scipy.special.erf impl, requires extra
parameters */
DPNP_FN_FFT_FFT, /**< Used in numpy.fft.fft() impl */
DPNP_FN_FFT_FFT_EXT, /**< Used in numpy.fft.fft() impl, requires extra
parameters */
DPNP_FN_INITVAL, /**< Used in numpy ones, ones_like, zeros, zeros_like impls
*/
DPNP_FN_INITVAL_EXT, /**< Used in numpy ones, ones_like, zeros, zeros_like
Expand Down
Loading
Loading