Skip to content

Commit 5c5258e

Browse files
authored
Merge pull request #37 from iotamudelta/rocfft_improvements
Enable FFT operations on the GPU for ROCm
2 parents c4de7e7 + 3f94bd8 commit 5c5258e

File tree

2 files changed

+14
-28
lines changed

2 files changed

+14
-28
lines changed

tensorflow/stream_executor/rocm/rocm_fft.cc

Lines changed: 12 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -70,14 +70,12 @@ namespace wrap {
7070
__macro(hipfftMakePlan3d) \
7171
__macro(hipfftGetSizeMany) \
7272
__macro(hipfftMakePlanMany) \
73-
74-
// ROCM TODO disable before rocFFT uses proper HIP complex types
75-
//__macro(hipfftExecD2Z) \
76-
//__macro(hipfftExecZ2D) \
77-
//__macro(hipfftExecC2C) \
78-
//__macro(hipfftExecC2R) \
79-
//__macro(hipfftExecZ2Z) \
80-
//__macro(hipfftExecR2C) \
73+
__macro(hipfftExecD2Z) \
74+
__macro(hipfftExecZ2D) \
75+
__macro(hipfftExecC2C) \
76+
__macro(hipfftExecC2R) \
77+
__macro(hipfftExecZ2Z) \
78+
__macro(hipfftExecR2C) \
8179

8280
HIPFFT_ROUTINE_EACH(PERFTOOLS_GPUTOOLS_HIPFFT_WRAP)
8381

@@ -536,32 +534,20 @@ bool ROCMFft::DoFftWithDirectionInternal(Stream *stream, fft::Plan *plan,
536534
bool ROCMFft::DoFft(Stream *stream, fft::Plan *plan, \
537535
const DeviceMemory<std::complex<__type>> &input, \
538536
DeviceMemory<std::complex<__type>> *output) { \
539-
LOG(ERROR) << "rocFFT does not current support complex<float> " \
540-
<< "/ complex<double> datatypes" ; \
541-
/* ROCM TODO disable for now until rocFFT properly honors HIP complex types */ \
542-
/* return DoFftWithDirectionInternal( */ \
543-
/* stream, plan, wrap::hipfftExec##__fft_type1, input, output); */ \
544-
return false; \
537+
return DoFftWithDirectionInternal( \
538+
stream, plan, wrap::hipfftExec##__fft_type1, input, output); \
545539
} \
546540
bool ROCMFft::DoFft(Stream *stream, fft::Plan *plan, \
547541
const DeviceMemory<__type> &input, \
548542
DeviceMemory<std::complex<__type>> *output) { \
549-
LOG(ERROR) << "rocFFT does not current support complex<float> " \
550-
<< "/ complex<double> datatypes" ; \
551-
/* ROCM TODO disable for now until rocFFT properly honors HIP complex types */ \
552-
/* return DoFftInternal(stream, plan, wrap::hipfftExec##__fft_type2, input, */ \
553-
/* output); */ \
554-
return false; \
543+
return DoFftInternal(stream, plan, wrap::hipfftExec##__fft_type2, input, \
544+
output); \
555545
} \
556546
bool ROCMFft::DoFft(Stream *stream, fft::Plan *plan, \
557547
const DeviceMemory<std::complex<__type>> &input, \
558548
DeviceMemory<__type> *output) { \
559-
LOG(ERROR) << "rocFFT does not current support complex<float> " \
560-
<< "/ complex<double> datatypes" ; \
561-
/* ROCM TODO disable for now until rocFFT properly honors HIP complex types */ \
562-
/* return DoFftInternal(stream, plan, wrap::hipfftExec##__fft_type3, input, */ \
563-
/* output); */ \
564-
return false; \
549+
return DoFftInternal(stream, plan, wrap::hipfftExec##__fft_type3, input, \
550+
output); \
565551
}
566552

567553
PERFTOOLS_GPUTOOLS_ROCM_DEFINE_FFT(float, C2C, R2C, C2R)

tensorflow/stream_executor/rocm/rocm_fft.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
1515

16-
// ROCM-specific support for FFT functionality -- this wraps the hipFFT library
16+
// ROCM-specific support for FFT functionality -- this wraps the rocFFT library
1717
// capabilities, and is only included into ROCM implementation code -- it will
1818
// not introduce rocm headers into other code.
1919

@@ -86,7 +86,7 @@ class ROCMFftPlan : public fft::Plan {
8686
bool is_initialized_;
8787
};
8888

89-
// FFT support for ROCM platform via hipFFT library.
89+
// FFT support for ROCM platform via rocFFT library.
9090
//
9191
// This satisfies the platform-agnostic FftSupport interface.
9292
//

0 commit comments

Comments
 (0)