Skip to content

Commit 60864f3

Browse files
Authored commit: Use device mem instead of managed for fft workbuf (#467)
1 parent c47d098 commit 60864f3

File tree

1 file changed

+15
-10
lines changed
  • include/matx/transforms

1 file changed

+15
-10
lines changed

include/matx/transforms/fft.h

+15-10
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,9 @@ template <typename OutTensorType, typename InTensorType> class matxFFTPlan_t {
172172
MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)
173173
FftParams_t params;
174174

175+
// Default to default stream, but caller will generally overwrite this
176+
params.stream = 0;
177+
175178
params.irank = i.Rank();
176179
params.orank = o.Rank();
177180

@@ -429,9 +432,11 @@ class matxFFTPlan1D_t : public matxFFTPlan_t<OutTensorType, InTensorType> {
429432
* Output view
430433
* @param i
431434
* Input view
435+
* @param stream
436+
* CUDA stream in which device memory allocations may be made
432437
*
433438
* */
434-
matxFFTPlan1D_t(OutTensorType &o, const InTensorType &i)
439+
matxFFTPlan1D_t(OutTensorType &o, const InTensorType &i, cudaStream_t stream = 0)
435440
{
436441
MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)
437442

@@ -468,8 +473,7 @@ matxFFTPlan1D_t(OutTensorType &o, const InTensorType &i)
468473
&workspaceSize, this->params_.exec_type);
469474
MATX_ASSERT(error == CUFFT_SUCCESS, matxCufftError);
470475

471-
matxAlloc((void **)&this->workspace_, workspaceSize);
472-
cudaMemPrefetchAsync(this->workspace_, workspaceSize, dev, 0);
476+
matxAlloc((void **)&this->workspace_, workspaceSize, MATX_ASYNC_DEVICE_MEMORY, stream);
473477
cufftSetWorkArea(this->plan_, this->workspace_);
474478

475479
error = cufftXtMakePlanMany(
@@ -531,6 +535,8 @@ virtual void inline Exec(OutTensorType &o, const InTensorType &i,
531535
* Output view data type
532536
* @tparam T2
533537
* Input view data type
538+
* @param stream
539+
* CUDA stream in which device memory allocations may be made
534540
*/
535541
template <typename OutTensorType, typename InTensorType = OutTensorType>
536542
class matxFFTPlan2D_t : public matxFFTPlan_t<OutTensorType, InTensorType> {
@@ -548,7 +554,7 @@ class matxFFTPlan2D_t : public matxFFTPlan_t<OutTensorType, InTensorType> {
548554
* Input view
549555
*
550556
* */
551-
matxFFTPlan2D_t(OutTensorType &o, const InTensorType &i)
557+
matxFFTPlan2D_t(OutTensorType &o, const InTensorType &i, cudaStream_t stream = 0)
552558
{
553559
static_assert(RANK >= 2, "2D FFTs require a rank-2 tensor or higher");
554560

@@ -595,8 +601,7 @@ class matxFFTPlan2D_t : public matxFFTPlan_t<OutTensorType, InTensorType> {
595601
this->params_.output_type, this->params_.batch,
596602
&workspaceSize, this->params_.exec_type);
597603

598-
matxAlloc((void **)&this->workspace_, workspaceSize);
599-
cudaMemPrefetchAsync(this->workspace_, workspaceSize, dev, 0);
604+
matxAlloc((void **)&this->workspace_, workspaceSize, MATX_ASYNC_DEVICE_MEMORY, stream);
600605
cufftSetWorkArea(this->plan_, this->workspace_);
601606

602607
error = cufftXtMakePlanMany(
@@ -892,7 +897,7 @@ __MATX_INLINE__ void fft_impl(OutputTensor o, const InputTensor i,
892897
// Get cache or new FFT plan if it doesn't exist
893898
auto ret = detail::cache_1d.Lookup(params);
894899
if (ret == std::nullopt) {
895-
auto tmp = new detail::matxFFTPlan1D_t<decltype(out), decltype(in)>{out, in};
900+
auto tmp = new detail::matxFFTPlan1D_t<decltype(out), decltype(in)>{out, in, stream};
896901
detail::cache_1d.Insert(params, static_cast<void *>(tmp));
897902
tmp->Forward(out, in, stream, norm);
898903
}
@@ -935,7 +940,7 @@ __MATX_INLINE__ void ifft_impl(OutputTensor o, const InputTensor i,
935940
// Get cache or new FFT plan if it doesn't exist
936941
auto ret = detail::cache_1d.Lookup(params);
937942
if (ret == std::nullopt) {
938-
auto tmp = new detail::matxFFTPlan1D_t<decltype(out), decltype(in)>{out, in};
943+
auto tmp = new detail::matxFFTPlan1D_t<decltype(out), decltype(in)>{out, in, stream};
939944
detail::cache_1d.Insert(params, static_cast<void *>(tmp));
940945
tmp->Inverse(out, in, stream, norm);
941946
}
@@ -991,7 +996,7 @@ __MATX_INLINE__ void fft2_impl(OutputTensor o, const InputTensor i,
991996
// Get cache or new FFT plan if it doesn't exist
992997
auto ret = detail::cache_2d.Lookup(params);
993998
if (ret == std::nullopt) {
994-
auto tmp = new detail::matxFFTPlan2D_t<decltype(out), decltype(in)>{out, in};
999+
auto tmp = new detail::matxFFTPlan2D_t<decltype(out), decltype(in)>{out, in, stream};
9951000
detail::cache_2d.Insert(params, static_cast<void *>(tmp));
9961001
tmp->Forward(out, in, stream);
9971002
}
@@ -1047,7 +1052,7 @@ __MATX_INLINE__ void ifft2_impl(OutputTensor o, const InputTensor i,
10471052
// Get cache or new FFT plan if it doesn't exist
10481053
auto ret = detail::cache_2d.Lookup(params);
10491054
if (ret == std::nullopt) {
1050-
auto tmp = new detail::matxFFTPlan2D_t<decltype(in), decltype(out)>{out, in};
1055+
auto tmp = new detail::matxFFTPlan2D_t<decltype(in), decltype(out)>{out, in, stream};
10511056
detail::cache_2d.Insert(params, static_cast<void *>(tmp));
10521057
tmp->Inverse(out, in, stream);
10531058
}

0 commit comments

Comments (0)