@@ -172,6 +172,9 @@ template <typename OutTensorType, typename InTensorType> class matxFFTPlan_t {
172
172
MATX_NVTX_START (" " , matx::MATX_NVTX_LOG_INTERNAL)
173
173
FftParams_t params;
174
174
175
+ // Default to default stream, but caller will generally overwrite this
176
+ params.stream = 0 ;
177
+
175
178
params.irank = i.Rank ();
176
179
params.orank = o.Rank ();
177
180
@@ -429,9 +432,11 @@ class matxFFTPlan1D_t : public matxFFTPlan_t<OutTensorType, InTensorType> {
429
432
* Output view
430
433
* @param i
431
434
* Input view
435
+ * @param stream
436
+ * CUDA stream in which device memory allocations may be made
432
437
*
433
438
* */
434
- matxFFTPlan1D_t (OutTensorType &o, const InTensorType &i)
439
+ matxFFTPlan1D_t (OutTensorType &o, const InTensorType &i, cudaStream_t stream = 0 )
435
440
{
436
441
MATX_NVTX_START (" " , matx::MATX_NVTX_LOG_INTERNAL)
437
442
@@ -468,8 +473,7 @@ matxFFTPlan1D_t(OutTensorType &o, const InTensorType &i)
468
473
&workspaceSize, this ->params_ .exec_type );
469
474
MATX_ASSERT (error == CUFFT_SUCCESS, matxCufftError);
470
475
471
- matxAlloc ((void **)&this ->workspace_ , workspaceSize);
472
- cudaMemPrefetchAsync (this ->workspace_ , workspaceSize, dev, 0 );
476
+ matxAlloc ((void **)&this ->workspace_ , workspaceSize, MATX_ASYNC_DEVICE_MEMORY, stream);
473
477
cufftSetWorkArea (this ->plan_ , this ->workspace_ );
474
478
475
479
error = cufftXtMakePlanMany (
@@ -531,6 +535,8 @@ virtual void inline Exec(OutTensorType &o, const InTensorType &i,
531
535
* Output view data type
532
536
* @tparam T2
533
537
* Input view data type
538
+ * @param stream
539
+ * CUDA stream in which device memory allocations may be made
534
540
*/
535
541
template <typename OutTensorType, typename InTensorType = OutTensorType>
536
542
class matxFFTPlan2D_t : public matxFFTPlan_t <OutTensorType, InTensorType> {
@@ -548,7 +554,7 @@ class matxFFTPlan2D_t : public matxFFTPlan_t<OutTensorType, InTensorType> {
548
554
* Input view
549
555
*
550
556
* */
551
- matxFFTPlan2D_t (OutTensorType &o, const InTensorType &i)
557
+ matxFFTPlan2D_t (OutTensorType &o, const InTensorType &i, cudaStream_t stream = 0 )
552
558
{
553
559
static_assert (RANK >= 2 , " 2D FFTs require a rank-2 tensor or higher" );
554
560
@@ -595,8 +601,7 @@ class matxFFTPlan2D_t : public matxFFTPlan_t<OutTensorType, InTensorType> {
595
601
this ->params_ .output_type , this ->params_ .batch ,
596
602
&workspaceSize, this ->params_ .exec_type );
597
603
598
- matxAlloc ((void **)&this ->workspace_ , workspaceSize);
599
- cudaMemPrefetchAsync (this ->workspace_ , workspaceSize, dev, 0 );
604
+ matxAlloc ((void **)&this ->workspace_ , workspaceSize, MATX_ASYNC_DEVICE_MEMORY, stream);
600
605
cufftSetWorkArea (this ->plan_ , this ->workspace_ );
601
606
602
607
error = cufftXtMakePlanMany (
@@ -892,7 +897,7 @@ __MATX_INLINE__ void fft_impl(OutputTensor o, const InputTensor i,
892
897
// Get cache or new FFT plan if it doesn't exist
893
898
auto ret = detail::cache_1d.Lookup (params);
894
899
if (ret == std::nullopt) {
895
- auto tmp = new detail::matxFFTPlan1D_t<decltype (out), decltype (in)>{out, in};
900
+ auto tmp = new detail::matxFFTPlan1D_t<decltype (out), decltype (in)>{out, in, stream };
896
901
detail::cache_1d.Insert (params, static_cast <void *>(tmp));
897
902
tmp->Forward (out, in, stream, norm);
898
903
}
@@ -935,7 +940,7 @@ __MATX_INLINE__ void ifft_impl(OutputTensor o, const InputTensor i,
935
940
// Get cache or new FFT plan if it doesn't exist
936
941
auto ret = detail::cache_1d.Lookup (params);
937
942
if (ret == std::nullopt) {
938
- auto tmp = new detail::matxFFTPlan1D_t<decltype (out), decltype (in)>{out, in};
943
+ auto tmp = new detail::matxFFTPlan1D_t<decltype (out), decltype (in)>{out, in, stream };
939
944
detail::cache_1d.Insert (params, static_cast <void *>(tmp));
940
945
tmp->Inverse (out, in, stream, norm);
941
946
}
@@ -991,7 +996,7 @@ __MATX_INLINE__ void fft2_impl(OutputTensor o, const InputTensor i,
991
996
// Get cache or new FFT plan if it doesn't exist
992
997
auto ret = detail::cache_2d.Lookup (params);
993
998
if (ret == std::nullopt) {
994
- auto tmp = new detail::matxFFTPlan2D_t<decltype (out), decltype (in)>{out, in};
999
+ auto tmp = new detail::matxFFTPlan2D_t<decltype (out), decltype (in)>{out, in, stream };
995
1000
detail::cache_2d.Insert (params, static_cast <void *>(tmp));
996
1001
tmp->Forward (out, in, stream);
997
1002
}
@@ -1047,7 +1052,7 @@ __MATX_INLINE__ void ifft2_impl(OutputTensor o, const InputTensor i,
1047
1052
// Get cache or new FFT plan if it doesn't exist
1048
1053
auto ret = detail::cache_2d.Lookup (params);
1049
1054
if (ret == std::nullopt) {
1050
- auto tmp = new detail::matxFFTPlan2D_t<decltype (in), decltype (out)>{out, in};
1055
+ auto tmp = new detail::matxFFTPlan2D_t<decltype (in), decltype (out)>{out, in, stream };
1051
1056
detail::cache_2d.Insert (params, static_cast <void *>(tmp));
1052
1057
tmp->Inverse (out, in, stream);
1053
1058
}
0 commit comments