diff --git a/external/fv3gfs-fortran b/external/fv3gfs-fortran index 8e9f185329..75605127be 160000 --- a/external/fv3gfs-fortran +++ b/external/fv3gfs-fortran @@ -1 +1 @@ -Subproject commit 8e9f18532998dcf4e144f93b7689c25686cec169 +Subproject commit 75605127be63132c1d16fb6da6e70f60f88e2d40 diff --git a/external/fv3gfs-wrapper b/external/fv3gfs-wrapper index 32878b2c77..1d049991c7 160000 --- a/external/fv3gfs-wrapper +++ b/external/fv3gfs-wrapper @@ -1 +1 @@ -Subproject commit 32878b2c777d25b21c86274435a0757762dfa65d +Subproject commit 1d049991c7c7bbddebff3b52b5ac8820d0f5e816 diff --git a/workflows/prognostic_c48_run/_regtest_outputs/test_prepare_config.test_prepare_ml_config_regression.out b/workflows/prognostic_c48_run/_regtest_outputs/test_prepare_config.test_prepare_ml_config_regression.out index a3f9cb46fb..0c64cc3f15 100644 --- a/workflows/prognostic_c48_run/_regtest_outputs/test_prepare_config.test_prepare_ml_config_regression.out +++ b/workflows/prognostic_c48_run/_regtest_outputs/test_prepare_config.test_prepare_ml_config_regression.out @@ -1530,6 +1530,7 @@ namelist: ldebug: false nudging: null orographic_forcing: gs://vcm-fv3config/data/orographic_data/v1.0 +prephysics: null scikit_learn: diagnostic_ml: false input_standard_names: {} diff --git a/workflows/prognostic_c48_run/_regtest_outputs/test_prepare_config.test_prepare_nudge_to_obs_config_regression.out b/workflows/prognostic_c48_run/_regtest_outputs/test_prepare_config.test_prepare_nudge_to_obs_config_regression.out index 5fa8807124..5a2e0a46e0 100644 --- a/workflows/prognostic_c48_run/_regtest_outputs/test_prepare_config.test_prepare_nudge_to_obs_config_regression.out +++ b/workflows/prognostic_c48_run/_regtest_outputs/test_prepare_config.test_prepare_nudge_to_obs_config_regression.out @@ -1719,6 +1719,7 @@ namelist: ldebug: false nudging: null orographic_forcing: gs://vcm-fv3config/data/orographic_data/v1.0 +prephysics: null scikit_learn: diagnostic_ml: false input_standard_names: {} diff --git 
a/workflows/prognostic_c48_run/_regtest_outputs/test_prepare_config.test_prepare_nudging_config_regression.out b/workflows/prognostic_c48_run/_regtest_outputs/test_prepare_config.test_prepare_nudging_config_regression.out index e7f3130315..d202cf74cc 100644 --- a/workflows/prognostic_c48_run/_regtest_outputs/test_prepare_config.test_prepare_nudging_config_regression.out +++ b/workflows/prognostic_c48_run/_regtest_outputs/test_prepare_config.test_prepare_nudging_config_regression.out @@ -1615,6 +1615,7 @@ nudging: x_wind: 12 y_wind: 12 orographic_forcing: gs://vcm-fv3config/data/orographic_data/v1.0 +prephysics: null scikit_learn: diagnostic_ml: false input_standard_names: {} diff --git a/workflows/prognostic_c48_run/prepare_config.py b/workflows/prognostic_c48_run/prepare_config.py index e6e27064a7..ed2f418ad8 100644 --- a/workflows/prognostic_c48_run/prepare_config.py +++ b/workflows/prognostic_c48_run/prepare_config.py @@ -99,6 +99,13 @@ def user_config_from_dict_and_args(config_dict: dict, args) -> UserConfig: nudge_to_observations = ( config_dict.get("namelist", {}).get("fv_core_nml", {}).get("nudge", False) ) + + prephysics: Optional[MachineLearningConfig] + if "prephysics" in config_dict: + prephysics = dacite.from_dict(MachineLearningConfig, config_dict["prephysics"]) + else: + prephysics = None + nudging: Optional[NudgingConfig] if "nudging" in config_dict: config_dict["nudging"]["restarts_path"] = config_dict["nudging"].get( @@ -138,6 +145,7 @@ def user_config_from_dict_and_args(config_dict: dict, args) -> UserConfig: ) return UserConfig( + prephysics=prephysics, nudging=nudging, diagnostics=diagnostics, fortran_diagnostics=fortran_diagnostics, diff --git a/workflows/prognostic_c48_run/runtime/config.py b/workflows/prognostic_c48_run/runtime/config.py index a6e50c4893..01eb4f4a76 100644 --- a/workflows/prognostic_c48_run/runtime/config.py +++ b/workflows/prognostic_c48_run/runtime/config.py @@ -24,6 +24,8 @@ class UserConfig: diagnostics: list of 
diagnostic file configurations fortran_diagnostics: list of Fortran diagnostic outputs. Currently only used by post-processing and so only name and chunks items need to be specified. + prephysics: optional configuration of computations prior to physics, + specified by a machine learning configuration scikit_learn: a machine learning configuration nudging: nudge2fine configuration. Cannot be used if any scikit_learn model urls are specified. @@ -36,6 +38,7 @@ class UserConfig: diagnostics: List[DiagnosticFileConfig] fortran_diagnostics: List[FortranFileConfig] + prephysics: Optional[MachineLearningConfig] = None scikit_learn: MachineLearningConfig = MachineLearningConfig() nudging: Optional[NudgingConfig] = None step_tendency_variables: List[str] = dataclasses.field( diff --git a/workflows/prognostic_c48_run/runtime/diagnostics/machine_learning.py b/workflows/prognostic_c48_run/runtime/diagnostics/machine_learning.py index 625f0fbadd..c7fdd324e9 100644 --- a/workflows/prognostic_c48_run/runtime/diagnostics/machine_learning.py +++ b/workflows/prognostic_c48_run/runtime/diagnostics/machine_learning.py @@ -123,12 +123,15 @@ def rename_diagnostics(diags: Diagnostics): "net_heating", "column_integrated_dQu", "column_integrated_dQv", + "override_for_time_adjusted_total_sky_downward_shortwave_flux_at_surface", + "override_for_time_adjusted_total_sky_net_shortwave_flux_at_surface", + "override_for_time_adjusted_total_sky_downward_longwave_flux_at_surface", } ml_tendencies_in_diags = ml_tendencies & set(diags) for variable in ml_tendencies_in_diags: attrs = diags[variable].attrs diags[f"{variable}_diagnostic"] = diags[variable].assign_attrs( - description=attrs["description"] + " (diagnostic only)" + description=attrs.get("description", "") + " (diagnostic only)" ) diags[variable] = xr.zeros_like(diags[variable]).assign_attrs(attrs) diff --git a/workflows/prognostic_c48_run/runtime/loop.py b/workflows/prognostic_c48_run/runtime/loop.py index 1336241594..d9a4b56950 100644 ---
a/workflows/prognostic_c48_run/runtime/loop.py +++ b/workflows/prognostic_c48_run/runtime/loop.py @@ -1,7 +1,18 @@ import datetime import json +import os +import tempfile import logging -from typing import Any, Callable, Iterable, List, Mapping, Optional, Sequence, Tuple +from typing import ( + Any, + Callable, + Iterable, + List, + Mapping, + Optional, + Sequence, + Tuple, +) import cftime import fv3gfs.util @@ -14,12 +25,16 @@ from runtime.diagnostics.machine_learning import ( compute_baseline_diagnostics, rename_diagnostics, -) -from runtime.diagnostics.machine_learning import ( precipitation_rate, precipitation_sum, ) -from runtime.steppers.machine_learning import PureMLStepper, open_model, download_model +from runtime.steppers.machine_learning import ( + PureMLStepper, + open_model, + download_model, + MachineLearningConfig, + MLStateStepper, +) from runtime.steppers.nudging import PureNudger from runtime.types import Diagnostics, State, Tendencies from runtime.names import TENDENCY_TO_STATE_NAME @@ -114,6 +129,15 @@ def add_tendency(state: Any, tendency: State, dt: float) -> State: return updated # type: ignore +def assign_attrs_from(src: Any, dst: State) -> State: + """Given src state and a dst state, return dst state with src attrs + """ + updated = {} + for name in dst: + updated[name] = dst[name].assign_attrs(src[name].attrs) + return updated # type: ignore + + class LoggingMixin: rank: int @@ -139,16 +163,17 @@ class TimeLoop(Iterable[Tuple[cftime.DatetimeJulian, Diagnostics]], LoggingMixin Each time step of the model evolutions proceeds like this:: step_dynamics, + step_prephysics, compute_physics, - apply_python_to_physics_state - apply_physics - compute_python_updates - apply_python_to_dycore_state + apply_postphysics_to_physics_state, + apply_physics, + compute_postphysics, + apply_postphysics_to_dycore_state, The time loop relies on objects implementing the :py:class:`Stepper` interface to enable ML and other updates. 
The steppers compute their - updates in ``_compute_python_updates``. The ``TimeLoop`` controls when - and how to apply these updates to the FV3 state. + updates in ``_step_prephysics`` and ``_compute_postphysics``. The + ``TimeLoop`` controls when and how to apply these updates to the FV3 state. """ def __init__( @@ -171,13 +196,17 @@ def __init__( self._timestep = timestep self._log_info(f"Timestep: {timestep}") - self._do_only_diagnostic_ml = config.scikit_learn.diagnostic_ml + self._prephysics_only_diagnostic_ml: bool = getattr( + getattr(config, "prephysics"), "diagnostic_ml", False + ) + self._postphysics_only_diagnostic_ml: bool = config.scikit_learn.diagnostic_ml self._tendencies: Tendencies = {} self._state_updates: State = {} self._states_to_output: Sequence[str] = self._get_states_to_output(config) self._log_debug(f"States to output: {self._states_to_output}") - self.stepper = self._get_stepper(config) + self._prephysics_stepper = self._get_prephysics_stepper(config) + self._postphysics_stepper = self._get_postphysics_stepper(config) self._log_info(self._fv3gfs.get_tracer_metadata()) MPI.COMM_WORLD.barrier() # wait for initialization to finish @@ -191,30 +220,51 @@ def _get_states_to_output(self, config: UserConfig) -> Sequence[str]: states_to_output = diagnostic.variables # type: ignore return states_to_output - def _get_stepper(self, config: UserConfig) -> Optional[Stepper]: + def _get_prephysics_stepper(self, config: UserConfig) -> Optional[Stepper]: + if config.prephysics is not None and isinstance( + config.prephysics, MachineLearningConfig + ): + self._log_info("Using MLStateStepper for prephysics") + model = self._open_model(config.prephysics, "_prephysics") + stepper: Optional[Stepper] = MLStateStepper(model, self._timestep) + else: + self._log_info("No prephysics computations") + stepper = None + return stepper + + def _get_postphysics_stepper(self, config: UserConfig) -> Optional[Stepper]: if config.scikit_learn.model: - self._log_info("Using 
MLStepper") - self._log_info("Downloading ML Model") - if self.rank == 0: - local_model_paths = download_model(config.scikit_learn, "ml_model") - else: - local_model_paths = None # type: ignore - local_model_paths = self.comm.bcast(local_model_paths, root=0) - setattr(config.scikit_learn, "model", local_model_paths) - self._log_info("Model Downloaded From Remote") - model = open_model(config.scikit_learn) - self._log_info("Model Loaded") - return PureMLStepper(model, self._timestep) + self._log_info("Using MLStepper for postphysics updates") + model = self._open_model(config.scikit_learn, "_postphysics") + stepper: Optional[Stepper] = PureMLStepper(model, self._timestep) elif config.nudging: - self._log_info("Using NudgingStepper") + self._log_info("Using NudgingStepper for postphysics updates") partitioner = fv3gfs.util.CubedSpherePartitioner.from_namelist( get_namelist() ) communicator = fv3gfs.util.CubedSphereCommunicator(self.comm, partitioner) - return PureNudger(config.nudging, communicator) + stepper = PureNudger(config.nudging, communicator) else: self._log_info("Performing baseline simulation") - return None + stepper = None + return stepper + + def _open_model(self, ml_config: MachineLearningConfig, step: str): + self._log_info("Downloading ML Model") + with tempfile.TemporaryDirectory() as tmpdir: + if self.rank == 0: + local_model_paths = download_model( + ml_config, os.path.join(tmpdir, step) + ) + else: + local_model_paths = None # type: ignore + local_model_paths = self.comm.bcast(local_model_paths, root=0) + setattr(ml_config, "model", local_model_paths) + self._log_info("Model Downloaded From Remote") + model = open_model(ml_config) + MPI.COMM_WORLD.barrier() + self._log_info("Model Loaded") + return model @property def time(self) -> cftime.DatetimeJulian: @@ -288,26 +338,59 @@ def _print_global_timings(self, root=0): def _substeps(self) -> Sequence[Callable[..., Diagnostics]]: return [ self._step_dynamics, + self._step_prephysics, 
self._compute_physics, - self._apply_python_to_physics_state, + self._apply_postphysics_to_physics_state, self._apply_physics, - self._compute_python_updates, - self._apply_python_to_dycore_state, + self._compute_postphysics, + self._apply_postphysics_to_dycore_state, ] - def _apply_python_to_physics_state(self) -> Diagnostics: + def _step_prephysics(self) -> Diagnostics: + + if self._prephysics_stepper is None: + diagnostics: Diagnostics = {} + else: + self._log_debug("Computing prephysics updates") + _, diagnostics, state_updates = self._prephysics_stepper( + self._state.time, self._state + ) + if self._prephysics_only_diagnostic_ml: + rename_diagnostics(diagnostics) + else: + self._state_updates.update(state_updates) + prephysics_overrides = [ + "override_for_time_adjusted_total_sky_downward_shortwave_flux_at_surface", + "override_for_time_adjusted_total_sky_net_shortwave_flux_at_surface", + "override_for_time_adjusted_total_sky_downward_longwave_flux_at_surface", + ] + state_updates = { + k: v for k, v in self._state_updates.items() if k in prephysics_overrides + } + self._state_updates = dissoc(self._state_updates, *prephysics_overrides) + self._log_debug( + f"Applying prephysics state updates for: {list(state_updates.keys())}" + ) + updated_state = assign_attrs_from(self._state, state_updates) + self._state.update_mass_conserving(updated_state) + + return diagnostics + + def _apply_postphysics_to_physics_state(self) -> Diagnostics: """Apply computed tendencies and state updates to the physics state Mostly used for updating the eastward and northward winds. 
""" - self._log_debug(f"Apply python tendencies to physics state") + self._log_debug(f"Apply postphysics tendencies to physics state") tendency = {k: v for k, v in self._tendencies.items() if k in ["dQu", "dQv"]} diagnostics: Diagnostics = {} - if self.stepper is not None: - diagnostics = self.stepper.get_momentum_diagnostics(self._state, tendency) - if self._do_only_diagnostic_ml: + if self._postphysics_stepper is not None: + diagnostics = self._postphysics_stepper.get_momentum_diagnostics( + self._state, tendency + ) + if self._postphysics_only_diagnostic_ml: rename_diagnostics(diagnostics) else: updated_state = add_tendency(self._state, tendency, dt=self._timestep) @@ -315,15 +398,17 @@ def _apply_python_to_physics_state(self) -> Diagnostics: return diagnostics - def _compute_python_updates(self) -> Diagnostics: - self._log_info("Computing Python Updates") + def _compute_postphysics(self) -> Diagnostics: + self._log_info("Computing Postphysics Updates") - if self.stepper is None: + if self._postphysics_stepper is None: return {} else: - (self._tendencies, diagnostics, self._state_updates,) = self.stepper( - self._state.time, self._state - ) + ( + self._tendencies, + diagnostics, + self._state_updates, + ) = self._postphysics_stepper(self._state.time, self._state) try: rank_updated_points = diagnostics["rank_updated_points"] except KeyError: @@ -340,21 +425,23 @@ def _compute_python_updates(self) -> Diagnostics: ) return diagnostics - def _apply_python_to_dycore_state(self) -> Diagnostics: + def _apply_postphysics_to_dycore_state(self) -> Diagnostics: tendency = dissoc(self._tendencies, "dQu", "dQv") - if self.stepper is None: + if self._postphysics_stepper is None: diagnostics = compute_baseline_diagnostics(self._state) else: - diagnostics = self.stepper.get_diagnostics(self._state, tendency) - if self._do_only_diagnostic_ml: + diagnostics = self._postphysics_stepper.get_diagnostics( + self._state, tendency + ) + if self._postphysics_only_diagnostic_ml: 
rename_diagnostics(diagnostics) else: updated_state = add_tendency(self._state, tendency, dt=self._timestep) updated_state[TOTAL_PRECIP] = precipitation_sum( self._state[TOTAL_PRECIP], - diagnostics[self.stepper.net_moistening], + diagnostics[self._postphysics_stepper.net_moistening], self._timestep, ) diagnostics[TOTAL_PRECIP] = updated_state[TOTAL_PRECIP] @@ -464,8 +551,8 @@ def __init__( self._storage_variables = list(storage_variables) _apply_physics = monitor("fv3_physics", TimeLoop._apply_physics) - _apply_python_to_dycore_state = monitor( - "python", TimeLoop._apply_python_to_dycore_state + _apply_postphysics_to_dycore_state = monitor( + "python", TimeLoop._apply_postphysics_to_dycore_state ) diff --git a/workflows/prognostic_c48_run/runtime/steppers/machine_learning.py b/workflows/prognostic_c48_run/runtime/steppers/machine_learning.py index bbdaea1755..bef78ac7c5 100644 --- a/workflows/prognostic_c48_run/runtime/steppers/machine_learning.py +++ b/workflows/prognostic_c48_run/runtime/steppers/machine_learning.py @@ -155,7 +155,7 @@ def download_model(config: MachineLearningConfig, path: str) -> Sequence[str]: def predict(model: MultiModelAdapter, state: State) -> State: - """Given ML model and state, return tendency prediction.""" + """Given ML model and state, return prediction""" state_loaded = {key: state[key] for key in model.input_variables} ds = xr.Dataset(state_loaded) # type: ignore output = model.predict_columnwise(ds, feature_dim="z") @@ -212,3 +212,20 @@ def get_diagnostics(self, state, tendency): def get_momentum_diagnostics(self, state, tendency): return runtime.compute_ml_momentum_diagnostics(state, tendency) + + +class MLStateStepper(PureMLStepper): + def __call__(self, time, state): + + diagnostics: Diagnostics = {} + state_updates: State = predict(self.model, state) + + for name in state_updates.keys(): + diagnostics[name] = state_updates[name] + + tendency = {} + return ( + tendency, + diagnostics, + state_updates, + ) diff --git 
a/workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_ml_steppers_regression_checksum[MLStateStepper].out b/workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_ml_steppers_regression_checksum[MLStateStepper].out new file mode 100644 index 0000000000..3e2b87b688 --- /dev/null +++ b/workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_ml_steppers_regression_checksum[MLStateStepper].out @@ -0,0 +1,17 @@ +- - tendencies + - [] +- - diagnostics + - - - downward_longwave + - 8b886991eea4c48475709bca29505185 + - - downward_shortwave + - 47735602b45938d453a31013ef9410ba + - - net_shortwave + - 01b249036fc7c5eec78879aa849ec524 +- - states + - - - downward_longwave + - 8b886991eea4c48475709bca29505185 + - - downward_shortwave + - 47735602b45938d453a31013ef9410ba + - - net_shortwave + - 01b249036fc7c5eec78879aa849ec524 + diff --git a/workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_PureMLStepper_regression_checksum.out b/workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_ml_steppers_regression_checksum[PureMLStepper].out similarity index 96% rename from workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_PureMLStepper_regression_checksum.out rename to workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_ml_steppers_regression_checksum[PureMLStepper].out index 7967633322..bd5e48845c 100644 --- a/workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_PureMLStepper_regression_checksum.out +++ b/workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_ml_steppers_regression_checksum[PureMLStepper].out @@ -14,4 +14,6 @@ - fcc46bebe36ea131688f8e15700e18d4 - - rank_updated_points - 261cfa795cd717bb3a4b2bde266ba20a +- - states + - [] diff --git 
a/workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_ml_steppers_schema_unchanged[MLStateStepper].out b/workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_ml_steppers_schema_unchanged[MLStateStepper].out new file mode 100644 index 0000000000..adb9b74035 --- /dev/null +++ b/workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_ml_steppers_schema_unchanged[MLStateStepper].out @@ -0,0 +1,29 @@ +xarray.Dataset { +dimensions: + x = 4 ; + y = 4 ; + +variables: + float64 downward_longwave(y, x) ; + float64 downward_shortwave(y, x) ; + float64 net_shortwave(y, x) ; + +// global attributes: +}xarray.Dataset { +dimensions: + +variables: + +// global attributes: +}xarray.Dataset { +dimensions: + x = 4 ; + y = 4 ; + +variables: + float64 downward_longwave(y, x) ; + float64 downward_shortwave(y, x) ; + float64 net_shortwave(y, x) ; + +// global attributes: +} \ No newline at end of file diff --git a/workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_PureMLStepper_schema_unchanged.out b/workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_ml_steppers_schema_unchanged[PureMLStepper].out similarity index 91% rename from workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_PureMLStepper_schema_unchanged.out rename to workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_ml_steppers_schema_unchanged[PureMLStepper].out index 2fa482b289..9be83ff889 100644 --- a/workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_PureMLStepper_schema_unchanged.out +++ b/workflows/prognostic_c48_run/tests/_regtest_outputs/test_machine_learning.test_ml_steppers_schema_unchanged[PureMLStepper].out @@ -25,5 +25,11 @@ variables: float64 dQu(z, y, x) ; float64 dQv(z, y, x) ; +// global attributes: +}xarray.Dataset { +dimensions: + +variables: + // global attributes: } \ No newline at end of file 
diff --git a/workflows/prognostic_c48_run/tests/machine_learning_mocks.py b/workflows/prognostic_c48_run/tests/machine_learning_mocks.py index 1a94e8eb4a..266e20c6d6 100644 --- a/workflows/prognostic_c48_run/tests/machine_learning_mocks.py +++ b/workflows/prognostic_c48_run/tests/machine_learning_mocks.py @@ -11,12 +11,17 @@ def _model_dataset() -> xr.Dataset: nz = 63 arr = np.zeros((1, nz)) + arr_1d = np.zeros((1,)) dims = ["sample", "z"] + dims_1d = ["sample"] data = xr.Dataset( { "specific_humidity": (dims, arr), "air_temperature": (dims, arr), + "downward_shortwave": (dims_1d, arr_1d), + "net_shortwave": (dims_1d, arr_1d), + "downward_longwave": (dims_1d, arr_1d), "dQ1": (dims, arr), "dQ2": (dims, arr), "dQu": (dims, arr), @@ -27,33 +32,52 @@ def _model_dataset() -> xr.Dataset: return data -def get_mock_sklearn_model() -> fv3fit.Predictor: +def get_mock_sklearn_model(model_predictands: str = "tendencies") -> fv3fit.Predictor: data = _model_dataset() - nz = data.sizes["z"] - heating_constant_K_per_s = np.zeros(nz) - # include nonzero moistening to test for mass conservation - moistening_constant_per_s = -np.full(nz, 1e-4 / 86400) - wind_tendency_constant_m_per_s_per_s = np.full(nz, 1 / 86400) - constant = np.concatenate( - [ - heating_constant_K_per_s, - moistening_constant_per_s, - wind_tendency_constant_m_per_s_per_s, - wind_tendency_constant_m_per_s_per_s, - ] - ) - estimator = RegressorEnsemble( - DummyRegressor(strategy="constant", constant=constant) - ) - - model = SklearnWrapper( - "sample", - ["specific_humidity", "air_temperature"], - ["dQ1", "dQ2", "dQu", "dQv"], - estimator, - ) + if model_predictands == "tendencies": + nz = data.sizes["z"] + heating_constant_K_per_s = np.zeros(nz) + # include nonzero moistening to test for mass conservation + moistening_constant_per_s = -np.full(nz, 1e-4 / 86400) + wind_tendency_constant_m_per_s_per_s = np.full(nz, 1 / 86400) + constant = np.concatenate( + [ + heating_constant_K_per_s, + moistening_constant_per_s, + 
wind_tendency_constant_m_per_s_per_s, + wind_tendency_constant_m_per_s_per_s, + ] + ) + estimator = RegressorEnsemble( + DummyRegressor(strategy="constant", constant=constant) + ) + model = SklearnWrapper( + "sample", + ["specific_humidity", "air_temperature"], + ["dQ1", "dQ2", "dQu", "dQv"], + estimator, + ) + elif model_predictands == "rad_fluxes": + n_sample = data.sizes["sample"] + downward_shortwave = np.full(n_sample, 300.0) + net_shortwave = np.full(n_sample, 250.0) + downward_longwave = np.full(n_sample, 400.0) + constant = np.concatenate( + [[downward_shortwave], [net_shortwave], [downward_longwave]] + ) + estimator = RegressorEnsemble( + DummyRegressor(strategy="constant", constant=constant) + ) + model = SklearnWrapper( + "sample", + ["air_temperature", "specific_humidity"], + ["downward_shortwave", "net_shortwave", "downward_longwave"], + estimator, + ) + else: + raise ValueError(f"Undefined mock model type: {model_predictands}") # needed to avoid sklearn.exceptions.NotFittedError model.fit([data]) diff --git a/workflows/prognostic_c48_run/tests/test_machine_learning.py b/workflows/prognostic_c48_run/tests/test_machine_learning.py index 565684c39a..978edc05a4 100644 --- a/workflows/prognostic_c48_run/tests/test_machine_learning.py +++ b/workflows/prognostic_c48_run/tests/test_machine_learning.py @@ -1,4 +1,4 @@ -from runtime.steppers.machine_learning import PureMLStepper +from runtime.steppers.machine_learning import PureMLStepper, MLStateStepper from machine_learning_mocks import get_mock_sklearn_model import requests import xarray as xr @@ -16,12 +16,28 @@ def state(tmp_path_factory): return xr.open_dataset(str(lpath)) -def test_PureMLStepper_schema_unchanged(state, regtest): - model = get_mock_sklearn_model() +@pytest.fixture(params=["PureMLStepper", "MLStateStepper"]) +def ml_stepper_name(request): + return request.param + + +@pytest.fixture +def ml_stepper(ml_stepper_name): timestep = 900 - (tendencies, diagnostics, _,) = PureMLStepper(model, 
timestep)(None, state) + if ml_stepper_name == "PureMLStepper": + mock_model = get_mock_sklearn_model("tendencies") + ml_stepper = PureMLStepper(mock_model, timestep) + elif ml_stepper_name == "MLStateStepper": + mock_model = get_mock_sklearn_model("rad_fluxes") + ml_stepper = MLStateStepper(mock_model, timestep) + return ml_stepper + + +def test_ml_steppers_schema_unchanged(state, ml_stepper, regtest): + (tendencies, diagnostics, states) = ml_stepper(None, state) xr.Dataset(diagnostics).info(regtest) xr.Dataset(tendencies).info(regtest) + xr.Dataset(states).info(regtest) def test_state_regression(state, regtest): @@ -29,14 +45,13 @@ def test_state_regression(state, regtest): print(checksum, file=regtest) -def test_PureMLStepper_regression_checksum(state, regtest): - model = get_mock_sklearn_model() - timestep = 900 - (tendencies, diagnostics, _,) = PureMLStepper(model, timestep)(None, state) +def test_ml_steppers_regression_checksum(state, ml_stepper, regtest): + (tendencies, diagnostics, states) = ml_stepper(None, state) checksums = yaml.safe_dump( [ ("tendencies", vcm.testing.checksum_dataarray_mapping(tendencies)), ("diagnostics", vcm.testing.checksum_dataarray_mapping(diagnostics)), + ("states", vcm.testing.checksum_dataarray_mapping(states)), ] ) diff --git a/workflows/prognostic_c48_run/tests/test_regression.py b/workflows/prognostic_c48_run/tests/test_regression.py index 56411cd096..dcee04420b 100644 --- a/workflows/prognostic_c48_run/tests/test_regression.py +++ b/workflows/prognostic_c48_run/tests/test_regression.py @@ -1,6 +1,7 @@ from pathlib import Path import json import fv3config +import fv3fit import runtime.metrics import tempfile import numpy as np @@ -443,11 +444,11 @@ def completed_rundir(configuration, tmpdir_factory): if configuration == ConfigEnum.sklearn: model = get_mock_sklearn_model() - model.dump(str(model_path)) + fv3fit.dump(model, str(model_path)) config = get_ml_config(model_path) elif configuration == ConfigEnum.keras: model = 
get_mock_keras_model() - model.dump(str(model_path)) + fv3fit.dump(model, str(model_path)) config = get_ml_config(model_path) elif configuration == ConfigEnum.nudging: config = get_nudging_config()