Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add OpenMP for FFTW #2040

Merged
merged 2 commits into from
Jun 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Source/FieldSolver/SpectralSolver/SpectralFieldData.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ SpectralFieldData::ForwardTransform (const int lev,
#endif

// Loop over boxes
// Note: we do NOT OpenMP parallelize here, since we use OpenMP threads for
// the FFTs on each box!
for ( MFIter mfi(mf); mfi.isValid(); ++mfi ){
if (cost && WarpX::load_balance_costs_update_algo == LoadBalanceCostsUpdateAlgo::Timers)
{
Expand Down Expand Up @@ -247,6 +249,8 @@ SpectralFieldData::BackwardTransform( const int lev,
#endif

// Loop over boxes
// Note: we do NOT OpenMP parallelize here, since we use OpenMP threads for
// the iFFTs on each box!
for ( MFIter mfi(mf); mfi.isValid(); ++mfi ){
if (cost && WarpX::load_balance_costs_update_algo == LoadBalanceCostsUpdateAlgo::Timers)
{
Expand Down
12 changes: 11 additions & 1 deletion Source/FieldSolver/SpectralSolver/WrapFFTW.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright 2019-2020
/* Copyright 2019-2021
*
* This file is part of WarpX.
*
Expand Down Expand Up @@ -32,6 +32,16 @@ namespace AnyFFT
{
FFTplan fft_plan;

#if defined(AMREX_USE_OMP) && defined(WarpX_FFTW_OMP)
# ifdef AMREX_USE_FLOAT
fftwf_init_threads();
fftwf_plan_with_nthreads(omp_get_max_threads());
# else
fftw_init_threads();
fftw_plan_with_nthreads(omp_get_max_threads());
# endif
#endif

// Initialize fft_plan.m_plan with the vendor fft plan.
// Swap dimensions: AMReX FAB are Fortran-order but FFTW is C-order
if (dir == direction::R2C){
Expand Down
105 changes: 75 additions & 30 deletions cmake/dependencies/FFT.cmake
Original file line number Diff line number Diff line change
@@ -1,4 +1,54 @@
if(WarpX_PSATD)
# Helper Functions ############################################################
#
option(WarpX_FFTW_IGNORE_OMP "Ignore FFTW3 OpenMP support, even if found" OFF)
mark_as_advanced(WarpX_FFTW_IGNORE_OMP)

# Set the WarpX_FFTW_OMP=1 define on WarpX::thirdparty::FFT if TRUE and print
# a message
#
function(fftw_add_define HAS_FFTW_OMP_LIB)
if(HAS_FFTW_OMP_LIB)
message(STATUS "FFTW: Found OpenMP support")
target_compile_definitions(WarpX::thirdparty::FFT INTERFACE WarpX_FFTW_OMP=1)
else()
message(STATUS "FFTW: Could NOT find OpenMP support")
endif()
endfunction()

# Check if the PkgConfig target location has an _omp library, e.g.,
# libfftw3(f)_omp.a shipped and if yes, set the WarpX_FFTW_OMP=1 define.
#
function(fftw_check_omp library_paths fftw_precision_suffix)
if(WarpX_FFTW_IGNORE_OMP)
fftw_add_define(FALSE)
return()
endif()

find_library(HAS_FFTW_OMP_LIB fftw3${fftw_precision_suffix}_omp
PATHS ${library_paths}
NO_DEFAULT_PATH
NO_PACKAGE_ROOT_PATH
NO_CMAKE_PATH
NO_CMAKE_ENVIRONMENT_PATH
NO_SYSTEM_ENVIRONMENT_PATH
NO_CMAKE_SYSTEM_PATH
NO_CMAKE_FIND_ROOT_PATH
)
if(HAS_FFTW_OMP_LIB)
# the .pc files here forget to link the _omp.a/so files
# explicitly - we add those manually to avoid any trouble,
# e.g., in static builds.
target_link_libraries(WarpX::thirdparty::FFT INTERFACE ${HAS_FFTW_OMP_LIB})
endif()

fftw_add_define("${HAS_FFTW_OMP_LIB}")
endfunction()


# Various FFT implementations that we want to use #############################
#

# cuFFT (CUDA)
# TODO: check if `find_package` search works

Expand Down Expand Up @@ -29,20 +79,18 @@ if(WarpX_PSATD)
endif()
mark_as_advanced(WarpX_FFTW_SEARCH)

# floating point precision suffixes: float, double and quad precision
if(WarpX_PRECISION STREQUAL "DOUBLE")
set(HFFTWp "")
else()
set(HFFTWp "f")
endif()

if(WarpX_FFTW_SEARCH STREQUAL CMAKE)
if(WarpX_PRECISION STREQUAL "DOUBLE")
find_package(FFTW3 CONFIG REQUIRED)
else()
find_package(FFTW3f CONFIG REQUIRED)
endif()
find_package(FFTW3${HFFTWp} CONFIG REQUIRED)
else()
if(WarpX_PRECISION STREQUAL "DOUBLE")
find_package(PkgConfig REQUIRED QUIET)
pkg_check_modules(fftw3 REQUIRED IMPORTED_TARGET fftw3)
else()
find_package(PkgConfig REQUIRED QUIET)
pkg_check_modules(fftw3f REQUIRED IMPORTED_TARGET fftw3f)
endif()
find_package(PkgConfig REQUIRED QUIET)
pkg_check_modules(fftw3${HFFTWp} REQUIRED IMPORTED_TARGET fftw3${HFFTWp})
endif()
endif()

Expand All @@ -53,28 +101,25 @@ if(WarpX_PSATD)
elseif(WarpX_COMPUTE STREQUAL HIP)
make_third_party_includes_system(roc::rocfft FFT)
else()
if(WarpX_PRECISION STREQUAL "DOUBLE")
if(FFTW3_FOUND)
# subtargets: fftw3, fftw3_threads, fftw3_omp
if(WarpX_COMPUTE STREQUAL OMP AND TARGET FFTW3::fftw3_omp)
make_third_party_includes_system(FFTW3::fftw3_omp FFT)
else()
make_third_party_includes_system(FFTW3::fftw3 FFT)
endif()
if(FFTW3_FOUND)
# subtargets: fftw3(p), fftw3(p)_threads, fftw3(p)_omp
if(WarpX_COMPUTE STREQUAL OMP AND
TARGET FFTW3::fftw3${HFFTWp}_omp AND
NOT WarpX_FFTW_IGNORE_OMP)
make_third_party_includes_system(FFTW3::fftw3${HFFTWp}_omp FFT)
fftw_add_define(TRUE)
else()
make_third_party_includes_system(PkgConfig::fftw3 FFT)
make_third_party_includes_system(FFTW3::fftw3${HFFTWp} FFT)
fftw_add_define(FALSE)
endif()
else()
if(FFTW3f_FOUND)
# subtargets: fftw3f, fftw3f_threads, fftw3f_omp
if(WarpX_COMPUTE STREQUAL OMP AND TARGET FFTW3::fftw3f_omp)
make_third_party_includes_system(FFTW3::fftw3f_omp FFT)
else()
make_third_party_includes_system(FFTW3::fftw3f FFT)
endif()
make_third_party_includes_system(PkgConfig::fftw3${HFFTWp} FFT)
if(WarpX_COMPUTE STREQUAL OMP AND
NOT WarpX_FFTW_IGNORE_OMP)
fftw_check_omp("${fftw3${HFFTWp}_LIBRARY_DIRS}" "${HFFTWp}")
else()
make_third_party_includes_system(PkgConfig::fftw3f FFT)
fftw_add_define(FALSE)
endif()
endif()
endif()
endif()
endif(WarpX_PSATD)