Refactor AnyFFT (#1099)
AlexanderSinn authored May 2, 2024
1 parent 9ce6af5 commit 34711d3
Showing 30 changed files with 1,399 additions and 1,738 deletions.
30 changes: 23 additions & 7 deletions docs/source/run/parameters.rst
Original file line number Diff line number Diff line change
@@ -252,15 +252,31 @@ The default is to use the explicit solver. **We strongly recommend to use the ex
Which solver to use.
Possible values: ``explicit`` and ``predictor-corrector``.

* ``fields.poisson_solver`` (`string`) optional (default `FFTDirichlet`)
* ``fields.poisson_solver`` (`string`) optional (default CPU: `FFTDirichletDirect`, GPU: `FFTDirichletFast`)
Which Poisson solver to use for ``Psi``, ``Ez`` and ``Bz``. The ``predictor-corrector`` BxBy
solver also uses this poisson solver for ``Bx`` and ``By`` internally. Available solvers are
``FFTDirichlet``, ``FFTPeriodic`` and ``MGDirichlet``.
solver also uses this Poisson solver for ``Bx`` and ``By`` internally. Available solvers are:

* ``hipace.use_small_dst`` (`bool`) optional (default `0` or `1`)
Whether to use a large R2C or a small C2R fft in the dst of the Poisson solver.
The small dst is quicker for simulations with :math:`\geq 511` transverse grid points.
The default is set accordingly.
* ``FFTDirichletDirect`` Use the discrete sine transformation that is directly implemented
by FFTW to solve the Poisson equation with Dirichlet boundary conditions.
This option is only available when compiling for CPUs with FFTW.
Preferred resolution: :math:`2^N-1`.

* ``FFTDirichletExpanded`` Perform the discrete sine transformation by symmetrically
expanding the field to twice its size.
Preferred resolution: :math:`2^N-1`.

* ``FFTDirichletFast`` Perform the discrete sine transformation using a fast sine transform
algorithm that uses FFTs of the same size as the fields.
Preferred resolution: :math:`2^N-1`.

* ``MGDirichlet`` Use the HiPACE++ multigrid solver to solve the Poisson equation with
Dirichlet boundary conditions.
Preferred resolution: :math:`2^N` and :math:`2^N-1`.

* ``FFTPeriodic`` Use FFTs to solve the Poisson equation with periodic boundary conditions.
Note that this does not work with features that change the boundary values,
like mesh refinement or open boundaries.
Preferred resolution: :math:`2^N`.

* ``fields.extended_solve`` (`bool`) optional (default `0`)
Extends the area of the FFT Poisson solver to the ghost cells. This can reduce artifacts
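
The "Preferred resolution" notes in the documentation hunk above can be checked before launching a run. The following is a minimal sketch, not part of HiPACE++: `is_pow2_minus_one` and `is_pow2` are hypothetical helper names, and the functions only test whether a transverse grid size has the form 2^N-1 or 2^N.

```cpp
#include <cassert>

// Hypothetical helper (not a HiPACE++ function): true if n == 2^N - 1
// for some integer N >= 1, e.g. 1, 3, 7, ..., 1023.
inline bool is_pow2_minus_one (int n)
{
    assert(n >= 0);
    // n + 1 must be a power of two, i.e. have exactly one bit set.
    const unsigned int m = static_cast<unsigned int>(n) + 1u;
    return n >= 1 && (m & (m - 1u)) == 0u;
}

// Hypothetical helper: true if n == 2^N for some integer N >= 0.
inline bool is_pow2 (int n)
{
    assert(n >= 0);
    const unsigned int m = static_cast<unsigned int>(n);
    return n >= 1 && (m & (m - 1u)) == 0u;
}
```

With these helpers, 1023 transverse grid points would match the preference listed for the three `FFTDirichlet*` solvers, while 1024 would match the one listed for `FFTPeriodic`.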
3 changes: 3 additions & 0 deletions src/Hipace.H
@@ -44,6 +44,9 @@ struct Hipace_early_init
* and Parser Constants */
Hipace_early_init (Hipace* instance);

/** Destructor for FFT cleanup */
~Hipace_early_init ();

/** Struct containing physical constants (whose values depend on the unit system, determined
* at runtime): SI or normalized units. */
PhysConst m_phys_const;
7 changes: 7 additions & 0 deletions src/Hipace.cpp
@@ -17,6 +17,7 @@
#include "utils/GPUUtil.H"
#include "particles/pusher/GetAndSetPosition.H"
#include "mg_solver/HpMultiGrid.H"
#include "fields/fft_poisson_solver/fft/AnyFFT.H"

#include <AMReX_ParmParse.H>
#include <AMReX_IntVect.H>
@@ -52,6 +53,12 @@ Hipace_early_init::Hipace_early_init (Hipace* instance)
int max_level = 0;
queryWithParser(pp_amr, "max_level", max_level);
m_N_level = max_level + 1;
AnyFFT::setup();
}

Hipace_early_init::~Hipace_early_init ()
{
AnyFFT::cleanup();
}

Hipace&
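
The constructor/destructor pair above ties `AnyFFT::setup()` and `AnyFFT::cleanup()` to the lifetime of `Hipace_early_init`. The bodies of those two functions are not shown in this part of the diff, so the sketch below uses a hypothetical stand-in (`FakeFFTLibrary`) to illustrate only the RAII pattern. All names in it are assumptions made for illustration.

```cpp
#include <cassert>

// Hypothetical stand-in for a library with global state. It only counts
// how many setup() calls have not yet been matched by a cleanup() call.
struct FakeFFTLibrary {
    static int& active () { static int n = 0; return n; }
    static void setup ()   { ++active(); }
    static void cleanup () { assert(active() > 0); --active(); }
};

// RAII guard in the spirit of Hipace_early_init: global setup runs in the
// constructor, global cleanup in the destructor.
struct EarlyInitSketch {
    EarlyInitSketch ()  { FakeFFTLibrary::setup(); }
    ~EarlyInitSketch () { FakeFFTLibrary::cleanup(); }
    // Non-copyable, so that cleanup cannot run twice for one setup.
    EarlyInitSketch (EarlyInitSketch const&) = delete;
    EarlyInitSketch& operator= (EarlyInitSketch const&) = delete;
};
```

The benefit of this arrangement is that cleanup also runs when the owning object is destroyed during stack unwinding, without an explicit call at every exit path.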
2 changes: 1 addition & 1 deletion src/fields/Fields.H
@@ -507,7 +507,7 @@ private:
/** Vector over levels of all required fields to compute current slice */
amrex::Vector<amrex::MultiFab> m_slices;
/** Type of poisson solver to use */
std::string m_poisson_solver_str = "FFTDirichlet";
std::string m_poisson_solver_str = "";
/** Class to handle transverse FFT Poisson solver on 1 slice */
amrex::Vector<std::unique_ptr<FFTPoissonSolver>> m_poisson_solver;
/** Stores temporary values for z interpolation in Fields::Copy */
33 changes: 26 additions & 7 deletions src/fields/Fields.cpp
@@ -8,7 +8,9 @@
*/
#include "Fields.H"
#include "fft_poisson_solver/FFTPoissonSolverPeriodic.H"
#include "fft_poisson_solver/FFTPoissonSolverDirichlet.H"
#include "fft_poisson_solver/FFTPoissonSolverDirichletDirect.H"
#include "fft_poisson_solver/FFTPoissonSolverDirichletExpanded.H"
#include "fft_poisson_solver/FFTPoissonSolverDirichletFast.H"
#include "fft_poisson_solver/MGPoissonSolverDirichlet.H"
#include "Hipace.H"
#include "OpenBoundary.H"
@@ -29,6 +31,12 @@ Fields::Fields (const int nlev)
{
amrex::ParmParse ppf("fields");
DeprecatedInput("fields", "do_dirichlet_poisson", "poisson_solver", "");
// set default Poisson solver based on the platform
#ifdef AMREX_USE_GPU
m_poisson_solver_str = "FFTDirichletFast";
#else
m_poisson_solver_str = "FFTDirichletDirect";
#endif
queryWithParser(ppf, "poisson_solver", m_poisson_solver_str);
queryWithParser(ppf, "extended_solve", m_extended_solve);
queryWithParser(ppf, "open_boundary", m_open_boundary);
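
The hunk above sets a platform-dependent default first and then lets the input file override it. A minimal sketch of that ordering follows, assuming a hypothetical macro `USE_GPU_SKETCH` in place of `AMREX_USE_GPU` and a hypothetical function `resolve_poisson_solver` in place of the `queryWithParser` call. An empty string stands for "not set by the user".

```cpp
#include <cassert>
#include <string>

// USE_GPU_SKETCH is a hypothetical macro standing in for AMREX_USE_GPU.
inline std::string default_poisson_solver ()
{
#ifdef USE_GPU_SKETCH
    return "FFTDirichletFast";
#else
    return "FFTDirichletDirect";
#endif
}

// Hypothetical replacement for the parser query: an empty user value
// keeps the platform default, anything else overrides it.
inline std::string resolve_poisson_solver (std::string const& user_value)
{
    std::string result = user_value.empty() ? default_poisson_solver() : user_value;
    assert(!result.empty());  // a solver name is always selected
    return result;
}
```

Setting the default before the query means the member can be initialized to an empty string in the header, as the `Fields.H` change does, without an empty name ever reaching the solver selection.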
@@ -178,11 +186,21 @@ Fields::AllocData (
// The Poisson solver operates on transverse slices only.
// The constructor takes the BoxArray and the DistributionMap of a slice,
// so the FFTPlans are built on a slice.
if (m_poisson_solver_str == "FFTDirichlet"){
m_poisson_solver.push_back(std::unique_ptr<FFTPoissonSolverDirichlet>(
new FFTPoissonSolverDirichlet(getSlices(lev).boxArray(),
getSlices(lev).DistributionMap(),
geom)) );
if (m_poisson_solver_str == "FFTDirichletDirect"){
m_poisson_solver.push_back(std::unique_ptr<FFTPoissonSolverDirichletDirect>(
new FFTPoissonSolverDirichletDirect(getSlices(lev).boxArray(),
getSlices(lev).DistributionMap(),
geom)) );
} else if (m_poisson_solver_str == "FFTDirichletExpanded"){
m_poisson_solver.push_back(std::unique_ptr<FFTPoissonSolverDirichletExpanded>(
new FFTPoissonSolverDirichletExpanded(getSlices(lev).boxArray(),
getSlices(lev).DistributionMap(),
geom)) );
} else if (m_poisson_solver_str == "FFTDirichletFast"){
m_poisson_solver.push_back(std::unique_ptr<FFTPoissonSolverDirichletFast>(
new FFTPoissonSolverDirichletFast(getSlices(lev).boxArray(),
getSlices(lev).DistributionMap(),
geom)) );
} else if (m_poisson_solver_str == "FFTPeriodic") {
m_poisson_solver.push_back(std::unique_ptr<FFTPoissonSolverPeriodic>(
new FFTPoissonSolverPeriodic(getSlices(lev).boxArray(),
@@ -195,7 +213,8 @@ Fields::AllocData (
geom)) );
} else {
amrex::Abort("Unknown poisson solver '" + m_poisson_solver_str +
"', must be 'FFTDirichlet', 'FFTPeriodic' or 'MGDirichlet'");
"', must be 'FFTDirichletDirect', 'FFTDirichletExpanded', 'FFTDirichletFast', " +
"'FFTPeriodic' or 'MGDirichlet'");
}

if (lev == 0 && m_insitu_period > 0) {
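
The if/else chain above maps a solver name to a concrete class and aborts on unknown names. The sketch below reproduces that dispatch with stand-in classes that take no constructor arguments (the real solvers take a `BoxArray`, a `DistributionMapping` and a `Geometry`). It uses `std::make_unique` and throws `std::invalid_argument` instead of calling `amrex::Abort`. Both are substitutions made so the example compiles on its own, not what the commit does.

```cpp
#include <cassert>
#include <memory>
#include <stdexcept>
#include <string>

// Stand-in base class; the real FFTPoissonSolver interface is not reproduced.
struct SolverSketch {
    virtual ~SolverSketch () = default;
    virtual std::string name () const = 0;
};

// One trivial stand-in per name accepted by fields.poisson_solver.
struct DirectSketch final : SolverSketch { std::string name () const override { return "FFTDirichletDirect"; } };
struct ExpandedSketch final : SolverSketch { std::string name () const override { return "FFTDirichletExpanded"; } };
struct FastSketch final : SolverSketch { std::string name () const override { return "FFTDirichletFast"; } };
struct PeriodicSketch final : SolverSketch { std::string name () const override { return "FFTPeriodic"; } };
struct MultigridSketch final : SolverSketch { std::string name () const override { return "MGDirichlet"; } };

// Hypothetical factory mirroring the branch structure of Fields::AllocData.
inline std::unique_ptr<SolverSketch> make_solver (std::string const& type)
{
    std::unique_ptr<SolverSketch> solver;
    if (type == "FFTDirichletDirect") { solver = std::make_unique<DirectSketch>(); }
    else if (type == "FFTDirichletExpanded") { solver = std::make_unique<ExpandedSketch>(); }
    else if (type == "FFTDirichletFast") { solver = std::make_unique<FastSketch>(); }
    else if (type == "FFTPeriodic") { solver = std::make_unique<PeriodicSketch>(); }
    else if (type == "MGDirichlet") { solver = std::make_unique<MultigridSketch>(); }
    else { throw std::invalid_argument("Unknown poisson solver '" + type + "'"); }
    assert(solver != nullptr);
    return solver;
}
```

As far as this diff shows, the previous value `FFTDirichlet` no longer matches any branch, so an input file that still sets it would reach the error path.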
4 changes: 3 additions & 1 deletion src/fields/fft_poisson_solver/CMakeLists.txt
@@ -9,7 +9,9 @@ target_sources(HiPACE
PRIVATE
FFTPoissonSolver.cpp
FFTPoissonSolverPeriodic.cpp
FFTPoissonSolverDirichlet.cpp
FFTPoissonSolverDirichletDirect.cpp
FFTPoissonSolverDirichletExpanded.cpp
FFTPoissonSolverDirichletFast.cpp
MGPoissonSolverDirichlet.cpp
)

Original file line number Diff line number Diff line change
@@ -5,10 +5,10 @@
* Authors: AlexanderSinn, MaxThevenet, Severin Diederichs
* License: BSD-3-Clause-LBNL
*/
#ifndef FFT_POISSON_SOLVER_DIRICHLET_H_
#define FFT_POISSON_SOLVER_DIRICHLET_H_
#ifndef FFT_POISSON_SOLVER_DIRICHLET_DIRECT_H_
#define FFT_POISSON_SOLVER_DIRICHLET_DIRECT_H_

#include "fields/fft_poisson_solver/fft/AnyDST.H"
#include "fields/fft_poisson_solver/fft/AnyFFT.H"
#include "FFTPoissonSolver.H"

#include <AMReX_MultiFab.H>
@@ -23,16 +23,16 @@
* 2. Call FFTPoissonSolver::SolvePoissonEquation(mf), which will solve Poisson equation with RHS
* in the staging area and return the LHS in mf.
*/
class FFTPoissonSolverDirichlet final : public FFTPoissonSolver
class FFTPoissonSolverDirichletDirect final : public FFTPoissonSolver
{
public:
/** Constructor */
FFTPoissonSolverDirichlet ( amrex::BoxArray const& a_realspace_ba,
amrex::DistributionMapping const& dm,
amrex::Geometry const& gm);
FFTPoissonSolverDirichletDirect ( amrex::BoxArray const& a_realspace_ba,
amrex::DistributionMapping const& dm,
amrex::Geometry const& gm);

/** virtual destructor */
virtual ~FFTPoissonSolverDirichlet () override final {}
virtual ~FFTPoissonSolverDirichletDirect () override final {}

/**
* \brief Define real space and spectral space boxes and multifabs, Dirichlet
@@ -63,8 +63,12 @@ private:
amrex::MultiFab m_tmpSpectralField;
/** Multifab eigenvalues, to solve Poisson equation with Dirichlet BC. */
amrex::MultiFab m_eigenvalue_matrix;
/** DST plans */
AnyDST::DSTplans m_plan;
/** forward DST plan */
AnyFFT m_forward_fft;
/** backward DST plan */
AnyFFT m_backward_fft;
/** work area for both DST plans */
amrex::Gpu::DeviceVector<char> m_fft_work_area;
};

#endif
Original file line number Diff line number Diff line change
@@ -6,14 +6,14 @@
*
* License: BSD-3-Clause-LBNL
*/
#include "FFTPoissonSolverDirichlet.H"
#include "fft/AnyDST.H"
#include "FFTPoissonSolverDirichletDirect.H"
#include "fft/AnyFFT.H"
#include "fields/Fields.H"
#include "utils/Constants.H"
#include "utils/GPUUtil.H"
#include "utils/HipaceProfilerWrapper.H"

FFTPoissonSolverDirichlet::FFTPoissonSolverDirichlet (
FFTPoissonSolverDirichletDirect::FFTPoissonSolverDirichletDirect (
amrex::BoxArray const& realspace_ba,
amrex::DistributionMapping const& dm,
amrex::Geometry const& gm )
@@ -22,10 +22,11 @@ FFTPoissonSolverDirichlet::FFTPoissonSolverDirichlet (
}

void
FFTPoissonSolverDirichlet::define (amrex::BoxArray const& a_realspace_ba,
amrex::DistributionMapping const& dm,
amrex::Geometry const& gm )
FFTPoissonSolverDirichletDirect::define (amrex::BoxArray const& a_realspace_ba,
amrex::DistributionMapping const& dm,
amrex::Geometry const& gm )
{
HIPACE_PROFILE("FFTPoissonSolverDirichletDirect::define()");
using namespace amrex::literals;

// If we are going to support parallel FFT, the constructor needs to take a communicator.
@@ -48,16 +49,18 @@ FFTPoissonSolverDirichlet::define (amrex::BoxArray const& a_realspace_ba,
"There should be only one box locally.");

const amrex::Box fft_box = m_stagingArea[0].box();
const amrex::IntVect fft_size = fft_box.length();
const int nx = fft_size[0];
const int ny = fft_size[1];
const auto dx = gm.CellSizeArray();
const amrex::Real dxsquared = dx[0]*dx[0];
const amrex::Real dysquared = dx[1]*dx[1];
const amrex::Real sine_x_factor = MathConst::pi / ( 2. * ( fft_box.length(0) + 1 ));
const amrex::Real sine_y_factor = MathConst::pi / ( 2. * ( fft_box.length(1) + 1 ));
const amrex::Real sine_x_factor = MathConst::pi / ( 2. * ( nx + 1 ));
const amrex::Real sine_y_factor = MathConst::pi / ( 2. * ( ny + 1 ));

// Normalization of FFTW's 'DST-I' discrete sine transform (FFTW_RODFT00)
// This normalization is used regardless of the sine transform library
const amrex::Real norm_fac = 0.5 / ( 2 * (( fft_box.length(0) + 1 )
*( fft_box.length(1) + 1 )));
const amrex::Real norm_fac = 0.5 / ( 2 * (( nx + 1 ) * ( ny + 1 )));

// Calculate the array of m_eigenvalue_matrix
for (amrex::MFIter mfi(m_eigenvalue_matrix, DfltMfi); mfi.isValid(); ++mfi ){
@@ -67,9 +70,9 @@ FFTPoissonSolverDirichlet::define (amrex::BoxArray const& a_realspace_ba,
fft_box, [=] AMREX_GPU_DEVICE (int i, int j, int /* k */) noexcept
{
/* fast poisson solver diagonal x coeffs */
amrex::Real sinex_sq = sin(( i - lo[0] + 1 ) * sine_x_factor) * sin(( i - lo[0] + 1 ) * sine_x_factor);
amrex::Real sinex_sq = std::sin(( i - lo[0] + 1 ) * sine_x_factor) * std::sin(( i - lo[0] + 1 ) * sine_x_factor);
/* fast poisson solver diagonal y coeffs */
amrex::Real siney_sq = sin(( j - lo[1] + 1 ) * sine_y_factor) * sin(( j - lo[1] + 1 ) * sine_y_factor);
amrex::Real siney_sq = std::sin(( j - lo[1] + 1 ) * sine_y_factor) * std::sin(( j - lo[1] + 1 ) * sine_y_factor);

if ((sinex_sq!=0) && (siney_sq!=0)) {
eigenvalue_matrix(i,j) = norm_fac / ( -4.0 * ( sinex_sq / dxsquared + siney_sq / dysquared ));
@@ -81,29 +84,25 @@ FFTPoissonSolverDirichlet::define (amrex::BoxArray const& a_realspace_ba,
}

// Allocate and initialize the FFT plans
m_plan = AnyDST::DSTplans(a_realspace_ba, dm);
// Loop over boxes and allocate the corresponding plan
// for each box owned by the local MPI proc
for ( amrex::MFIter mfi(m_stagingArea, DfltMfi); mfi.isValid(); ++mfi ){
// Note: the size of the real-space box and spectral-space box
// differ when using real-to-complex FFT. When initializing
// the FFT plan, the valid dimensions are those of the real-space box.
amrex::IntVect fft_size = fft_box.length();
m_plan[mfi] = AnyDST::CreatePlan(
fft_size, &m_stagingArea[mfi], &m_tmpSpectralField[mfi]);
}
std::size_t fwd_area = m_forward_fft.Initialize(FFTType::R2R_2D, fft_size[0], fft_size[1]);
std::size_t bkw_area = m_backward_fft.Initialize(FFTType::R2R_2D, fft_size[0], fft_size[1]);

// Allocate work area for both FFTs
m_fft_work_area.resize(std::max(fwd_area, bkw_area));

m_forward_fft.SetBuffers(m_stagingArea[0].dataPtr(), m_tmpSpectralField[0].dataPtr(),
m_fft_work_area.dataPtr());
m_backward_fft.SetBuffers(m_tmpSpectralField[0].dataPtr(), m_stagingArea[0].dataPtr(),
m_fft_work_area.dataPtr());
}
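
The denominator used for `eigenvalue_matrix` in `define` is the eigenvalue of the 5-point Laplacian with homogeneous Dirichlet boundaries for sine mode (i, j). The sketch below recomputes that eigenvalue in plain C++ and measures how well a sine mode satisfies the eigenvalue relation. The function names are hypothetical, indices are zero-based (the source subtracts `lo` for the same purpose), and the normalization factor `norm_fac` is deliberately left out.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>

// Eigenvalue of the 5-point Laplacian with zero Dirichlet boundaries for
// the sine mode (i, j), 0 <= i < nx, 0 <= j < ny. Hypothetical helper.
inline double laplacian_eigenvalue (int i, int j, int nx, int ny, double dx, double dy)
{
    assert(0 <= i && i < nx && 0 <= j && j < ny);
    const double pi = std::acos(-1.0);
    const double sx = std::sin((i + 1) * pi / (2.0 * (nx + 1)));
    const double sy = std::sin((j + 1) * pi / (2.0 * (ny + 1)));
    return -4.0 * (sx * sx / (dx * dx) + sy * sy / (dy * dy));
}

// Largest |L u - lambda u| over the grid, where u is the sine mode (i, j)
// and L is the 5-point stencil with u = 0 outside the grid.
inline double mode_residual (int i, int j, int nx, int ny, double dx, double dy)
{
    const double pi = std::acos(-1.0);
    auto u = [=] (int p, int q) {
        if (p < 0 || p >= nx || q < 0 || q >= ny) { return 0.0; }
        return std::sin((i + 1) * pi * (p + 1) / (nx + 1))
             * std::sin((j + 1) * pi * (q + 1) / (ny + 1));
    };
    const double lambda = laplacian_eigenvalue(i, j, nx, ny, dx, dy);
    double res = 0.0;
    for (int q = 0; q < ny; ++q) {
        for (int p = 0; p < nx; ++p) {
            const double lap =
                (u(p - 1, q) - 2.0 * u(p, q) + u(p + 1, q)) / (dx * dx) +
                (u(p, q - 1) - 2.0 * u(p, q) + u(p, q + 1)) / (dy * dy);
            res = std::max(res, std::abs(lap - lambda * u(p, q)));
        }
    }
    return res;
}
```

In the source, the stored value is `norm_fac` divided by this eigenvalue, so multiplying the transformed right-hand side by `eigenvalue_matrix` both inverts the Laplacian and normalizes the transform pair in one pass.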


void
FFTPoissonSolverDirichlet::SolvePoissonEquation (amrex::MultiFab& lhs_mf)
FFTPoissonSolverDirichletDirect::SolvePoissonEquation (amrex::MultiFab& lhs_mf)
{
HIPACE_PROFILE("FFTPoissonSolverDirichlet::SolvePoissonEquation()");
HIPACE_PROFILE("FFTPoissonSolverDirichletDirect::SolvePoissonEquation()");

for ( amrex::MFIter mfi(m_stagingArea, DfltMfi); mfi.isValid(); ++mfi ){
// Perform Fourier transform from the staging area to `tmpSpectralField`
AnyDST::Execute(m_plan[mfi], AnyDST::direction::forward);
}
m_forward_fft.Execute();

#ifdef AMREX_USE_OMP
#pragma omp parallel if (amrex::Gpu::notInLaunchRegion())
@@ -120,10 +119,7 @@ FFTPoissonSolverDirichlet::SolvePoissonEquation (amrex::MultiFab& lhs_mf)
});
}

for ( amrex::MFIter mfi(m_stagingArea, DfltMfi); mfi.isValid(); ++mfi ){
// Perform Fourier transform from `tmpSpectralField` to the staging area
AnyDST::Execute(m_plan[mfi], AnyDST::direction::backward);
}
m_backward_fft.Execute();

#ifdef AMREX_USE_OMP
#pragma omp parallel if (amrex::Gpu::notInLaunchRegion())
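
The forward and backward `Execute()` calls above apply the same unnormalized sine transform twice, which is why `define` folds a normalization factor into the eigenvalue matrix. The sketch below is a naive O(n^2) DST-I in the convention of FFTW's `FFTW_RODFT00`. It is written only to make the scaling visible and is not how `AnyFFT` computes the transform. `dst1` is a hypothetical name.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Naive DST-I in the FFTW_RODFT00 convention:
//   Y[k] = 2 * sum_j X[j] * sin(pi * (j+1) * (k+1) / (n+1))
// Applying it twice returns the input scaled by 2 * (n + 1).
inline std::vector<double> dst1 (std::vector<double> const& x)
{
    const int n = static_cast<int>(x.size());
    assert(n > 0);
    const double pi = std::acos(-1.0);
    std::vector<double> y(static_cast<std::size_t>(n), 0.0);
    for (int k = 0; k < n; ++k) {
        for (int j = 0; j < n; ++j) {
            y[k] += 2.0 * x[j] * std::sin(pi * (j + 1) * (k + 1) / (n + 1));
        }
    }
    return y;
}
```

In two dimensions the round trip therefore scales by 2(nx+1) times 2(ny+1), which is 4(nx+1)(ny+1), and this is the reciprocal of `norm_fac = 0.5 / (2 * (nx + 1) * (ny + 1))` in the source.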