Skip to content

Commit

Permalink
Fix parallel diffusion tests with dace orchestration (#572)
Browse files Browse the repository at this point in the history
* filter out unused connectivities
* remove diffusion_instance fixture
* update member blacklist in Diffusion.orchestration_uid
  • Loading branch information
DropD authored Nov 20, 2024
1 parent bac96a5 commit df0c34a
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 111 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -922,19 +922,12 @@ def orchestration_uid(self) -> str:
members_to_disregard = [
"_backend",
"_exchange",
"mo_intp_rbf_rbf_vec_interpol_vertex",
"calculate_nabla2_and_smag_coefficients_for_vn",
"calculate_diagnostic_quantities_for_turbulence",
"apply_diffusion_to_vn",
"apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulence",
"calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools",
"calculate_nabla2_for_theta",
"truly_horizontal_diffusion_nabla_of_theta_over_steep_points",
"update_theta_and_exner",
"copy_field",
"scale_k",
"setup_fields_for_initial_step",
"init_diffusion_local_fields_for_regular_timestep",
"_grid",
*[
name
for name in self.__dict__.keys()
if isinstance(self.__dict__[name], gtx.ffront.decorator.Program)
],
]
return orchestration.generate_orchestration_uid(
self, members_to_disregard=members_to_disregard
Expand Down
2 changes: 0 additions & 2 deletions model/atmosphere/diffusion/tests/diffusion_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,3 @@
stretch_factor,
top_height_limit_for_maximal_layer_thickness,
)

from .utils import diffusion_instance # noqa: F401 # import fixtures from test_utils package
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from icon4py.model.common.test_utils import datatest_utils, helpers, parallel_helpers

from .. import utils
from ..utils import diffusion_instance # noqa


@pytest.mark.mpi
Expand All @@ -41,7 +40,6 @@ def test_parallel_diffusion(
damping_height,
caplog,
backend,
diffusion_instance, # noqa: F811
):
caplog.set_level("INFO")
parallel_helpers.check_comm_size(processor_props)
Expand All @@ -66,8 +64,50 @@ def test_parallel_diffusion(
print(
f"rank={processor_props.rank}/{processor_props.comm_size}: setup: using {processor_props.comm_name} with {processor_props.comm_size} nodes"
)
vertical_config = v_grid.VerticalGridConfig(
icon_grid.num_levels,
lowest_layer_thickness=lowest_layer_thickness,
model_top_height=model_top_height,
stretch_factor=stretch_factor,
rayleigh_damping_height=damping_height,
)

diffusion = diffusion_instance # the fixture makes sure that the orchestrator cache is cleared properly between pytest runs -if applicable-
diffusion_params = diffusion_.DiffusionParams(config)
metric_state = diffusion_states.DiffusionMetricState(
mask_hdiff=metrics_savepoint.mask_hdiff(),
theta_ref_mc=metrics_savepoint.theta_ref_mc(),
wgtfac_c=metrics_savepoint.wgtfac_c(),
zd_intcoef=metrics_savepoint.zd_intcoef(),
zd_vertoffset=metrics_savepoint.zd_vertoffset(),
zd_diffcoef=metrics_savepoint.zd_diffcoef(),
)
interpolation_state = diffusion_states.DiffusionInterpolationState(
e_bln_c_s=helpers.as_1D_sparse_field(interpolation_savepoint.e_bln_c_s(), dims.CEDim),
rbf_coeff_1=interpolation_savepoint.rbf_vec_coeff_v1(),
rbf_coeff_2=interpolation_savepoint.rbf_vec_coeff_v2(),
geofac_div=helpers.as_1D_sparse_field(interpolation_savepoint.geofac_div(), dims.CEDim),
geofac_n2s=interpolation_savepoint.geofac_n2s(),
geofac_grg_x=interpolation_savepoint.geofac_grg()[0],
geofac_grg_y=interpolation_savepoint.geofac_grg()[1],
nudgecoeff_e=interpolation_savepoint.nudgecoeff_e(),
)
cell_geometry = grid_savepoint.construct_cell_geometry()
edge_geometry = grid_savepoint.construct_edge_geometry()
exchange = definitions.create_exchange(processor_props, decomposition_info)
diffusion = diffusion_.Diffusion(
grid=icon_grid,
config=config,
params=diffusion_params,
vertical_grid=v_grid.VerticalGrid(
vertical_config, grid_savepoint.vct_a(), grid_savepoint.vct_b()
),
metric_state=metric_state,
interpolation_state=interpolation_state,
edge_params=edge_geometry,
cell_params=cell_geometry,
exchange=exchange,
backend=backend,
)

print(f"rank={processor_props.rank}/{processor_props.comm_size}: diffusion initialized ")

Expand Down Expand Up @@ -127,7 +167,6 @@ def test_parallel_diffusion_multiple_steps(
damping_height,
caplog,
backend,
diffusion_instance, # noqa: F811
):
if settings.dace_orchestration is None:
raise pytest.skip("This test is only executed for `--dace-orchestration=True`.")
Expand Down Expand Up @@ -244,9 +283,8 @@ def test_parallel_diffusion_multiple_steps(
######################################################################
settings.dace_orchestration = True

diffusion = diffusion_instance # the fixture makes sure that the orchestrator cache is cleared properly between pytest runs -if applicable-

diffusion.init(
exchange = definitions.create_exchange(processor_props, decomposition_info)
diffusion = diffusion_.Diffusion(
grid=icon_grid,
config=config,
params=diffusion_params,
Expand All @@ -257,6 +295,8 @@ def test_parallel_diffusion_multiple_steps(
interpolation_state=interpolation_state,
edge_params=edge_geometry,
cell_params=cell_geometry,
exchange=exchange,
backend=backend,
)
print(f"rank={processor_props.rank}/{processor_props.comm_size}: diffusion initialized ")

Expand Down
13 changes: 11 additions & 2 deletions model/atmosphere/diffusion/tests/diffusion_tests/test_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,6 @@ def test_run_diffusion_multiple_steps(
damping_height,
ndyn_substeps,
backend,
diffusion_instance, # F811 fixture
icon_grid,
):
if settings.dace_orchestration is None:
Expand Down Expand Up @@ -605,7 +604,17 @@ def test_run_diffusion_multiple_steps(
)
prognostic_state_dace_orch = savepoint_diffusion_init.construct_prognostics()

diffusion_granule = diffusion_instance # the fixture makes sure that the orchestrator cache is cleared properly between pytest runs -if applicable-
diffusion_granule = diffusion.Diffusion(
grid=icon_grid,
config=config,
params=additional_parameters,
vertical_grid=vertical_params,
metric_state=metric_state,
interpolation_state=interpolation_state,
edge_params=edge_geometry,
cell_params=cell_geometry,
backend=backend,
)

for _ in range(3):
diffusion_granule.run(
Expand Down
78 changes: 0 additions & 78 deletions model/atmosphere/diffusion/tests/diffusion_tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,8 @@
# SPDX-License-Identifier: BSD-3-Clause

import numpy as np
import pytest

import icon4py.model.common.dimension as dims
from icon4py.model.atmosphere.diffusion import diffusion, diffusion_states
from icon4py.model.common import settings
from icon4py.model.common.decomposition import definitions
from icon4py.model.common.grid import vertical as v_grid
from icon4py.model.common.states import prognostic_state as prognostics
from icon4py.model.common.test_utils import helpers, serialbox_utils as sb

Expand Down Expand Up @@ -159,76 +154,3 @@ def compare_dace_orchestration_multiple_steps(
assert np.allclose(vn_dace_non_orch, vn_dace_orch)
else:
raise ValueError("Field type not recognized")


@pytest.fixture
def diffusion_instance(
    icon_grid,
    metrics_savepoint,
    interpolation_savepoint,
    ndyn_substeps,
    experiment,
    lowest_layer_thickness,
    model_top_height,
    stretch_factor,
    damping_height,
    grid_savepoint,
    backend,
    processor_props,  # fixture
    decomposition_info,  # fixture
):
    """Yield a ``Diffusion`` granule assembled from the savepoint fixtures.

    Once the requesting test has finished, the cache attached to the granule's
    ``_do_diffusion_step`` is cleared, but only when
    ``settings.dace_orchestration`` is not ``None``.
    """
    halo_exchange = definitions.create_exchange(processor_props, decomposition_info)
    edges = grid_savepoint.construct_edge_geometry()
    cells = grid_savepoint.construct_cell_geometry()

    # The vertical grid config is built first (as the first keyword argument),
    # followed by the savepoint look-ups, in the same order as before.
    vertical_grid = v_grid.VerticalGrid(
        config=v_grid.VerticalGridConfig(
            icon_grid.num_levels,
            lowest_layer_thickness=lowest_layer_thickness,
            model_top_height=model_top_height,
            stretch_factor=stretch_factor,
            rayleigh_damping_height=damping_height,
        ),
        vct_a=grid_savepoint.vct_a(),
        vct_b=grid_savepoint.vct_b(),
        _min_index_flat_horizontal_grad_pressure=grid_savepoint.nflat_gradp(),
    )

    diffusion_config = construct_diffusion_config(experiment, ndyn_substeps=ndyn_substeps)
    diffusion_params = diffusion.DiffusionParams(diffusion_config)

    metrics = diffusion_states.DiffusionMetricState(
        mask_hdiff=metrics_savepoint.mask_hdiff(),
        theta_ref_mc=metrics_savepoint.theta_ref_mc(),
        wgtfac_c=metrics_savepoint.wgtfac_c(),
        zd_intcoef=metrics_savepoint.zd_intcoef(),
        zd_vertoffset=metrics_savepoint.zd_vertoffset(),
        zd_diffcoef=metrics_savepoint.zd_diffcoef(),
    )
    interpolation = diffusion_states.DiffusionInterpolationState(
        e_bln_c_s=helpers.as_1D_sparse_field(interpolation_savepoint.e_bln_c_s(), dims.CEDim),
        rbf_coeff_1=interpolation_savepoint.rbf_vec_coeff_v1(),
        rbf_coeff_2=interpolation_savepoint.rbf_vec_coeff_v2(),
        geofac_div=helpers.as_1D_sparse_field(interpolation_savepoint.geofac_div(), dims.CEDim),
        geofac_n2s=interpolation_savepoint.geofac_n2s(),
        geofac_grg_x=interpolation_savepoint.geofac_grg()[0],
        geofac_grg_y=interpolation_savepoint.geofac_grg()[1],
        nudgecoeff_e=interpolation_savepoint.nudgecoeff_e(),
    )

    granule = diffusion.Diffusion(
        grid=icon_grid,
        config=diffusion_config,
        params=diffusion_params,
        vertical_grid=vertical_grid,
        metric_state=metrics,
        interpolation_state=interpolation,
        edge_params=edges,
        cell_params=cells,
        backend=backend,
        exchange=halo_exchange,
    )

    yield granule

    # Teardown: nothing to clear unless the dace orchestration setting is set.
    if settings.dace_orchestration is None:
        return
    granule._do_diffusion_step.clear_cache()
17 changes: 8 additions & 9 deletions model/common/src/icon4py/model/common/orchestration/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,14 @@ def wrapper(*args, **kwargs):
)
updated_kwargs = {
**updated_kwargs,
**dace_specific_kwargs(exchange_obj, grid.offset_providers),
**dace_specific_kwargs(
exchange_obj,
{
k: v
for k, v in grid.offset_providers.items()
if connectivity_identifier(k) in sdfg.arrays
},
),
}
updated_kwargs = {
**updated_kwargs,
Expand All @@ -183,14 +190,6 @@ def wrapper(*args, **kwargs):
configure_dace_temp_env(default_build_folder)
return compiled_sdfg(**sdfg_args)

# Pytest does not properly clear the cache between runs (e.g. across `pytest.mark.parametrize` cases).
# This leads to a corrupted cache and subsequent errors.
# To avoid this, we provide a way to clear the cache.
def clear_cache():
orchestrator_cache.clear()

wrapper.clear_cache = clear_cache

return wrapper

else:
Expand Down

0 comments on commit df0c34a

Please sign in to comment.