Skip to content

Consolidating OpenACC device-host memory transfers #1315

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: develop
Choose a base branch
from
32 changes: 0 additions & 32 deletions src/core_atmosphere/dynamics/mpas_atm_boundaries.F
Original file line number Diff line number Diff line change
Expand Up @@ -395,18 +395,14 @@ subroutine mpas_atm_get_bdy_tend(clock, block, vertDim, horizDim, field, delta_t
nullify(tend)
call mpas_pool_get_array(lbc, 'lbc_'//trim(field), tend, 1)

MPAS_ACC_TIMER_START('mpas_atm_get_bdy_tend [ACC_data_xfer]')
if (associated(tend)) then
!$acc enter data copyin(tend)
else
call mpas_pool_get_array(lbc, 'lbc_scalars', tend_scalars, 1)
!$acc enter data copyin(tend_scalars)

! Ensure the integer pointed to by idx_ptr is copied to the gpu device
call mpas_pool_get_dimension(lbc, 'index_'//trim(field), idx_ptr)
idx = idx_ptr
end if
MPAS_ACC_TIMER_STOP('mpas_atm_get_bdy_tend [ACC_data_xfer]')

!$acc parallel default(present)
if (associated(tend)) then
Expand All @@ -426,13 +422,6 @@ subroutine mpas_atm_get_bdy_tend(clock, block, vertDim, horizDim, field, delta_t
end if
!$acc end parallel

MPAS_ACC_TIMER_START('mpas_atm_get_bdy_tend [ACC_data_xfer]')
if (associated(tend)) then
!$acc exit data delete(tend)
else
!$acc exit data delete(tend_scalars)
end if
MPAS_ACC_TIMER_STOP('mpas_atm_get_bdy_tend [ACC_data_xfer]')

end subroutine mpas_atm_get_bdy_tend

Expand Down Expand Up @@ -533,9 +522,6 @@ subroutine mpas_atm_get_bdy_state_2d(clock, block, vertDim, horizDim, field, del
! query the field as a scalar constituent
!
if (associated(tend) .and. associated(state)) then
MPAS_ACC_TIMER_START('mpas_atm_get_bdy_state_2d [ACC_data_xfer]')
!$acc enter data copyin(tend, state)
MPAS_ACC_TIMER_STOP('mpas_atm_get_bdy_state_2d [ACC_data_xfer]')

!$acc parallel default(present)
!$acc loop gang vector collapse(2)
Expand All @@ -546,20 +532,13 @@ subroutine mpas_atm_get_bdy_state_2d(clock, block, vertDim, horizDim, field, del
end do
!$acc end parallel

MPAS_ACC_TIMER_START('mpas_atm_get_bdy_state_2d [ACC_data_xfer]')
!$acc exit data delete(tend, state)
MPAS_ACC_TIMER_STOP('mpas_atm_get_bdy_state_2d [ACC_data_xfer]')
else
call mpas_pool_get_array(lbc, 'lbc_scalars', tend_scalars, 1)
call mpas_pool_get_array(lbc, 'lbc_scalars', state_scalars, 2)
call mpas_pool_get_dimension(lbc, 'index_'//trim(field), idx_ptr)

idx=idx_ptr ! Avoid non-array pointer for OpenACC

MPAS_ACC_TIMER_START('mpas_atm_get_bdy_state_2d [ACC_data_xfer]')
!$acc enter data copyin(tend_scalars, state_scalars)
MPAS_ACC_TIMER_STOP('mpas_atm_get_bdy_state_2d [ACC_data_xfer]')

!$acc parallel default(present)
!$acc loop gang vector collapse(2)
do i=1, horizDim+1
Expand All @@ -569,9 +548,6 @@ subroutine mpas_atm_get_bdy_state_2d(clock, block, vertDim, horizDim, field, del
end do
!$acc end parallel

MPAS_ACC_TIMER_START('mpas_atm_get_bdy_state_2d [ACC_data_xfer]')
!$acc exit data delete(tend_scalars, state_scalars)
MPAS_ACC_TIMER_STOP('mpas_atm_get_bdy_state_2d [ACC_data_xfer]')
end if

end subroutine mpas_atm_get_bdy_state_2d
Expand Down Expand Up @@ -652,10 +628,6 @@ subroutine mpas_atm_get_bdy_state_3d(clock, block, innerDim, vertDim, horizDim,
call mpas_pool_get_array(lbc, 'lbc_'//trim(field), tend, 1)
call mpas_pool_get_array(lbc, 'lbc_'//trim(field), state, 2)

MPAS_ACC_TIMER_START('mpas_atm_get_bdy_state_3d [ACC_data_xfer]')
!$acc enter data copyin(tend, state)
MPAS_ACC_TIMER_STOP('mpas_atm_get_bdy_state_3d [ACC_data_xfer]')

!$acc parallel default(present)
!$acc loop gang vector collapse(3)
do i=1, horizDim+1
Expand All @@ -667,10 +639,6 @@ subroutine mpas_atm_get_bdy_state_3d(clock, block, innerDim, vertDim, horizDim,
end do
!$acc end parallel

MPAS_ACC_TIMER_START('mpas_atm_get_bdy_state_3d [ACC_data_xfer]')
!$acc exit data delete(tend, state)
MPAS_ACC_TIMER_STOP('mpas_atm_get_bdy_state_3d [ACC_data_xfer]')

end subroutine mpas_atm_get_bdy_state_3d


Expand Down
47 changes: 45 additions & 2 deletions src/core_atmosphere/dynamics/mpas_atm_iau.F
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,15 @@
! Additional copyright and license information can be found in the LICENSE file
! distributed with this code, or at http://mpas-dev.github.com/license.html
!

!
! Timer macros for instrumenting OpenACC data-transfer regions.
!
! When MPAS_OPENACC is defined, MPAS_ACC_TIMER_START(X) and
! MPAS_ACC_TIMER_STOP(X) expand to calls to mpas_timer_start(X) and
! mpas_timer_stop(X), respectively, where X is the timer label.
! When MPAS_OPENACC is not defined, both macros expand to nothing, so
! the instrumentation adds no statements to non-OpenACC builds.
!
#ifdef MPAS_OPENACC
#define MPAS_ACC_TIMER_START(X) call mpas_timer_start(X)
#define MPAS_ACC_TIMER_STOP(X) call mpas_timer_stop(X)
#else
#define MPAS_ACC_TIMER_START(X)
#define MPAS_ACC_TIMER_STOP(X)
#endif

module mpas_atm_iau

use mpas_derived_types
Expand All @@ -13,9 +22,10 @@ module mpas_atm_iau
use mpas_dmpar
use mpas_constants
use mpas_log, only : mpas_log_write
use mpas_timer

!public :: atm_compute_iau_coef, atm_add_tend_anal_incr

!public :: atm_compute_iau_coef, atm_add_tend_anal_incr

contains

!==================================================================================================
Expand Down Expand Up @@ -76,6 +86,39 @@ real (kind=RKIND) function atm_iau_coef(configs, itimestep, dt) result(wgt_iau)
end if

end function atm_iau_coef

!==================================================================================================
subroutine update_d2h_pre_add_tend_anal_incr(configs,structs)
!==================================================================================================
!
! Refresh the host copies of the fields listed below from the device by
! issuing OpenACC "update self" directives. Without OpenACC the directives
! are plain comments and this routine only performs the pool look-ups.
!
!   configs  (in)    : configuration pool; not referenced in this routine
!   structs  (inout) : pool holding the 'tend', 'state' and 'diag' subpools
!
! Fields updated on the host:
!   state : theta_m, scalars (retrieved with trailing argument 1)
!           rho_zz           (retrieved with trailing argument 2)
!   diag  : rho_edge
!   tend  : scalars_tend
!
! NOTE(review): the trailing integer passed to mpas_pool_get_array is
! presumably the time level -- confirm against the pool API.
!
! The pool array look-ups and the transfers are both enclosed by the
! 'atm_srk3: physics ACC_data_xfer' timer (active only with MPAS_OPENACC).
!

    implicit none

    ! Arguments
    type (mpas_pool_type), intent(in)    :: configs
    type (mpas_pool_type), intent(inout) :: structs

    ! Subpools extracted from structs
    type (mpas_pool_type), pointer :: state_pool
    type (mpas_pool_type), pointer :: diag_pool
    type (mpas_pool_type), pointer :: tend_pool

    ! Field pointers named in the update directives
    real (kind=RKIND), dimension(:,:),   pointer :: theta_m, rho_zz, rho_edge
    real (kind=RKIND), dimension(:,:,:), pointer :: scalars, tend_scalars

    call mpas_pool_get_subpool(structs, 'state', state_pool)
    call mpas_pool_get_subpool(structs, 'diag', diag_pool)
    call mpas_pool_get_subpool(structs, 'tend', tend_pool)

    MPAS_ACC_TIMER_START('atm_srk3: physics ACC_data_xfer')

    ! State and diagnostic fields
    call mpas_pool_get_array(diag_pool, 'rho_edge', rho_edge)
    call mpas_pool_get_array(state_pool, 'rho_zz', rho_zz, 2)
    call mpas_pool_get_array(state_pool, 'scalars', scalars, 1)
    call mpas_pool_get_array(state_pool, 'theta_m', theta_m, 1)
    !$acc update self(theta_m, scalars, rho_zz, rho_edge)

    ! Scalar tendencies
    call mpas_pool_get_array(tend_pool, 'scalars_tend', tend_scalars)
    !$acc update self(tend_scalars)

    MPAS_ACC_TIMER_STOP('atm_srk3: physics ACC_data_xfer')

end subroutine update_d2h_pre_add_tend_anal_incr

!==================================================================================================
subroutine atm_add_tend_anal_incr (configs, structs, itimestep, dt, tend_ru, tend_rtheta, tend_rho)
Expand Down
Loading