diff --git a/imod/mf6/boundary_condition.py b/imod/mf6/boundary_condition.py
index afe63ffa7..29dffc42d 100644
--- a/imod/mf6/boundary_condition.py
+++ b/imod/mf6/boundary_condition.py
@@ -202,12 +202,23 @@ def _get_unfiltered_pkg_options(
         options = copy(predefined_options)
         if not_options is None:
-            not_options = self._get_period_varnames()
+            not_options = []
+            if hasattr(self, "_period_data"):
+                not_options.extend(self._period_data)
+            if hasattr(self, "_auxiliary_data"):
+                not_options.extend(get_variable_names(self))
+                not_options.extend(self._auxiliary_data.keys())

         for varname in self.dataset.data_vars.keys():  # pylint:disable=no-member
             if varname in not_options:
                 continue
-            v = self.dataset[varname].values[()]
+            # TODO: can we easily avoid this try-except?
+            # On which keys does it fail?
+            try:
+                v = self.dataset[varname].item()
+            except ValueError:
+                # Apparently not a scalar, therefore not an option entry.
+                continue
             options[varname] = v
         return options
diff --git a/imod/mf6/hfb.py b/imod/mf6/hfb.py
index 32d594d00..12e8e327e 100644
--- a/imod/mf6/hfb.py
+++ b/imod/mf6/hfb.py
@@ -33,6 +33,7 @@
 from imod.mf6.disv import VerticesDiscretization
 from imod.mf6.mf6_hfb_adapter import Mf6HorizontalFlowBarrier
 from imod.mf6.package import Package
+from imod.mf6.utilities.zarr_helper import to_zarr
 from imod.mf6.validation_settings import ValidationSettings
 from imod.prepare.cleanup import cleanup_hfb
 from imod.schemata import (
@@ -562,6 +563,11 @@ def to_netcdf(
         new.dataset["geometry"] = new.line_data.to_json()
         new.dataset.to_netcdf(*args, **kwargs)

+    def to_zarr(self, path, engine: str, **kwargs):
+        new = deepcopy(self)
+        new.dataset["geometry"] = new.line_data.to_json()
+        to_zarr(new.dataset, path, engine, **kwargs)
+
     def _netcdf_encoding(self):
         return {"geometry": {"dtype": "str"}}
diff --git a/imod/mf6/model.py b/imod/mf6/model.py
index 8f3888abb..2cd0d38e2 100644
--- a/imod/mf6/model.py
+++ b/imod/mf6/model.py
@@ -592,6 +592,7 @@ def dump(
         validate: bool = True,
         mdal_compliant: bool = False,
         crs: Optional[Any] = None,
+        engine="netCDF4",
     ):
         """
         Dump simulation to files. Writes a model definition as .TOML file, which
@@ -615,6 +616,8 @@
         crs: Any, optional
            Anything accepted by rasterio.crs.CRS.from_user_input
            Requires ``rioxarray`` installed.
+        engine: str, optional
+            "netCDF4" or "zarr" or "zarr.zip". Defaults to "netCDF4".
""" modeldirectory = pathlib.Path(directory) / modelname modeldirectory.mkdir(exist_ok=True, parents=True) @@ -624,13 +627,26 @@ def dump( if statusinfo.has_errors(): raise ValidationError(statusinfo.to_string()) + match engine: + case "netCDF4": + ext = "nc" + case "zarr": + ext = "zarr" + case "zarr.zip": + ext = "zarr.zip" + case _: + raise ValueError(f"Unknown engine: {engine}") + toml_content: dict = collections.defaultdict(dict) for pkgname, pkg in self.items(): - pkg_path = f"{pkgname}.nc" + pkg_path = f"{pkgname}.{ext}" toml_content[type(pkg).__name__][pkgname] = pkg_path - pkg.to_netcdf( - modeldirectory / pkg_path, crs=crs, mdal_compliant=mdal_compliant - ) + if engine == "netCDF4": + pkg.to_netcdf( + modeldirectory / pkg_path, crs=crs, mdal_compliant=mdal_compliant + ) + else: + pkg.to_zarr(modeldirectory / pkg_path, engine=engine) toml_path = modeldirectory / f"{modelname}.toml" with open(toml_path, "wb") as f: diff --git a/imod/mf6/multimodel/exchange_creator.py b/imod/mf6/multimodel/exchange_creator.py index 31a55e6a8..8a9a16b00 100644 --- a/imod/mf6/multimodel/exchange_creator.py +++ b/imod/mf6/multimodel/exchange_creator.py @@ -1,5 +1,5 @@ import abc -from typing import Dict +from typing import Dict, NamedTuple import numpy as np import pandas as pd @@ -8,10 +8,14 @@ from imod.common.utilities.grid import get_active_domain_slice, to_cell_idx from imod.mf6.gwfgwf import GWFGWF from imod.mf6.gwtgwt import GWTGWT -from imod.mf6.multimodel.modelsplitter import PartitionInfo from imod.typing import GridDataArray +class PartitionInfo(NamedTuple): + active_domain: GridDataArray + partition_id: int + + def _adjust_gridblock_indexing(connected_cells: xr.Dataset) -> xr.Dataset: """ adjusts the gridblock numbering from 0-based to 1-based. @@ -25,8 +29,7 @@ class ExchangeCreator(abc.ABC): """ Creates the GroundWaterFlow to GroundWaterFlow exchange package (gwfgwf) as a function of a submodel label array and a PartitionInfo object. This file contains the cell indices of coupled cells. With coupled cells we mean - cells that are adjacent but that are located in different subdomains. At the moment only structured grids are - supported, for unstructured grids the geometric information is still set to default values. + cells that are adjacent but that are located in different subdomains. The submodel_labels array should have the same topology as the domain being partitioned. The array will be used to determine the connectivity of the submodels after the split operation has been performed. 
@@ -248,7 +251,7 @@ def _create_global_cellidx_to_local_cellid_mapping(
         mapping = {}
         for submodel_partition_info in partition_info:
-            model_id = submodel_partition_info.id
+            model_id = submodel_partition_info.partition_id
             mapping[model_id] = pd.merge(
                 global_to_local_idx[model_id], local_cell_idx_to_id[model_id]
             )
@@ -268,7 +271,7 @@ def _get_local_cell_indices(
     def _local_cell_idx_to_id(cls, partition_info) -> Dict[int, pd.DataFrame]:
         local_cell_idx_to_id = {}
         for submodel_partition_info in partition_info:
-            model_id = submodel_partition_info.id
+            model_id = submodel_partition_info.partition_id
             local_cell_indices = cls._get_local_cell_indices(submodel_partition_info)
             local_cell_id = list(np.ndindex(local_cell_indices.shape))
diff --git a/imod/mf6/multimodel/exchange_creator_structured.py b/imod/mf6/multimodel/exchange_creator_structured.py
index 79b3fee91..7d1572efd 100644
--- a/imod/mf6/multimodel/exchange_creator_structured.py
+++ b/imod/mf6/multimodel/exchange_creator_structured.py
@@ -6,12 +6,11 @@
 from imod.common.utilities.grid import create_geometric_grid_info
-from imod.mf6.multimodel.exchange_creator import ExchangeCreator
-from imod.mf6.multimodel.modelsplitter import PartitionInfo
+from imod.mf6.multimodel.exchange_creator import ExchangeCreator, PartitionInfo
 from imod.typing import GridDataArray

 NOT_CONNECTED_VALUE = -999


 class ExchangeCreator_Structured(ExchangeCreator):
     """
     Creates the GroundWaterFlow to GroundWaterFlow exchange package (gwfgwf) as
@@ -130,7 +129,7 @@
                 compat="override",
             )["label"]

-            model_id = submodel_partition_info.id
+            model_id = submodel_partition_info.partition_id
             global_to_local_idx[model_id] = pd.DataFrame(
                 {
                     "global_idx": overlap.values.flatten(),
diff --git a/imod/mf6/multimodel/exchange_creator_unstructured.py b/imod/mf6/multimodel/exchange_creator_unstructured.py
index 2869bc1ea..43d40aaca 100644
--- a/imod/mf6/multimodel/exchange_creator_unstructured.py
+++ b/imod/mf6/multimodel/exchange_creator_unstructured.py
@@ -4,8 +4,7 @@
 import pandas as pd
 import xarray as xr

-from imod.mf6.multimodel.exchange_creator import ExchangeCreator
-from imod.mf6.multimodel.modelsplitter import PartitionInfo
+from imod.mf6.multimodel.exchange_creator import ExchangeCreator, PartitionInfo
 from imod.typing import GridDataArray


@@ -122,7 +121,7 @@
            compat="override",
        )["label"]

-        model_id = submodel_partition_info.id
+        model_id = submodel_partition_info.partition_id
         global_to_local_idx[model_id] = pd.DataFrame(
             {
                 "global_idx": overlap.values.flatten(),
diff --git a/imod/mf6/multimodel/modelsplitter.py b/imod/mf6/multimodel/modelsplitter.py
index b77a9f725..07e5bde34 100644
--- a/imod/mf6/multimodel/modelsplitter.py
+++ b/imod/mf6/multimodel/modelsplitter.py
@@ -1,8 +1,14 @@
-from typing import List, NamedTuple
+import collections
+from typing import Any, NamedTuple

 import numpy as np
+from plum import Dispatcher

+import imod
+from imod.common.interfaces.ilinedatapackage import ILineDataPackage
 from imod.common.interfaces.imodel import IModel
+from imod.common.interfaces.ipackagebase import IPackageBase
+from imod.common.interfaces.ipointdatapackage import IPointDataPackage
 from imod.common.utilities.clip import clip_by_grid
 from imod.mf6.auxiliary_variables import (
     expand_transient_auxiliary_variables,
@@ -10,76 +16,285 @@
 )
 from imod.mf6.boundary_condition import BoundaryCondition
 from imod.mf6.hfb import HorizontalFlowBarrierBase
+from imod.mf6.multimodel.exchange_creator import PartitionInfo
+from imod.mf6.multimodel.exchange_creator_structured import ExchangeCreator_Structured
+from imod.mf6.multimodel.exchange_creator_unstructured import (
+    ExchangeCreator_Unstructured,
+)
 from imod.mf6.wel import Well
 from imod.typing import GridDataArray
-from imod.typing.grid import ones_like
+from imod.typing.grid import bounding_polygon, is_unstructured

 HIGH_LEVEL_PKGS = (HorizontalFlowBarrierBase, Well)


-class PartitionInfo(NamedTuple):
-    active_domain: GridDataArray
-    id: int
+dispatch = Dispatcher()
+
+
+@dispatch
+def activity_count(
+    package: object, labels: object, polygons: list[Any], ignore_time_purge_empty: bool
+) -> dict:
+    raise TypeError(
+        f"No activity count available for package of type {type(package)}"
+    )
+
+
+@dispatch
+def activity_count(  # noqa: F811
+    package: IPackageBase,
+    labels: object,
+    polygons: list[Any],
+    ignore_time_purge_empty: bool,
+) -> dict:
+    label_dims = set(labels.dims)
+    dataset = package.dataset
+
+    # Determine sample variable: it should be spatial.
+    # Otherwise return a count of 1 for each partition.
+    if not label_dims.intersection(dataset.dims):
+        return dict.fromkeys(range(len(polygons)), 1)
+
+    # Find variable with spatial dimensions
+    # Accessing variables is cheaper than creating a DataArray.
+    ndim_per_variable = {
+        var_name: len(dims)
+        for var_name in dataset.data_vars
+        if label_dims.intersection(dims := dataset.variables[var_name].dims)
+    }
+    max_variable = max(ndim_per_variable, key=ndim_per_variable.get)
+    # TODO: there might be a more robust way to do this.
+    # Alternatively, we just define a predicate variable (e.g. conductance)
+    # on each package.
+
+    sample = dataset[max_variable]
+    if "time" in sample.coords:
+        if ignore_time_purge_empty:
+            sample = sample.isel(time=0)
+        else:
+            sample = sample.max("time")
+    # Reduce over all dimensions except the label dims.
+    dims_to_aggregate = [dim for dim in sample.dims if dim not in label_dims]
+    counts = sample.notnull().sum(dim=dims_to_aggregate).groupby(labels).sum()
+    return {label: int(n) for label, n in enumerate(counts.data)}

-def create_partition_info(submodel_labels: GridDataArray) -> List[PartitionInfo]:
+
+@dispatch
+def activity_count(  # noqa: F811
+    package: IPointDataPackage,
+    labels: object,
+    polygons: list[Any],
+    ignore_time_purge_empty: bool,
+) -> dict:
+    point_labels = imod.select.points_values(
+        labels, out_of_bounds="ignore", x=package.x, y=package.y
+    )
+    # minlength: partitions without any points must still get an entry.
+    counts = np.bincount(point_labels, minlength=len(polygons))
+    return {label: int(n) for label, n in enumerate(counts)}
+
+
+@dispatch
+def activity_count(  # noqa: F811
+    package: ILineDataPackage,
+    labels: object,
+    polygons: list[Any],
+    ignore_time_purge_empty: bool,
+) -> dict:
+    counts = {}
+    gdf_linestrings = package.line_data
+    for partition_id, polygon in enumerate(polygons):
+        partition_linestrings = gdf_linestrings.clip(polygon)
+        # Catch edge case: when line crosses only vertex of polygon, a point
+        # or multipoint is returned. These will be dropped, and can be
+        # identified by zero length.
+        counts[partition_id] = sum(partition_linestrings.length > 0)
+    return counts
+
+
+class PartitionModels(NamedTuple):
     """
-    A PartitionInfo is used to partition a model or package. The partition info's of a domain are created using a
-    submodel_labels array. The submodel_labels provided as input should have the same shape as a single layer of the
-    model grid (all layers are split the same way), and contains an integer value in each cell. Each cell in the
-    model grid will end up in the submodel with the index specified by the corresponding label of that cell. The
-    labels should be numbers between 0 and the number of partitions.
+    Container for the partitioned models:
+
+    * flow_models: flow_model_name (str) => model (object)
+    * transport_models: partition_id (int) => transport_model_name (str) => model (object)
     """

-    _validate_submodel_label_array(submodel_labels)
-    unique_labels = np.unique(submodel_labels.values)
+    flow_models: dict[str, object]
+    transport_models: dict[int, dict[str, object]]

-    partition_infos = []
-    for label_id in unique_labels:
-        active_domain = submodel_labels.where(submodel_labels.values == label_id)
-        active_domain = ones_like(active_domain).where(active_domain.notnull(), 0)
-        active_domain = active_domain.astype(submodel_labels.dtype)
+    def paired_keys(self):
+        for partition_id, key in enumerate(self.flow_models.keys()):
+            yield key, list(self.transport_models[partition_id].keys())

-        submodel_partition_info = PartitionInfo(
-            id=label_id, active_domain=active_domain
-        )
-        partition_infos.append(submodel_partition_info)
+    def paired_models(self):
+        for partition_id, model in enumerate(self.flow_models.values()):
+            yield model, list(self.transport_models[partition_id].values())

-    return partition_infos
+    def paired_items(self):
+        for partition_id, (key, model) in enumerate(self.flow_models.items()):
+            partition_models = self.transport_models[partition_id]
+            yield (
+                (key, model),
+                (list(partition_models.keys()), list(partition_models.values())),
+            )

+    @property
+    def flat_transport_models(self):
+        return {
+            name: model
+            for partition_models in self.transport_models.values()
+            for name, model in partition_models.items()
+        }

-def _validate_submodel_label_array(submodel_labels: GridDataArray) -> None:
-    unique_labels = np.unique(submodel_labels.values)

-    if not (
-        len(unique_labels) == unique_labels.max() + 1
-        and unique_labels.min() == 0
-        and np.issubdtype(submodel_labels.dtype, np.integer)
+class ModelSplitter:
+    def __init__(
+        self,
+        flow_models: dict[str, object],
+        transport_models: dict[str, object],
+        submodel_labels: GridDataArray,
+        ignore_time_purge_empty: bool = False,
     ):
-        raise ValueError(
-            "The submodel_label array should be integer and contain all the numbers between 0 and the number of "
-            "partitions minus 1."
-        )
+        self.flow_models = flow_models
+        self.transport_models = transport_models
+        self.models = {**flow_models, **transport_models}
+        self.submodel_labels = submodel_labels
+        self.unique_labels = self._validate_submodel_label_array(submodel_labels)
+        self.ignore_time_purge_empty = ignore_time_purge_empty
+        self._create_partition_info()
+        self.bounding_polygons = [
+            bounding_polygon(partition.active_domain)
+            for partition in self.partition_info
+        ]
+        self.exchange_creator: ExchangeCreator_Unstructured | ExchangeCreator_Structured
+        if is_unstructured(self.submodel_labels):
+            self.exchange_creator = ExchangeCreator_Unstructured(
+                self.submodel_labels, self.partition_info
+            )
+        else:
+            self.exchange_creator = ExchangeCreator_Structured(
+                self.submodel_labels, self.partition_info
+            )

-def slice_model(partition_info: PartitionInfo, model: IModel) -> IModel:
-    """
-    This function slices a Modflow6Model. A sliced model is a model that
-    consists of packages of the original model that are sliced using the
-    domain_slice. A domain_slice can be created using the
-    :func:`imod.mf6.modelsplitter.create_domain_slices` function.
-    """
-    modelclass = type(model)
-    new_model = modelclass(**model.options)
+        self._count_boundary_activity_per_partition()
+
+    @staticmethod
+    def _validate_submodel_label_array(submodel_labels: GridDataArray) -> np.ndarray:
+        unique_labels = np.unique(submodel_labels)
+
+        if not (
+            len(unique_labels) == unique_labels.max() + 1
+            and unique_labels.min() == 0
+            and np.issubdtype(submodel_labels.dtype, np.integer)
+        ):
+            raise ValueError(
+                "The submodel_label array should be integer and contain all the numbers between 0 and the number of "
+                "partitions minus 1."
+            )
+        return unique_labels
+
+    def _create_partition_info(self):
+        self.partition_info = []
+        labels = self.submodel_labels
+        for label_id in self.unique_labels:
+            active_domain = (labels == label_id).astype(labels.dtype)
+            self.partition_info.append(
+                PartitionInfo(
+                    active_domain=active_domain,
+                    partition_id=int(label_id),
+                )
+            )
+
+    def _create_partition_polygons(self):
+        self.partition_polygons = {
+            info.partition_id: bounding_polygon(info.active_domain)
+            for info in self.partition_info
+        }
+
+    def _count_boundary_activity_per_partition(self):
+        counts = {}
+        for model_name, model in self.models.items():
+            model_counts = {}
+            for pkg_name, package in model.items():
+                # Packages like NPF, DIS are always required.
+                # We only need to check packages with a MAXBOUND entry.
+                if not isinstance(package, BoundaryCondition):
+                    continue
+                model_counts[pkg_name] = activity_count(
+                    package,
+                    self.submodel_labels,
+                    self.bounding_polygons,
+                    self.ignore_time_purge_empty,
+                )
+            counts[model_name] = model_counts
+        self.boundary_activity_counts = counts
+
+    def slice_model(
+        self, model: IModel, info: PartitionInfo, boundary_activity_counts: dict
+    ) -> IModel:
+        modelclass = type(model)
+        new_model = modelclass(**model.options)

-    for pkg_name, package in model.items():
-        if isinstance(package, BoundaryCondition):
-            remove_expanded_auxiliary_variables_from_dataset(package)
+        for pkg_name, package in model.items():
+            if isinstance(package, BoundaryCondition):
+                # Skip empty boundary conditions
+                if boundary_activity_counts[pkg_name][info.partition_id] == 0:
+                    continue
+                remove_expanded_auxiliary_variables_from_dataset(package)
+
+            sliced_package = clip_by_grid(package, info.active_domain)
+            if sliced_package is not None:
+                new_model[pkg_name] = sliced_package
+
+            if isinstance(package, BoundaryCondition):
+                expand_transient_auxiliary_variables(sliced_package)
+
+        return new_model
+
+    def _split(self, models, nest: bool):
+        partition_models = collections.defaultdict(dict)
+        model_names = collections.defaultdict(list)
+        for model_name, model in models.items():
+            for info in self.partition_info:
+                new_model = self.slice_model(
+                    model, info, self.boundary_activity_counts[model_name]
+                )
+                new_model_name = f"{model_name}_{info.partition_id}"
+                if nest:
+                    partition_models[info.partition_id][new_model_name] = new_model
+                else:
+                    partition_models[new_model_name] = new_model
+
+                model_names[model_name].append(new_model_name)
+        return partition_models, model_names
+
+    def split(self):
+        # FUTURE: we may currently assume there is a single flow model. See check above.
+        # And each separate transport model represents a different species.
+        flow_models, flow_names = self._split(self.flow_models, nest=False)
+        transport_models, transport_names = self._split(
+            self.transport_models, nest=True
+        )
+        names = {**flow_names, **transport_names}
+        return PartitionModels(flow_models, transport_models), names

-        sliced_package = clip_by_grid(package, partition_info.active_domain)
-        if sliced_package is not None:
-            new_model[pkg_name] = sliced_package
+    def create_gwfgwf_exchanges(self):
+        exchanges: list[Any] = []
+        for model_name, model in self.flow_models.items():
+            exchanges += self.exchange_creator.create_gwfgwf_exchanges(
+                model_name, model.domain.layer
+            )
+        return exchanges

-        if isinstance(package, BoundaryCondition):
-            expand_transient_auxiliary_variables(sliced_package)

-    return new_model
+    def create_gwtgwt_exchanges(self):
+        exchanges: list[Any] = []
+        # TODO: weird/arbitrary dependence on the single flow model?
+        flow_model_name = list(self.flow_models.keys())[0]
+        model = self.flow_models[flow_model_name]
+        if self.transport_models:
+            for transport_model_name in self.transport_models:
+                exchanges += self.exchange_creator.create_gwtgwt_exchanges(
+                    transport_model_name, flow_model_name, model.domain.layer
+                )
+        return exchanges
diff --git a/imod/mf6/pkgbase.py b/imod/mf6/pkgbase.py
index d2fbf99bf..8e24232c6 100644
--- a/imod/mf6/pkgbase.py
+++ b/imod/mf6/pkgbase.py
@@ -10,6 +10,7 @@
 import imod
 from imod.common.interfaces.ipackagebase import IPackageBase
+from imod.mf6.utilities.zarr_helper import to_zarr
 from imod.typing.grid import (
     GridDataArray,
     GridDataset,
@@ -108,6 +109,9 @@ def to_netcdf(
         dataset = imod.util.spatial.gdal_compliant_grid(dataset, crs=crs)
         dataset.to_netcdf(*args, **kwargs)

+    def to_zarr(self, path, engine, **kwargs):
+        to_zarr(self.dataset, path, engine, **kwargs)
+
     def _netcdf_encoding(self) -> dict:
         """
@@ -163,10 +167,16 @@ def from_file(cls, path: str | Path, **kwargs) -> Self:

         Refer to the xarray documentation for the possible keyword arguments.
         """
+        path = Path(path)
         if path.suffix in (".zip", ".zarr"):
-            # TODO: seems like a bug? Remove str() call if fixed in xarray/zarr
-            dataset = xr.open_zarr(str(path), **kwargs)
+            import zarr
+
+            if path.suffix == ".zip":
+                with zarr.storage.ZipStore(path, mode="r") as store:
+                    # Load eagerly: the ZipStore is closed when this block
+                    # exits, after which lazy reads would fail.
+                    dataset = xr.open_zarr(store, **kwargs).load()
+            else:
+                dataset = xr.open_zarr(str(path), **kwargs)
         else:
             dataset = xr.open_dataset(path, **kwargs)
diff --git a/imod/mf6/simulation.py b/imod/mf6/simulation.py
index 759612969..5c1d682c6 100644
--- a/imod/mf6/simulation.py
+++ b/imod/mf6/simulation.py
@@ -39,14 +39,11 @@
 from imod.mf6.model import Modflow6Model
 from imod.mf6.model_gwf import GroundwaterFlowModel
 from imod.mf6.model_gwt import GroundwaterTransportModel
-from imod.mf6.multimodel.exchange_creator_structured import ExchangeCreator_Structured
-from imod.mf6.multimodel.exchange_creator_unstructured import (
-    ExchangeCreator_Unstructured,
-)
-from imod.mf6.multimodel.modelsplitter import create_partition_info, slice_model
+from imod.mf6.multimodel.modelsplitter import ModelSplitter, PartitionModels
 from imod.mf6.out import open_cbc, open_conc, open_hds
 from imod.mf6.package import Package
 from imod.mf6.ssm import SourceSinkMixing
+from imod.mf6.utilities.zarr_helper import to_zarr
 from imod.mf6.validation_settings import ValidationSettings
 from imod.mf6.write_context import WriteContext
 from imod.prepare.partition import create_partition_labels
@@ -92,6 +89,12 @@ def get_packages(simulation: Modflow6Simulation) -> dict[str, Package]:
     }


+def force_load_dis(model):
+    key = model.get_diskey()
+    model[key].dataset.load()
+    return
+
+
 class Modflow6Simulation(collections.UserDict, ISimulation):
     """
     Modflow6Simulation is a class that represents a Modflow 6 simulation. It
@@ -969,6 +972,7 @@ def dump(
         validate: bool = True,
         mdal_compliant: bool = False,
         crs=None,
+        engine="netCDF4",
     ) -> None:
         """
         Dump simulation to files. Writes a model definition as .TOML file, which
@@ -990,6 +994,8 @@
         crs: Any, optional
            Anything accepted by rasterio.crs.CRS.from_user_input
            Requires ``rioxarray`` installed.
+        engine: str, optional
+            "netCDF4" or "zarr" or "zarr.zip". Defaults to "netCDF4".
        Examples
        --------
@@ -1015,16 +1021,27 @@
         directory = pathlib.Path(directory)
         directory.mkdir(parents=True, exist_ok=True)

+        match engine:
+            case "netCDF4":
+                ext = "nc"
+            case "zarr":
+                ext = "zarr"
+            case "zarr.zip":
+                ext = "zarr.zip"
+            case _:
+                raise ValueError(f"Unknown engine: {engine}")
+
         toml_content: DefaultDict[str, dict] = collections.defaultdict(dict)
         # Dump version number
         version = get_version()
         toml_content["version"] = {"imod-python": version}
+        # Dump models and exchanges
         for key, value in self.items():
             cls_name = type(value).__name__
             if isinstance(value, Modflow6Model):
                 model_toml_path = value.dump(
-                    directory, key, validate, mdal_compliant, crs
+                    directory, key, validate, mdal_compliant, crs, engine=engine
                 )
                 toml_content[cls_name][key] = model_toml_path.relative_to(
                     directory
                 )
@@ -1034,13 +1051,23 @@
                 for exchange_package in self[key]:
                     _, filename, _, _ = exchange_package.get_specification()
                     exchange_class_short = type(exchange_package).__name__
-                    path = f"{filename}.nc"
-                    exchange_package.dataset.to_netcdf(directory / path)
+                    path = f"{filename}.{ext}"
+
+                    if engine == "netCDF4":
+                        exchange_package.dataset.to_netcdf(directory / path)
+                    else:
+                        to_zarr(
+                            exchange_package.dataset, directory / path, engine=engine
+                        )
+
                     toml_content[key][exchange_class_short].append(path)
             else:
-                path = f"{key}.nc"
-                value.dataset.to_netcdf(directory / path)
+                path = f"{key}.{ext}"
+                if engine == "netCDF4":
+                    value.dataset.to_netcdf(directory / path)
+                else:
+                    to_zarr(value.dataset, directory / path, engine=engine)
                 toml_content[cls_name][key] = path

         with open(directory / f"{self.name}.toml", "wb") as f:
@@ -1412,57 +1439,50 @@
                 f"simulation cannot be split due to presence of package '{error_with_object}' in model '{model_name}'"
             )

-        original_packages = get_packages(self)
-
-        partition_info = create_partition_info(submodel_labels)
+        # Make sure the DIS package is available in memory and not lazily evaluated,
+        # since we need its values repeatedly.
+        for model in original_models.values():
+            force_load_dis(model)

-        exchange_creator: ExchangeCreator_Unstructured | ExchangeCreator_Structured
-        if is_unstructured(submodel_labels):
-            exchange_creator = ExchangeCreator_Unstructured(
-                submodel_labels, partition_info
-            )
-        else:
-            exchange_creator = ExchangeCreator_Structured(
-                submodel_labels, partition_info
-            )
+        model_splitter = ModelSplitter(
+            flow_models,
+            transport_models,
+            submodel_labels,
+            ignore_time_purge_empty,
+        )
+        partition_models, model_names = model_splitter.split()

+        # Create new simulation object and add the partitioned models.
         new_simulation = imod.mf6.Modflow6Simulation(
             f"{self.name}_partioned", validation_settings=self._validation_context
         )
-        for package_name, package in {**original_packages}.items():
-            new_simulation[package_name] = deepcopy(package)
-
-        for model_name, model in original_models.items():
-            solution_name = self.get_solution_name(model_name)
-            solution = cast(Solution, new_simulation[solution_name])
-            solution._remove_model_from_solution(model_name)
-            for submodel_partition_info in partition_info:
-                new_model_name = f"{model_name}_{submodel_partition_info.id}"
-                new_simulation[new_model_name] = slice_model(
-                    submodel_partition_info, model
-                )
-                new_simulation[new_model_name].purge_empty_packages(
-                    ignore_time=ignore_time_purge_empty
-                )
-                solution._add_model_to_solution(new_model_name)
-
-        exchanges: list[Any] = []
-
-        for flow_model_name, flow_model in flow_models.items():
-            exchanges += exchange_creator.create_gwfgwf_exchanges(
-                flow_model_name, flow_model.domain.layer
-            )
+        chained = {  # ChainMap reverses order, annoyingly...
+            **partition_models.flow_models,
+            **partition_models.flat_transport_models,
+        }
+        for partition_model_name, partition_model in chained.items():
+            new_simulation[partition_model_name] = partition_model

-        if any(transport_models):
-            for tpt_model_name in transport_models:
-                exchanges += exchange_creator.create_gwtgwt_exchanges(
-                    tpt_model_name, flow_model_name, model.domain.layer
-                )
+        # Add solution, time_discretization, etc.
+        # Replace the single model name by the partition model names.
+        original_packages = get_packages(self)
+        for package_name, package in original_packages.items():
+            new_package = deepcopy(package)
+            if isinstance(package, Solution):
+                old_name = package.dataset["modelnames"].item()
+                new_package["modelnames"] = xr.DataArray(model_names[old_name])
+            new_simulation[package_name] = new_package
+
+        # Add exchanges
+        exchanges: list[Any] = (
+            model_splitter.create_gwfgwf_exchanges()
+            + model_splitter.create_gwtgwt_exchanges()
+        )
         new_simulation._add_modelsplit_exchanges(exchanges)
-        new_simulation._update_buoyancy_packages()
+        new_simulation._update_buoyancy_packages(partition_models)
         new_simulation._set_flow_exchange_options()
         new_simulation._set_transport_exchange_options()
-        new_simulation._update_ssm_packages()
+        new_simulation._update_ssm_packages(partition_models)
         new_simulation._filter_inactive_cells_from_exchanges()

         return new_simulation
@@ -1506,6 +1526,7 @@ def _add_modelsplit_exchanges(self, exchanges_list: list[GWFGWF]) -> None:
         if not self.is_split():
             self["split_exchanges"] = []
         self["split_exchanges"].extend(exchanges_list)
+        return

     def _set_flow_exchange_options(self) -> None:
         # collect some options that we will auto-set
@@ -1520,6 +1541,7 @@ def _set_flow_exchange_options(self) -> None:
                 xt3d=model_1["npf"].get_xt3d_option(),
                 newton=model_1.is_use_newton(),
             )
+        return

     def _set_transport_exchange_options(self) -> None:
         for exchange in self["split_exchanges"]:
@@ -1552,6 +1574,7 @@ def _filter_inactive_cells_from_exchanges(self) -> None:
             # Remove exchange if no cells are left
             if ex.dataset.sizes["index"] == 0:
                 self["split_exchanges"].remove(ex)
+        return

     def _filter_inactive_cells_exchange_domain(self, ex: GWFGWF, i: int) -> None:
         """Filters inactive cells from one exchange domain inplace"""
@@ -1575,6 +1598,7 @@ def _filter_inactive_cells_exchange_domain(self, ex: GWFGWF, i: int) -> None:
         active_exchange_domain = exchange_domain.where(exchange_domain > 0)
         active_exchange_domain = active_exchange_domain.dropna("index")
         ex.dataset = ex.dataset.sel(index=active_exchange_domain["index"])
+        return

     def get_solution_name(self, model_name: str) -> Optional[str]:
         for k, v in self.items():
@@ -1626,15 +1650,12 @@ def _generate_gwfgwt_exchanges(self) -> list[GWFGWT]:

         return exchanges

-    def _update_ssm_packages(self) -> None:
-        flow_transport_mapping = self._get_transport_models_per_flow_model()
-        for flow_name, tpt_models_of_flow_model in flow_transport_mapping.items():
-            flow_model = self[flow_name]
-            for tpt_model_name in tpt_models_of_flow_model:
-                tpt_model = self[tpt_model_name]
-                ssm_key = tpt_model._get_pkgkey("ssm")
+    def _update_ssm_packages(self, partition_models: PartitionModels) -> None:
+        for flow_model, paired_transport_models in partition_models.paired_models():
+            for transport_model in paired_transport_models:
+                ssm_key = transport_model._get_pkgkey("ssm")
                 if ssm_key is not None:
-                    old_ssm_package = tpt_model.pop(ssm_key)
+                    old_ssm_package = transport_model.pop(ssm_key)
                     state_variable_name = old_ssm_package.dataset[
                         "auxiliary_variable_name"
                     ].values[0]
@@ -1642,13 +1663,13 @@
                         flow_model, state_variable_name, is_split=self.is_split()
                     )
                     if ssm_package is not None:
-                        tpt_model[ssm_key] = ssm_package
+                        transport_model[ssm_key] = ssm_package
+        return

-    def _update_buoyancy_packages(self) -> None:
-        flow_transport_mapping = self._get_transport_models_per_flow_model()
-        for flow_name, tpt_models_of_flow_model in flow_transport_mapping.items():
-            flow_model = cast(GroundwaterFlowModel, self[flow_name])
-            flow_model._update_buoyancy_package(tpt_models_of_flow_model)
+    def _update_buoyancy_packages(self, partition_models: PartitionModels) -> None:
+        for (_, flow_model), (names, _) in partition_models.paired_items():
+            flow_model._update_buoyancy_package(names)
+        return

     def is_split(self) -> bool:
         """
diff --git a/imod/mf6/utilities/zarr_helper.py b/imod/mf6/utilities/zarr_helper.py
new file mode 100644
index 000000000..3ae21552a
--- /dev/null
+++ b/imod/mf6/utilities/zarr_helper.py
@@ -0,0 +1,35 @@
+import shutil
+from pathlib import Path
+
+import xugrid as xu
+
+
+def to_zarr(dataset, path: str | Path, engine: str, **kwargs):
+    import zarr
+
+    path = Path(path)
+    if path.exists():
+        # Remove any existing store: a directory (.zarr) or a file (ZipStore).
+        if path.is_dir():
+            shutil.rmtree(path)
+        else:
+            path.unlink()
+
+    match engine:
+        case "zarr":
+            if isinstance(dataset, xu.UgridDataset):
+                dataset.ugrid.to_zarr(path, **kwargs)
+            else:
+                dataset.to_zarr(path, **kwargs)
+        case "zarr.zip":
+            with zarr.storage.ZipStore(path, mode="w") as store:
+                if isinstance(dataset, xu.UgridDataset):
+                    dataset.ugrid.to_zarr(store, **kwargs)
+                else:
+                    dataset.to_zarr(store, **kwargs)
+        case _:
+            raise ValueError(
+                f'Expected engine to be "zarr" or "zarr.zip", got: {engine}'
+            )
+
+    return
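
A self-contained sketch of the "zarr.zip" round trip through the new helper (the file name is hypothetical). Note that the data must be read while the ZipStore is still open, which is why Package.from_file loads eagerly inside the with-block:

    import xarray as xr
    import zarr

    from imod.mf6.utilities.zarr_helper import to_zarr

    ds = xr.Dataset({"conductance": ("index", [1.0, 2.0, 3.0])})
    to_zarr(ds, "example.zarr.zip", engine="zarr.zip")

    with zarr.storage.ZipStore("example.zarr.zip", mode="r") as store:
        # Load while the store is open; lazy access after closing fails.
        back = xr.open_zarr(store).load()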
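For the new `engine` option, a minimal usage sketch; the simulation object, the model name "gwf", and the package name "riv" are hypothetical, while `dump` and `Package.from_file` are the entry points changed above:

    import imod

    # `simulation` is an existing imod.mf6.Modflow6Simulation.
    simulation.dump("dumped_simulation", engine="zarr.zip")

    # Each package is written as "<pkgname>.zarr.zip"; from_file() dispatches
    # on the suffix and opens a ZipStore for ".zip" paths.
    riv = imod.mf6.River.from_file("dumped_simulation/gwf/riv.zarr.zip")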
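Finally, the split refactor boils down to the call sequence that Modflow6Simulation.split now drives. A condensed sketch, under the assumption of a single flow model and with dicts of name => model as inputs (as collected in split above):

    from imod.mf6.multimodel.modelsplitter import ModelSplitter

    # flow_models, transport_models: dicts of name => model;
    # submodel_labels: integer labels, one per cell, 0 .. n_partitions - 1.
    splitter = ModelSplitter(flow_models, transport_models, submodel_labels)
    partition_models, model_names = splitter.split()

    # Exchanges between the partitioned models:
    exchanges = (
        splitter.create_gwfgwf_exchanges() + splitter.create_gwtgwt_exchanges()
    )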