From ebf22c7d99bda870bbe4eba8f2020203ae57d6c6 Mon Sep 17 00:00:00 2001 From: brianhenn Date: Wed, 10 Mar 2021 05:24:49 +0000 Subject: [PATCH 01/17] prephysics call, mock prescriber and PrephysicsConfig --- .../prognostic_c48_run/prepare_config.py | 7 + .../prognostic_c48_run/runtime/config.py | 2 + workflows/prognostic_c48_run/runtime/loop.py | 138 ++++++++++++++---- .../runtime/steppers/machine_learning.py | 4 +- .../runtime/steppers/prephysics.py | 63 ++++++++ 5 files changed, 184 insertions(+), 30 deletions(-) create mode 100644 workflows/prognostic_c48_run/runtime/steppers/prephysics.py diff --git a/workflows/prognostic_c48_run/prepare_config.py b/workflows/prognostic_c48_run/prepare_config.py index 1ebb96a7e2..8886bd0ab6 100644 --- a/workflows/prognostic_c48_run/prepare_config.py +++ b/workflows/prognostic_c48_run/prepare_config.py @@ -17,6 +17,7 @@ from runtime.steppers.nudging import NudgingConfig from runtime.config import UserConfig from runtime.steppers.machine_learning import MachineLearningConfig +from runtime.steppers.prephysics import PrephysicsConfig logger = logging.getLogger(__name__) @@ -96,6 +97,11 @@ def user_config_from_dict_and_args(config_dict: dict, args) -> UserConfig: config_dict.get("namelist", {}).get("fv_core_nml", {}).get("nudge", False) ) + if "prephysics" in config_dict: + prephysics = dacite.from_dict(PrephysicsConfig, config_dict["prephysics"]) + else: + prephysics = None + if "nudging" in config_dict: config_dict["nudging"]["restarts_path"] = config_dict["nudging"].get( "restarts_path", args.initial_condition_url @@ -134,6 +140,7 @@ def user_config_from_dict_and_args(config_dict: dict, args) -> UserConfig: ) return UserConfig( + prephysics=prephysics, nudging=nudging, diagnostics=diagnostics, fortran_diagnostics=fortran_diagnostics, diff --git a/workflows/prognostic_c48_run/runtime/config.py b/workflows/prognostic_c48_run/runtime/config.py index a6e50c4893..38ebdb5cb7 100644 --- a/workflows/prognostic_c48_run/runtime/config.py +++ b/workflows/prognostic_c48_run/runtime/config.py @@ -12,6 +12,7 @@ ) from runtime.steppers.nudging import NudgingConfig from runtime.steppers.machine_learning import MachineLearningConfig +from runtime.steppers.prephysics import PrephysicsConfig FV3CONFIG_FILENAME = "fv3config.yml" @@ -36,6 +37,7 @@ class UserConfig: diagnostics: List[DiagnosticFileConfig] fortran_diagnostics: List[FortranFileConfig] + prephysics: Optional[PrephysicsConfig] = None scikit_learn: MachineLearningConfig = MachineLearningConfig() nudging: Optional[NudgingConfig] = None step_tendency_variables: List[str] = dataclasses.field( diff --git a/workflows/prognostic_c48_run/runtime/loop.py b/workflows/prognostic_c48_run/runtime/loop.py index 1336241594..2eb32248c8 100644 --- a/workflows/prognostic_c48_run/runtime/loop.py +++ b/workflows/prognostic_c48_run/runtime/loop.py @@ -1,7 +1,18 @@ import datetime import json +import os import logging -from typing import Any, Callable, Iterable, List, Mapping, Optional, Sequence, Tuple +from typing import ( + Any, + Callable, + Iterable, + List, + Mapping, + MutableMapping, + Optional, + Sequence, + Tuple, +) import cftime import fv3gfs.util @@ -14,13 +25,17 @@ from runtime.diagnostics.machine_learning import ( compute_baseline_diagnostics, rename_diagnostics, -) -from runtime.diagnostics.machine_learning import ( precipitation_rate, precipitation_sum, ) -from runtime.steppers.machine_learning import PureMLStepper, open_model, download_model +from runtime.steppers.machine_learning import ( + PureMLStepper, + load_adapted_model, + download_model, + MachineLearningConfig, +) from runtime.steppers.nudging import PureNudger +from runtime.steppers.prephysics import Prescriber, PrescriberConfig from runtime.types import Diagnostics, State, Tendencies from runtime.names import TENDENCY_TO_STATE_NAME from toolz import dissoc @@ -177,7 +192,7 @@ def __init__( self._states_to_output: Sequence[str] = self._get_states_to_output(config) self._log_debug(f"States to output: {self._states_to_output}") - self.stepper = self._get_stepper(config) + self.steppers = self._get_steppers(config) self._log_info(self._fv3gfs.get_tracer_metadata()) MPI.COMM_WORLD.barrier() # wait for initialization to finish @@ -191,30 +206,64 @@ def _get_states_to_output(self, config: UserConfig) -> Sequence[str]: states_to_output = diagnostic.variables # type: ignore return states_to_output - def _get_stepper(self, config: UserConfig) -> Optional[Stepper]: + def _get_steppers(self, config: UserConfig) -> Mapping[str, Optional[Stepper]]: + + steppers: MutableMapping[str, Optional[Stepper]] = {} + + if config.prephysics is not None and isinstance( + config.prephysics.config, MachineLearningConfig + ): + self._log_info("Using MLStepper for prephysics") + model = self._open_model(config.prephysics.config, "_compute_prephysics") + steppers["_compute_prephysics"] = PureMLStepper(model, self._timestep) + elif config.prephysics is not None and isinstance( + config.prephysics.config, PrescriberConfig + ): + self._log_info("Using Prescriber for prephysics") + partitioner = fv3gfs.util.CubedSpherePartitioner.from_namelist( + get_namelist() + ) + communicator = fv3gfs.util.CubedSphereCommunicator(self.comm, partitioner) + steppers["_compute_prephysics"] = Prescriber( + config.prephysics.config, communicator + ) + else: + self._log_info("No prephysics computations") + steppers["_compute_prephysics"] = None + if config.scikit_learn.model: - self._log_info("Using MLStepper") - self._log_info("Downloading ML Model") - if self.rank == 0: - local_model_paths = download_model(config.scikit_learn, "ml_model") - else: - local_model_paths = None # type: ignore - local_model_paths = self.comm.bcast(local_model_paths, root=0) - setattr(config.scikit_learn, "model", local_model_paths) - self._log_info("Model Downloaded From Remote") - model = open_model(config.scikit_learn) - self._log_info("Model Loaded") - return PureMLStepper(model, self._timestep) + self._log_info("Using MLStepper for python updates") + model = self._open_model(config.scikit_learn, "_compute_python_updates") + steppers["_compute_python_updates"] = PureMLStepper(model, self._timestep) elif config.nudging: - self._log_info("Using NudgingStepper") + self._log_info("Using NudgingStepper for python updates") partitioner = fv3gfs.util.CubedSpherePartitioner.from_namelist( get_namelist() ) communicator = fv3gfs.util.CubedSphereCommunicator(self.comm, partitioner) - return PureNudger(config.nudging, communicator) + steppers["_compute_python_updates"] = PureNudger( + config.nudging, communicator + ) else: self._log_info("Performing baseline simulation") - return None + steppers["_compute_python_updates"] = None + + return steppers + + def _open_model(self, ml_config: MachineLearningConfig, step: str): + self._log_info("Downloading ML Model") + if self.rank == 0: + local_model_paths = download_model( + ml_config, os.path.join(step, "ml_model") + ) + else: + local_model_paths = None # type: ignore + local_model_paths = self.comm.bcast(local_model_paths, root=0) + setattr(ml_config, "model", local_model_paths) + self._log_info("Model Downloaded From Remote") + model = load_adapted_model(ml_config) + self._log_info("Model Loaded") + return model @property def time(self) -> cftime.DatetimeJulian: @@ -288,6 +337,8 @@ def _print_global_timings(self, root=0): def _substeps(self) -> Sequence[Callable[..., Diagnostics]]: return [ self._step_dynamics, + self._compute_prephysics, + self._apply_prephysics, self._compute_physics, self._apply_python_to_physics_state, self._apply_physics, @@ -295,6 +346,31 @@ def _substeps(self) -> Sequence[Callable[..., Diagnostics]]: self._apply_python_to_dycore_state, ] + def _compute_prephysics(self) -> Diagnostics: + stepper = self.steppers["_compute_prephysics"] + if stepper is None: + diagnostics: Diagnostics = {} + else: + _, diagnostics, state_updates = stepper(self._state.time, self._state) + self._state_updates.update(state_updates) + return diagnostics + + def _apply_prephysics(self): + radiative_fluxes = [ + "total_sky_downward_shortwave_flux_at_surface_override", + "total_sky_net_shortwave_flux_at_surface_override", + "total_sky_downward_longwave_flux_at_surface_override", + ] + state_updates = { + k: v for k, v in self._state_updates.items() if k in radiative_fluxes + } + self._state_updates = dissoc(self._state_updates, *radiative_fluxes) + self._log_debug( + f"Applying prephysics state updates for {list(state_updates.keys())}" + ) + self._state.update_mass_conserving(state_updates) + return {} + def _apply_python_to_physics_state(self) -> Diagnostics: """Apply computed tendencies and state updates to the physics state @@ -305,8 +381,10 @@ def _apply_python_to_physics_state(self) -> Diagnostics: diagnostics: Diagnostics = {} - if self.stepper is not None: - diagnostics = self.stepper.get_momentum_diagnostics(self._state, tendency) + stepper = self.steppers["_compute_python_updates"] + + if stepper is not None: + diagnostics = stepper.get_momentum_diagnostics(self._state, tendency) if self._do_only_diagnostic_ml: rename_diagnostics(diagnostics) else: @@ -318,10 +396,12 @@ def _apply_python_to_physics_state(self) -> Diagnostics: def _compute_python_updates(self) -> Diagnostics: self._log_info("Computing Python Updates") - if self.stepper is None: + stepper = self.steppers["_compute_python_updates"] + + if stepper is None: return {} else: - (self._tendencies, diagnostics, self._state_updates,) = self.stepper( + (self._tendencies, diagnostics, self._state_updates,) = stepper( self._state.time, self._state ) try: @@ -344,17 +424,19 @@ def _apply_python_to_dycore_state(self) -> Diagnostics: tendency = dissoc(self._tendencies, "dQu", "dQv") - if self.stepper is None: + stepper = self.steppers["_compute_python_updates"] + + if stepper is None: diagnostics = compute_baseline_diagnostics(self._state) else: - diagnostics = self.stepper.get_diagnostics(self._state, tendency) + diagnostics = stepper.get_diagnostics(self._state, tendency) if self._do_only_diagnostic_ml: rename_diagnostics(diagnostics) else: updated_state = add_tendency(self._state, tendency, dt=self._timestep) updated_state[TOTAL_PRECIP] = precipitation_sum( self._state[TOTAL_PRECIP], - diagnostics[self.stepper.net_moistening], + diagnostics[stepper.net_moistening], self._timestep, ) diagnostics[TOTAL_PRECIP] = updated_state[TOTAL_PRECIP] diff --git a/workflows/prognostic_c48_run/runtime/steppers/machine_learning.py b/workflows/prognostic_c48_run/runtime/steppers/machine_learning.py index e8aac21218..c2b441234a 100644 --- a/workflows/prognostic_c48_run/runtime/steppers/machine_learning.py +++ b/workflows/prognostic_c48_run/runtime/steppers/machine_learning.py @@ -13,7 +13,7 @@ from vcm import thermo import vcm -__all__ = ["MachineLearningConfig", "PureMLStepper", "open_model"] +__all__ = ["MachineLearningConfig", "PureMLStepper", "load_adapted_model"] logger = logging.getLogger(__name__) @@ -130,7 +130,7 @@ def predict_columnwise(self, arg: xr.Dataset, **kwargs) -> xr.Dataset: return xr.merge(predictions) -def open_model(config: MachineLearningConfig) -> MultiModelAdapter: +def load_adapted_model(config: MachineLearningConfig) -> MultiModelAdapter: model_paths = config.model models = [] for path in model_paths: diff --git a/workflows/prognostic_c48_run/runtime/steppers/prephysics.py b/workflows/prognostic_c48_run/runtime/steppers/prephysics.py new file mode 100644 index 0000000000..95896e1785 --- /dev/null +++ b/workflows/prognostic_c48_run/runtime/steppers/prephysics.py @@ -0,0 +1,63 @@ +from typing import Union, Sequence +import dataclasses +from runtime.steppers.machine_learning import MachineLearningConfig +import fv3gfs.util + + +@dataclasses.dataclass +class PrescriberConfig: + """Configuration for prescribing states in the model from an external source + + Attributes: + variables: list variable names to prescribe + data_source: path to the source of the data to prescribe + + Example:: + + PrescriberConfig( + variables=[''] + data_source="" + ) + + """ + + variables: Sequence[str] + data_source: str + + +class Prescriber: + """A pre-physics stepper which obtains prescribed values from an external source + + TODO: Implement methods + """ + + net_moistening = "net_moistening" + + def __init__( + self, config: PrescriberConfig, communicator: fv3gfs.util.Commmunicator + ): + + self._prescribed_variables: Sequence[str] = list(config.variables) + self._data_source: str = config.data_source + + def __call__(self, time, state): + return {}, {}, {} + + def get_diagnostics(self, state, tendency): + return {} + + def get_momentum_diagnostics(self, state, tendency): + return {} + + +@dataclasses.dataclass +class PrephysicsConfig: + """Configuration of pre-physics computations + + Attributes: + config: can be either a MachineLearningConfig or a + PrescriberConfig, as these are the allowed pre-physics computations + + """ + + config: Union[PrescriberConfig, MachineLearningConfig] From e6543eb8d50287d3e0c08095570463a69b0eb402 Mon Sep 17 00:00:00 2001 From: brianhenn Date: Wed, 10 Mar 2021 06:23:19 +0000 Subject: [PATCH 02/17] tests passing --- ...st_prepare_config.test_prepare_ml_config_regression.out | 1 + ..._config.test_prepare_nudge_to_obs_config_regression.out | 1 + ...epare_config.test_prepare_nudging_config_regression.out | 7 +++++++ .../prognostic_c48_run/examples/nudge_to_fine_config.yml | 6 ++++++ workflows/prognostic_c48_run/prepare_config.py | 4 +++- workflows/prognostic_c48_run/runtime/loop.py | 1 + .../prognostic_c48_run/runtime/steppers/prephysics.py | 4 +++- 7 files changed, 22 insertions(+), 2 deletions(-) diff --git a/workflows/prognostic_c48_run/_regtest_outputs/test_prepare_config.test_prepare_ml_config_regression.out b/workflows/prognostic_c48_run/_regtest_outputs/test_prepare_config.test_prepare_ml_config_regression.out index d64bc9de36..350b711036 100644 --- a/workflows/prognostic_c48_run/_regtest_outputs/test_prepare_config.test_prepare_ml_config_regression.out +++ b/workflows/prognostic_c48_run/_regtest_outputs/test_prepare_config.test_prepare_ml_config_regression.out @@ -454,6 +454,7 @@ namelist: ldebug: false nudging: null orographic_forcing: gs://vcm-fv3config/data/orographic_data/v1.0 +prephysics: null scikit_learn: diagnostic_ml: false input_standard_names: {} diff --git a/workflows/prognostic_c48_run/_regtest_outputs/test_prepare_config.test_prepare_nudge_to_obs_config_regression.out b/workflows/prognostic_c48_run/_regtest_outputs/test_prepare_config.test_prepare_nudge_to_obs_config_regression.out index e1e0d13caa..a235fc5874 100644 --- a/workflows/prognostic_c48_run/_regtest_outputs/test_prepare_config.test_prepare_nudge_to_obs_config_regression.out +++ b/workflows/prognostic_c48_run/_regtest_outputs/test_prepare_config.test_prepare_nudge_to_obs_config_regression.out @@ -522,6 +522,7 @@ namelist: ldebug: false nudging: null orographic_forcing: gs://vcm-fv3config/data/orographic_data/v1.0 +prephysics: null scikit_learn: diagnostic_ml: false input_standard_names: {} diff --git a/workflows/prognostic_c48_run/_regtest_outputs/test_prepare_config.test_prepare_nudging_config_regression.out b/workflows/prognostic_c48_run/_regtest_outputs/test_prepare_config.test_prepare_nudging_config_regression.out index cf4c9d4d33..a655a72952 100644 --- a/workflows/prognostic_c48_run/_regtest_outputs/test_prepare_config.test_prepare_nudging_config_regression.out +++ b/workflows/prognostic_c48_run/_regtest_outputs/test_prepare_config.test_prepare_nudging_config_regression.out @@ -539,6 +539,13 @@ nudging: x_wind: 12 y_wind: 12 orographic_forcing: gs://vcm-fv3config/data/orographic_data/v1.0 +prephysics: + config: + data_source: 40day_c48_gfsphysics_15min_may2020 + variables: + - DLWRFsfc_coarse + - DSWRFsfc_coarse + - USWRFsfc_coarse scikit_learn: diagnostic_ml: false input_standard_names: {} diff --git a/workflows/prognostic_c48_run/examples/nudge_to_fine_config.yml b/workflows/prognostic_c48_run/examples/nudge_to_fine_config.yml index 95718dbcfd..cdf3de2fc9 100644 --- a/workflows/prognostic_c48_run/examples/nudge_to_fine_config.yml +++ b/workflows/prognostic_c48_run/examples/nudge_to_fine_config.yml @@ -6,6 +6,12 @@ nudging: x_wind: 12 y_wind: 12 pressure_thickness_of_atmospheric_layer: 12 +prephysics: + data_source: 40day_c48_gfsphysics_15min_may2020 + variables: + - DLWRFsfc_coarse + - DSWRFsfc_coarse + - USWRFsfc_coarse namelist: coupler_nml: current_date: diff --git a/workflows/prognostic_c48_run/prepare_config.py b/workflows/prognostic_c48_run/prepare_config.py index 8886bd0ab6..fd434ac63f 100644 --- a/workflows/prognostic_c48_run/prepare_config.py +++ b/workflows/prognostic_c48_run/prepare_config.py @@ -98,7 +98,9 @@ def user_config_from_dict_and_args(config_dict: dict, args) -> UserConfig: ) if "prephysics" in config_dict: - prephysics = dacite.from_dict(PrephysicsConfig, config_dict["prephysics"]) + prephysics = dacite.from_dict( + PrephysicsConfig, {"config": config_dict["prephysics"]} + ) else: prephysics = None diff --git a/workflows/prognostic_c48_run/runtime/loop.py b/workflows/prognostic_c48_run/runtime/loop.py index 2eb32248c8..5cdd49e24e 100644 --- a/workflows/prognostic_c48_run/runtime/loop.py +++ b/workflows/prognostic_c48_run/runtime/loop.py @@ -353,6 +353,7 @@ def _compute_prephysics(self) -> Diagnostics: else: _, diagnostics, state_updates = stepper(self._state.time, self._state) self._state_updates.update(state_updates) + self._log_debug(f"Computing prephysics state updates") return diagnostics def _apply_prephysics(self): diff --git a/workflows/prognostic_c48_run/runtime/steppers/prephysics.py b/workflows/prognostic_c48_run/runtime/steppers/prephysics.py index 95896e1785..d0fae1a689 100644 --- a/workflows/prognostic_c48_run/runtime/steppers/prephysics.py +++ b/workflows/prognostic_c48_run/runtime/steppers/prephysics.py @@ -34,7 +34,9 @@ class Prescriber: net_moistening = "net_moistening" def __init__( - self, config: PrescriberConfig, communicator: fv3gfs.util.Commmunicator + self, + config: PrescriberConfig, + communicator: fv3gfs.util.CubedSphereCommunicator, ): self._prescribed_variables: Sequence[str] = list(config.variables) From 1914f4717f54bed52367c299f6d64aa513e0600b Mon Sep 17 00:00:00 2001 From: brianhenn Date: Wed, 10 Mar 2021 07:02:02 +0000 Subject: [PATCH 03/17] updated wrapper and fortran model for radiative flux setting --- external/fv3gfs-fortran | 2 +- external/fv3gfs-wrapper | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/external/fv3gfs-fortran b/external/fv3gfs-fortran index 8e9f185329..b8340fb5b9 160000 --- a/external/fv3gfs-fortran +++ b/external/fv3gfs-fortran @@ -1 +1 @@ -Subproject commit 8e9f18532998dcf4e144f93b7689c25686cec169 +Subproject commit b8340fb5b990e70c27e57b1d4ef2b86a772ed85c diff --git a/external/fv3gfs-wrapper b/external/fv3gfs-wrapper index 32878b2c77..bebc8a2a52 160000 --- a/external/fv3gfs-wrapper +++ b/external/fv3gfs-wrapper @@ -1 +1 @@ -Subproject commit 32878b2c777d25b21c86274435a0757762dfa65d +Subproject commit bebc8a2a52fff0e1fb34896ff6560e231dda5e45 From 381b76ff34f3baba1edc7a934e15eb85a4dbc23b Mon Sep 17 00:00:00 2001 From: brianhenn Date: Wed, 10 Mar 2021 20:11:34 +0000 Subject: [PATCH 04/17] add prephysics ML subclass --- ...nfig.test_prepare_ml_config_regression.out | 15 +++++++++--- .../examples/prognostic_config.yml | 11 +++++++-- workflows/prognostic_c48_run/runtime/loop.py | 23 +++++++++++++++---- .../runtime/steppers/machine_learning.py | 19 ++++++++++++++- 4 files changed, 58 insertions(+), 10 deletions(-) diff --git a/workflows/prognostic_c48_run/_regtest_outputs/test_prepare_config.test_prepare_ml_config_regression.out b/workflows/prognostic_c48_run/_regtest_outputs/test_prepare_config.test_prepare_ml_config_regression.out index 350b711036..10b582b5e2 100644 --- a/workflows/prognostic_c48_run/_regtest_outputs/test_prepare_config.test_prepare_ml_config_regression.out +++ b/workflows/prognostic_c48_run/_regtest_outputs/test_prepare_config.test_prepare_ml_config_regression.out @@ -219,11 +219,11 @@ namelist: - 0 - 0 - 0 - days: 10 + days: 0 dt_atmos: 900 dt_ocean: 900 force_date_from_namelist: true - hours: 0 + hours: 4 memuse_verbose: true minutes: 0 months: 0 @@ -454,7 +454,16 @@ namelist: ldebug: false nudging: null orographic_forcing: gs://vcm-fv3config/data/orographic_data/v1.0 -prephysics: null +prephysics: + config: + diagnostic_ml: false + input_standard_names: {} + model: + - gs://vcm-ml-experiments/2021-03-01-predict-surface-radiative-flux/static-t-q/trained_model + output_standard_names: + total_sky_downward_longwave_flux_at_surface_override: DLWRFsfc_verif + total_sky_downward_shortwave_flux_at_surface_override: DSWRFsfc_verif + total_sky_net_shortwave_flux_at_surface_override: NSWRFsfc_verif scikit_learn: diagnostic_ml: false input_standard_names: {} diff --git a/workflows/prognostic_c48_run/examples/prognostic_config.yml b/workflows/prognostic_c48_run/examples/prognostic_config.yml index 38320c0e4b..995e7c975c 100644 --- a/workflows/prognostic_c48_run/examples/prognostic_config.yml +++ b/workflows/prognostic_c48_run/examples/prognostic_config.yml @@ -1,8 +1,8 @@ base_version: v0.5 namelist: coupler_nml: - days: 10 # total length - hours: 0 + days: 0 # total length + hours: 4 minutes: 0 seconds: 0 dt_atmos: 900 # seconds @@ -14,3 +14,10 @@ namelist: fhzero: 0.25 # hours - frequency at which precip is set back to zero fv_core_nml: n_split: 6 # num dynamics steps per physics step +prephysics: + model: + - gs://vcm-ml-experiments/2021-03-01-predict-surface-radiative-flux/static-t-q/trained_model + output_standard_names: + total_sky_downward_shortwave_flux_at_surface_override: DSWRFsfc_verif + total_sky_net_shortwave_flux_at_surface_override: NSWRFsfc_verif + total_sky_downward_longwave_flux_at_surface_override: DLWRFsfc_verif diff --git a/workflows/prognostic_c48_run/runtime/loop.py b/workflows/prognostic_c48_run/runtime/loop.py index 5cdd49e24e..7b198d6ccf 100644 --- a/workflows/prognostic_c48_run/runtime/loop.py +++ b/workflows/prognostic_c48_run/runtime/loop.py @@ -33,6 +33,7 @@ load_adapted_model, download_model, MachineLearningConfig, + MLStateStepper, ) from runtime.steppers.nudging import PureNudger from runtime.steppers.prephysics import Prescriber, PrescriberConfig @@ -129,6 +130,17 @@ def add_tendency(state: Any, tendency: State, dt: float) -> State: return updated # type: ignore +def override_state(state: Any, overriding_state: State) -> State: + """Given state and an overriding state, return updated state. Needed + to maintain attributes of the target state + """ + with xr.set_options(keep_attrs=True): + updated = {} + for name in overriding_state: + updated[name] = 0.0 * state[name] + overriding_state[name] + return updated # type: ignore + + class LoggingMixin: rank: int @@ -213,9 +225,9 @@ def _get_steppers(self, config: UserConfig) -> Mapping[str, Optional[Stepper]]: if config.prephysics is not None and isinstance( config.prephysics.config, MachineLearningConfig ): - self._log_info("Using MLStepper for prephysics") + self._log_info("Using MLStateStepper for prephysics") model = self._open_model(config.prephysics.config, "_compute_prephysics") - steppers["_compute_prephysics"] = PureMLStepper(model, self._timestep) + steppers["_compute_prephysics"] = MLStateStepper(model, self._timestep) elif config.prephysics is not None and isinstance( config.prephysics.config, PrescriberConfig ): @@ -353,7 +365,9 @@ def _compute_prephysics(self) -> Diagnostics: else: _, diagnostics, state_updates = stepper(self._state.time, self._state) self._state_updates.update(state_updates) - self._log_debug(f"Computing prephysics state updates") + self._log_debug( + f"Computing prephysics state updates for {list(self._state_updates.keys())}" + ) return diagnostics def _apply_prephysics(self): @@ -369,7 +383,8 @@ def _apply_prephysics(self): self._log_debug( f"Applying prephysics state updates for {list(state_updates.keys())}" ) - self._state.update_mass_conserving(state_updates) + updated_state = override_state(self._state, state_updates) + self._state.update_mass_conserving(updated_state) return {} def _apply_python_to_physics_state(self) -> Diagnostics: diff --git a/workflows/prognostic_c48_run/runtime/steppers/machine_learning.py b/workflows/prognostic_c48_run/runtime/steppers/machine_learning.py index c2b441234a..8ad67b60d7 100644 --- a/workflows/prognostic_c48_run/runtime/steppers/machine_learning.py +++ b/workflows/prognostic_c48_run/runtime/steppers/machine_learning.py @@ -155,7 +155,7 @@ def download_model(config: MachineLearningConfig, path: str) -> Sequence[str]: def predict(model: MultiModelAdapter, state: State) -> State: - """Given ML model and state, return tendency prediction.""" + """Given ML model and state, return prediction""" state_loaded = {key: state[key] for key in model.input_variables} ds = xr.Dataset(state_loaded) # type: ignore output = model.predict_columnwise(ds, feature_dim="z") @@ -212,3 +212,20 @@ def get_diagnostics(self, state, tendency): def get_momentum_diagnostics(self, state, tendency): return runtime.compute_ml_momentum_diagnostics(state, tendency) + + +class MLStateStepper(PureMLStepper): + def __call__(self, time, state): + + diagnostics: Diagnostics = {} + state_updates: State = predict(self.model, state) + + for name in state_updates.keys(): + diagnostics[name] = state_updates[name] + + tendency = {} + return ( + tendency, + diagnostics, + state_updates, + ) From 4a0fd3382e4d8b246ada2e82a868361eb8d503c4 Mon Sep 17 00:00:00 2001 From: brianhenn Date: Wed, 10 Mar 2021 21:43:33 +0000 Subject: [PATCH 05/17] update to wrapper with albedo --- external/fv3gfs-wrapper | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/fv3gfs-wrapper b/external/fv3gfs-wrapper index bebc8a2a52..657c7005bd 160000 --- a/external/fv3gfs-wrapper +++ b/external/fv3gfs-wrapper @@ -1 +1 @@ -Subproject commit bebc8a2a52fff0e1fb34896ff6560e231dda5e45 +Subproject commit 657c7005bd4e4ae895d3797d388bd2e1c0c8d234 From c6878094ea0062bb8349f97f6bcc27b8c46484a7 Mon Sep 17 00:00:00 2001 From: brianhenn Date: Thu, 11 Mar 2021 05:46:41 +0000 Subject: [PATCH 06/17] MLStateStepper tests --- external/fv3gfs-wrapper | 2 +- ...nfig.test_prepare_ml_config_regression.out | 15 ++---- ...test_prepare_nudging_config_regression.out | 8 +--- .../examples/nudge_to_fine_config.yml | 6 --- .../examples/prognostic_config.yml | 13 ++--- workflows/prognostic_c48_run/runtime/loop.py | 12 ++--- ...er_regression_checksum[MLStateStepper].out | 17 +++++++ ...er_regression_checksum[PureMLStepper].out} | 2 + ...epper_schema_unchanged[MLStateStepper].out | 29 ++++++++++++ ...epper_schema_unchanged[PureMLStepper].out} | 6 +++ .../tests/machine_learning_mocks.py | 47 +++++++++++++++++++ .../tests/test_machine_learning.py | 42 +++++++++++++---- 12 files changed, 148 insertions(+), 51 deletions(-) create mode 100644 workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_MLStepper_regression_checksum[MLStateStepper].out rename workflows/prognostic_c48_run/tests/_regtest_outputs/{test_machine_learning.test_PureMLStepper_regression_checksum.out => test_machine_learning.test_MLStepper_regression_checksum[PureMLStepper].out} (96%) create mode 100644 workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_MLStepper_schema_unchanged[MLStateStepper].out rename workflows/prognostic_c48_run/tests/_regtest_outputs/{test_machine_learning.test_PureMLStepper_schema_unchanged.out => test_machine_learning.test_MLStepper_schema_unchanged[PureMLStepper].out} (91%) diff --git a/external/fv3gfs-wrapper b/external/fv3gfs-wrapper index 657c7005bd..6c4d01c937 160000 --- a/external/fv3gfs-wrapper +++ b/external/fv3gfs-wrapper @@ -1 +1 @@ -Subproject commit 657c7005bd4e4ae895d3797d388bd2e1c0c8d234 +Subproject commit 6c4d01c9379f410a5d629dbb4d9f91e3552b10f7 diff --git a/workflows/prognostic_c48_run/_regtest_outputs/test_prepare_config.test_prepare_ml_config_regression.out b/workflows/prognostic_c48_run/_regtest_outputs/test_prepare_config.test_prepare_ml_config_regression.out index 10b582b5e2..350b711036 100644 --- a/workflows/prognostic_c48_run/_regtest_outputs/test_prepare_config.test_prepare_ml_config_regression.out +++ b/workflows/prognostic_c48_run/_regtest_outputs/test_prepare_config.test_prepare_ml_config_regression.out @@ -219,11 +219,11 @@ namelist: - 0 - 0 - 0 - days: 0 + days: 10 dt_atmos: 900 dt_ocean: 900 force_date_from_namelist: true - hours: 4 + hours: 0 memuse_verbose: true minutes: 0 months: 0 @@ -454,16 +454,7 @@ namelist: ldebug: false nudging: null orographic_forcing: gs://vcm-fv3config/data/orographic_data/v1.0 -prephysics: - config: - diagnostic_ml: false - input_standard_names: {} - model: - - gs://vcm-ml-experiments/2021-03-01-predict-surface-radiative-flux/static-t-q/trained_model - output_standard_names: - total_sky_downward_longwave_flux_at_surface_override: DLWRFsfc_verif - total_sky_downward_shortwave_flux_at_surface_override: DSWRFsfc_verif - total_sky_net_shortwave_flux_at_surface_override: NSWRFsfc_verif +prephysics: null scikit_learn: diagnostic_ml: false input_standard_names: {} diff --git a/workflows/prognostic_c48_run/_regtest_outputs/test_prepare_config.test_prepare_nudging_config_regression.out b/workflows/prognostic_c48_run/_regtest_outputs/test_prepare_config.test_prepare_nudging_config_regression.out index a655a72952..0ae5181876 100644 --- a/workflows/prognostic_c48_run/_regtest_outputs/test_prepare_config.test_prepare_nudging_config_regression.out +++ b/workflows/prognostic_c48_run/_regtest_outputs/test_prepare_config.test_prepare_nudging_config_regression.out @@ -539,13 +539,7 @@ nudging: x_wind: 12 y_wind: 12 orographic_forcing: gs://vcm-fv3config/data/orographic_data/v1.0 -prephysics: - config: - data_source: 40day_c48_gfsphysics_15min_may2020 - variables: - - DLWRFsfc_coarse - - DSWRFsfc_coarse - - USWRFsfc_coarse +prephysics: null scikit_learn: diagnostic_ml: false input_standard_names: {} diff --git a/workflows/prognostic_c48_run/examples/nudge_to_fine_config.yml b/workflows/prognostic_c48_run/examples/nudge_to_fine_config.yml index cdf3de2fc9..95718dbcfd 100644 --- a/workflows/prognostic_c48_run/examples/nudge_to_fine_config.yml +++ b/workflows/prognostic_c48_run/examples/nudge_to_fine_config.yml @@ -6,12 +6,6 @@ nudging: x_wind: 12 y_wind: 12 pressure_thickness_of_atmospheric_layer: 12 -prephysics: - data_source: 40day_c48_gfsphysics_15min_may2020 - variables: - - DLWRFsfc_coarse - - DSWRFsfc_coarse - - USWRFsfc_coarse namelist: coupler_nml: current_date: diff --git a/workflows/prognostic_c48_run/examples/prognostic_config.yml b/workflows/prognostic_c48_run/examples/prognostic_config.yml index 995e7c975c..76c706e70c 100644 --- a/workflows/prognostic_c48_run/examples/prognostic_config.yml +++ b/workflows/prognostic_c48_run/examples/prognostic_config.yml @@ -1,8 +1,8 @@ base_version: v0.5 namelist: coupler_nml: - days: 0 # total length - hours: 4 + days: 10 # total length + hours: 0 minutes: 0 seconds: 0 dt_atmos: 900 # seconds @@ -14,10 +14,5 @@ namelist: fhzero: 0.25 # hours - frequency at which precip is set back to zero fv_core_nml: n_split: 6 # num dynamics steps per physics step -prephysics: - model: - - gs://vcm-ml-experiments/2021-03-01-predict-surface-radiative-flux/static-t-q/trained_model - output_standard_names: - total_sky_downward_shortwave_flux_at_surface_override: DSWRFsfc_verif - total_sky_net_shortwave_flux_at_surface_override: NSWRFsfc_verif - total_sky_downward_longwave_flux_at_surface_override: DLWRFsfc_verif + + \ No newline at end of file diff --git a/workflows/prognostic_c48_run/runtime/loop.py b/workflows/prognostic_c48_run/runtime/loop.py index 7b198d6ccf..b4d3f717e2 100644 --- a/workflows/prognostic_c48_run/runtime/loop.py +++ b/workflows/prognostic_c48_run/runtime/loop.py @@ -363,25 +363,23 @@ def _compute_prephysics(self) -> Diagnostics: if stepper is None: diagnostics: Diagnostics = {} else: + self._log_debug("Computing prephysics updates") _, diagnostics, state_updates = stepper(self._state.time, self._state) self._state_updates.update(state_updates) - self._log_debug( - f"Computing prephysics state updates for {list(self._state_updates.keys())}" - ) return diagnostics def _apply_prephysics(self): - radiative_fluxes = [ + prephysics_overrides = [ "total_sky_downward_shortwave_flux_at_surface_override", "total_sky_net_shortwave_flux_at_surface_override", "total_sky_downward_longwave_flux_at_surface_override", ] state_updates = { - k: v for k, v in self._state_updates.items() if k in radiative_fluxes + k: v for k, v in self._state_updates.items() if k in prephysics_overrides } - self._state_updates = dissoc(self._state_updates, *radiative_fluxes) + self._state_updates = dissoc(self._state_updates, *prephysics_overrides) self._log_debug( - f"Applying prephysics state updates for {list(state_updates.keys())}" + f"Applying prephysics state updates for: {list(state_updates.keys())}" ) updated_state = override_state(self._state, state_updates) self._state.update_mass_conserving(updated_state) diff --git a/workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_MLStepper_regression_checksum[MLStateStepper].out b/workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_MLStepper_regression_checksum[MLStateStepper].out new file mode 100644 index 0000000000..3e2b87b688 --- /dev/null +++ b/workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_MLStepper_regression_checksum[MLStateStepper].out @@ -0,0 +1,17 @@ +- - tendencies + - [] +- - diagnostics + - - - downward_longwave + - 8b886991eea4c48475709bca29505185 + - - downward_shortwave + - 47735602b45938d453a31013ef9410ba + - - net_shortwave + - 01b249036fc7c5eec78879aa849ec524 +- - states + - - - downward_longwave + - 8b886991eea4c48475709bca29505185 + - - downward_shortwave + - 47735602b45938d453a31013ef9410ba + - - net_shortwave + - 01b249036fc7c5eec78879aa849ec524 + diff --git a/workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_PureMLStepper_regression_checksum.out b/workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_MLStepper_regression_checksum[PureMLStepper].out similarity index 96% rename from workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_PureMLStepper_regression_checksum.out rename to workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_MLStepper_regression_checksum[PureMLStepper].out index 7967633322..bd5e48845c 100644 --- a/workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_PureMLStepper_regression_checksum.out +++ b/workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_MLStepper_regression_checksum[PureMLStepper].out @@ -14,4 +14,6 @@ - fcc46bebe36ea131688f8e15700e18d4 - - rank_updated_points - 261cfa795cd717bb3a4b2bde266ba20a +- - states + - [] diff --git a/workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_MLStepper_schema_unchanged[MLStateStepper].out b/workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_MLStepper_schema_unchanged[MLStateStepper].out new file mode 100644 index 0000000000..adb9b74035 --- /dev/null +++ b/workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_MLStepper_schema_unchanged[MLStateStepper].out @@ -0,0 +1,29 @@ +xarray.Dataset { +dimensions: + x = 4 ; + y = 4 ; + +variables: + float64 downward_longwave(y, x) ; + float64 downward_shortwave(y, x) ; + float64 net_shortwave(y, x) ; + +// global attributes: +}xarray.Dataset { +dimensions: + +variables: + +// global attributes: +}xarray.Dataset { +dimensions: + x = 4 ; + y = 4 ; + +variables: + float64 downward_longwave(y, x) ; + float64 downward_shortwave(y, x) ; + float64 net_shortwave(y, x) ; + +// global attributes: +} \ No newline at end of file diff --git a/workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_PureMLStepper_schema_unchanged.out b/workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_MLStepper_schema_unchanged[PureMLStepper].out similarity index 91% rename from workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_PureMLStepper_schema_unchanged.out rename to workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_MLStepper_schema_unchanged[PureMLStepper].out index 2fa482b289..9be83ff889 100644 --- a/workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_PureMLStepper_schema_unchanged.out +++ b/workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_MLStepper_schema_unchanged[PureMLStepper].out @@ -25,5 +25,11 @@ variables: float64 dQu(z, y, x) ; float64 dQv(z, y, x) ; +// global attributes: +}xarray.Dataset { +dimensions: + +variables: + // global attributes: } \ No newline at end of file diff --git a/workflows/prognostic_c48_run/tests/machine_learning_mocks.py b/workflows/prognostic_c48_run/tests/machine_learning_mocks.py index 1a94e8eb4a..d9ad09422b 100644 --- a/workflows/prognostic_c48_run/tests/machine_learning_mocks.py +++ b/workflows/prognostic_c48_run/tests/machine_learning_mocks.py @@ -27,6 +27,26 @@ def _model_dataset() -> xr.Dataset: return data +def _rad_model_dataset() -> xr.Dataset: + + arr = np.zeros((1,)) + dims = [ + "sample", + ] + + data = xr.Dataset( + { + "specific_humidity": (dims, arr), + "air_temperature": (dims, arr), + "downward_shortwave": (dims, arr), + "net_shortwave": (dims, arr), + "downward_longwave": (dims, arr), + } + ) + + return data + + def get_mock_sklearn_model() -> fv3fit.Predictor: data = _model_dataset() @@ -60,6 +80,33 @@ def get_mock_sklearn_model() -> fv3fit.Predictor: return model +def get_mock_rad_flux_model() -> fv3fit.Predictor: + + data = _rad_model_dataset() + + n_sample = data.sizes["sample"] + downward_shortwave = np.full(n_sample, 300.0) + net_shortwave = np.full(n_sample, 250.0) + downward_longwave = np.full(n_sample, 400.0) + constant = np.concatenate( + [[downward_shortwave], [net_shortwave], [downward_longwave]] + ) + estimator = RegressorEnsemble( + DummyRegressor(strategy="constant", constant=constant) + ) + + model = SklearnWrapper( + "sample", + ["air_temperature", "specific_humidity"], + ["downward_shortwave", "net_shortwave", "downward_longwave"], + estimator, + ) + + # needed to avoid sklearn.exceptions.NotFittedError + model.fit([data]) + return model + + def get_mock_keras_model() -> fv3fit.Predictor: input_variables = ["air_temperature", "specific_humidity"] diff --git a/workflows/prognostic_c48_run/tests/test_machine_learning.py b/workflows/prognostic_c48_run/tests/test_machine_learning.py index cd3c096f7e..ca8a3c9f5f 100644 --- a/workflows/prognostic_c48_run/tests/test_machine_learning.py +++ b/workflows/prognostic_c48_run/tests/test_machine_learning.py @@ -1,5 +1,5 @@ -from runtime.steppers.machine_learning import PureMLStepper -from machine_learning_mocks import get_mock_sklearn_model +from runtime.steppers.machine_learning import PureMLStepper, MLStateStepper +from machine_learning_mocks import get_mock_sklearn_model, get_mock_rad_flux_model import requests import xarray as xr import joblib @@ -26,12 +26,37 @@ def state(tmp_path_factory): return xr.open_dataset(str(lpath)) -def test_PureMLStepper_schema_unchanged(state, regtest): - model = get_mock_sklearn_model() +@pytest.fixture(params=["PureMLStepper", "MLStateStepper"]) +def ml_stepper_name(request): + return request.param + + +@pytest.fixture +def mock_model(ml_stepper_name): + if ml_stepper_name == "PureMLStepper": + model = get_mock_sklearn_model() + elif ml_stepper_name == "MLStateStepper": + model = get_mock_rad_flux_model() + else: + raise ValueError("ML Stepper name not defined.") + return model + + +@pytest.fixture +def ml_stepper(ml_stepper_name, mock_model): timestep = 900 - (tendencies, diagnostics, _,) = PureMLStepper(model, timestep)(None, state) + if ml_stepper_name == "PureMLStepper": + stepper = PureMLStepper(mock_model, timestep) + elif ml_stepper_name == "MLStateStepper": + stepper = MLStateStepper(mock_model, timestep) + return stepper + + +def test_MLStepper_schema_unchanged(state, ml_stepper, regtest): + (tendencies, diagnostics, states) = ml_stepper(None, state) xr.Dataset(diagnostics).info(regtest) xr.Dataset(tendencies).info(regtest) + xr.Dataset(states).info(regtest) def test_state_regression(state, regtest): @@ -39,14 +64,13 @@ def test_state_regression(state, regtest): print(checksum, file=regtest) -def test_PureMLStepper_regression_checksum(state, regtest): - model = get_mock_sklearn_model() - timestep = 900 - (tendencies, diagnostics, _,) = PureMLStepper(model, timestep)(None, state) +def test_MLStepper_regression_checksum(state, ml_stepper, regtest): + (tendencies, diagnostics, states) = ml_stepper(None, state) checksums = yaml.safe_dump( [ ("tendencies", checksum_xarray_dict(tendencies)), ("diagnostics", checksum_xarray_dict(diagnostics)), + ("states", checksum_xarray_dict(states)), ] ) From 427f4c556d9b605a658f997f0dd56cafd5ea8d6d Mon Sep 17 00:00:00 2001 From: brianhenn Date: Thu, 11 Mar 2021 05:51:17 +0000 Subject: [PATCH 07/17] cleanup --- workflows/prognostic_c48_run/examples/prognostic_config.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/workflows/prognostic_c48_run/examples/prognostic_config.yml b/workflows/prognostic_c48_run/examples/prognostic_config.yml index 76c706e70c..38320c0e4b 100644 --- a/workflows/prognostic_c48_run/examples/prognostic_config.yml +++ b/workflows/prognostic_c48_run/examples/prognostic_config.yml @@ -14,5 +14,3 @@ namelist: fhzero: 0.25 # hours - frequency at which precip is set back to zero fv_core_nml: n_split: 6 # num dynamics steps per physics step - - \ No newline at end of file From cd9ec8f3483bcca8b6cb5a29ffb227c59d5ad157 Mon Sep 17 00:00:00 2001 From: brianhenn Date: Fri, 12 Mar 2021 19:22:25 +0000 Subject: [PATCH 08/17] updated docs --- .../prognostic_c48_run/docs/config-usage.rst | 16 ++++++++++++++++ .../docs/configuration-api.rst | 3 +++ 2 files changed, 19 insertions(+) diff --git a/workflows/prognostic_c48_run/docs/config-usage.rst b/workflows/prognostic_c48_run/docs/config-usage.rst index 4d3c806460..5c4a6903fd 100644 --- a/workflows/prognostic_c48_run/docs/config-usage.rst +++ b/workflows/prognostic_c48_run/docs/config-usage.rst @@ -109,6 +109,22 @@ It can be used multiple times to specify multiple models. For example:: --model_url path/to_another/model > fv3config.yaml +Prephysics +~~~~~~~~~~ + +If prephysics computations (currently, only setting radiative fluxes) are needed, +they can be configured by setting the :py:attr:`UserConfig.prephysics` section, +following what is required by :py:class:`runtime.steppers.prephysics.PrephysicsConfig`. +Its `config` attribute may either specify setting values via an ML model, following +:py:class:`runtime.steppers.machine_learning.MachineLearningConfig` specs, or +setting values directly via an external source (not yet implemented). See example +for setting values via an ML model:: + + prephysics: + config: + model: ["path/to/model"] + + Diagnostics ~~~~~~~~~~~ diff --git a/workflows/prognostic_c48_run/docs/configuration-api.rst b/workflows/prognostic_c48_run/docs/configuration-api.rst index c64e0688f3..03363b439f 100644 --- a/workflows/prognostic_c48_run/docs/configuration-api.rst +++ b/workflows/prognostic_c48_run/docs/configuration-api.rst @@ -40,6 +40,9 @@ Python "Physics" .. py:module:: runtime.nudging .. autoclass:: NudgingConfig +.. py:module:: runtime.steppers.prephysics +.. autoclass:: PrephysicsConfig + Diagnostics ~~~~~~~~~~~ From bf163ad492929523fc61564f3a3c7d9558d41d7e Mon Sep 17 00:00:00 2001 From: brianhenn Date: Fri, 12 Mar 2021 19:25:47 +0000 Subject: [PATCH 09/17] updated fortran model external --- external/fv3gfs-fortran | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/fv3gfs-fortran b/external/fv3gfs-fortran index b8340fb5b9..75605127be 160000 --- a/external/fv3gfs-fortran +++ b/external/fv3gfs-fortran @@ -1 +1 @@ -Subproject commit b8340fb5b990e70c27e57b1d4ef2b86a772ed85c +Subproject commit 75605127be63132c1d16fb6da6e70f60f88e2d40 From b6f6b45609dc81e4f46198fdddadee2e63b2f3cb Mon Sep 17 00:00:00 2001 From: brianhenn Date: Sun, 14 Mar 2021 06:05:31 +0000 Subject: [PATCH 10/17] revert Prescriber --- .../prognostic_c48_run/docs/config-usage.rst | 16 ----- .../docs/configuration-api.rst | 5 -- .../prognostic_c48_run/prepare_config.py | 7 +- .../prognostic_c48_run/runtime/config.py | 6 +- workflows/prognostic_c48_run/runtime/loop.py | 16 +---- .../runtime/steppers/prephysics.py | 65 ------------------- 6 files changed, 8 insertions(+), 107 deletions(-) delete mode 100644 workflows/prognostic_c48_run/runtime/steppers/prephysics.py diff --git a/workflows/prognostic_c48_run/docs/config-usage.rst b/workflows/prognostic_c48_run/docs/config-usage.rst index 5c4a6903fd..c162e50cff 100644 --- a/workflows/prognostic_c48_run/docs/config-usage.rst +++ b/workflows/prognostic_c48_run/docs/config-usage.rst @@ -108,22 +108,6 @@ It can be used multiple times to specify multiple models. For example:: --model_url path/to/model --model_url path/to_another/model > fv3config.yaml - -Prephysics -~~~~~~~~~~ - -If prephysics computations (currently, only setting radiative fluxes) are needed, -they can be configured by setting the :py:attr:`UserConfig.prephysics` section, -following what is required by :py:class:`runtime.steppers.prephysics.PrephysicsConfig`. -Its `config` attribute may either specify setting values via an ML model, following -:py:class:`runtime.steppers.machine_learning.MachineLearningConfig` specs, or -setting values directly via an external source (not yet implemented). See example -for setting values via an ML model:: - - prephysics: - config: - model: ["path/to/model"] - Diagnostics ~~~~~~~~~~~ diff --git a/workflows/prognostic_c48_run/docs/configuration-api.rst b/workflows/prognostic_c48_run/docs/configuration-api.rst index 03363b439f..3a5b1584ac 100644 --- a/workflows/prognostic_c48_run/docs/configuration-api.rst +++ b/workflows/prognostic_c48_run/docs/configuration-api.rst @@ -33,17 +33,12 @@ Top-level Python "Physics" ~~~~~~~~~~~~~~~~ - .. py:module:: runtime.steppers.machine_learning .. autoclass:: MachineLearningConfig .. py:module:: runtime.nudging .. autoclass:: NudgingConfig -.. py:module:: runtime.steppers.prephysics -.. autoclass:: PrephysicsConfig - - Diagnostics ~~~~~~~~~~~ diff --git a/workflows/prognostic_c48_run/prepare_config.py b/workflows/prognostic_c48_run/prepare_config.py index fea2fa9242..ca7e38342e 100644 --- a/workflows/prognostic_c48_run/prepare_config.py +++ b/workflows/prognostic_c48_run/prepare_config.py @@ -17,7 +17,6 @@ from runtime.steppers.nudging import NudgingConfig from runtime.config import UserConfig from runtime.steppers.machine_learning import MachineLearningConfig -from runtime.steppers.prephysics import PrephysicsConfig logger = logging.getLogger(__name__) @@ -97,11 +96,9 @@ def user_config_from_dict_and_args(config_dict: dict, args) -> UserConfig: config_dict.get("namelist", {}).get("fv_core_nml", {}).get("nudge", False) ) - prephysics: Optional[PrephysicsConfig] + prephysics: Optional[MachineLearningConfig] if "prephysics" in config_dict: - prephysics = dacite.from_dict( - PrephysicsConfig, {"config": config_dict["prephysics"]} - ) + prephysics = dacite.from_dict(MachineLearningConfig, config_dict["prephysics"]) else: prephysics = None diff --git a/workflows/prognostic_c48_run/runtime/config.py b/workflows/prognostic_c48_run/runtime/config.py index 38ebdb5cb7..38de10cf90 100644 --- a/workflows/prognostic_c48_run/runtime/config.py +++ b/workflows/prognostic_c48_run/runtime/config.py @@ -12,7 +12,6 @@ ) from runtime.steppers.nudging import NudgingConfig from runtime.steppers.machine_learning import MachineLearningConfig -from runtime.steppers.prephysics import PrephysicsConfig FV3CONFIG_FILENAME = "fv3config.yml" @@ -22,9 +21,12 @@ class UserConfig: """The top-level object for python runtime configurations Attributes: + diagnostics: list of diagnostic file configurations fortran_diagnostics: list of Fortran diagnostic outputs. Currently only used by post-processing and so only name and chunks items need to be specified. + prephysics: optional configuration of computations prior to physics, + specified by a machine learning configuation scikit_learn: a machine learning configuration nudging: nudge2fine configuration. Cannot be used if any scikit_learn model urls are specified. @@ -37,7 +39,7 @@ class UserConfig: diagnostics: List[DiagnosticFileConfig] fortran_diagnostics: List[FortranFileConfig] - prephysics: Optional[PrephysicsConfig] = None + prephysics: Optional[MachineLearningConfig] = None scikit_learn: MachineLearningConfig = MachineLearningConfig() nudging: Optional[NudgingConfig] = None step_tendency_variables: List[str] = dataclasses.field( diff --git a/workflows/prognostic_c48_run/runtime/loop.py b/workflows/prognostic_c48_run/runtime/loop.py index b4d3f717e2..2766a32666 100644 --- a/workflows/prognostic_c48_run/runtime/loop.py +++ b/workflows/prognostic_c48_run/runtime/loop.py @@ -36,7 +36,6 @@ MLStateStepper, ) from runtime.steppers.nudging import PureNudger -from runtime.steppers.prephysics import Prescriber, PrescriberConfig from runtime.types import Diagnostics, State, Tendencies from runtime.names import TENDENCY_TO_STATE_NAME from toolz import dissoc @@ -223,22 +222,11 @@ def _get_steppers(self, config: UserConfig) -> Mapping[str, Optional[Stepper]]: steppers: MutableMapping[str, Optional[Stepper]] = {} if config.prephysics is not None and isinstance( - config.prephysics.config, MachineLearningConfig + config.prephysics, MachineLearningConfig ): self._log_info("Using MLStateStepper for prephysics") - model = self._open_model(config.prephysics.config, "_compute_prephysics") + model = self._open_model(config.prephysics, "_compute_prephysics") steppers["_compute_prephysics"] = MLStateStepper(model, self._timestep) - elif config.prephysics is not None and isinstance( - config.prephysics.config, PrescriberConfig - ): - self._log_info("Using Prescriber for prephysics") - partitioner = fv3gfs.util.CubedSpherePartitioner.from_namelist( - get_namelist() - ) - communicator = fv3gfs.util.CubedSphereCommunicator(self.comm, partitioner) - steppers["_compute_prephysics"] = Prescriber( - config.prephysics.config, communicator - ) else: self._log_info("No prephysics computations") steppers["_compute_prephysics"] = None diff --git a/workflows/prognostic_c48_run/runtime/steppers/prephysics.py b/workflows/prognostic_c48_run/runtime/steppers/prephysics.py deleted file mode 100644 index d0fae1a689..0000000000 --- a/workflows/prognostic_c48_run/runtime/steppers/prephysics.py +++ /dev/null @@ -1,65 +0,0 @@ -from typing import Union, Sequence -import dataclasses -from runtime.steppers.machine_learning import MachineLearningConfig -import fv3gfs.util - - -@dataclasses.dataclass -class PrescriberConfig: - """Configuration for prescribing states in the model from an external source - - Attributes: - variables: list variable names to prescribe - data_source: path to the source of the data to prescribe - - Example:: - - PrescriberConfig( - variables=[''] - data_source="" - ) - - """ - - variables: Sequence[str] - data_source: str - - -class Prescriber: - """A pre-physics stepper which obtains prescribed values from an external source - - TODO: Implement methods - """ - - net_moistening = "net_moistening" - - def __init__( - self, - config: PrescriberConfig, - communicator: fv3gfs.util.CubedSphereCommunicator, - ): - - self._prescribed_variables: Sequence[str] = list(config.variables) - self._data_source: str = config.data_source - - def __call__(self, time, state): - return {}, {}, {} - - def get_diagnostics(self, state, tendency): - return {} - - def get_momentum_diagnostics(self, state, tendency): - return {} - - -@dataclasses.dataclass -class PrephysicsConfig: - """Configuration of pre-physics computations - - Attributes: - config: can be either a MachineLearningConfig or a - PrescriberConfig, as these are the allowed pre-physics computations - - """ - - config: Union[PrescriberConfig, MachineLearningConfig] From bb19f9e9553023611aa549805cf1b39fb30fe65e Mon Sep 17 00:00:00 2001 From: brianhenn Date: Sun, 14 Mar 2021 06:40:02 +0000 Subject: [PATCH 11/17] prephysics and postphysics steppers --- workflows/prognostic_c48_run/runtime/loop.py | 117 +++++++++---------- 1 file changed, 57 insertions(+), 60 deletions(-) diff --git a/workflows/prognostic_c48_run/runtime/loop.py b/workflows/prognostic_c48_run/runtime/loop.py index 2766a32666..576a9e4399 100644 --- a/workflows/prognostic_c48_run/runtime/loop.py +++ b/workflows/prognostic_c48_run/runtime/loop.py @@ -8,7 +8,6 @@ Iterable, List, Mapping, - MutableMapping, Optional, Sequence, Tuple, @@ -129,14 +128,12 @@ def add_tendency(state: Any, tendency: State, dt: float) -> State: return updated # type: ignore -def override_state(state: Any, overriding_state: State) -> State: - """Given state and an overriding state, return updated state. Needed - to maintain attributes of the target state +def assign_attrs_from(src: Any, dst: State) -> State: + """Given src state and a dst state, return dst state with src attrs """ - with xr.set_options(keep_attrs=True): - updated = {} - for name in overriding_state: - updated[name] = 0.0 * state[name] + overriding_state[name] + updated = {} + for name in dst: + updated[name] = dst[name].assign_attrs(src[name].attrs) return updated # type: ignore @@ -165,16 +162,18 @@ class TimeLoop(Iterable[Tuple[cftime.DatetimeJulian, Diagnostics]], LoggingMixin Each time step of the model evolutions proceeds like this:: step_dynamics, + compute_prephysics, + apply_prephysics, compute_physics, - apply_python_to_physics_state - apply_physics - compute_python_updates - apply_python_to_dycore_state + apply_postphysics_to_physics_state, + apply_physics, + compute_postphysics, + apply_postphysics_to_dycore_state, The time loop relies on objects implementing the :py:class:`Stepper` interface to enable ML and other updates. The steppers compute their - updates in ``_compute_python_updates``. The ``TimeLoop`` controls when - and how to apply these updates to the FV3 state. + updates in ``_compute_prephysics`` and ``_compute_postphysics``. The + ``TimeLoop`` controls when and how to apply these updates to the FV3 state. """ def __init__( @@ -203,7 +202,8 @@ def __init__( self._states_to_output: Sequence[str] = self._get_states_to_output(config) self._log_debug(f"States to output: {self._states_to_output}") - self.steppers = self._get_steppers(config) + self._prephysics_stepper = self._get_prephysics_stepper(config) + self._postphysics_stepper = self._get_postphysics_stepper(config) self._log_info(self._fv3gfs.get_tracer_metadata()) MPI.COMM_WORLD.barrier() # wait for initialization to finish @@ -217,38 +217,34 @@ def _get_states_to_output(self, config: UserConfig) -> Sequence[str]: states_to_output = diagnostic.variables # type: ignore return states_to_output - def _get_steppers(self, config: UserConfig) -> Mapping[str, Optional[Stepper]]: - - steppers: MutableMapping[str, Optional[Stepper]] = {} - + def _get_prephysics_stepper(self, config: UserConfig) -> Optional[Stepper]: if config.prephysics is not None and isinstance( config.prephysics, MachineLearningConfig ): self._log_info("Using MLStateStepper for prephysics") model = self._open_model(config.prephysics, "_compute_prephysics") - steppers["_compute_prephysics"] = MLStateStepper(model, self._timestep) + stepper: Optional[Stepper] = MLStateStepper(model, self._timestep) else: self._log_info("No prephysics computations") - steppers["_compute_prephysics"] = None + stepper = None + return stepper + def _get_postphysics_stepper(self, config: UserConfig) -> Optional[Stepper]: if config.scikit_learn.model: - self._log_info("Using MLStepper for python updates") - model = self._open_model(config.scikit_learn, "_compute_python_updates") - steppers["_compute_python_updates"] = PureMLStepper(model, self._timestep) + self._log_info("Using MLStepper for postphysics updates") + model = self._open_model(config.scikit_learn, "_compute_postphysics") + stepper: Optional[Stepper] = PureMLStepper(model, self._timestep) elif config.nudging: - self._log_info("Using NudgingStepper for python updates") + self._log_info("Using NudgingStepper for postphysics updates") partitioner = fv3gfs.util.CubedSpherePartitioner.from_namelist( get_namelist() ) communicator = fv3gfs.util.CubedSphereCommunicator(self.comm, partitioner) - steppers["_compute_python_updates"] = PureNudger( - config.nudging, communicator - ) + stepper = PureNudger(config.nudging, communicator) else: self._log_info("Performing baseline simulation") - steppers["_compute_python_updates"] = None - - return steppers + stepper = None + return stepper def _open_model(self, ml_config: MachineLearningConfig, step: str): self._log_info("Downloading ML Model") @@ -340,19 +336,20 @@ def _substeps(self) -> Sequence[Callable[..., Diagnostics]]: self._compute_prephysics, self._apply_prephysics, self._compute_physics, - self._apply_python_to_physics_state, + self._apply_postphysics_to_physics_state, self._apply_physics, - self._compute_python_updates, - self._apply_python_to_dycore_state, + self._compute_postphysics, + self._apply_postphysics_to_dycore_state, ] def _compute_prephysics(self) -> Diagnostics: - stepper = self.steppers["_compute_prephysics"] - if stepper is None: + if self._prephysics_stepper is None: diagnostics: Diagnostics = {} else: self._log_debug("Computing prephysics updates") - _, diagnostics, state_updates = stepper(self._state.time, self._state) + _, diagnostics, state_updates = self._prephysics_stepper( + self._state.time, self._state + ) self._state_updates.update(state_updates) return diagnostics @@ -369,24 +366,24 @@ def _apply_prephysics(self): self._log_debug( f"Applying prephysics state updates for: {list(state_updates.keys())}" ) - updated_state = override_state(self._state, state_updates) + updated_state = assign_attrs_from(self._state, state_updates) self._state.update_mass_conserving(updated_state) return {} - def _apply_python_to_physics_state(self) -> Diagnostics: + def _apply_postphysics_to_physics_state(self) -> Diagnostics: """Apply computed tendencies and state updates to the physics state Mostly used for updating the eastward and northward winds. """ - self._log_debug(f"Apply python tendencies to physics state") + self._log_debug(f"Apply postphysics tendencies to physics state") tendency = {k: v for k, v in self._tendencies.items() if k in ["dQu", "dQv"]} diagnostics: Diagnostics = {} - stepper = self.steppers["_compute_python_updates"] - - if stepper is not None: - diagnostics = stepper.get_momentum_diagnostics(self._state, tendency) + if self._postphysics_stepper is not None: + diagnostics = self._postphysics_stepper.get_momentum_diagnostics( + self._state, tendency + ) if self._do_only_diagnostic_ml: rename_diagnostics(diagnostics) else: @@ -395,17 +392,17 @@ def _apply_python_to_physics_state(self) -> Diagnostics: return diagnostics - def _compute_python_updates(self) -> Diagnostics: - self._log_info("Computing Python Updates") + def _compute_postphysics(self) -> Diagnostics: + self._log_info("Computing Postphysics Updates") - stepper = self.steppers["_compute_python_updates"] - - if stepper is None: + if self._postphysics_stepper is None: return {} else: - (self._tendencies, diagnostics, self._state_updates,) = stepper( - self._state.time, self._state - ) + ( + self._tendencies, + diagnostics, + self._state_updates, + ) = self._postphysics_stepper(self._state.time, self._state) try: rank_updated_points = diagnostics["rank_updated_points"] except KeyError: @@ -422,23 +419,23 @@ def _compute_python_updates(self) -> Diagnostics: ) return diagnostics - def _apply_python_to_dycore_state(self) -> Diagnostics: + def _apply_postphysics_to_dycore_state(self) -> Diagnostics: tendency = dissoc(self._tendencies, "dQu", "dQv") - stepper = self.steppers["_compute_python_updates"] - - if stepper is None: + if self._postphysics_stepper is None: diagnostics = compute_baseline_diagnostics(self._state) else: - diagnostics = stepper.get_diagnostics(self._state, tendency) + diagnostics = self._postphysics_stepper.get_diagnostics( + self._state, tendency + ) if self._do_only_diagnostic_ml: rename_diagnostics(diagnostics) else: updated_state = add_tendency(self._state, tendency, dt=self._timestep) updated_state[TOTAL_PRECIP] = precipitation_sum( self._state[TOTAL_PRECIP], - diagnostics[stepper.net_moistening], + diagnostics[self._postphysics_stepper.net_moistening], self._timestep, ) diagnostics[TOTAL_PRECIP] = updated_state[TOTAL_PRECIP] @@ -548,8 +545,8 @@ def __init__( self._storage_variables = list(storage_variables) _apply_physics = monitor("fv3_physics", TimeLoop._apply_physics) - _apply_python_to_dycore_state = monitor( - "python", TimeLoop._apply_python_to_dycore_state + _apply_postphysics_to_dycore_state = monitor( + "python", TimeLoop._apply_postphysics_to_dycore_state ) From 436e036a6eec6a3430de3aaa91eb863a03f14b3a Mon Sep 17 00:00:00 2001 From: brianhenn Date: Mon, 15 Mar 2021 21:25:11 +0000 Subject: [PATCH 12/17] prephysics diagnostic ml --- .../runtime/diagnostics/machine_learning.py | 5 ++++- workflows/prognostic_c48_run/runtime/loop.py | 20 +++++++++++++------ .../runtime/steppers/machine_learning.py | 4 ++-- 3 files changed, 20 insertions(+), 9 deletions(-) diff --git a/workflows/prognostic_c48_run/runtime/diagnostics/machine_learning.py b/workflows/prognostic_c48_run/runtime/diagnostics/machine_learning.py index a93da8bc18..c548baeb35 100644 --- a/workflows/prognostic_c48_run/runtime/diagnostics/machine_learning.py +++ b/workflows/prognostic_c48_run/runtime/diagnostics/machine_learning.py @@ -111,12 +111,15 @@ def rename_diagnostics(diags: Diagnostics): "net_heating", "column_integrated_dQu", "column_integrated_dQv", + "total_sky_downward_shortwave_flux_at_surface_override", + "total_sky_net_shortwave_flux_at_surface_override", + "total_sky_downward_longwave_flux_at_surface_override", } ml_tendencies_in_diags = ml_tendencies & set(diags) for variable in ml_tendencies_in_diags: attrs = diags[variable].attrs diags[f"{variable}_diagnostic"] = diags[variable].assign_attrs( - description=attrs["description"] + " (diagnostic only)" + description=attrs.get("description", "") + " (diagnostic only)" ) diags[variable] = xr.zeros_like(diags[variable]).assign_attrs(attrs) diff --git a/workflows/prognostic_c48_run/runtime/loop.py b/workflows/prognostic_c48_run/runtime/loop.py index 576a9e4399..ee2dffd821 100644 --- a/workflows/prognostic_c48_run/runtime/loop.py +++ b/workflows/prognostic_c48_run/runtime/loop.py @@ -29,7 +29,7 @@ ) from runtime.steppers.machine_learning import ( PureMLStepper, - load_adapted_model, + open_model, download_model, MachineLearningConfig, MLStateStepper, @@ -196,7 +196,12 @@ def __init__( self._timestep = timestep self._log_info(f"Timestep: {timestep}") - self._do_only_diagnostic_ml = config.scikit_learn.diagnostic_ml + self._prephysics_only_diagnostic_ml: bool = getattr( + getattr(config, "prephysics"), "diagnostic_ml", False + ) + print(f"_prephysics_only_diagnostic_ml: {self._prephysics_only_diagnostic_ml}") + + self._postphysics_only_diagnostic_ml: bool = config.scikit_learn.diagnostic_ml self._tendencies: Tendencies = {} self._state_updates: State = {} @@ -257,7 +262,7 @@ def _open_model(self, ml_config: MachineLearningConfig, step: str): local_model_paths = self.comm.bcast(local_model_paths, root=0) setattr(ml_config, "model", local_model_paths) self._log_info("Model Downloaded From Remote") - model = load_adapted_model(ml_config) + model = open_model(ml_config) self._log_info("Model Loaded") return model @@ -350,7 +355,10 @@ def _compute_prephysics(self) -> Diagnostics: _, diagnostics, state_updates = self._prephysics_stepper( self._state.time, self._state ) - self._state_updates.update(state_updates) + if self._prephysics_only_diagnostic_ml: + rename_diagnostics(diagnostics) + else: + self._state_updates.update(state_updates) return diagnostics def _apply_prephysics(self): @@ -384,7 +392,7 @@ def _apply_postphysics_to_physics_state(self) -> Diagnostics: diagnostics = self._postphysics_stepper.get_momentum_diagnostics( self._state, tendency ) - if self._do_only_diagnostic_ml: + if self._postphysics_only_diagnostic_ml: rename_diagnostics(diagnostics) else: updated_state = add_tendency(self._state, tendency, dt=self._timestep) @@ -429,7 +437,7 @@ def _apply_postphysics_to_dycore_state(self) -> Diagnostics: diagnostics = self._postphysics_stepper.get_diagnostics( self._state, tendency ) - if self._do_only_diagnostic_ml: + if self._postphysics_only_diagnostic_ml: rename_diagnostics(diagnostics) else: updated_state = add_tendency(self._state, tendency, dt=self._timestep) diff --git a/workflows/prognostic_c48_run/runtime/steppers/machine_learning.py b/workflows/prognostic_c48_run/runtime/steppers/machine_learning.py index 8ad67b60d7..ae6cca56f7 100644 --- a/workflows/prognostic_c48_run/runtime/steppers/machine_learning.py +++ b/workflows/prognostic_c48_run/runtime/steppers/machine_learning.py @@ -13,7 +13,7 @@ from vcm import thermo import vcm -__all__ = ["MachineLearningConfig", "PureMLStepper", "load_adapted_model"] +__all__ = ["MachineLearningConfig", "PureMLStepper", "open_model"] logger = logging.getLogger(__name__) @@ -130,7 +130,7 @@ def predict_columnwise(self, arg: xr.Dataset, **kwargs) -> xr.Dataset: return xr.merge(predictions) -def load_adapted_model(config: MachineLearningConfig) -> MultiModelAdapter: +def open_model(config: MachineLearningConfig) -> MultiModelAdapter: model_paths = config.model models = [] for path in model_paths: From e7be2636b85bc47340c518eace9bf06a50e468e1 Mon Sep 17 00:00:00 2001 From: brianhenn Date: Mon, 15 Mar 2021 22:04:39 +0000 Subject: [PATCH 13/17] modifying tests per PR review --- workflows/prognostic_c48_run/runtime/loop.py | 2 - ...s_regression_checksum[MLStateStepper].out} | 0 ...rs_regression_checksum[PureMLStepper].out} | 0 ...pers_schema_unchanged[MLStateStepper].out} | 0 ...ppers_schema_unchanged[PureMLStepper].out} | 0 .../tests/machine_learning_mocks.py | 119 +++++++----------- .../tests/test_machine_learning.py | 27 ++-- 7 files changed, 57 insertions(+), 91 deletions(-) rename workflows/prognostic_c48_run/tests/_regtest_outputs/{test_machine_learning.test_MLStepper_regression_checksum[MLStateStepper].out => test_machine_learning.test_ml_steppers_regression_checksum[MLStateStepper].out} (100%) rename workflows/prognostic_c48_run/tests/_regtest_outputs/{test_machine_learning.test_MLStepper_regression_checksum[PureMLStepper].out => test_machine_learning.test_ml_steppers_regression_checksum[PureMLStepper].out} (100%) rename workflows/prognostic_c48_run/tests/_regtest_outputs/{test_machine_learning.test_MLStepper_schema_unchanged[MLStateStepper].out => test_machine_learning.test_ml_steppers_schema_unchanged[MLStateStepper].out} (100%) rename workflows/prognostic_c48_run/tests/_regtest_outputs/{test_machine_learning.test_MLStepper_schema_unchanged[PureMLStepper].out => test_machine_learning.test_ml_steppers_schema_unchanged[PureMLStepper].out} (100%) diff --git a/workflows/prognostic_c48_run/runtime/loop.py b/workflows/prognostic_c48_run/runtime/loop.py index ee2dffd821..0f13b4d901 100644 --- a/workflows/prognostic_c48_run/runtime/loop.py +++ b/workflows/prognostic_c48_run/runtime/loop.py @@ -199,8 +199,6 @@ def __init__( self._prephysics_only_diagnostic_ml: bool = getattr( getattr(config, "prephysics"), "diagnostic_ml", False ) - print(f"_prephysics_only_diagnostic_ml: {self._prephysics_only_diagnostic_ml}") - self._postphysics_only_diagnostic_ml: bool = config.scikit_learn.diagnostic_ml self._tendencies: Tendencies = {} self._state_updates: State = {} diff --git a/workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_MLStepper_regression_checksum[MLStateStepper].out b/workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_ml_steppers_regression_checksum[MLStateStepper].out similarity index 100% rename from workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_MLStepper_regression_checksum[MLStateStepper].out rename to workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_ml_steppers_regression_checksum[MLStateStepper].out diff --git a/workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_MLStepper_regression_checksum[PureMLStepper].out b/workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_ml_steppers_regression_checksum[PureMLStepper].out similarity index 100% rename from workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_MLStepper_regression_checksum[PureMLStepper].out rename to workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_ml_steppers_regression_checksum[PureMLStepper].out diff --git a/workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_MLStepper_schema_unchanged[MLStateStepper].out b/workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_ml_steppers_schema_unchanged[MLStateStepper].out similarity index 100% rename from workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_MLStepper_schema_unchanged[MLStateStepper].out rename to workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_ml_steppers_schema_unchanged[MLStateStepper].out diff --git a/workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_MLStepper_schema_unchanged[PureMLStepper].out b/workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_ml_steppers_schema_unchanged[PureMLStepper].out similarity index 100% rename from workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_MLStepper_schema_unchanged[PureMLStepper].out rename to workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_ml_steppers_schema_unchanged[PureMLStepper].out diff --git a/workflows/prognostic_c48_run/tests/machine_learning_mocks.py b/workflows/prognostic_c48_run/tests/machine_learning_mocks.py index d9ad09422b..0fbd4dd487 100644 --- a/workflows/prognostic_c48_run/tests/machine_learning_mocks.py +++ b/workflows/prognostic_c48_run/tests/machine_learning_mocks.py @@ -11,12 +11,17 @@ def _model_dataset() -> xr.Dataset: nz = 63 arr = np.zeros((1, nz)) + arr_1d = np.zeros((1,)) dims = ["sample", "z"] + dims_1d = ["sample"] data = xr.Dataset( { "specific_humidity": (dims, arr), "air_temperature": (dims, arr), + "downward_shortwave": (dims_1d, arr_1d), + "net_shortwave": (dims_1d, arr_1d), + "downward_longwave": (dims_1d, arr_1d), "dQ1": (dims, arr), "dQ2": (dims, arr), "dQu": (dims, arr), @@ -27,80 +32,52 @@ def _model_dataset() -> xr.Dataset: return data -def _rad_model_dataset() -> xr.Dataset: - - arr = np.zeros((1,)) - dims = [ - "sample", - ] - - data = xr.Dataset( - { - "specific_humidity": (dims, arr), - "air_temperature": (dims, arr), - "downward_shortwave": (dims, arr), - "net_shortwave": (dims, arr), - "downward_longwave": (dims, arr), - } - ) - - return data - - -def get_mock_sklearn_model() -> fv3fit.Predictor: +def get_mock_sklearn_model(model_predictands: str) -> fv3fit.Predictor: data = _model_dataset() - nz = data.sizes["z"] - heating_constant_K_per_s = np.zeros(nz) - # include nonzero moistening to test for mass conservation - moistening_constant_per_s = -np.full(nz, 1e-4 / 86400) - wind_tendency_constant_m_per_s_per_s = np.full(nz, 1 / 86400) - constant = np.concatenate( - [ - heating_constant_K_per_s, - moistening_constant_per_s, - wind_tendency_constant_m_per_s_per_s, - wind_tendency_constant_m_per_s_per_s, - ] - ) - estimator = RegressorEnsemble( - DummyRegressor(strategy="constant", constant=constant) - ) - - model = SklearnWrapper( - "sample", - ["specific_humidity", "air_temperature"], - ["dQ1", "dQ2", "dQu", "dQv"], - estimator, - ) - - # needed to avoid sklearn.exceptions.NotFittedError - model.fit([data]) - return model - - -def get_mock_rad_flux_model() -> fv3fit.Predictor: - - data = _rad_model_dataset() - - n_sample = data.sizes["sample"] - downward_shortwave = np.full(n_sample, 300.0) - net_shortwave = np.full(n_sample, 250.0) - downward_longwave = np.full(n_sample, 400.0) - constant = np.concatenate( - [[downward_shortwave], [net_shortwave], [downward_longwave]] - ) - estimator = RegressorEnsemble( - DummyRegressor(strategy="constant", constant=constant) - ) - - model = SklearnWrapper( - "sample", - ["air_temperature", "specific_humidity"], - ["downward_shortwave", "net_shortwave", "downward_longwave"], - estimator, - ) + if model_predictands == "tendencies": + nz = data.sizes["z"] + heating_constant_K_per_s = np.zeros(nz) + # include nonzero moistening to test for mass conservation + moistening_constant_per_s = -np.full(nz, 1e-4 / 86400) + wind_tendency_constant_m_per_s_per_s = np.full(nz, 1 / 86400) + constant = np.concatenate( + [ + heating_constant_K_per_s, + moistening_constant_per_s, + wind_tendency_constant_m_per_s_per_s, + wind_tendency_constant_m_per_s_per_s, + ] + ) + estimator = RegressorEnsemble( + DummyRegressor(strategy="constant", constant=constant) + ) + model = SklearnWrapper( + "sample", + ["specific_humidity", "air_temperature"], + ["dQ1", "dQ2", "dQu", "dQv"], + estimator, + ) + elif model_predictands == "rad_fluxes": + n_sample = data.sizes["sample"] + downward_shortwave = np.full(n_sample, 300.0) + net_shortwave = np.full(n_sample, 250.0) + downward_longwave = np.full(n_sample, 400.0) + constant = np.concatenate( + [[downward_shortwave], [net_shortwave], [downward_longwave]] + ) + estimator = RegressorEnsemble( + DummyRegressor(strategy="constant", constant=constant) + ) + model = SklearnWrapper( + "sample", + ["air_temperature", "specific_humidity"], + ["downward_shortwave", "net_shortwave", "downward_longwave"], + estimator, + ) + else: + raise ValueError(f"Undefined mock model type: {model_predictands}") # needed to avoid sklearn.exceptions.NotFittedError model.fit([data]) diff --git a/workflows/prognostic_c48_run/tests/test_machine_learning.py b/workflows/prognostic_c48_run/tests/test_machine_learning.py index ca8a3c9f5f..8a5160723c 100644 --- a/workflows/prognostic_c48_run/tests/test_machine_learning.py +++ b/workflows/prognostic_c48_run/tests/test_machine_learning.py @@ -1,5 +1,5 @@ from runtime.steppers.machine_learning import PureMLStepper, MLStateStepper -from machine_learning_mocks import get_mock_sklearn_model, get_mock_rad_flux_model +from machine_learning_mocks import get_mock_sklearn_model import requests import xarray as xr import joblib @@ -32,27 +32,18 @@ def ml_stepper_name(request): @pytest.fixture -def mock_model(ml_stepper_name): - if ml_stepper_name == "PureMLStepper": - model = get_mock_sklearn_model() - elif ml_stepper_name == "MLStateStepper": - model = get_mock_rad_flux_model() - else: - raise ValueError("ML Stepper name not defined.") - return model - - -@pytest.fixture -def ml_stepper(ml_stepper_name, mock_model): +def ml_stepper(ml_stepper_name): timestep = 900 if ml_stepper_name == "PureMLStepper": - stepper = PureMLStepper(mock_model, timestep) + mock_model = get_mock_sklearn_model("tendencies") + ml_stepper = PureMLStepper(mock_model, timestep) elif ml_stepper_name == "MLStateStepper": - stepper = MLStateStepper(mock_model, timestep) - return stepper + mock_model = get_mock_sklearn_model("rad_fluxes") + ml_stepper = MLStateStepper(mock_model, timestep) + return ml_stepper -def test_MLStepper_schema_unchanged(state, ml_stepper, regtest): +def test_ml_steppers_schema_unchanged(state, ml_stepper, regtest): (tendencies, diagnostics, states) = ml_stepper(None, state) xr.Dataset(diagnostics).info(regtest) xr.Dataset(tendencies).info(regtest) @@ -64,7 +55,7 @@ def test_state_regression(state, regtest): print(checksum, file=regtest) -def test_MLStepper_regression_checksum(state, ml_stepper, regtest): +def test_ml_steppers_regression_checksum(state, ml_stepper, regtest): (tendencies, diagnostics, states) = ml_stepper(None, state) checksums = yaml.safe_dump( [ From dcd1e050f20cd129cf0a086ec1d0fd9b4772a544 Mon Sep 17 00:00:00 2001 From: brianhenn Date: Mon, 15 Mar 2021 22:14:12 +0000 Subject: [PATCH 14/17] cleanup --- workflows/prognostic_c48_run/docs/config-usage.rst | 2 +- workflows/prognostic_c48_run/docs/configuration-api.rst | 2 ++ workflows/prognostic_c48_run/runtime/config.py | 1 - 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/workflows/prognostic_c48_run/docs/config-usage.rst b/workflows/prognostic_c48_run/docs/config-usage.rst index c162e50cff..4d3c806460 100644 --- a/workflows/prognostic_c48_run/docs/config-usage.rst +++ b/workflows/prognostic_c48_run/docs/config-usage.rst @@ -108,7 +108,7 @@ It can be used multiple times to specify multiple models. For example:: --model_url path/to/model --model_url path/to_another/model > fv3config.yaml - + Diagnostics ~~~~~~~~~~~ diff --git a/workflows/prognostic_c48_run/docs/configuration-api.rst b/workflows/prognostic_c48_run/docs/configuration-api.rst index 3a5b1584ac..c64e0688f3 100644 --- a/workflows/prognostic_c48_run/docs/configuration-api.rst +++ b/workflows/prognostic_c48_run/docs/configuration-api.rst @@ -33,12 +33,14 @@ Top-level Python "Physics" ~~~~~~~~~~~~~~~~ + .. py:module:: runtime.steppers.machine_learning .. autoclass:: MachineLearningConfig .. py:module:: runtime.nudging .. autoclass:: NudgingConfig + Diagnostics ~~~~~~~~~~~ diff --git a/workflows/prognostic_c48_run/runtime/config.py b/workflows/prognostic_c48_run/runtime/config.py index 38de10cf90..01eb4f4a76 100644 --- a/workflows/prognostic_c48_run/runtime/config.py +++ b/workflows/prognostic_c48_run/runtime/config.py @@ -21,7 +21,6 @@ class UserConfig: """The top-level object for python runtime configurations Attributes: - diagnostics: list of diagnostic file configurations fortran_diagnostics: list of Fortran diagnostic outputs. Currently only used by post-processing and so only name and chunks items need to be specified. From e0e13a70e8479882cdae91d4086578998ab9ad5f Mon Sep 17 00:00:00 2001 From: brianhenn Date: Tue, 16 Mar 2021 04:14:42 +0000 Subject: [PATCH 15/17] address addl PR comments --- workflows/prognostic_c48_run/runtime/loop.py | 43 +++++++++---------- .../tests/machine_learning_mocks.py | 2 +- 2 files changed, 22 insertions(+), 23 deletions(-) diff --git a/workflows/prognostic_c48_run/runtime/loop.py b/workflows/prognostic_c48_run/runtime/loop.py index 0f13b4d901..6d5fd0d69a 100644 --- a/workflows/prognostic_c48_run/runtime/loop.py +++ b/workflows/prognostic_c48_run/runtime/loop.py @@ -1,6 +1,7 @@ import datetime import json import os +import tempfile import logging from typing import ( Any, @@ -162,8 +163,7 @@ class TimeLoop(Iterable[Tuple[cftime.DatetimeJulian, Diagnostics]], LoggingMixin Each time step of the model evolutions proceeds like this:: step_dynamics, - compute_prephysics, - apply_prephysics, + step_prephysics, compute_physics, apply_postphysics_to_physics_state, apply_physics, @@ -172,7 +172,7 @@ class TimeLoop(Iterable[Tuple[cftime.DatetimeJulian, Diagnostics]], LoggingMixin The time loop relies on objects implementing the :py:class:`Stepper` interface to enable ML and other updates. The steppers compute their - updates in ``_compute_prephysics`` and ``_compute_postphysics``. The + updates in ``_step_prephysics`` and ``_compute_postphysics``. The ``TimeLoop`` controls when and how to apply these updates to the FV3 state. """ @@ -225,7 +225,7 @@ def _get_prephysics_stepper(self, config: UserConfig) -> Optional[Stepper]: config.prephysics, MachineLearningConfig ): self._log_info("Using MLStateStepper for prephysics") - model = self._open_model(config.prephysics, "_compute_prephysics") + model = self._open_model(config.prephysics, "_prephysics") stepper: Optional[Stepper] = MLStateStepper(model, self._timestep) else: self._log_info("No prephysics computations") @@ -235,7 +235,7 @@ def _get_prephysics_stepper(self, config: UserConfig) -> Optional[Stepper]: def _get_postphysics_stepper(self, config: UserConfig) -> Optional[Stepper]: if config.scikit_learn.model: self._log_info("Using MLStepper for postphysics updates") - model = self._open_model(config.scikit_learn, "_compute_postphysics") + model = self._open_model(config.scikit_learn, "_postphysics") stepper: Optional[Stepper] = PureMLStepper(model, self._timestep) elif config.nudging: self._log_info("Using NudgingStepper for postphysics updates") @@ -251,16 +251,17 @@ def _get_postphysics_stepper(self, config: UserConfig) -> Optional[Stepper]: def _open_model(self, ml_config: MachineLearningConfig, step: str): self._log_info("Downloading ML Model") - if self.rank == 0: - local_model_paths = download_model( - ml_config, os.path.join(step, "ml_model") - ) - else: - local_model_paths = None # type: ignore - local_model_paths = self.comm.bcast(local_model_paths, root=0) - setattr(ml_config, "model", local_model_paths) - self._log_info("Model Downloaded From Remote") - model = open_model(ml_config) + with tempfile.TemporaryDirectory() as tmpdir: + if self.rank == 0: + local_model_paths = download_model( + ml_config, os.path.join(tmpdir, step) + ) + else: + local_model_paths = None # type: ignore + local_model_paths = self.comm.bcast(local_model_paths, root=0) + setattr(ml_config, "model", local_model_paths) + self._log_info("Model Downloaded From Remote") + model = open_model(ml_config) self._log_info("Model Loaded") return model @@ -336,8 +337,7 @@ def _print_global_timings(self, root=0): def _substeps(self) -> Sequence[Callable[..., Diagnostics]]: return [ self._step_dynamics, - self._compute_prephysics, - self._apply_prephysics, + self._step_prephysics, self._compute_physics, self._apply_postphysics_to_physics_state, self._apply_physics, @@ -345,7 +345,8 @@ def _substeps(self) -> Sequence[Callable[..., Diagnostics]]: self._apply_postphysics_to_dycore_state, ] - def _compute_prephysics(self) -> Diagnostics: + def _step_prephysics(self) -> Diagnostics: + if self._prephysics_stepper is None: diagnostics: Diagnostics = {} else: @@ -357,9 +358,6 @@ def _compute_prephysics(self) -> Diagnostics: rename_diagnostics(diagnostics) else: self._state_updates.update(state_updates) - return diagnostics - - def _apply_prephysics(self): prephysics_overrides = [ "total_sky_downward_shortwave_flux_at_surface_override", "total_sky_net_shortwave_flux_at_surface_override", @@ -374,7 +372,8 @@ def _apply_prephysics(self): ) updated_state = assign_attrs_from(self._state, state_updates) self._state.update_mass_conserving(updated_state) - return {} + + return diagnostics def _apply_postphysics_to_physics_state(self) -> Diagnostics: """Apply computed tendencies and state updates to the physics state diff --git a/workflows/prognostic_c48_run/tests/machine_learning_mocks.py b/workflows/prognostic_c48_run/tests/machine_learning_mocks.py index 0fbd4dd487..266e20c6d6 100644 --- a/workflows/prognostic_c48_run/tests/machine_learning_mocks.py +++ b/workflows/prognostic_c48_run/tests/machine_learning_mocks.py @@ -32,7 +32,7 @@ def _model_dataset() -> xr.Dataset: return data -def get_mock_sklearn_model(model_predictands: str) -> fv3fit.Predictor: +def get_mock_sklearn_model(model_predictands: str = "tendencies") -> fv3fit.Predictor: data = _model_dataset() From 73a80d0618ed408a029f434b84a3f1060bbfd13b Mon Sep 17 00:00:00 2001 From: brianhenn Date: Tue, 23 Mar 2021 20:25:40 +0000 Subject: [PATCH 16/17] update to use master wrapper --- external/fv3gfs-wrapper | 2 +- .../runtime/diagnostics/machine_learning.py | 6 ++-- workflows/prognostic_c48_run/runtime/loop.py | 36 +++++++++++-------- 3 files changed, 25 insertions(+), 19 deletions(-) diff --git a/external/fv3gfs-wrapper b/external/fv3gfs-wrapper index 6c4d01c937..1d049991c7 160000 --- a/external/fv3gfs-wrapper +++ b/external/fv3gfs-wrapper @@ -1 +1 @@ -Subproject commit 6c4d01c9379f410a5d629dbb4d9f91e3552b10f7 +Subproject commit 1d049991c7c7bbddebff3b52b5ac8820d0f5e816 diff --git a/workflows/prognostic_c48_run/runtime/diagnostics/machine_learning.py b/workflows/prognostic_c48_run/runtime/diagnostics/machine_learning.py index c548baeb35..08b8846a65 100644 --- a/workflows/prognostic_c48_run/runtime/diagnostics/machine_learning.py +++ b/workflows/prognostic_c48_run/runtime/diagnostics/machine_learning.py @@ -111,9 +111,9 @@ def rename_diagnostics(diags: Diagnostics): "net_heating", "column_integrated_dQu", "column_integrated_dQv", - "total_sky_downward_shortwave_flux_at_surface_override", - "total_sky_net_shortwave_flux_at_surface_override", - "total_sky_downward_longwave_flux_at_surface_override", + "override_for_time_adjusted_total_sky_downward_shortwave_flux_at_surface", + "override_for_time_adjusted_total_sky_net_shortwave_flux_at_surface", + "override_for_time_adjusted_total_sky_downward_longwave_flux_at_surface", } ml_tendencies_in_diags = ml_tendencies & set(diags) for variable in ml_tendencies_in_diags: diff --git a/workflows/prognostic_c48_run/runtime/loop.py b/workflows/prognostic_c48_run/runtime/loop.py index 6d5fd0d69a..108638607f 100644 --- a/workflows/prognostic_c48_run/runtime/loop.py +++ b/workflows/prognostic_c48_run/runtime/loop.py @@ -1,7 +1,8 @@ import datetime import json import os -import tempfile + +# import tempfile import logging from typing import ( Any, @@ -251,17 +252,22 @@ def _get_postphysics_stepper(self, config: UserConfig) -> Optional[Stepper]: def _open_model(self, ml_config: MachineLearningConfig, step: str): self._log_info("Downloading ML Model") - with tempfile.TemporaryDirectory() as tmpdir: - if self.rank == 0: - local_model_paths = download_model( - ml_config, os.path.join(tmpdir, step) - ) - else: - local_model_paths = None # type: ignore - local_model_paths = self.comm.bcast(local_model_paths, root=0) - setattr(ml_config, "model", local_model_paths) - self._log_info("Model Downloaded From Remote") - model = open_model(ml_config) + # with tempfile.TemporaryDirectory() as tmpdir: + # self._log_info(f"Model Downloading to {tmpdir}") + self._log_info(f"Model Downloading to {step}") + self._log_info(f"current working directory {os.getcwd()}") + if self.rank == 0: + local_model_paths = download_model( + # ml_config, os.path.join(tmpdir, step) + ml_config, + step, + ) + else: + local_model_paths = None # type: ignore + local_model_paths = self.comm.bcast(local_model_paths, root=0) + setattr(ml_config, "model", local_model_paths) + self._log_info("Model Downloaded From Remote") + model = open_model(ml_config) self._log_info("Model Loaded") return model @@ -359,9 +365,9 @@ def _step_prephysics(self) -> Diagnostics: else: self._state_updates.update(state_updates) prephysics_overrides = [ - "total_sky_downward_shortwave_flux_at_surface_override", - "total_sky_net_shortwave_flux_at_surface_override", - "total_sky_downward_longwave_flux_at_surface_override", + "override_for_time_adjusted_total_sky_downward_shortwave_flux_at_surface", + "override_for_time_adjusted_total_sky_net_shortwave_flux_at_surface", + "override_for_time_adjusted_total_sky_downward_longwave_flux_at_surface", ] state_updates = { k: v for k, v in self._state_updates.items() if k in prephysics_overrides From f8fa385ac5e0a6ecdb9ad771e8ff2dfd5b1246d4 Mon Sep 17 00:00:00 2001 From: brianhenn Date: Tue, 23 Mar 2021 22:25:49 +0000 Subject: [PATCH 17/17] fix regression test model caching problem --- workflows/prognostic_c48_run/runtime/loop.py | 31 ++++++++----------- .../tests/test_regression.py | 5 +-- 2 files changed, 16 insertions(+), 20 deletions(-) diff --git a/workflows/prognostic_c48_run/runtime/loop.py b/workflows/prognostic_c48_run/runtime/loop.py index 108638607f..d9a4b56950 100644 --- a/workflows/prognostic_c48_run/runtime/loop.py +++ b/workflows/prognostic_c48_run/runtime/loop.py @@ -1,8 +1,7 @@ import datetime import json import os - -# import tempfile +import tempfile import logging from typing import ( Any, @@ -252,22 +251,18 @@ def _get_postphysics_stepper(self, config: UserConfig) -> Optional[Stepper]: def _open_model(self, ml_config: MachineLearningConfig, step: str): self._log_info("Downloading ML Model") - # with tempfile.TemporaryDirectory() as tmpdir: - # self._log_info(f"Model Downloading to {tmpdir}") - self._log_info(f"Model Downloading to {step}") - self._log_info(f"current working directory {os.getcwd()}") - if self.rank == 0: - local_model_paths = download_model( - # ml_config, os.path.join(tmpdir, step) - ml_config, - step, - ) - else: - local_model_paths = None # type: ignore - local_model_paths = self.comm.bcast(local_model_paths, root=0) - setattr(ml_config, "model", local_model_paths) - self._log_info("Model Downloaded From Remote") - model = open_model(ml_config) + with tempfile.TemporaryDirectory() as tmpdir: + if self.rank == 0: + local_model_paths = download_model( + ml_config, os.path.join(tmpdir, step) + ) + else: + local_model_paths = None # type: ignore + local_model_paths = self.comm.bcast(local_model_paths, root=0) + setattr(ml_config, "model", local_model_paths) + self._log_info("Model Downloaded From Remote") + model = open_model(ml_config) + MPI.COMM_WORLD.barrier() self._log_info("Model Loaded") return model diff --git a/workflows/prognostic_c48_run/tests/test_regression.py b/workflows/prognostic_c48_run/tests/test_regression.py index 56411cd096..dcee04420b 100644 --- a/workflows/prognostic_c48_run/tests/test_regression.py +++ b/workflows/prognostic_c48_run/tests/test_regression.py @@ -1,6 +1,7 @@ from pathlib import Path import json import fv3config +import fv3fit import runtime.metrics import tempfile import numpy as np @@ -443,11 +444,11 @@ def completed_rundir(configuration, tmpdir_factory): if configuration == ConfigEnum.sklearn: model = get_mock_sklearn_model() - model.dump(str(model_path)) + fv3fit.dump(model, str(model_path)) config = get_ml_config(model_path) elif configuration == ConfigEnum.keras: model = get_mock_keras_model() - model.dump(str(model_path)) + fv3fit.dump(model, str(model_path)) config = get_ml_config(model_path) elif configuration == ConfigEnum.nudging: config = get_nudging_config()