diff --git a/include/matx/transforms/fft.h b/include/matx/transforms/fft.h index 292369ba0..4711a4c30 100644 --- a/include/matx/transforms/fft.h +++ b/include/matx/transforms/fft.h @@ -172,6 +172,9 @@ template class matxFFTPlan_t { MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) FftParams_t params; + // Default to default stream, but caller will generally overwrite this + params.stream = 0; + params.irank = i.Rank(); params.orank = o.Rank(); @@ -429,9 +432,11 @@ class matxFFTPlan1D_t : public matxFFTPlan_t { * Output view * @param i * Input view + * @param stream + * CUDA stream in which device memory allocations may be made * * */ -matxFFTPlan1D_t(OutTensorType &o, const InTensorType &i) +matxFFTPlan1D_t(OutTensorType &o, const InTensorType &i, cudaStream_t stream = 0) { MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) @@ -468,8 +473,7 @@ matxFFTPlan1D_t(OutTensorType &o, const InTensorType &i) &workspaceSize, this->params_.exec_type); MATX_ASSERT(error == CUFFT_SUCCESS, matxCufftError); - matxAlloc((void **)&this->workspace_, workspaceSize); - cudaMemPrefetchAsync(this->workspace_, workspaceSize, dev, 0); + matxAlloc((void **)&this->workspace_, workspaceSize, MATX_ASYNC_DEVICE_MEMORY, stream); cufftSetWorkArea(this->plan_, this->workspace_); error = cufftXtMakePlanMany( @@ -531,6 +535,8 @@ virtual void inline Exec(OutTensorType &o, const InTensorType &i, * Output view data type * @tparam T2 * Input view data type + * @param stream + * CUDA stream in which device memory allocations may be made */ template class matxFFTPlan2D_t : public matxFFTPlan_t { @@ -548,7 +554,7 @@ class matxFFTPlan2D_t : public matxFFTPlan_t { * Input view * * */ - matxFFTPlan2D_t(OutTensorType &o, const InTensorType &i) + matxFFTPlan2D_t(OutTensorType &o, const InTensorType &i, cudaStream_t stream = 0) { static_assert(RANK >= 2, "2D FFTs require a rank-2 tensor or higher"); @@ -595,8 +601,7 @@ class matxFFTPlan2D_t : public matxFFTPlan_t { this->params_.output_type, this->params_.batch, &workspaceSize, this->params_.exec_type); - matxAlloc((void **)&this->workspace_, workspaceSize); - cudaMemPrefetchAsync(this->workspace_, workspaceSize, dev, 0); + matxAlloc((void **)&this->workspace_, workspaceSize, MATX_ASYNC_DEVICE_MEMORY, stream); cufftSetWorkArea(this->plan_, this->workspace_); error = cufftXtMakePlanMany( @@ -892,7 +897,7 @@ __MATX_INLINE__ void fft_impl(OutputTensor o, const InputTensor i, // Get cache or new FFT plan if it doesn't exist auto ret = detail::cache_1d.Lookup(params); if (ret == std::nullopt) { - auto tmp = new detail::matxFFTPlan1D_t{out, in}; + auto tmp = new detail::matxFFTPlan1D_t{out, in, stream}; detail::cache_1d.Insert(params, static_cast(tmp)); tmp->Forward(out, in, stream, norm); } @@ -935,7 +940,7 @@ __MATX_INLINE__ void ifft_impl(OutputTensor o, const InputTensor i, // Get cache or new FFT plan if it doesn't exist auto ret = detail::cache_1d.Lookup(params); if (ret == std::nullopt) { - auto tmp = new detail::matxFFTPlan1D_t{out, in}; + auto tmp = new detail::matxFFTPlan1D_t{out, in, stream}; detail::cache_1d.Insert(params, static_cast(tmp)); tmp->Inverse(out, in, stream, norm); } @@ -991,7 +996,7 @@ __MATX_INLINE__ void fft2_impl(OutputTensor o, const InputTensor i, // Get cache or new FFT plan if it doesn't exist auto ret = detail::cache_2d.Lookup(params); if (ret == std::nullopt) { - auto tmp = new detail::matxFFTPlan2D_t{out, in}; + auto tmp = new detail::matxFFTPlan2D_t{out, in, stream}; detail::cache_2d.Insert(params, static_cast(tmp)); tmp->Forward(out, in, stream); } @@ -1047,7 +1052,7 @@ __MATX_INLINE__ void ifft2_impl(OutputTensor o, const InputTensor i, // Get cache or new FFT plan if it doesn't exist auto ret = detail::cache_2d.Lookup(params); if (ret == std::nullopt) { - auto tmp = new detail::matxFFTPlan2D_t{out, in}; + auto tmp = new detail::matxFFTPlan2D_t{out, in, stream}; detail::cache_2d.Insert(params, static_cast(tmp)); tmp->Inverse(out, in, stream); }