Skip to content

Commit

Permalink
[CUDA][cuFFT] Initialize CUDA context for cuFFT before execute is called
Browse files Browse the repository at this point in the history
  • Loading branch information
eqy authored and pytorchmergebot committed Oct 13, 2023
1 parent f68d6e8 commit 5a2ab7d
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 1 deletion.
2 changes: 2 additions & 0 deletions aten/src/ATen/cuda/detail/LazyNVRTC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,10 @@ CUDA_STUB3(cuModuleGetFunction, CUfunction *, CUmodule, const char *);
CUDA_STUB4(cuOccupancyMaxActiveBlocksPerMultiprocessor, int *, CUfunction, int, size_t);
CUDA_STUB2(cuGetErrorString, CUresult, const char **);
CUDA_STUB1(cuCtxGetCurrent, CUcontext *);
// Lazily-loaded driver entry point used to install a context as current
// (needed by the cuFFT no-current-context workaround in SpectralOps.cpp).
CUDA_STUB1(cuCtxSetCurrent, CUcontext);
CUDA_STUB1(cuModuleUnload, CUmodule);
CUDA_STUB3(cuDevicePrimaryCtxGetState, CUdevice, unsigned int *, int *);
// Retains the device's primary context so it can be made current as a
// fallback when cuCtxGetCurrent reports no active context.
CUDA_STUB2(cuDevicePrimaryCtxRetain, CUcontext *, CUdevice);
CUDA_STUB4(cuLinkCreate, unsigned int, CUjit_option *, void **, CUlinkState *);
CUDA_STUB3(cuLinkComplete, CUlinkState, void **, size_t *);
CUDA_STUB3(cuFuncSetAttribute, CUfunction, CUfunction_attribute, int);
Expand Down
2 changes: 2 additions & 0 deletions aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,10 @@ namespace at { namespace cuda {
_(cuLaunchKernel) \
_(cuLaunchCooperativeKernel) \
_(cuCtxGetCurrent) \
_(cuCtxSetCurrent) \
_(cuModuleUnload) \
_(cuDevicePrimaryCtxGetState) \
_(cuDevicePrimaryCtxRetain) \
_(cuLinkCreate) \
_(cuLinkAddData) \
_(cuLinkComplete) \
Expand Down
13 changes: 12 additions & 1 deletion aten/src/ATen/native/cuda/SpectralOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <ATen/native/SpectralOpsUtils.h>
#include <ATen/native/cuda/CuFFTUtils.h>
#include <ATen/native/cuda/CuFFTPlanCache.h>
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
Expand All @@ -27,7 +28,6 @@
#include <cufftXt.h>

#include <cmath>
#include <vector>


namespace at::native {
Expand Down Expand Up @@ -304,6 +304,17 @@ static const Tensor& _exec_fft(Tensor& out, const Tensor& self, IntArrayRef out_
CUFFT_CHECK(cufftSetWorkArea(plan, workspace.mutable_data_ptr()));

// execute transform plan
#if !defined(USE_ROCM)
CUcontext pctx = nullptr;
at::globalContext().getNVRTC().cuCtxGetCurrent(&pctx);
if (C10_UNLIKELY(!pctx)) {
// workaround for corner case where a primary context exists but is not
// the current context
TORCH_WARN_ONCE("Attempting to run cuFFT, but there was no current CUDA context! Attempting to set the primary context...");
at::globalContext().getNVRTC().cuDevicePrimaryCtxRetain(&pctx, 0);
at::globalContext().getNVRTC().cuCtxSetCurrent(pctx);
}
#endif /* !defined(USE_ROCM) */
exec_cufft_plan(*config, input.data_ptr(), out.data_ptr(), forward);

// Inplace reshaping to original batch shape and inverting the dimension permutation
Expand Down
16 changes: 16 additions & 0 deletions test/test_spectral_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,6 +874,22 @@ def plan_cache_max_size(device, n):
self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 10) # default is cuda:0
self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 11) # default is cuda:1

@onlyCUDA
@dtypes(torch.cfloat, torch.cdouble)
def test_cufft_context(self, device, dtype):
    # Regression test for https://github.com/pytorch/pytorch/issues/109448:
    # cuFFT execution must produce correct gradients even when no CUDA
    # context was current at the time of the call.
    x = torch.randn(32, dtype=dtype, device=device, requires_grad=True)
    dout = torch.zeros(32, dtype=dtype, device=device)

    # Round-trip transform: out = iFFT(FFT(x)), then backprop dout through it.
    out = torch.fft.ifft(torch.fft.fft(x))
    out.backward(dout, retain_graph=True)

    # Analytically, the gradient of the round trip is FFT(iFFT(dout)).
    expected_grad = torch.fft.fft(torch.fft.ifft(dout))

    max_diff = (x.grad - expected_grad).abs().max()
    self.assertTrue(max_diff == 0)
    # Sanity check: the gradient is not just a copy of the input.
    self.assertFalse((x.grad - x).abs().max() == 0)

# passes on ROCm w/ python 2.7, fails w/ python 3.6
@skipCPUIfNoFFT
@onlyNativeDeviceTypes
Expand Down

0 comments on commit 5a2ab7d

Please sign in to comment.