Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Runfile timeloop prephysics computations and configuration #1081

Merged
merged 22 commits into from
Mar 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion external/fv3gfs-fortran
2 changes: 1 addition & 1 deletion external/fv3gfs-wrapper
Original file line number Diff line number Diff line change
Expand Up @@ -1530,6 +1530,7 @@ namelist:
ldebug: false
nudging: null
orographic_forcing: gs://vcm-fv3config/data/orographic_data/v1.0
prephysics: null
scikit_learn:
diagnostic_ml: false
input_standard_names: {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1719,6 +1719,7 @@ namelist:
ldebug: false
nudging: null
orographic_forcing: gs://vcm-fv3config/data/orographic_data/v1.0
prephysics: null
scikit_learn:
diagnostic_ml: false
input_standard_names: {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1615,6 +1615,7 @@ nudging:
x_wind: 12
y_wind: 12
orographic_forcing: gs://vcm-fv3config/data/orographic_data/v1.0
prephysics: null
scikit_learn:
diagnostic_ml: false
input_standard_names: {}
Expand Down
8 changes: 8 additions & 0 deletions workflows/prognostic_c48_run/prepare_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,13 @@ def user_config_from_dict_and_args(config_dict: dict, args) -> UserConfig:
nudge_to_observations = (
config_dict.get("namelist", {}).get("fv_core_nml", {}).get("nudge", False)
)

prephysics: Optional[MachineLearningConfig]
if "prephysics" in config_dict:
prephysics = dacite.from_dict(MachineLearningConfig, config_dict["prephysics"])
else:
prephysics = None

nudging: Optional[NudgingConfig]
if "nudging" in config_dict:
config_dict["nudging"]["restarts_path"] = config_dict["nudging"].get(
Expand Down Expand Up @@ -138,6 +145,7 @@ def user_config_from_dict_and_args(config_dict: dict, args) -> UserConfig:
)

return UserConfig(
prephysics=prephysics,
nudging=nudging,
diagnostics=diagnostics,
fortran_diagnostics=fortran_diagnostics,
Expand Down
3 changes: 3 additions & 0 deletions workflows/prognostic_c48_run/runtime/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class UserConfig:
diagnostics: list of diagnostic file configurations
fortran_diagnostics: list of Fortran diagnostic outputs. Currently only used by
post-processing and so only name and chunks items need to be specified.
prephysics: optional configuration of computations prior to physics,
specified by a machine learning configuation
scikit_learn: a machine learning configuration
nudging: nudge2fine configuration. Cannot be used if any scikit_learn model
urls are specified.
Expand All @@ -36,6 +38,7 @@ class UserConfig:

diagnostics: List[DiagnosticFileConfig]
fortran_diagnostics: List[FortranFileConfig]
prephysics: Optional[MachineLearningConfig] = None
scikit_learn: MachineLearningConfig = MachineLearningConfig()
nudging: Optional[NudgingConfig] = None
step_tendency_variables: List[str] = dataclasses.field(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,15 @@ def rename_diagnostics(diags: Diagnostics):
"net_heating",
"column_integrated_dQu",
"column_integrated_dQv",
"override_for_time_adjusted_total_sky_downward_shortwave_flux_at_surface",
"override_for_time_adjusted_total_sky_net_shortwave_flux_at_surface",
"override_for_time_adjusted_total_sky_downward_longwave_flux_at_surface",
}
ml_tendencies_in_diags = ml_tendencies & set(diags)
for variable in ml_tendencies_in_diags:
attrs = diags[variable].attrs
diags[f"{variable}_diagnostic"] = diags[variable].assign_attrs(
description=attrs["description"] + " (diagnostic only)"
description=attrs.get("description", "") + " (diagnostic only)"
)
diags[variable] = xr.zeros_like(diags[variable]).assign_attrs(attrs)

Expand Down
185 changes: 136 additions & 49 deletions workflows/prognostic_c48_run/runtime/loop.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
import datetime
import json
import os
import tempfile
import logging
from typing import Any, Callable, Iterable, List, Mapping, Optional, Sequence, Tuple
from typing import (
Any,
Callable,
Iterable,
List,
Mapping,
Optional,
Sequence,
Tuple,
)

import cftime
import fv3gfs.util
Expand All @@ -14,12 +25,16 @@
from runtime.diagnostics.machine_learning import (
compute_baseline_diagnostics,
rename_diagnostics,
)
from runtime.diagnostics.machine_learning import (
precipitation_rate,
precipitation_sum,
)
from runtime.steppers.machine_learning import PureMLStepper, open_model, download_model
from runtime.steppers.machine_learning import (
PureMLStepper,
open_model,
download_model,
MachineLearningConfig,
MLStateStepper,
)
from runtime.steppers.nudging import PureNudger
from runtime.types import Diagnostics, State, Tendencies
from runtime.names import TENDENCY_TO_STATE_NAME
Expand Down Expand Up @@ -114,6 +129,15 @@ def add_tendency(state: Any, tendency: State, dt: float) -> State:
return updated # type: ignore


def assign_attrs_from(src: Any, dst: State) -> State:
"""Given src state and a dst state, return dst state with src attrs
"""
updated = {}
for name in dst:
updated[name] = dst[name].assign_attrs(src[name].attrs)
return updated # type: ignore


class LoggingMixin:

rank: int
Expand All @@ -139,16 +163,17 @@ class TimeLoop(Iterable[Tuple[cftime.DatetimeJulian, Diagnostics]], LoggingMixin
Each time step of the model evolutions proceeds like this::

step_dynamics,
step_prephysics,
compute_physics,
apply_python_to_physics_state
apply_physics
compute_python_updates
apply_python_to_dycore_state
apply_postphysics_to_physics_state,
apply_physics,
compute_postphysics,
apply_postphysics_to_dycore_state,

The time loop relies on objects implementing the :py:class:`Stepper`
interface to enable ML and other updates. The steppers compute their
updates in ``_compute_python_updates``. The ``TimeLoop`` controls when
and how to apply these updates to the FV3 state.
updates in ``_step_prephysics`` and ``_compute_postphysics``. The
``TimeLoop`` controls when and how to apply these updates to the FV3 state.
"""

def __init__(
Expand All @@ -171,13 +196,17 @@ def __init__(
self._timestep = timestep
self._log_info(f"Timestep: {timestep}")

self._do_only_diagnostic_ml = config.scikit_learn.diagnostic_ml
self._prephysics_only_diagnostic_ml: bool = getattr(
getattr(config, "prephysics"), "diagnostic_ml", False
)
self._postphysics_only_diagnostic_ml: bool = config.scikit_learn.diagnostic_ml
self._tendencies: Tendencies = {}
self._state_updates: State = {}

self._states_to_output: Sequence[str] = self._get_states_to_output(config)
self._log_debug(f"States to output: {self._states_to_output}")
self.stepper = self._get_stepper(config)
self._prephysics_stepper = self._get_prephysics_stepper(config)
self._postphysics_stepper = self._get_postphysics_stepper(config)
self._log_info(self._fv3gfs.get_tracer_metadata())
MPI.COMM_WORLD.barrier() # wait for initialization to finish

Expand All @@ -191,30 +220,51 @@ def _get_states_to_output(self, config: UserConfig) -> Sequence[str]:
states_to_output = diagnostic.variables # type: ignore
return states_to_output

def _get_stepper(self, config: UserConfig) -> Optional[Stepper]:
def _get_prephysics_stepper(self, config: UserConfig) -> Optional[Stepper]:
if config.prephysics is not None and isinstance(
config.prephysics, MachineLearningConfig
):
self._log_info("Using MLStateStepper for prephysics")
model = self._open_model(config.prephysics, "_prephysics")
stepper: Optional[Stepper] = MLStateStepper(model, self._timestep)
else:
self._log_info("No prephysics computations")
stepper = None
return stepper

def _get_postphysics_stepper(self, config: UserConfig) -> Optional[Stepper]:
if config.scikit_learn.model:
self._log_info("Using MLStepper")
self._log_info("Downloading ML Model")
if self.rank == 0:
local_model_paths = download_model(config.scikit_learn, "ml_model")
else:
local_model_paths = None # type: ignore
local_model_paths = self.comm.bcast(local_model_paths, root=0)
setattr(config.scikit_learn, "model", local_model_paths)
self._log_info("Model Downloaded From Remote")
model = open_model(config.scikit_learn)
self._log_info("Model Loaded")
return PureMLStepper(model, self._timestep)
self._log_info("Using MLStepper for postphysics updates")
model = self._open_model(config.scikit_learn, "_postphysics")
stepper: Optional[Stepper] = PureMLStepper(model, self._timestep)
elif config.nudging:
self._log_info("Using NudgingStepper")
self._log_info("Using NudgingStepper for postphysics updates")
partitioner = fv3gfs.util.CubedSpherePartitioner.from_namelist(
get_namelist()
)
communicator = fv3gfs.util.CubedSphereCommunicator(self.comm, partitioner)
return PureNudger(config.nudging, communicator)
stepper = PureNudger(config.nudging, communicator)
else:
self._log_info("Performing baseline simulation")
return None
stepper = None
return stepper

def _open_model(self, ml_config: MachineLearningConfig, step: str):
self._log_info("Downloading ML Model")
with tempfile.TemporaryDirectory() as tmpdir:
if self.rank == 0:
local_model_paths = download_model(
ml_config, os.path.join(tmpdir, step)
)
else:
local_model_paths = None # type: ignore
local_model_paths = self.comm.bcast(local_model_paths, root=0)
setattr(ml_config, "model", local_model_paths)
self._log_info("Model Downloaded From Remote")
model = open_model(ml_config)
MPI.COMM_WORLD.barrier()
self._log_info("Model Loaded")
return model

@property
def time(self) -> cftime.DatetimeJulian:
Expand Down Expand Up @@ -288,42 +338,77 @@ def _print_global_timings(self, root=0):
def _substeps(self) -> Sequence[Callable[..., Diagnostics]]:
return [
self._step_dynamics,
self._step_prephysics,
self._compute_physics,
self._apply_python_to_physics_state,
self._apply_postphysics_to_physics_state,
self._apply_physics,
self._compute_python_updates,
self._apply_python_to_dycore_state,
self._compute_postphysics,
self._apply_postphysics_to_dycore_state,
]

def _apply_python_to_physics_state(self) -> Diagnostics:
def _step_prephysics(self) -> Diagnostics:

if self._prephysics_stepper is None:
diagnostics: Diagnostics = {}
else:
self._log_debug("Computing prephysics updates")
_, diagnostics, state_updates = self._prephysics_stepper(
self._state.time, self._state
)
if self._prephysics_only_diagnostic_ml:
rename_diagnostics(diagnostics)
else:
self._state_updates.update(state_updates)
prephysics_overrides = [
"override_for_time_adjusted_total_sky_downward_shortwave_flux_at_surface",
"override_for_time_adjusted_total_sky_net_shortwave_flux_at_surface",
"override_for_time_adjusted_total_sky_downward_longwave_flux_at_surface",
]
state_updates = {
k: v for k, v in self._state_updates.items() if k in prephysics_overrides
}
self._state_updates = dissoc(self._state_updates, *prephysics_overrides)
self._log_debug(
f"Applying prephysics state updates for: {list(state_updates.keys())}"
)
updated_state = assign_attrs_from(self._state, state_updates)
self._state.update_mass_conserving(updated_state)

return diagnostics

def _apply_postphysics_to_physics_state(self) -> Diagnostics:
"""Apply computed tendencies and state updates to the physics state

Mostly used for updating the eastward and northward winds.
"""
self._log_debug(f"Apply python tendencies to physics state")
self._log_debug(f"Apply postphysics tendencies to physics state")
tendency = {k: v for k, v in self._tendencies.items() if k in ["dQu", "dQv"]}

diagnostics: Diagnostics = {}

if self.stepper is not None:
diagnostics = self.stepper.get_momentum_diagnostics(self._state, tendency)
if self._do_only_diagnostic_ml:
if self._postphysics_stepper is not None:
diagnostics = self._postphysics_stepper.get_momentum_diagnostics(
self._state, tendency
)
if self._postphysics_only_diagnostic_ml:
rename_diagnostics(diagnostics)
else:
updated_state = add_tendency(self._state, tendency, dt=self._timestep)
self._state.update_mass_conserving(updated_state)

return diagnostics

def _compute_python_updates(self) -> Diagnostics:
self._log_info("Computing Python Updates")
def _compute_postphysics(self) -> Diagnostics:
self._log_info("Computing Postphysics Updates")

if self.stepper is None:
if self._postphysics_stepper is None:
return {}
else:
(self._tendencies, diagnostics, self._state_updates,) = self.stepper(
self._state.time, self._state
)
(
self._tendencies,
diagnostics,
self._state_updates,
) = self._postphysics_stepper(self._state.time, self._state)
try:
rank_updated_points = diagnostics["rank_updated_points"]
except KeyError:
Expand All @@ -340,21 +425,23 @@ def _compute_python_updates(self) -> Diagnostics:
)
return diagnostics

def _apply_python_to_dycore_state(self) -> Diagnostics:
def _apply_postphysics_to_dycore_state(self) -> Diagnostics:

tendency = dissoc(self._tendencies, "dQu", "dQv")

if self.stepper is None:
if self._postphysics_stepper is None:
diagnostics = compute_baseline_diagnostics(self._state)
else:
diagnostics = self.stepper.get_diagnostics(self._state, tendency)
if self._do_only_diagnostic_ml:
diagnostics = self._postphysics_stepper.get_diagnostics(
self._state, tendency
)
if self._postphysics_only_diagnostic_ml:
rename_diagnostics(diagnostics)
else:
updated_state = add_tendency(self._state, tendency, dt=self._timestep)
updated_state[TOTAL_PRECIP] = precipitation_sum(
self._state[TOTAL_PRECIP],
diagnostics[self.stepper.net_moistening],
diagnostics[self._postphysics_stepper.net_moistening],
self._timestep,
)
diagnostics[TOTAL_PRECIP] = updated_state[TOTAL_PRECIP]
Expand Down Expand Up @@ -464,8 +551,8 @@ def __init__(
self._storage_variables = list(storage_variables)

_apply_physics = monitor("fv3_physics", TimeLoop._apply_physics)
_apply_python_to_dycore_state = monitor(
"python", TimeLoop._apply_python_to_dycore_state
_apply_postphysics_to_dycore_state = monitor(
"python", TimeLoop._apply_postphysics_to_dycore_state
)


Expand Down
Loading