Use device mem instead of managed for fft workbuf (#467)
tbensonatl authored Aug 23, 2023 · 1 parent c47d098 · commit 60864f3
Showing 1 changed file with 15 additions and 10 deletions.
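The substance of the change: the cuFFT work buffer was previously allocated as managed (unified) memory and then prefetched to the device; it is now allocated as stream-ordered device memory on a caller-supplied stream via matxAlloc's MATX_ASYNC_DEVICE_MEMORY space. A minimal sketch of the two allocation patterns in raw CUDA (the helper names are illustrative only; MatX routes both through matxAlloc):

#include <cuda_runtime.h>
#include <cstddef>

// Before: managed memory plus an explicit prefetch, so the first kernel
// that touches the buffer does not take page faults on the device.
void *workbuf_managed(size_t bytes, int dev, cudaStream_t stream) {
  void *p = nullptr;
  cudaMallocManaged(&p, bytes);
  cudaMemPrefetchAsync(p, bytes, dev, stream);
  return p;
}

// After: stream-ordered device memory. The allocation is ordered with the
// stream's other work, lives only in device memory, and never migrates.
void *workbuf_device(size_t bytes, cudaStream_t stream) {
  void *p = nullptr;
  cudaMallocAsync(&p, bytes, stream);
  return p;
}

Plain device memory sidesteps the unified-memory migration machinery entirely, a reasonable fit for a scratch buffer that only cuFFT kernels ever touch. (Error checking elided for brevity.)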
25 changes: 15 additions & 10 deletions include/matx/transforms/fft.h
@@ -172,6 +172,9 @@ template <typename OutTensorType, typename InTensorType> class matxFFTPlan_t {
MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)
FftParams_t params;

+// Default to default stream, but caller will generally overwrite this
+params.stream = 0;
+
params.irank = i.Rank();
params.orank = o.Rank();

@@ -429,9 +432,11 @@ class matxFFTPlan1D_t : public matxFFTPlan_t<OutTensorType, InTensorType> {
* Output view
* @param i
* Input view
+* @param stream
+* CUDA stream in which device memory allocations may be made
*
* */
-matxFFTPlan1D_t(OutTensorType &o, const InTensorType &i)
+matxFFTPlan1D_t(OutTensorType &o, const InTensorType &i, cudaStream_t stream = 0)
{
MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)

@@ -468,8 +473,7 @@ matxFFTPlan1D_t(OutTensorType &o, const InTensorType &i)
&workspaceSize, this->params_.exec_type);
MATX_ASSERT(error == CUFFT_SUCCESS, matxCufftError);

-matxAlloc((void **)&this->workspace_, workspaceSize);
-cudaMemPrefetchAsync(this->workspace_, workspaceSize, dev, 0);
+matxAlloc((void **)&this->workspace_, workspaceSize, MATX_ASYNC_DEVICE_MEMORY, stream);
cufftSetWorkArea(this->plan_, this->workspace_);

error = cufftXtMakePlanMany(
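For context, the constructor follows cuFFT's manual work-area pattern: query the workspace size, allocate it yourself, and hand it to the plan with cufftSetWorkArea. A minimal sketch with a plain 1D C2C plan, assuming stream-ordered allocation as in this commit (the function name is hypothetical; MatX itself uses cufftXtGetSizeMany and cufftXtMakePlanMany, as the hunk above shows):

#include <cufft.h>
#include <cuda_runtime.h>

cufftHandle plan_with_stream_workbuf(int n, cudaStream_t stream) {
  cufftHandle plan;
  cufftCreate(&plan);
  cufftSetAutoAllocation(plan, 0);      // we supply the work area ourselves
  size_t ws = 0;
  cufftMakePlan1d(plan, n, CUFFT_C2C, /*batch=*/1, &ws);
  void *work = nullptr;
  cudaMallocAsync(&work, ws, stream);   // stream-ordered device memory
  cufftSetWorkArea(plan, work);
  cufftSetStream(plan, stream);         // execute on the same stream
  return plan;
}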
@@ -531,6 +535,8 @@ virtual void inline Exec(OutTensorType &o, const InTensorType &i,
* Output view data type
* @tparam T2
* Input view data type
+* @param stream
+* CUDA stream in which device memory allocations may be made
*/
template <typename OutTensorType, typename InTensorType = OutTensorType>
class matxFFTPlan2D_t : public matxFFTPlan_t<OutTensorType, InTensorType> {
@@ -548,7 +554,7 @@ class matxFFTPlan2D_t : public matxFFTPlan_t<OutTensorType, InTensorType> {
* Input view
*
* */
-matxFFTPlan2D_t(OutTensorType &o, const InTensorType &i)
+matxFFTPlan2D_t(OutTensorType &o, const InTensorType &i, cudaStream_t stream = 0)
{
static_assert(RANK >= 2, "2D FFTs require a rank-2 tensor or higher");

@@ -595,8 +601,7 @@ class matxFFTPlan2D_t : public matxFFTPlan_t<OutTensorType, InTensorType> {
this->params_.output_type, this->params_.batch,
&workspaceSize, this->params_.exec_type);

-matxAlloc((void **)&this->workspace_, workspaceSize);
-cudaMemPrefetchAsync(this->workspace_, workspaceSize, dev, 0);
+matxAlloc((void **)&this->workspace_, workspaceSize, MATX_ASYNC_DEVICE_MEMORY, stream);
cufftSetWorkArea(this->plan_, this->workspace_);

error = cufftXtMakePlanMany(
@@ -892,7 +897,7 @@ __MATX_INLINE__ void fft_impl(OutputTensor o, const InputTensor i,
// Get cache or new FFT plan if it doesn't exist
auto ret = detail::cache_1d.Lookup(params);
if (ret == std::nullopt) {
-auto tmp = new detail::matxFFTPlan1D_t<decltype(out), decltype(in)>{out, in};
+auto tmp = new detail::matxFFTPlan1D_t<decltype(out), decltype(in)>{out, in, stream};
detail::cache_1d.Insert(params, static_cast<void *>(tmp));
tmp->Forward(out, in, stream, norm);
}
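Caller-side, nothing changes: the stream already passed to run() now also determines where the cached plan's work buffer is allocated on the first cache miss. The same lookup-then-construct pattern repeats for ifft_impl, fft2_impl, and ifft2_impl below. A hedged usage sketch (tensor and type spellings may vary across MatX versions):

#include <matx.h>

int main() {
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  auto a = matx::make_tensor<cuda::std::complex<float>>({1024});
  auto b = matx::make_tensor<cuda::std::complex<float>>({1024});

  // The first call misses the plan cache, builds the plan, and now
  // allocates its work buffer as device memory on this stream.
  (b = matx::fft(a)).run(stream);

  cudaStreamSynchronize(stream);
  cudaStreamDestroy(stream);
  return 0;
}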
@@ -935,7 +940,7 @@ __MATX_INLINE__ void ifft_impl(OutputTensor o, const InputTensor i,
// Get cache or new FFT plan if it doesn't exist
auto ret = detail::cache_1d.Lookup(params);
if (ret == std::nullopt) {
-auto tmp = new detail::matxFFTPlan1D_t<decltype(out), decltype(in)>{out, in};
+auto tmp = new detail::matxFFTPlan1D_t<decltype(out), decltype(in)>{out, in, stream};
detail::cache_1d.Insert(params, static_cast<void *>(tmp));
tmp->Inverse(out, in, stream, norm);
}
@@ -991,7 +996,7 @@ __MATX_INLINE__ void fft2_impl(OutputTensor o, const InputTensor i,
// Get cache or new FFT plan if it doesn't exist
auto ret = detail::cache_2d.Lookup(params);
if (ret == std::nullopt) {
-auto tmp = new detail::matxFFTPlan2D_t<decltype(out), decltype(in)>{out, in};
+auto tmp = new detail::matxFFTPlan2D_t<decltype(out), decltype(in)>{out, in, stream};
detail::cache_2d.Insert(params, static_cast<void *>(tmp));
tmp->Forward(out, in, stream);
}
@@ -1047,7 +1052,7 @@ __MATX_INLINE__ void ifft2_impl(OutputTensor o, const InputTensor i,
// Get cache or new FFT plan if it doesn't exist
auto ret = detail::cache_2d.Lookup(params);
if (ret == std::nullopt) {
-auto tmp = new detail::matxFFTPlan2D_t<decltype(in), decltype(out)>{out, in};
+auto tmp = new detail::matxFFTPlan2D_t<decltype(in), decltype(out)>{out, in, stream};
detail::cache_2d.Insert(params, static_cast<void *>(tmp));
tmp->Inverse(out, in, stream);
}
