Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 87 additions & 76 deletions dpnp/backend/kernels/dpnp_krnl_fft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,6 @@ static void dpnp_fft_fft_mathlib_cmplx_to_cmplx_c(DPCTLSyclQueueRef q_ref,
const size_t shape_size,
const size_t input_size,
const size_t result_size,
_Descriptor_type& desc,
size_t inverse,
const size_t norm)
{
Expand All @@ -187,14 +186,15 @@ static void dpnp_fft_fft_mathlib_cmplx_to_cmplx_c(DPCTLSyclQueueRef q_ref,
(void)input_size;
(void)result_size;

if (!shape_size) {
if (!shape_size)
{
return;
}

sycl::queue queue = *(reinterpret_cast<sycl::queue*>(q_ref));

_DataType_input* array_1 = static_cast<_DataType_input *>(const_cast<void *>(array1_in));
_DataType_output* result = static_cast<_DataType_output *>(result_out);
_DataType_input* array_1 = static_cast<_DataType_input*>(const_cast<void*>(array1_in));
_DataType_output* result = static_cast<_DataType_output*>(result_out);

const size_t n_iter =
std::accumulate(input_shape, input_shape + shape_size - 1, 1, std::multiplies<shape_elem_type>());
Expand All @@ -204,39 +204,49 @@ static void dpnp_fft_fft_mathlib_cmplx_to_cmplx_c(DPCTLSyclQueueRef q_ref,
double backward_scale = 1.;
double forward_scale = 1.;

if (norm == 0) { // norm = "backward"
if (norm == 0) // norm = "backward"
{
backward_scale = 1. / shift;
} else if (norm == 1) { // norm = "forward"
}
else if (norm == 1) // norm = "forward"
{
forward_scale = 1. / shift;
} else { // norm = "ortho"
if (inverse) {
}
else // norm = "ortho"
{
if (inverse)
{
backward_scale = 1. / sqrt(shift);
} else {
}
else
{
forward_scale = 1. / sqrt(shift);
}
}

desc.set_value(mkl_dft::config_param::BACKWARD_SCALE, backward_scale);
desc.set_value(mkl_dft::config_param::FORWARD_SCALE, forward_scale);
// enum value from math library C interface
// instead of mkl_dft::config_value::NOT_INPLACE
desc.set_value(mkl_dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
desc.commit(queue);

std::vector<sycl::event> fft_events;
fft_events.reserve(n_iter);

for (size_t i = 0; i < n_iter; ++i) {
if (inverse) {
fft_events.push_back(mkl_dft::compute_backward(desc, array_1 + i * shift, result + i * shift));
} else {
fft_events.push_back(mkl_dft::compute_forward(desc, array_1 + i * shift, result + i * shift));
std::vector<sycl::event> fft_events(n_iter);

for (size_t i = 0; i < n_iter; ++i)
{
std::unique_ptr<_Descriptor_type> desc = std::make_unique<_Descriptor_type>(shift);
desc->set_value(mkl_dft::config_param::BACKWARD_SCALE, backward_scale);
desc->set_value(mkl_dft::config_param::FORWARD_SCALE, forward_scale);
desc->set_value(mkl_dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
desc->commit(queue);

if (inverse)
{
fft_events[i] = mkl_dft::compute_backward<_Descriptor_type, _DataType_input, _DataType_output>(
*desc, array_1 + i * shift, result + i * shift);
}
else
{
fft_events[i] = mkl_dft::compute_forward<_Descriptor_type, _DataType_input, _DataType_output>(
*desc, array_1 + i * shift, result + i * shift);
}
}

sycl::event::wait(fft_events);

return;
}

template <typename _KernelNameSpecialization1, typename _KernelNameSpecialization2, typename _KernelNameSpecialization3>
Expand All @@ -251,7 +261,6 @@ static DPCTLSyclEventRef dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef
const size_t shape_size,
const size_t input_size,
const size_t result_size,
_Descriptor_type& desc,
size_t inverse,
const size_t norm,
const size_t real)
Expand All @@ -260,14 +269,15 @@ static DPCTLSyclEventRef dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef
(void)input_size;

DPCTLSyclEventRef event_ref = nullptr;
if (!shape_size) {
if (!shape_size)
{
return event_ref;
}

sycl::queue queue = *(reinterpret_cast<sycl::queue*>(q_ref));

_DataType_input* array_1 = static_cast<_DataType_input *>(const_cast<void *>(array1_in));
_DataType_output* result = static_cast<_DataType_output *>(result_out);
_DataType_input* array_1 = static_cast<_DataType_input*>(const_cast<void*>(array1_in));
_DataType_output* result = static_cast<_DataType_output*>(result_out);

const size_t n_iter =
std::accumulate(input_shape, input_shape + shape_size - 1, 1, std::multiplies<shape_elem_type>());
Expand All @@ -278,38 +288,52 @@ static DPCTLSyclEventRef dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef
double backward_scale = 1.;
double forward_scale = 1.;

if (norm == 0) { // norm = "backward"
if (inverse) {
if (norm == 0) // norm = "backward"
{
if (inverse)
{
forward_scale = 1. / result_shift;
} else {
}
else
{
backward_scale = 1. / result_shift;
}
} else if (norm == 1) { // norm = "forward"
if (inverse) {
}
else if (norm == 1) // norm = "forward"
{
if (inverse)
{
backward_scale = 1. / result_shift;
} else {
}
else
{
forward_scale = 1. / result_shift;
}
} else { // norm = "ortho"
}
else // norm = "ortho"
{
forward_scale = 1. / sqrt(result_shift);
}

desc.set_value(mkl_dft::config_param::BACKWARD_SCALE, backward_scale);
desc.set_value(mkl_dft::config_param::FORWARD_SCALE, forward_scale);
desc.set_value(mkl_dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
std::vector<sycl::event> fft_events(n_iter);

desc.commit(queue);

std::vector<sycl::event> fft_events;
fft_events.reserve(n_iter);

for (size_t i = 0; i < n_iter; ++i) {
fft_events.push_back(mkl_dft::compute_forward(desc, array_1 + i * input_shift, result + i * result_shift * 2));
for (size_t i = 0; i < n_iter; ++i)
{
std::unique_ptr<_Descriptor_type> desc = std::make_unique<_Descriptor_type>(input_shift);
desc->set_value(mkl_dft::config_param::BACKWARD_SCALE, backward_scale);
desc->set_value(mkl_dft::config_param::FORWARD_SCALE, forward_scale);
desc->set_value(mkl_dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
desc->commit(queue);

// the per-iteration offset into "result" is 2 * result_shift, because each output element is complex and therefore twice as wide as '_DataType_output'
fft_events[i] = mkl_dft::compute_forward<_Descriptor_type, _DataType_input, _DataType_output>(
*desc, array_1 + i * input_shift, result + i * result_shift * 2);
}

sycl::event::wait(fft_events);

if (real) { // the output size of the rfft function is input_size/2 + 1 so we don't need to fill the second half of the output
if (real) // the output size of the rfft function is input_size/2 + 1 so we don't need to fill the second half of the output
{
return event_ref;
}

Expand All @@ -325,19 +349,22 @@ static DPCTLSyclEventRef dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef
size_t j = global_id[1];
{
*(reinterpret_cast<std::complex<_DataType_output>*>(result) + result_shift * (i + 1) - (j + 1)) =
std::conj(*(reinterpret_cast<std::complex<_DataType_output>*>(result) + result_shift * i + (j + 1)));
std::conj(
*(reinterpret_cast<std::complex<_DataType_output>*>(result) + result_shift * i + (j + 1)));
}
}
};

auto kernel_func = [&](sycl::handler& cgh) {
cgh.parallel_for<class dpnp_fft_fft_mathlib_real_to_cmplx_c_kernel<_DataType_input, _DataType_output, _Descriptor_type>>(
cgh.parallel_for<
class dpnp_fft_fft_mathlib_real_to_cmplx_c_kernel<_DataType_input, _DataType_output, _Descriptor_type>>(
gws, kernel_parallel_for_func);
};

event = queue.submit(kernel_func);

if (inverse) {
if (inverse)
{
event.wait();
event = oneapi::mkl::vm::conj(queue,
result_size,
Expand All @@ -346,7 +373,6 @@ static DPCTLSyclEventRef dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef
}

event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event);

return DPCTLEvent_Copy(event_ref);
}

Expand Down Expand Up @@ -375,43 +401,35 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
const size_t input_size =
std::accumulate(input_shape, input_shape + shape_size, 1, std::multiplies<shape_elem_type>());

size_t dim = input_shape[shape_size - 1];

if constexpr (std::is_same<_DataType_output, std::complex<float>>::value ||
std::is_same<_DataType_output, std::complex<double>>::value)
{
if constexpr (std::is_same<_DataType_input, std::complex<double>>::value &&
std::is_same<_DataType_output, std::complex<double>>::value)
{
desc_dp_cmplx_t desc(dim);
dpnp_fft_fft_mathlib_cmplx_to_cmplx_c<_DataType_input, _DataType_output, desc_dp_cmplx_t>(
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm);
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, inverse, norm);
}
/* complex-to-complex, single precision */
else if constexpr (std::is_same<_DataType_input, std::complex<float>>::value &&
std::is_same<_DataType_output, std::complex<float>>::value)
{
desc_sp_cmplx_t desc(dim);
dpnp_fft_fft_mathlib_cmplx_to_cmplx_c<_DataType_input, _DataType_output, desc_sp_cmplx_t>(
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm);
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, inverse, norm);
}
/* real-to-complex, double precision */
else if constexpr (std::is_same<_DataType_input, double>::value &&
std::is_same<_DataType_output, std::complex<double>>::value)
{
desc_dp_real_t desc(dim);

event_ref = dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, double, desc_dp_real_t>(
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm, 0);
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, inverse, norm, 0);
}
/* real-to-complex, single precision */
else if constexpr (std::is_same<_DataType_input, float>::value &&
std::is_same<_DataType_output, std::complex<float>>::value)
{
desc_sp_real_t desc(dim); // try: 2 * result_size

event_ref = dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, float, desc_sp_real_t>(
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm, 0);
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, inverse, norm, 0);
}
else if constexpr (std::is_same<_DataType_input, int32_t>::value ||
std::is_same<_DataType_input, int64_t>::value)
Expand All @@ -428,9 +446,8 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
DPCTLEvent_WaitAndThrow(event_ref);
DPCTLEvent_Delete(event_ref);

desc_dp_real_t desc(dim);
event_ref = dpnp_fft_fft_mathlib_real_to_cmplx_c<double, double, desc_dp_real_t>(
q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm, 0);
q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, inverse, norm, 0);

DPCTLEvent_WaitAndThrow(event_ref);
DPCTLEvent_Delete(event_ref);
Expand Down Expand Up @@ -537,26 +554,21 @@ DPCTLSyclEventRef dpnp_fft_rfft_c(DPCTLSyclQueueRef q_ref,
const size_t input_size =
std::accumulate(input_shape, input_shape + shape_size, 1, std::multiplies<shape_elem_type>());

size_t dim = input_shape[shape_size - 1];

if constexpr (std::is_same<_DataType_output, std::complex<float>>::value ||
std::is_same<_DataType_output, std::complex<double>>::value)
{
if constexpr (std::is_same<_DataType_input, double>::value &&
std::is_same<_DataType_output, std::complex<double>>::value)
std::is_same<_DataType_output, std::complex<double>>::value)
{
desc_dp_real_t desc(dim);

event_ref = dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, double, desc_dp_real_t>(
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm, 1);
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, inverse, norm, 1);
}
/* real-to-complex, single precision */
else if constexpr (std::is_same<_DataType_input, float>::value &&
std::is_same<_DataType_output, std::complex<float>>::value)
{
desc_sp_real_t desc(dim); // try: 2 * result_size
event_ref = dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, float, desc_sp_real_t>(
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm, 1);
q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, inverse, norm, 1);
}
else if constexpr (std::is_same<_DataType_input, int32_t>::value ||
std::is_same<_DataType_input, int64_t>::value)
Expand All @@ -573,9 +585,8 @@ DPCTLSyclEventRef dpnp_fft_rfft_c(DPCTLSyclQueueRef q_ref,
DPCTLEvent_WaitAndThrow(event_ref);
DPCTLEvent_Delete(event_ref);

desc_dp_real_t desc(dim);
event_ref = dpnp_fft_fft_mathlib_real_to_cmplx_c<double, double, desc_dp_real_t>(
q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm, 1);
q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, inverse, norm, 1);

DPCTLEvent_WaitAndThrow(event_ref);
DPCTLEvent_Delete(event_ref);
Expand Down