Skip to content

Enable GPU execution of mpas_reconstruct_2d via OpenACC #1289

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions src/core_atmosphere/dynamics/mpas_atm_time_integration.F
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,9 @@ subroutine mpas_atm_dynamics_init(domain)
real (kind=RKIND), dimension(:,:), pointer :: zxu
real (kind=RKIND), dimension(:,:), pointer :: dss
real (kind=RKIND), dimension(:), pointer :: specZoneMaskCell
real (kind=RKIND), dimension(:), pointer :: latCell
real (kind=RKIND), dimension(:), pointer :: lonCell
real (kind=RKIND), dimension(:,:,:), pointer :: coeffs_reconstruct
#endif


Expand Down Expand Up @@ -395,6 +398,15 @@ subroutine mpas_atm_dynamics_init(domain)

call mpas_pool_get_array(mesh, 'specZoneMaskCell', specZoneMaskCell)
!$acc enter data copyin(specZoneMaskCell)

call mpas_pool_get_array(mesh, 'latCell', latCell)
!$acc enter data copyin(latCell)

call mpas_pool_get_array(mesh, 'lonCell', lonCell)
!$acc enter data copyin(lonCell)

call mpas_pool_get_array(mesh, 'coeffs_reconstruct', coeffs_reconstruct)
!$acc enter data copyin(coeffs_reconstruct)
#endif

end subroutine mpas_atm_dynamics_init
Expand Down Expand Up @@ -474,6 +486,9 @@ subroutine mpas_atm_dynamics_finalize(domain)
real (kind=RKIND), dimension(:,:), pointer :: zxu
real (kind=RKIND), dimension(:,:), pointer :: dss
real (kind=RKIND), dimension(:), pointer :: specZoneMaskCell
real (kind=RKIND), dimension(:), pointer :: latCell
real (kind=RKIND), dimension(:), pointer :: lonCell
real (kind=RKIND), dimension(:,:,:), pointer :: coeffs_reconstruct
#endif


Expand Down Expand Up @@ -626,6 +641,15 @@ subroutine mpas_atm_dynamics_finalize(domain)

call mpas_pool_get_array(mesh, 'specZoneMaskCell', specZoneMaskCell)
!$acc exit data delete(specZoneMaskCell)

call mpas_pool_get_array(mesh, 'latCell', latCell)
!$acc exit data delete(latCell)

call mpas_pool_get_array(mesh, 'lonCell', lonCell)
!$acc exit data delete(lonCell)

call mpas_pool_get_array(mesh, 'coeffs_reconstruct', coeffs_reconstruct)
!$acc exit data delete(coeffs_reconstruct)
#endif

end subroutine mpas_atm_dynamics_finalize
Expand Down
97 changes: 76 additions & 21 deletions src/operators/mpas_vector_reconstruction.F
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,16 @@ module mpas_vector_reconstruction
use mpas_rbf_interpolation
use mpas_vector_operations

#ifdef MPAS_OPENACC
! For use in regions ported with OpenACC to track in-function transfers
use mpas_timer, only : mpas_timer_start, mpas_timer_stop
#define MPAS_ACC_TIMER_START(X) call mpas_timer_start(X)
#define MPAS_ACC_TIMER_STOP(X) call mpas_timer_stop(X)
#else
#define MPAS_ACC_TIMER_START(X)
#define MPAS_ACC_TIMER_STOP(X)
#endif

implicit none

public :: mpas_init_reconstruct, mpas_reconstruct
Expand Down Expand Up @@ -207,10 +217,11 @@ subroutine mpas_reconstruct_2d(meshPool, u, uReconstructX, uReconstructY, uRecon

! temporary arrays needed in the compute procedure
logical :: includeHalosLocal
integer, pointer :: nCells
integer, pointer :: nCells_ptr, nVertLevels_ptr
integer :: nCells, nVertLevels
integer, dimension(:,:), pointer :: edgesOnCell
integer, dimension(:), pointer :: nEdgesOnCell
integer :: iCell,iEdge, i
integer :: iCell,iEdge, i, k
real(kind=RKIND), dimension(:), pointer :: latCell, lonCell

real (kind=RKIND), dimension(:,:,:), pointer :: coeffs_reconstruct
Expand All @@ -233,64 +244,108 @@ subroutine mpas_reconstruct_2d(meshPool, u, uReconstructX, uReconstructY, uRecon
call mpas_pool_get_array(meshPool, 'edgesOnCell', edgesOnCell)

if ( includeHalosLocal ) then
call mpas_pool_get_dimension(meshPool, 'nCells', nCells)
call mpas_pool_get_dimension(meshPool, 'nCells', nCells_ptr)
else
call mpas_pool_get_dimension(meshPool, 'nCellsSolve', nCells)
call mpas_pool_get_dimension(meshPool, 'nCellsSolve', nCells_ptr)
end if
call mpas_pool_get_dimension(meshPool, 'nVertLevels', nVertLevels_ptr)
! Dereference scalar (single-value) pointers to ensure OpenACC copies the value pointed to implicitly
nCells = nCells_ptr
nVertLevels = nVertLevels_ptr

call mpas_pool_get_array(meshPool, 'latCell', latCell)
call mpas_pool_get_array(meshPool, 'lonCell', lonCell)

call mpas_pool_get_config(meshPool, 'on_a_sphere', on_a_sphere)

MPAS_ACC_TIMER_START('mpas_reconstruct_2d [ACC_data_xfer]')
! Only use sections needed, nCells may be all cells or only non-halo cells
!$acc enter data copyin(coeffs_reconstruct(:,:,1:nCells),nEdgesOnCell(1:nCells), &
!$acc edgesOnCell(:,1:nCells),latCell(1:nCells),lonCell(1:nCells))
!$acc enter data copyin(u(:,:))
!$acc enter data create(uReconstructX(:,1:nCells),uReconstructY(:,1:nCells), &
!$acc uReconstructZ(:,1:nCells),uReconstructZonal(:,1:nCells), &
!$acc uReconstructMeridional(:,1:nCells))
MPAS_ACC_TIMER_STOP('mpas_reconstruct_2d [ACC_data_xfer]')

! loop over cell centers
!$omp do schedule(runtime)
!$acc parallel default(present)
!$acc loop gang
do iCell = 1, nCells
! initialize the reconstructed vectors
uReconstructX(:,iCell) = 0.0
uReconstructY(:,iCell) = 0.0
uReconstructZ(:,iCell) = 0.0
!$acc loop vector
do k = 1, nVertLevels
uReconstructX(k,iCell) = 0.0
uReconstructY(k,iCell) = 0.0
uReconstructZ(k,iCell) = 0.0
end do

! a more efficient reconstruction where rbf_values*matrix_reconstruct
! has been precomputed in coeffs_reconstruct
do i=1,nEdgesOnCell(iCell)
!$acc loop seq
do i = 1, nEdgesOnCell(iCell)
iEdge = edgesOnCell(i,iCell)
uReconstructX(:,iCell) = uReconstructX(:,iCell) &
+ coeffs_reconstruct(1,i,iCell) * u(:,iEdge)
uReconstructY(:,iCell) = uReconstructY(:,iCell) &
+ coeffs_reconstruct(2,i,iCell) * u(:,iEdge)
uReconstructZ(:,iCell) = uReconstructZ(:,iCell) &
+ coeffs_reconstruct(3,i,iCell) * u(:,iEdge)
!$acc loop vector
do k = 1, nVertLevels
uReconstructX(k,iCell) = uReconstructX(k,iCell) &
+ coeffs_reconstruct(1,i,iCell) * u(k,iEdge)
uReconstructY(k,iCell) = uReconstructY(k,iCell) &
+ coeffs_reconstruct(2,i,iCell) * u(k,iEdge)
uReconstructZ(k,iCell) = uReconstructZ(k,iCell) &
+ coeffs_reconstruct(3,i,iCell) * u(k,iEdge)
end do

enddo
enddo ! iCell
!$acc end parallel
!$omp end do

call mpas_threading_barrier()

if (on_a_sphere) then
!$omp do schedule(runtime)
!$acc parallel default(present)
!$acc loop gang
do iCell = 1, nCells
clat = cos(latCell(iCell))
slat = sin(latCell(iCell))
clon = cos(lonCell(iCell))
slon = sin(lonCell(iCell))
uReconstructZonal(:,iCell) = -uReconstructX(:,iCell)*slon + &
uReconstructY(:,iCell)*clon
uReconstructMeridional(:,iCell) = -(uReconstructX(:,iCell)*clon &
+ uReconstructY(:,iCell)*slon)*slat &
+ uReconstructZ(:,iCell)*clat
!$acc loop vector
do k = 1, nVertLevels
uReconstructZonal(k,iCell) = -uReconstructX(k,iCell)*slon + &
uReconstructY(k,iCell)*clon
uReconstructMeridional(k,iCell) = -(uReconstructX(k,iCell)*clon &
+ uReconstructY(k,iCell)*slon)*slat &
+ uReconstructZ(k,iCell)*clat
end do
end do
!$acc end parallel
!$omp end do
else
!$omp do schedule(runtime)
!$acc parallel default(present)
!$acc loop gang vector collapse(2)
do iCell = 1, nCells
uReconstructZonal (:,iCell) = uReconstructX(:,iCell)
uReconstructMeridional(:,iCell) = uReconstructY(:,iCell)
do k = 1, nVertLevels
uReconstructZonal (k,iCell) = uReconstructX(k,iCell)
uReconstructMeridional(k,iCell) = uReconstructY(k,iCell)
end do
end do
!$acc end parallel
!$omp end do
end if

MPAS_ACC_TIMER_START('mpas_reconstruct_2d [ACC_data_xfer]')
!$acc exit data delete(coeffs_reconstruct(:,:,1:nCells),nEdgesOnCell(1:nCells), &
!$acc edgesOnCell(:,1:nCells),latCell(1:nCells),lonCell(1:nCells))
!$acc exit data delete(u(:,:))
!$acc exit data copyout(uReconstructX(:,1:nCells),uReconstructY(:,1:nCells), &
!$acc uReconstructZ(:,1:nCells), uReconstructZonal(:,1:nCells), &
!$acc uReconstructMeridional(:,1:nCells))
MPAS_ACC_TIMER_STOP('mpas_reconstruct_2d [ACC_data_xfer]')

end subroutine mpas_reconstruct_2d!}}}


Expand Down