Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace 2-elements sequences by Pair utility for double buffering #590

Open
wants to merge 53 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
f80987e
Initial refactoring to use Swapping for double buffering
egparedes Nov 6, 2024
ec37607
More refactorings and cleanups in the driver.
egparedes Nov 7, 2024
3136ba3
Use keyword arg in DriverParams
egparedes Nov 7, 2024
cc12537
Extend docstrings
egparedes Nov 7, 2024
bbeab74
Extend double buffer changes to solve_nonhydro
egparedes Nov 7, 2024
3a83830
Refactor Swapping to generic Pair class, and specialized it for diffe…
egparedes Nov 11, 2024
84d544c
Change common.utils import alias
egparedes Nov 11, 2024
4a8605a
Recover methods of Pair deleted by accident in previous commits
egparedes Nov 11, 2024
9098322
Format
egparedes Nov 11, 2024
d0c8919
Export `namedproperty` utility
egparedes Nov 11, 2024
04736c9
Update Pair to have both accessors read/writable by default
egparedes Nov 11, 2024
283e422
Replace `ddt_vn_apc_ntl1` and `ddt_vn_apc_ntl2` by `ddt_vn_apc_pc`
egparedes Nov 11, 2024
8fb1a2b
Replace `ddt_w_adv_ntl1` and `ddt_w_adv_ntl2` by `ddt_w_adv_pc`
egparedes Nov 12, 2024
c06fe90
Fix Pair
egparedes Nov 12, 2024
0fc58f6
Format issues
egparedes Nov 12, 2024
81afd2e
Fixes
egparedes Nov 12, 2024
9cc4967
More replacements of prognostic_states lists
egparedes Nov 12, 2024
24110c4
More replacements
egparedes Nov 12, 2024
857a762
Final missing replacement (in theory)
egparedes Nov 13, 2024
636944f
Fixes
egparedes Nov 13, 2024
12d3a08
More missing replacements in tests
egparedes Nov 13, 2024
5f9b795
Update model/driver/src/icon4py/model/driver/icon4py_driver.py
egparedes Nov 13, 2024
5bfd3c4
More missing replacements and deletions.
egparedes Nov 13, 2024
e633ba9
Fix
egparedes Nov 13, 2024
9a8f1fe
More fixes
egparedes Nov 13, 2024
090241a
Simplify Pair base class.
egparedes Nov 13, 2024
d4db5f0
Simplify docstrings
egparedes Nov 13, 2024
1074f53
Testing
egparedes Nov 13, 2024
c2e0026
Fixes after debugging
egparedes Nov 14, 2024
cb7327d
More fixes
egparedes Nov 14, 2024
7f3b229
Refactorings and style
egparedes Nov 14, 2024
efad473
Minor fix to Pair and namedproperty utility classes
egparedes Nov 15, 2024
b3447fb
Rename named_property
egparedes Nov 15, 2024
031e0f4
Missing changes from previous commits
egparedes Nov 15, 2024
ba05499
New refactoring of Pair adding direct item access.
egparedes Nov 18, 2024
a1111e7
More fixes
egparedes Nov 19, 2024
26c435c
Fix remaining failing tests
egparedes Nov 19, 2024
55eae93
Merge branch 'main' into double_buffer_backup
egparedes Nov 20, 2024
f18ad08
Renaming symbols
egparedes Nov 20, 2024
333897f
Fix merging errors
egparedes Nov 21, 2024
69ad9c7
Rename NextStepPair
egparedes Nov 21, 2024
186c93d
Fix spellings
egparedes Nov 21, 2024
62873ba
Fix typos and expand documentation
egparedes Nov 21, 2024
2f97371
Readability improvements
egparedes Nov 22, 2024
da44a25
Address reviewer's comments.
egparedes Nov 22, 2024
0a6772d
Enhance diagnostic states swap documentation.
egparedes Nov 22, 2024
d8eb6f6
Fix style of comments
egparedes Nov 22, 2024
da4c9c9
Rename and enhance documentation related to the velocity tendencies i…
egparedes Nov 22, 2024
4801651
Minor rename
egparedes Nov 22, 2024
857e2ad
Add forgotten changes from Pair to TimeStepPair
egparedes Nov 22, 2024
74bf1fa
Merge branch 'main' into use_double_buffer_class
egparedes Nov 22, 2024
93cc389
Fix bug in dycore wrapper
egparedes Nov 22, 2024
1825590
Remove unneeded indices in dycore_wrapper
egparedes Nov 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 105 additions & 89 deletions model/driver/src/icon4py/model/driver/icon4py_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,24 @@
import logging
import pathlib
import uuid
from typing import Callable
from typing import Callable, NamedTuple

import click
from devtools import Timer
from gt4py.next import gtfn_cpu

import icon4py.model.common.utils as imc_utils
egparedes marked this conversation as resolved.
Show resolved Hide resolved
from icon4py.model.atmosphere.diffusion import (
diffusion,
diffusion_states,
)
from icon4py.model.atmosphere.dycore.nh_solve import solve_nonhydro as solve_nh
from icon4py.model.atmosphere.dycore.state_utils import states as solve_nh_states
from icon4py.model.common.decomposition import definitions as decomposition
from icon4py.model.common.states import prognostic_state as prognostics
from icon4py.model.common.states import (
diagnostic_state as diagnostics,
prognostic_state as prognostics,
)
from icon4py.model.driver import (
icon4py_configuration as driver_config,
initialization_utils as driver_init,
Expand Down Expand Up @@ -62,15 +66,10 @@ def __init__(

self._is_first_step_in_simulation: bool = not self.run_config.restart_mode

self._now: int = 0 # TODO (Chia Rui): move to PrognosticState
self._next: int = 1 # TODO (Chia Rui): move to PrognosticState

def re_init(self):
self._simulation_date = self.run_config.start_date
self._is_first_step_in_simulation = True
self._n_substeps_var = self.run_config.n_substeps
self._now: int = 0 # TODO (Chia Rui): move to PrognosticState
self._next: int = 1 # TODO (Chia Rui): move to PrognosticState

def _validate_config(self):
if self._n_time_steps < 0:
Expand Down Expand Up @@ -98,14 +97,6 @@ def n_substeps_var(self):
def simulation_date(self):
return self._simulation_date

@property
def prognostic_now(self):
return self._now

@property
def prognostic_next(self):
return self._next

@property
def n_time_steps(self):
return self._n_time_steps
Expand All @@ -114,23 +105,17 @@ def n_time_steps(self):
def substep_timestep(self):
return self._substep_timestep

def _swap(self):
time_n_swap = self._next
self._next = self._now
self._now = time_n_swap

def _full_name(self, func: Callable):
return ":".join((self.__class__.__name__, func.__name__))

def time_integration(
self,
diffusion_diagnostic_state: diffusion_states.DiffusionDiagnosticState,
solve_nonhydro_diagnostic_state: solve_nh_states.DiagnosticStateNonHydro,
# TODO (Chia Rui): expand the PrognosticState to include indices of now and next, now it is always assumed that now = 0, next = 1 at the beginning
prognostic_state_list: list[prognostics.PrognosticState],
prognostic_state_swp: imc_utils.Swapping[prognostics.PrognosticState],
egparedes marked this conversation as resolved.
Show resolved Hide resolved
# below is a long list of arguments for dycore time_step that many can be moved to initialization of SolveNonhydro)
prep_adv: solve_nh_states.PrepAdvection,
inital_divdamp_fac_o2: float,
initial_divdamp_fac_o2: float,
do_prep_adv: bool,
):
log.info(
Expand All @@ -154,7 +139,7 @@ def time_integration(
log.info("running initial step to diffuse fields before timeloop starts")
self.diffusion.initial_run(
diffusion_diagnostic_state,
prognostic_state_list[self._now],
prognostic_state_swp.current,
self.dtime_in_seconds,
)
log.info(
Expand All @@ -164,10 +149,10 @@ def time_integration(
for time_step in range(self._n_time_steps):
log.info(f"simulation date : {self._simulation_date} run timestep : {time_step}")
log.info(
f" MAX VN: {prognostic_state_list[self._now].vn.asnumpy().max():.15e} , MAX W: {prognostic_state_list[self._now].w.asnumpy().max():.15e}"
f" MAX VN: {prognostic_state_swp.current.vn.asnumpy().max():.15e} , MAX W: {prognostic_state_swp.current.w.asnumpy().max():.15e}"
)
log.info(
f" MAX RHO: {prognostic_state_list[self._now].rho.asnumpy().max():.15e} , MAX THETA_V: {prognostic_state_list[self._now].theta_v.asnumpy().max():.15e}"
f" MAX RHO: {prognostic_state_swp.current.rho.asnumpy().max():.15e} , MAX THETA_V: {prognostic_state_swp.current.theta_v.asnumpy().max():.15e}"
)
# TODO (Chia Rui): check with Anurag about printing of max and min of variables.

Expand All @@ -179,9 +164,9 @@ def time_integration(
self._integrate_one_time_step(
diffusion_diagnostic_state,
solve_nonhydro_diagnostic_state,
prognostic_state_list,
prognostic_state_swp,
prep_adv,
inital_divdamp_fac_o2,
initial_divdamp_fac_o2,
do_prep_adv,
)
timer.capture()
Expand All @@ -198,36 +183,38 @@ def _integrate_one_time_step(
self,
diffusion_diagnostic_state: diffusion_states.DiffusionDiagnosticState,
solve_nonhydro_diagnostic_state: solve_nh_states.DiagnosticStateNonHydro,
prognostic_state_list: list[prognostics.PrognosticState],
prognostic_state_swp: imc_utils.Swapping[prognostics.PrognosticState],
prep_adv: solve_nh_states.PrepAdvection,
inital_divdamp_fac_o2: float,
initial_divdamp_fac_o2: float,
do_prep_adv: bool,
):
# TODO (Chia Rui): Add update_spinup_damping here to compute divdamp_fac_o2

self._do_dyn_substepping(
solve_nonhydro_diagnostic_state,
prognostic_state_list,
prognostic_state_swp,
prep_adv,
inital_divdamp_fac_o2,
initial_divdamp_fac_o2,
do_prep_adv,
)

if self.diffusion.config.apply_to_horizontal_wind:
self.diffusion.run(
diffusion_diagnostic_state, prognostic_state_list[self._next], self.dtime_in_seconds
diffusion_diagnostic_state,
prognostic_state_swp.other,
self.dtime_in_seconds,
)

self._swap()
prognostic_state_swp.swap()
egparedes marked this conversation as resolved.
Show resolved Hide resolved

# TODO (Chia Rui): add tracer advection here

def _do_dyn_substepping(
self,
solve_nonhydro_diagnostic_state: solve_nh_states.DiagnosticStateNonHydro,
prognostic_state_list: list[prognostics.PrognosticState],
prognostic_state_swp: imc_utils.Swapping[prognostics.PrognosticState],
prep_adv: solve_nh_states.PrepAdvection,
inital_divdamp_fac_o2: float,
initial_divdamp_fac_o2: float,
do_prep_adv: bool,
):
# TODO (Chia Rui): compute airmass for prognostic_state here
Expand All @@ -242,9 +229,9 @@ def _do_dyn_substepping(
)
self.solve_nonhydro.time_step(
solve_nonhydro_diagnostic_state,
prognostic_state_list,
prognostic_state_swp,
prep_adv=prep_adv,
divdamp_fac_o2=inital_divdamp_fac_o2,
divdamp_fac_o2=initial_divdamp_fac_o2,
dtime=self._substep_timestep,
l_recompute=do_recompute,
l_init=self._is_first_step_in_simulation,
Expand All @@ -260,13 +247,43 @@ def _do_dyn_substepping(
do_clean_mflx = False

if not self._is_last_substep(dyn_substep):
self._swap()
prognostic_state_swp.swap()

self._is_first_step_in_simulation = False

# TODO (Chia Rui): compute airmass for prognostic_state here


class DriverStates(NamedTuple):
egparedes marked this conversation as resolved.
Show resolved Hide resolved
"""
Initialized states for the driver run.

Attributes:
prep_advection_prognostic: Fields collecting data for advection during the solve nonhydro timestep.
solve_nonhydro_diagnostic: Initial state for solve_nonhydro diagnostic variables.
diffusion_diagnostic: Initial state for diffusion diagnostic variables.
prognostic_swp: Initial state for prognostic variables (double buffered).
diagnostic: Initial state for global diagnostic variables.
"""

prep_advection_prognostic: solve_nh_states.PrepAdvection
solve_nonhydro_diagnostic: solve_nh_states.DiagnosticStateNonHydro
diffusion_diagnostic: diffusion_states.DiffusionDiagnosticState
prognostic_swp: imc_utils.Swapping[prognostics.PrognosticState]
diagnostic: diagnostics.DiagnosticState


class DriverParams(NamedTuple):
"""
Parameters for the driver run.

Attributes:
divdamp_fac_o2: Second order divdamp factor.
egparedes marked this conversation as resolved.
Show resolved Hide resolved
"""

divdamp_fac_o2: float


def initialize(
file_path: pathlib.Path,
props: decomposition.ProcessProperties,
Expand All @@ -275,30 +292,30 @@ def initialize(
grid_id: uuid.UUID,
grid_root,
grid_level,
):
) -> tuple[TimeLoop, DriverStates, DriverParams]:
"""
Inititalize the driver run.

"reads" in
- load configuration

- load grid information

- initialize components: diffusion and solve_nh

- load diagnostic and prognostic variables (serialized data)

- setup the time loop

Returns:
tl: configured timeloop,
diffusion_diagnostic_state: initial state for diffusion diagnostic variables
nonhydro_diagnostic_state: initial state for solve_nonhydro diagnostic variables
prognostic_state: initial state for prognostic variables
diagnostic_state: initial state for global diagnostic variables
prep_advection: fields collecting data for advection during the solve nonhydro timestep
inital_divdamp_fac_o2: initial divergence damping factor

Initialize the driver run.

This function does the following:
- load configuration
- load grid information
- initialize components: diffusion and solve_nh
- load diagnostic and prognostic variables (serialized data)
- setup the time loop

Parameters:
file_path: Path to the serialized data.
props: Processor properties.
serialization_type: Serialization type.
experiment_type: Experiment type.
grid_id: Grid ID.
grid_root: Grid root.
grid_level: Grid level.

Returns:
TimeLoop: Time loop object.
DriverStates: Initial states for the driver run.
DriverParams: Parameters for the driver run.
"""
log.info("initialize parallel runtime")
log.info(f"reading configuration: experiment {experiment_type}")
Expand Down Expand Up @@ -379,7 +396,7 @@ def initialize(
diffusion_diagnostic_state,
solve_nonhydro_diagnostic_state,
prep_adv,
inital_divdamp_fac_o2,
initial_divdamp_fac_o2,
diagnostic_state,
prognostic_state_now,
prognostic_state_next,
Expand All @@ -391,21 +408,24 @@ def initialize(
rank=props.rank,
experiment_type=experiment_type,
)
prognostic_state_list = [prognostic_state_now, prognostic_state_next]
prognostics_swp = imc_utils.Swapping(prognostic_state_now, prognostic_state_next)

timeloop = TimeLoop(
run_config=config.run_config,
diffusion_granule=diffusion_granule,
solve_nonhydro_granule=solve_nonhydro_granule,
)

return (
timeloop,
diffusion_diagnostic_state,
solve_nonhydro_diagnostic_state,
prognostic_state_list,
diagnostic_state,
prep_adv,
inital_divdamp_fac_o2,
DriverStates(
prep_advection_prognostic=prep_adv,
solve_nonhydro_diagnostic=solve_nonhydro_diagnostic_state,
diffusion_diagnostic=diffusion_diagnostic_state,
prognostic_swp=prognostics_swp,
diagnostic=diagnostic_state,
),
DriverParams(divdamp_fac_o2=initial_divdamp_fac_o2),
)


Expand Down Expand Up @@ -455,7 +475,7 @@ def initialize(
)
def icon4py_driver(
input_path, run_path, mpi, serialization_type, experiment_type, grid_id, grid_root, grid_level
):
) -> None:
"""
usage: python dycore_driver.py abs_path_to_icon4py/testdata/ser_icondata/mpitask1/mch_ch_r04b09_dsl/ser_data

Expand All @@ -479,15 +499,11 @@ def icon4py_driver(
parallel_props = decomposition.get_processor_properties(decomposition.get_runtype(with_mpi=mpi))
grid_id = uuid.UUID(grid_id)
driver_init.configure_logging(run_path, experiment_type, parallel_props)
(
timeloop,
diffusion_diagnostic_state,
solve_nonhydro_diagnostic_state,
prognostic_state_list,
diagnostic_state,
prep_adv,
inital_divdamp_fac_o2,
) = initialize(

time_loop: TimeLoop
ds: DriverStates
dp: DriverParams
time_loop, ds, dp = initialize(
pathlib.Path(input_path),
parallel_props,
serialization_type,
Expand All @@ -496,22 +512,22 @@ def icon4py_driver(
grid_root,
grid_level,
)
log.info(f"Starting ICON dycore run: {timeloop.simulation_date.isoformat()}")
log.info(f"Starting ICON dycore run: {time_loop.simulation_date.isoformat()}")
log.info(
f"input args: input_path={input_path}, n_time_steps={timeloop.n_time_steps}, ending date={timeloop.run_config.end_date}"
f"input args: input_path={input_path}, n_time_steps={time_loop.n_time_steps}, ending date={time_loop.run_config.end_date}"
)

log.info(f"input args: input_path={input_path}, n_time_steps={timeloop.n_time_steps}")
log.info(f"input args: input_path={input_path}, n_time_steps={time_loop.n_time_steps}")

log.info("dycore configuring: DONE")
log.info("timeloop: START")

timeloop.time_integration(
diffusion_diagnostic_state,
solve_nonhydro_diagnostic_state,
prognostic_state_list,
prep_adv,
inital_divdamp_fac_o2,
time_loop.time_integration(
ds.diffusion_diagnostic,
ds.solve_nonhydro_diagnostic,
ds.prognostic_swp,
ds.prep_advection_prognostic,
dp.divdamp_fac_o2,
do_prep_adv=False,
)

Expand Down
Loading
Loading