diff --git a/imod/mf6/oc.py b/imod/mf6/oc.py
index ca820e7a0..6921a3db8 100644
--- a/imod/mf6/oc.py
+++ b/imod/mf6/oc.py
@@ -172,7 +172,7 @@ def _render(self, directory, pkgname, globaltimes, binary):
                 package_times = self.dataset[datavar].coords["time"].values
                 starts = np.searchsorted(globaltimes, package_times) + 1
                 for i, s in enumerate(starts):
-                    setting = self.dataset[datavar].isel(time=i).item()
+                    setting = self.dataset[datavar].isel(time=i).values[()]
                     periods[s][key] = self._get_ocsetting(setting)

             else:
diff --git a/imod/mf6/pkgbase.py b/imod/mf6/pkgbase.py
index d2fbf99bf..fe78a789a 100644
--- a/imod/mf6/pkgbase.py
+++ b/imod/mf6/pkgbase.py
@@ -94,6 +94,17 @@ def to_netcdf(
             kwargs.update({"encoding": self._netcdf_encoding()})

         dataset = self.dataset
+
+        # Create encoding dict for float16 variables
+        for var in dataset.data_vars:
+            if dataset[var].dtype == np.float16:
+                kwargs["encoding"][var] = {"dtype": "float32"}
+
+        # Also check coordinates
+        for coord in dataset.coords:
+            if dataset[coord].dtype == np.float16:
+                kwargs["encoding"][coord] = {"dtype": "float32"}
+
         if isinstance(dataset, xu.UgridDataset):
             if mdal_compliant:
                 dataset = dataset.ugrid.to_dataset()
@@ -168,7 +179,7 @@ def from_file(cls, path: str | Path, **kwargs) -> Self:
             # TODO: seems like a bug? Remove str() call if fixed in xarray/zarr
             dataset = xr.open_zarr(str(path), **kwargs)
         else:
-            dataset = xr.open_dataset(path, **kwargs)
+            dataset = xr.open_dataset(path, chunks="auto", **kwargs)

         if dataset.ugrid_roles.topology:
             dataset = xu.UgridDataset(dataset)
@@ -183,4 +194,12 @@ def from_file(cls, path: str | Path, **kwargs) -> Self:
             if _is_scalar_nan(value):
                 dataset[key] = None

+        # to_netcdf converts strings into NetCDF "variable-length UTF-8 strings"
+        # which are loaded as dtype=object arrays. This will convert them back
+        # to str.
+        vars = ["species"]
+        for var in vars:
+            if var in dataset:
+                dataset[var] = dataset[var].astype(str)
+
         return cls._from_dataset(dataset)
diff --git a/imod/mf6/simulation.py b/imod/mf6/simulation.py
index caa4fe796..d873db1b0 100644
--- a/imod/mf6/simulation.py
+++ b/imod/mf6/simulation.py
@@ -60,7 +60,7 @@
 from imod.typing import GridDataArray, GridDataset
 from imod.typing.grid import (
     concat,
-    is_equal,
+    is_same_domain,
     is_unstructured,
     merge_partitions,
 )
@@ -1622,10 +1622,16 @@ def _get_transport_models_per_flow_model(self) -> dict[str, list[str]]:

         for flow_model_name in flow_models:
             flow_model = self[flow_model_name]
+
+            matched_tsp_models = []
             for tpt_model_name in transport_models:
                 tpt_model = self[tpt_model_name]
-                if is_equal(tpt_model.domain, flow_model.domain):
+                if is_same_domain(tpt_model.domain, flow_model.domain):
                     result[flow_model_name].append(tpt_model_name)
+                    matched_tsp_models.append(tpt_model_name)
+            for tpt_model_name in matched_tsp_models:
+                transport_models.pop(tpt_model_name)
+
         return result

     def _generate_gwfgwt_exchanges(self) -> list[GWFGWT]:
diff --git a/imod/mf6/wel.py b/imod/mf6/wel.py
index 9c82b78d3..76f55a934 100644
--- a/imod/mf6/wel.py
+++ b/imod/mf6/wel.py
@@ -540,7 +540,7 @@ def _to_mf6_package_information(
         else:
             message += " The first 10 unplaced wells are: \n"

-        is_filtered = self.dataset["id"].isin([filtered_wells])
+        is_filtered = self.dataset["id"].compute().isin(filtered_wells)
         for i in range(min(10, len(filtered_wells))):
             ids = filtered_wells[i]
             x = self.dataset["x"].data[is_filtered][i]
@@ -1073,9 +1073,9 @@ def _assign_wells_to_layers(
     ) -> pd.DataFrame:
         # Ensure top, bottom & k
         # are broadcasted to 3d grid
-        like = ones_like(active)
-        bottom = like * bottom
-        top_2d = (like * top).sel(layer=1)
+        like = ones_like(active.compute())
+        bottom = like * bottom.compute()
+        top_2d = (like * top.compute()).sel(layer=1)
         top_3d = bottom.shift(layer=1).fillna(top_2d)
         k = like * k
diff --git a/imod/typing/grid.py b/imod/typing/grid.py
index 7a4f6dc92..98701dd2a 100644
--- a/imod/typing/grid.py
+++ b/imod/typing/grid.py
@@ -316,11 +316,15 @@ def is_spatial_grid(_: Any) -> bool:  # noqa: F811


 @dispatch
 def is_equal(array1: xu.UgridDataArray, array2: xu.UgridDataArray) -> bool:
+    if not is_same_domain(array1, array2):
+        return False
     return array1.equals(array2) and array1.ugrid.grid.equals(array2.ugrid.grid)


 @dispatch  # type: ignore[no-redef]
 def is_equal(array1: xr.DataArray, array2: xr.DataArray) -> bool:  # noqa: F811
+    if not is_same_domain(array1, array2):
+        return False
     return array1.equals(array2)
diff --git a/pyproject.toml b/pyproject.toml
index 4071e7c5e..07867f179 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -80,7 +80,6 @@ ignore = [
     "E501",  # line-too-long. This rule can't be fullfilled by the ruff formatter. The same behavior as black.
     "PD003",
     "PD004",
-    "PD901",
     "PD011",
     "PD013",
     "PD015",