@@ -178,7 +178,6 @@ static void dpnp_fft_fft_mathlib_cmplx_to_cmplx_c(DPCTLSyclQueueRef q_ref,
178178 const size_t shape_size,
179179 const size_t input_size,
180180 const size_t result_size,
181- _Descriptor_type& desc,
182181 size_t inverse,
183182 const size_t norm)
184183{
@@ -187,14 +186,15 @@ static void dpnp_fft_fft_mathlib_cmplx_to_cmplx_c(DPCTLSyclQueueRef q_ref,
187186 (void )input_size;
188187 (void )result_size;
189188
190- if (!shape_size) {
189+ if (!shape_size)
190+ {
191191 return ;
192192 }
193193
194194 sycl::queue queue = *(reinterpret_cast <sycl::queue*>(q_ref));
195195
196- _DataType_input* array_1 = static_cast <_DataType_input *>(const_cast <void *>(array1_in));
197- _DataType_output* result = static_cast <_DataType_output *>(result_out);
196+ _DataType_input* array_1 = static_cast <_DataType_input*>(const_cast <void *>(array1_in));
197+ _DataType_output* result = static_cast <_DataType_output*>(result_out);
198198
199199 const size_t n_iter =
200200 std::accumulate (input_shape, input_shape + shape_size - 1 , 1 , std::multiplies<shape_elem_type>());
@@ -204,39 +204,49 @@ static void dpnp_fft_fft_mathlib_cmplx_to_cmplx_c(DPCTLSyclQueueRef q_ref,
204204 double backward_scale = 1 .;
205205 double forward_scale = 1 .;
206206
207- if (norm == 0 ) { // norm = "backward"
207+ if (norm == 0 ) // norm = "backward"
208+ {
208209 backward_scale = 1 . / shift;
209- } else if (norm == 1 ) { // norm = "forward"
210+ }
211+ else if (norm == 1 ) // norm = "forward"
212+ {
210213 forward_scale = 1 . / shift;
211- } else { // norm = "ortho"
212- if (inverse) {
214+ }
215+ else // norm = "ortho"
216+ {
217+ if (inverse)
218+ {
213219 backward_scale = 1 . / sqrt (shift);
214- } else {
220+ }
221+ else
222+ {
215223 forward_scale = 1 . / sqrt (shift);
216224 }
217225 }
218226
219- desc.set_value (mkl_dft::config_param::BACKWARD_SCALE, backward_scale);
220- desc.set_value (mkl_dft::config_param::FORWARD_SCALE, forward_scale);
221- // enum value from math library C interface
222- // instead of mkl_dft::config_value::NOT_INPLACE
223- desc.set_value (mkl_dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
224- desc.commit (queue);
225-
226- std::vector<sycl::event> fft_events;
227- fft_events.reserve (n_iter);
228-
229- for (size_t i = 0 ; i < n_iter; ++i) {
230- if (inverse) {
231- fft_events.push_back (mkl_dft::compute_backward (desc, array_1 + i * shift, result + i * shift));
232- } else {
233- fft_events.push_back (mkl_dft::compute_forward (desc, array_1 + i * shift, result + i * shift));
227+ std::vector<sycl::event> fft_events (n_iter);
228+
229+ for (size_t i = 0 ; i < n_iter; ++i)
230+ {
231+ std::unique_ptr<_Descriptor_type> desc = std::make_unique<_Descriptor_type>(shift);
232+ desc->set_value (mkl_dft::config_param::BACKWARD_SCALE, backward_scale);
233+ desc->set_value (mkl_dft::config_param::FORWARD_SCALE, forward_scale);
234+ desc->set_value (mkl_dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
235+ desc->commit (queue);
236+
237+ if (inverse)
238+ {
239+ fft_events[i] = mkl_dft::compute_backward<_Descriptor_type, _DataType_input, _DataType_output>(
240+ *desc, array_1 + i * shift, result + i * shift);
241+ }
242+ else
243+ {
244+ fft_events[i] = mkl_dft::compute_forward<_Descriptor_type, _DataType_input, _DataType_output>(
245+ *desc, array_1 + i * shift, result + i * shift);
234246 }
235247 }
236248
237249 sycl::event::wait (fft_events);
238-
239- return ;
240250}
241251
242252template <typename _KernelNameSpecialization1, typename _KernelNameSpecialization2, typename _KernelNameSpecialization3>
@@ -251,7 +261,6 @@ static DPCTLSyclEventRef dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef
251261 const size_t shape_size,
252262 const size_t input_size,
253263 const size_t result_size,
254- _Descriptor_type& desc,
255264 size_t inverse,
256265 const size_t norm,
257266 const size_t real)
@@ -260,14 +269,15 @@ static DPCTLSyclEventRef dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef
260269 (void )input_size;
261270
262271 DPCTLSyclEventRef event_ref = nullptr ;
263- if (!shape_size) {
272+ if (!shape_size)
273+ {
264274 return event_ref;
265275 }
266276
267277 sycl::queue queue = *(reinterpret_cast <sycl::queue*>(q_ref));
268278
269- _DataType_input* array_1 = static_cast <_DataType_input *>(const_cast <void *>(array1_in));
270- _DataType_output* result = static_cast <_DataType_output *>(result_out);
279+ _DataType_input* array_1 = static_cast <_DataType_input*>(const_cast <void *>(array1_in));
280+ _DataType_output* result = static_cast <_DataType_output*>(result_out);
271281
272282 const size_t n_iter =
273283 std::accumulate (input_shape, input_shape + shape_size - 1 , 1 , std::multiplies<shape_elem_type>());
@@ -278,38 +288,52 @@ static DPCTLSyclEventRef dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef
278288 double backward_scale = 1 .;
279289 double forward_scale = 1 .;
280290
281- if (norm == 0 ) { // norm = "backward"
282- if (inverse) {
291+ if (norm == 0 ) // norm = "backward"
292+ {
293+ if (inverse)
294+ {
283295 forward_scale = 1 . / result_shift;
284- } else {
296+ }
297+ else
298+ {
285299 backward_scale = 1 . / result_shift;
286300 }
287- } else if (norm == 1 ) { // norm = "forward"
288- if (inverse) {
301+ }
302+ else if (norm == 1 ) // norm = "forward"
303+ {
304+ if (inverse)
305+ {
289306 backward_scale = 1 . / result_shift;
290- } else {
307+ }
308+ else
309+ {
291310 forward_scale = 1 . / result_shift;
292311 }
293- } else { // norm = "ortho"
312+ }
313+ else // norm = "ortho"
314+ {
294315 forward_scale = 1 . / sqrt (result_shift);
295316 }
296317
297- desc.set_value (mkl_dft::config_param::BACKWARD_SCALE, backward_scale);
298- desc.set_value (mkl_dft::config_param::FORWARD_SCALE, forward_scale);
299- desc.set_value (mkl_dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
318+ std::vector<sycl::event> fft_events (n_iter);
300319
301- desc.commit (queue);
302-
303- std::vector<sycl::event> fft_events;
304- fft_events.reserve (n_iter);
305-
306- for (size_t i = 0 ; i < n_iter; ++i) {
307- fft_events.push_back (mkl_dft::compute_forward (desc, array_1 + i * input_shift, result + i * result_shift * 2 ));
320+ for (size_t i = 0 ; i < n_iter; ++i)
321+ {
322+ std::unique_ptr<_Descriptor_type> desc = std::make_unique<_Descriptor_type>(input_shift);
323+ desc->set_value (mkl_dft::config_param::BACKWARD_SCALE, backward_scale);
324+ desc->set_value (mkl_dft::config_param::FORWARD_SCALE, forward_scale);
325+ desc->set_value (mkl_dft::config_param::PLACEMENT, DFTI_NOT_INPLACE);
326+ desc->commit (queue);
327+
328+ // real result_size = 2 * result_size, because real type of "result" is twice wider than '_DataType_output'
329+ fft_events[i] = mkl_dft::compute_forward<_Descriptor_type, _DataType_input, _DataType_output>(
330+ *desc, array_1 + i * input_shift, result + i * result_shift * 2 );
308331 }
309332
310333 sycl::event::wait (fft_events);
311334
312- if (real) { // the output size of the rfft function is input_size/2 + 1 so we don't need to fill the second half of the output
335+ if (real) // the output size of the rfft function is input_size/2 + 1 so we don't need to fill the second half of the output
336+ {
313337 return event_ref;
314338 }
315339
@@ -325,19 +349,22 @@ static DPCTLSyclEventRef dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef
325349 size_t j = global_id[1 ];
326350 {
327351 *(reinterpret_cast <std::complex <_DataType_output>*>(result) + result_shift * (i + 1 ) - (j + 1 )) =
328- std::conj (*(reinterpret_cast <std::complex <_DataType_output>*>(result) + result_shift * i + (j + 1 )));
352+ std::conj (
353+ *(reinterpret_cast <std::complex <_DataType_output>*>(result) + result_shift * i + (j + 1 )));
329354 }
330355 }
331356 };
332357
333358 auto kernel_func = [&](sycl::handler& cgh) {
334- cgh.parallel_for <class dpnp_fft_fft_mathlib_real_to_cmplx_c_kernel <_DataType_input, _DataType_output, _Descriptor_type>>(
359+ cgh.parallel_for <
360+ class dpnp_fft_fft_mathlib_real_to_cmplx_c_kernel <_DataType_input, _DataType_output, _Descriptor_type>>(
335361 gws, kernel_parallel_for_func);
336362 };
337363
338364 event = queue.submit (kernel_func);
339365
340- if (inverse) {
366+ if (inverse)
367+ {
341368 event.wait ();
342369 event = oneapi::mkl::vm::conj (queue,
343370 result_size,
@@ -346,7 +373,6 @@ static DPCTLSyclEventRef dpnp_fft_fft_mathlib_real_to_cmplx_c(DPCTLSyclQueueRef
346373 }
347374
348375 event_ref = reinterpret_cast <DPCTLSyclEventRef>(&event);
349-
350376 return DPCTLEvent_Copy (event_ref);
351377}
352378
@@ -375,43 +401,35 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
375401 const size_t input_size =
376402 std::accumulate (input_shape, input_shape + shape_size, 1 , std::multiplies<shape_elem_type>());
377403
378- size_t dim = input_shape[shape_size - 1 ];
379-
380404 if constexpr (std::is_same<_DataType_output, std::complex <float >>::value ||
381405 std::is_same<_DataType_output, std::complex <double >>::value)
382406 {
383407 if constexpr (std::is_same<_DataType_input, std::complex <double >>::value &&
384408 std::is_same<_DataType_output, std::complex <double >>::value)
385409 {
386- desc_dp_cmplx_t desc (dim);
387410 dpnp_fft_fft_mathlib_cmplx_to_cmplx_c<_DataType_input, _DataType_output, desc_dp_cmplx_t >(
388- q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm);
411+ q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, inverse, norm);
389412 }
390413 /* complex-to-complex, single precision */
391414 else if constexpr (std::is_same<_DataType_input, std::complex <float >>::value &&
392415 std::is_same<_DataType_output, std::complex <float >>::value)
393416 {
394- desc_sp_cmplx_t desc (dim);
395417 dpnp_fft_fft_mathlib_cmplx_to_cmplx_c<_DataType_input, _DataType_output, desc_sp_cmplx_t >(
396- q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm);
418+ q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, inverse, norm);
397419 }
398420 /* real-to-complex, double precision */
399421 else if constexpr (std::is_same<_DataType_input, double >::value &&
400422 std::is_same<_DataType_output, std::complex <double >>::value)
401423 {
402- desc_dp_real_t desc (dim);
403-
404424 event_ref = dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, double , desc_dp_real_t >(
405- q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm, 0 );
425+ q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, inverse, norm, 0 );
406426 }
407427 /* real-to-complex, single precision */
408428 else if constexpr (std::is_same<_DataType_input, float >::value &&
409429 std::is_same<_DataType_output, std::complex <float >>::value)
410430 {
411- desc_sp_real_t desc (dim); // try: 2 * result_size
412-
413431 event_ref = dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, float , desc_sp_real_t >(
414- q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm, 0 );
432+ q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, inverse, norm, 0 );
415433 }
416434 else if constexpr (std::is_same<_DataType_input, int32_t >::value ||
417435 std::is_same<_DataType_input, int64_t >::value)
@@ -428,9 +446,8 @@ DPCTLSyclEventRef dpnp_fft_fft_c(DPCTLSyclQueueRef q_ref,
428446 DPCTLEvent_WaitAndThrow (event_ref);
429447 DPCTLEvent_Delete (event_ref);
430448
431- desc_dp_real_t desc (dim);
432449 event_ref = dpnp_fft_fft_mathlib_real_to_cmplx_c<double , double , desc_dp_real_t >(
433- q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm, 0 );
450+ q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, inverse, norm, 0 );
434451
435452 DPCTLEvent_WaitAndThrow (event_ref);
436453 DPCTLEvent_Delete (event_ref);
@@ -537,26 +554,21 @@ DPCTLSyclEventRef dpnp_fft_rfft_c(DPCTLSyclQueueRef q_ref,
537554 const size_t input_size =
538555 std::accumulate (input_shape, input_shape + shape_size, 1 , std::multiplies<shape_elem_type>());
539556
540- size_t dim = input_shape[shape_size - 1 ];
541-
542557 if constexpr (std::is_same<_DataType_output, std::complex <float >>::value ||
543558 std::is_same<_DataType_output, std::complex <double >>::value)
544559 {
545560 if constexpr (std::is_same<_DataType_input, double >::value &&
546- std::is_same<_DataType_output, std::complex <double >>::value)
561+ std::is_same<_DataType_output, std::complex <double >>::value)
547562 {
548- desc_dp_real_t desc (dim);
549-
550563 event_ref = dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, double , desc_dp_real_t >(
551- q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm, 1 );
564+ q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, inverse, norm, 1 );
552565 }
553566 /* real-to-complex, single precision */
554567 else if constexpr (std::is_same<_DataType_input, float >::value &&
555568 std::is_same<_DataType_output, std::complex <float >>::value)
556569 {
557- desc_sp_real_t desc (dim); // try: 2 * result_size
558570 event_ref = dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, float , desc_sp_real_t >(
559- q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm, 1 );
571+ q_ref, array1_in, result_out, input_shape, result_shape, shape_size, input_size, result_size, inverse, norm, 1 );
560572 }
561573 else if constexpr (std::is_same<_DataType_input, int32_t >::value ||
562574 std::is_same<_DataType_input, int64_t >::value)
@@ -573,9 +585,8 @@ DPCTLSyclEventRef dpnp_fft_rfft_c(DPCTLSyclQueueRef q_ref,
573585 DPCTLEvent_WaitAndThrow (event_ref);
574586 DPCTLEvent_Delete (event_ref);
575587
576- desc_dp_real_t desc (dim);
577588 event_ref = dpnp_fft_fft_mathlib_real_to_cmplx_c<double , double , desc_dp_real_t >(
578- q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, desc, inverse, norm, 1 );
589+ q_ref, array1_copy, result_out, input_shape, result_shape, shape_size, input_size, result_size, inverse, norm, 1 );
579590
580591 DPCTLEvent_WaitAndThrow (event_ref);
581592 DPCTLEvent_Delete (event_ref);
0 commit comments