Skip to content

Commit

Permalink
Improve performance of FFTDirichletExpanded (#1111)
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderSinn authored Jun 3, 2024
1 parent 691cf8f commit 40690ff
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,6 @@ public:
virtual amrex::Real BoundaryFactor() override final { return 1.; }

private:
/** Spectral fields, contains (real) field in Fourier space */
amrex::MultiFab m_tmpSpectralField;
/** Multifab eigenvalues, to solve Poisson equation with Dirichlet BC. */
amrex::MultiFab m_eigenvalue_matrix;
/** Twice expanded field that gets filled symmetrically */
Expand Down
154 changes: 91 additions & 63 deletions src/fields/fft_poisson_solver/FFTPoissonSolverDirichletExpanded.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,39 +21,108 @@ FFTPoissonSolverDirichletExpanded::FFTPoissonSolverDirichletExpanded (
define(realspace_ba, dm, gm);
}

void ExpandR2R (amrex::FArrayBox& dst, amrex::FArrayBox& src)
void ExpandR2R (amrex::FArrayBox& dst, const amrex::FArrayBox& src)
{
const amrex::Box bx = src.box();
// This function expands
//
// 1 2 3
// 4 5 6
// 7 8 9
//
// into
//
// 0 0 0 0 0 0 0 0
// 0 1 2 3 0 -3 -2 -1
// 0 4 5 6 0 -6 -5 -4
// 0 7 8 9 0 -9 -8 -7
// 0 0 0 0 0 0 0 0
// 0 -7 -8 -9 0 9 8 7
// 0 -4 -5 -6 0 6 5 4
// 0 -1 -2 -3 0 3 2 1
amrex::Box bx = src.box();
bx.growLo(0, 1);
bx.growLo(1, 1);
const int lox = bx.smallEnd(0);
const int loy = bx.smallEnd(1);
const int nx = bx.length(0);
const int ny = bx.length(1);
const amrex::IntVect lo = bx.smallEnd();
Array2<amrex::Real const> const src_array = src.const_array();
Array2<amrex::Real> const dst_array = dst.array();
const int refx = dst.box().bigEnd(0)+lox+1;
const int refy = dst.box().bigEnd(1)+loy+1;
const Array2<amrex::Real const> src_array = src.array();
const Array2<amrex::Real> dst_array = dst.array();

amrex::ParallelFor(bx,
[=] AMREX_GPU_DEVICE(int i, int j, int)
{
/* upper left quadrant */
dst_array(i+1,j+1) = src_array(i, j);
/* lower left quadrant */
dst_array(i+1,j+ny+2) = -src_array(i, ny-1-j+2*lo[1]);
/* upper right quadrant */
dst_array(i+nx+2,j+1) = -src_array(nx-1-i+2*lo[0], j);
/* lower right quadrant */
dst_array(i+nx+2,j+ny+2) = src_array(nx-1-i+2*lo[0], ny-1-j+2*lo[1]);
if (i == lox || j == loy) {
dst_array(i, j) = 0;
dst_array(i, j+ny) = 0;
dst_array(i+nx, j) = 0;
dst_array(i+nx, j+ny) = 0;
} else {
const amrex::Real val = src_array(i, j);
/* upper left quadrant */
dst_array(i, j) = val;
/* lower left quadrant */
dst_array(i, refy-j) = -val;
/* upper right quadrant */
dst_array(refx-i, j) = -val;
/* lower right quadrant */
dst_array(refx-i, refy-j) = val;
}
});
}

void Shrink_Mult_Expand (amrex::FArrayBox& dst,
const amrex::BaseFab<amrex::GpuComplex<amrex::Real>>& src,
const amrex::FArrayBox& eigenvalue)
{
// This function combines ShrinkC2R -> multiply with eigenvalue -> ExpandR2R
amrex::Box bx = eigenvalue.box();
bx.growLo(0, 1);
bx.growLo(1, 1);
const int lox = bx.smallEnd(0);
const int loy = bx.smallEnd(1);
const int nx = bx.length(0);
const int ny = bx.length(1);
const int refx = dst.box().bigEnd(0)+lox+1;
const int refy = dst.box().bigEnd(1)+loy+1;
const Array2<amrex::GpuComplex<amrex::Real> const> src_array = src.array();
const Array2<amrex::Real> dst_array = dst.array();
const Array2<amrex::Real const> eigenvalue_array= eigenvalue.array();

amrex::ParallelFor(bx,
[=] AMREX_GPU_DEVICE(int i, int j, int)
{
if (i == lox || j == loy) {
dst_array(i, j) = 0;
dst_array(i, j+ny) = 0;
dst_array(i+nx, j) = 0;
dst_array(i+nx, j+ny) = 0;
} else {
const amrex::Real val = -src_array(i, j).real() * eigenvalue_array(i, j);
/* upper left quadrant */
dst_array(i, j) = val;
/* lower left quadrant */
dst_array(i, refy-j) = -val;
/* upper right quadrant */
dst_array(refx-i, j) = -val;
/* lower right quadrant */
dst_array(refx-i, refy-j) = val;
}
});
}

void ShrinkC2R (amrex::FArrayBox& dst, amrex::BaseFab<amrex::GpuComplex<amrex::Real>>& src)
void ShrinkC2R (amrex::FArrayBox& dst, const amrex::BaseFab<amrex::GpuComplex<amrex::Real>>& src,
amrex::Box bx)
{
const amrex::Box bx = dst.box();
Array2<amrex::GpuComplex<amrex::Real> const> const src_array = src.const_array();
Array2<amrex::Real> const dst_array = dst.array();
const Array2<amrex::GpuComplex<amrex::Real> const> src_array = src.array();
const Array2<amrex::Real> dst_array = dst.array();
amrex::ParallelFor(bx,
[=] AMREX_GPU_DEVICE(int i, int j, int)
{
/* upper left quadrant */
dst_array(i,j) = -src_array(i+1, j+1).real();
dst_array(i,j) = -src_array(i, j).real();
});
}

Expand All @@ -73,16 +142,12 @@ FFTPoissonSolverDirichletExpanded::define (amrex::BoxArray const& a_realspace_ba
// The stagingArea is also created from 0 to nx, because the real space array may have
// an offset for levels > 0
m_stagingArea = amrex::MultiFab(a_realspace_ba, dm, 1, Fields::m_poisson_nguards);
m_tmpSpectralField = amrex::MultiFab(a_realspace_ba, dm, 1, Fields::m_poisson_nguards);
m_eigenvalue_matrix = amrex::MultiFab(a_realspace_ba, dm, 1, Fields::m_poisson_nguards);
m_stagingArea.setVal(0.0, Fields::m_poisson_nguards); // this is not required
m_tmpSpectralField.setVal(0.0, Fields::m_poisson_nguards);

// This must be true even for parallel FFT.
AMREX_ALWAYS_ASSERT_WITH_MESSAGE(m_stagingArea.local_size() == 1,
"There should be only one box locally.");
AMREX_ALWAYS_ASSERT_WITH_MESSAGE(m_tmpSpectralField.local_size() == 1,
"There should be only one box locally.");

const amrex::Box fft_box = m_stagingArea[0].box();
const amrex::IntVect fft_size = fft_box.length();
Expand Down Expand Up @@ -121,8 +186,8 @@ FFTPoissonSolverDirichletExpanded::define (amrex::BoxArray const& a_realspace_ba

// Allocate expanded_position_array Real of size (2*nx+2, 2*ny+2)
// Allocate expanded_fourier_array Complex of size (nx+2, 2*ny+2)
amrex::Box expanded_position_box {{0, 0, 0}, {2*nx+1, 2*ny+1, 0}};
amrex::Box expanded_fourier_box {{0, 0, 0}, {nx+1, 2*ny+1, 0}};
amrex::Box expanded_position_box {{-1, -1, 0}, {2*nx, 2*ny, 0}};
amrex::Box expanded_fourier_box {{-1, -1, 0}, {nx, 2*ny, 0}};
// shift box to match rest of fields
expanded_position_box += fft_box.smallEnd();
expanded_fourier_box += fft_box.smallEnd();
Expand Down Expand Up @@ -150,50 +215,13 @@ FFTPoissonSolverDirichletExpanded::SolvePoissonEquation (amrex::MultiFab& lhs_mf
HIPACE_PROFILE("FFTPoissonSolverDirichletExpanded::SolvePoissonEquation()");
using namespace amrex::literals;

m_expanded_position_array.setVal<amrex::RunOn::Device>(0._rt);

ExpandR2R(m_expanded_position_array, m_stagingArea[0]);

m_fft.Execute();

ShrinkC2R(m_tmpSpectralField[0], m_expanded_fourier_array);

#ifdef AMREX_USE_OMP
#pragma omp parallel if (amrex::Gpu::notInLaunchRegion())
#endif
for ( amrex::MFIter mfi(m_stagingArea, DfltMfiTlng); mfi.isValid(); ++mfi ){
// Solve Poisson equation in Fourier space:
// Multiply `tmpSpectralField` by eigenvalue_matrix
Array2<amrex::Real> tmp_cmplx_arr = m_tmpSpectralField.array(mfi);
Array2<amrex::Real> eigenvalue_matrix = m_eigenvalue_matrix.array(mfi);

amrex::ParallelFor( mfi.growntilebox(),
[=] AMREX_GPU_DEVICE(int i, int j, int) noexcept {
tmp_cmplx_arr(i,j) *= eigenvalue_matrix(i,j);
});
}

m_expanded_position_array.setVal<amrex::RunOn::Device>(0._rt);

ExpandR2R(m_expanded_position_array, m_tmpSpectralField[0]);
Shrink_Mult_Expand(m_expanded_position_array, m_expanded_fourier_array, m_eigenvalue_matrix[0]);

m_fft.Execute();

ShrinkC2R(m_stagingArea[0], m_expanded_fourier_array);

#ifdef AMREX_USE_OMP
#pragma omp parallel if (amrex::Gpu::notInLaunchRegion())
#endif
for ( amrex::MFIter mfi(m_stagingArea, DfltMfiTlng); mfi.isValid(); ++mfi ){
// Copy from the staging area to output array (and normalize)
Array2<amrex::Real> tmp_real_arr = m_stagingArea.array(mfi);
Array2<amrex::Real> lhs_arr = lhs_mf.array(mfi);
AMREX_ALWAYS_ASSERT_WITH_MESSAGE(lhs_mf.size() == 1,
"Slice MFs must be defined on one box only");
amrex::ParallelFor( lhs_mf[mfi].box() & mfi.growntilebox(),
[=] AMREX_GPU_DEVICE(int i, int j, int) noexcept {
// Copy field
lhs_arr(i,j) = tmp_real_arr(i,j);
});
}
ShrinkC2R(lhs_mf[0], m_expanded_fourier_array, m_stagingArea[0].box());
}

0 comments on commit 40690ff

Please sign in to comment.