diff --git a/parcels/_core/utils/__init__.py b/parcels/_core/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/parcels/_core/utils/common.py b/parcels/_core/utils/common.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/parcels/_core/utils/structured.py b/parcels/_core/utils/structured.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/parcels/_core/utils/unstructured.py b/parcels/_core/utils/unstructured.py new file mode 100644 index 0000000000..76c4c93d7f --- /dev/null +++ b/parcels/_core/utils/unstructured.py @@ -0,0 +1,28 @@ +from collections.abc import Hashable + +DIM_TO_VERTICAL_LOCATION_MAP = { + "nz1": "center", + "nz": "face", +} + + +def get_vertical_location_from_dims(dims: tuple[Hashable, ...]): + """ + Determine the vertical location of the field based on the uxarray.UxDataArray object variables. + + Only used for unstructured grids. + """ + vertical_dims_in_data = set(dims) & set(DIM_TO_VERTICAL_LOCATION_MAP.keys()) + + if len(vertical_dims_in_data) != 1: + raise ValueError( + f"Expected exactly one vertical dimension ({set(DIM_TO_VERTICAL_LOCATION_MAP.keys())}) in the data, got {vertical_dims_in_data}" + ) + + return DIM_TO_VERTICAL_LOCATION_MAP[vertical_dims_in_data.pop()] + + +def get_vertical_dim_name_from_location(location: str): + """Determine the vertical location of the field based on the uxarray.UxGrid object variables.""" + location_to_dim_map = {v: k for k, v in DIM_TO_VERTICAL_LOCATION_MAP.items()} + return location_to_dim_map[location] diff --git a/parcels/field.py b/parcels/field.py index 364e483549..613a88ac82 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -10,6 +10,7 @@ import xarray as xr from uxarray.grid.neighbors import _barycentric_coordinates +from parcels._core.utils.unstructured import get_vertical_location_from_dims from parcels._typing import ( Mesh, VectorType, @@ -28,6 +29,7 @@ _raise_field_out_of_bound_error, ) from parcels.v4.grid import Grid +from parcels.v4.gridadapter import GridAdapter from ._index_search import _search_indices_rectilinear, _search_time_index @@ -166,6 +168,13 @@ def __init__( self.data = data self.grid = grid + # For compatibility with parts of the codebase that rely on v3 definition of Grid. + # Should be worked to be removed in v4 + if isinstance(grid, Grid): + self.gridadapter = GridAdapter(grid) + else: + self.gridadapter = None + try: if isinstance(data, ux.UxDataArray): _assert_valid_uxdataarray(data) @@ -178,8 +187,6 @@ def __init__( self._parent_mesh = data.attrs["mesh"] self._mesh_type = mesh_type - self._location = data.attrs["location"] - self._vertical_location = None # Setting the interpolation method dynamically if interp_method is None: @@ -202,55 +209,9 @@ def __init__( else: self.allow_time_extrapolation = allow_time_extrapolation - if type(self.data) is ux.UxDataArray: - self._spatialhash = self.grid.get_spatial_hash() - self._gtype = None - # Set the vertical location - if "nz1" in data.dims: - self._vertical_location = "center" - elif "nz" in data.dims: - self._vertical_location = "face" - else: # TODO Nick : This bit probably needs an overhaul once the parcels.Grid class is integrated. - self._spatialhash = None - # Set the grid type - if "x_g" in self.data.coords: - lon = self.data.x_g - elif "x_c" in self.data.coords: - lon = self.data.x_c - else: - lon = self.data.lon - - if "nz1" in self.data.coords: - depth = self.data.nz1 - elif "nz" in self.data.coords: - depth = self.data.nz - elif "depth" in self.data.coords: - depth = self.data.depth - else: - depth = None - - if len(lon.shape) <= 1: - if depth is None or len(depth.shape) <= 1: - self._gtype = GridType.RectilinearZGrid - else: - self._gtype = GridType.RectilinearSGrid - else: - if depth is None or len(depth.shape) <= 1: - self._gtype = GridType.CurvilinearZGrid - else: - self._gtype = GridType.CurvilinearSGrid - - self._lonlat_minmax = np.array( - [np.nanmin(self.lon), np.nanmax(self.lon), np.nanmin(self.lat), np.nanmax(self.lat)], dtype=np.float32 - ) - def __repr__(self): return field_repr(self) - @property - def lonlat_minmax(self): - return self._lonlat_minmax - @property def units(self): return self._units @@ -264,69 +225,63 @@ def units(self, value): @property def lat(self): if type(self.data) is ux.UxDataArray: - if self._location == "node": + if self.data.attrs["location"] == "node": return self.grid.node_lat - elif self._location == "face": + elif self.data.attrs["location"] == "face": return self.grid.face_lat - elif self._location == "edge": + elif self.data.attrs["location"] == "edge": return self.grid.edge_lat else: - return self.data.lat + return self.gridadapter.lat @property def lon(self): if type(self.data) is ux.UxDataArray: - if self._location == "node": + if self.data.attrs["location"] == "node": return self.grid.node_lon - elif self._location == "face": + elif self.data.attrs["location"] == "face": return self.grid.face_lon - elif self._location == "edge": + elif self.data.attrs["location"] == "edge": return self.grid.edge_lon else: - return self.data.lon + return self.gridadapter.lon @property def depth(self): if type(self.data) is ux.UxDataArray: - if self._vertical_location == "center": + vertical_location = get_vertical_location_from_dims(self.data.dims) + if vertical_location == "center": return self.grid.nz1 - elif self._vertical_location == "face": + elif vertical_location == "face": return self.grid.nz else: - return self.data.depth + return self.gridadapter.depth @property def xdim(self): if type(self.data) is xr.DataArray: - if "face_lon" in self.data.dims: - return self.data.sizes["face_lon"] - elif "node_lon" in self.data.dims: - return self.data.sizes["node_lon"] - else: - return self.data.sizes["lon"] + return self.gridadapter.xdim else: - return 0 # TODO : Discuss what we want to return as xdim for uxdataarray obj + raise NotImplementedError("xdim not implemented for unstructured grids") @property def ydim(self): if type(self.data) is xr.DataArray: - if "face_lat" in self.data.dims: - return self.data.sizes["face_lat"] - elif "node_lat" in self.data.dims: - return self.data.sizes["node_lat"] - else: - return self.data.sizes["lat"] + return self.gridadapter.ydim else: - return 0 # TODO : Discuss what we want to return as ydim for uxdataarray obj + raise NotImplementedError("ydim not implemented for unstructured grids") @property def zdim(self): - if "nz1" in self.data.dims: - return self.data.sizes["nz1"] - elif "nz" in self.data.dims: - return self.data.sizes["nz"] + if type(self.data) is xr.DataArray: + return self.gridadapter.zdim else: - return 0 + if "nz1" in self.data.dims: + return self.data.sizes["nz1"] + elif "nz" in self.data.dims: + return self.data.sizes["nz"] + else: + return 0 @property def n_face(self): @@ -365,7 +320,7 @@ def _search_indices_unstructured(self, z, y, x, ei=None, search2D=False): tol = 1e-10 if ei is None: # Search using global search - fi, bcoords = self._spatialhash.query([[x, y]]) # Get the face id for the particle + fi, bcoords = self.grid.get_spatial_hash().query([[x, y]]) # Get the face id for the particle if fi == -1: raise FieldOutOfBoundError(z, y, x) # TODO Joe : Do the vertical grid search @@ -389,12 +344,12 @@ def _search_indices_unstructured(self, z, y, x, ei=None, search2D=False): return bcoords, self.ravel_index(zi, 0, neighbor) # If we reach this point, we do a global search as a last ditch effort the particle is out of bounds - fi, bcoords = self._spatialhash.query([[x, y]]) # Get the face id for the particle + fi, bcoords = self.grid.get_spatial_hash().query([[x, y]]) # Get the face id for the particle if fi == -1: raise FieldOutOfBoundError(z, y, x) def _search_indices_structured(self, z, y, x, ei=None, search2D=False): - if self._gtype in [GridType.RectilinearSGrid, GridType.RectilinearZGrid]: + if self.gridadapter._gtype in [GridType.RectilinearSGrid, GridType.RectilinearZGrid]: (zeta, eta, xsi, zi, yi, xi) = _search_indices_rectilinear(self, z, y, x, ei=ei, search2D=search2D) else: ## TODO : Still need to implement the search_indices_curvilinear diff --git a/parcels/v4/gridadapter.py b/parcels/v4/gridadapter.py index 99e34a2c65..31329d6496 100644 --- a/parcels/v4/gridadapter.py +++ b/parcels/v4/gridadapter.py @@ -30,67 +30,68 @@ def get_time(axis: Axis) -> npt.NDArray: return axis._ds[axis.coords["center"]].values -class GridAdapter(NewGrid): - def __init__(self, ds, mesh="flat", *args, **kwargs): - super().__init__(ds, *args, **kwargs) +class GridAdapter: + def __init__(self, grid: NewGrid, mesh="flat"): + self.grid = grid self.mesh = mesh + # ! Not ideal... Triggers computation on a throwaway item. If adapter is still needed in codebase, and this is prohibitively expensive, perhaps store GridAdapter on Field object instead of Grid self.lonlat_minmax = np.array( [ - np.nanmin(self._ds["lon"]), - np.nanmax(self._ds["lon"]), - np.nanmin(self._ds["lat"]), - np.nanmax(self._ds["lat"]), + np.nanmin(self.grid._ds["lon"]), + np.nanmax(self.grid._ds["lon"]), + np.nanmin(self.grid._ds["lat"]), + np.nanmax(self.grid._ds["lat"]), ] ) @property def lon(self): try: - _ = self.axes["X"] + _ = self.grid.axes["X"] except KeyError: return np.zeros(1) - return self._ds["lon"].values + return self.grid._ds["lon"].values @property def lat(self): try: - _ = self.axes["Y"] + _ = self.grid.axes["Y"] except KeyError: return np.zeros(1) - return self._ds["lat"].values + return self.grid._ds["lat"].values @property def depth(self): try: - _ = self.axes["Z"] + _ = self.grid.axes["Z"] except KeyError: return np.zeros(1) - return self._ds["depth"].values + return self.grid._ds["depth"].values @property def time(self): try: - axis = self.axes["T"] + axis = self.grid.axes["T"] except KeyError: return np.zeros(1) return get_time(axis) @property def xdim(self): - return get_dimensionality(self.axes.get("X")) + return get_dimensionality(self.grid.axes.get("X")) @property def ydim(self): - return get_dimensionality(self.axes.get("Y")) + return get_dimensionality(self.grid.axes.get("Y")) @property def zdim(self): - return get_dimensionality(self.axes.get("Z")) + return get_dimensionality(self.grid.axes.get("Z")) @property def tdim(self): - return get_dimensionality(self.axes.get("T")) + return get_dimensionality(self.grid.axes.get("T")) @property def time_origin(self): diff --git a/tests/v4/test_field.py b/tests/v4/test_field.py index 082918567e..be0f0b87e4 100644 --- a/tests/v4/test_field.py +++ b/tests/v4/test_field.py @@ -32,11 +32,18 @@ def test_field_init_param_types(): "data,grid", [ pytest.param(ux.UxDataArray(), Grid(xr.Dataset()), id="uxdata-grid"), - pytest.param(xr.DataArray(), ux.UxDataArray().uxgrid, id="xarray-uxgrid"), + pytest.param( + xr.DataArray(), + ux.UxDataArray().uxgrid, + id="xarray-uxgrid", + marks=pytest.mark.xfail( + reason="Replace uxDataArray object with one that actually has a grid (once unstructured example datasets are in the codebase)." + ), + ), ], ) def test_field_incompatible_combination(data, grid): - with pytest.raises(ValueError, msg="Incompatible data-grid combination."): + with pytest.raises(ValueError, match="Incompatible data-grid combination."): Field( name="test_field", data=data, diff --git a/tests/v4/test_gridadapter.py b/tests/v4/test_gridadapter.py index 1fdcc66745..24c67b70b6 100644 --- a/tests/v4/test_gridadapter.py +++ b/tests/v4/test_gridadapter.py @@ -7,6 +7,7 @@ from parcels._datasets.structured.grid_datasets import N, T, datasets from parcels.grid import Grid as OldGrid from parcels.tools.converters import TimeConverter +from parcels.v4.grid import Grid as NewGrid from parcels.v4.gridadapter import GridAdapter TestCase = namedtuple("TestCase", ["Grid", "attr", "expected"]) @@ -38,7 +39,7 @@ def assert_equal(actual, expected): @pytest.mark.parametrize("ds, attr, expected", test_cases) def test_grid_adapter_properties_ground_truth(ds, attr, expected): - adapter = GridAdapter(ds, periodic=False) + adapter = GridAdapter(NewGrid(ds, periodic=False)) actual = getattr(adapter, attr) assert_equal(actual, expected) @@ -60,7 +61,7 @@ def test_grid_adapter_properties_ground_truth(ds, attr, expected): ) @pytest.mark.parametrize("ds", datasets.values()) def test_grid_adapter_against_old(ds, attr): - adapter = GridAdapter(ds, periodic=False) + adapter = GridAdapter(NewGrid(ds, periodic=False)) grid = OldGrid.create_grid( lon=ds.lon.values, diff --git a/tests/v4/utils/test_unstructured.py b/tests/v4/utils/test_unstructured.py new file mode 100644 index 0000000000..e8c296feca --- /dev/null +++ b/tests/v4/utils/test_unstructured.py @@ -0,0 +1,33 @@ +import pytest + +from parcels._core.utils.unstructured import ( + get_vertical_dim_name_from_location, + get_vertical_location_from_dims, +) + + +def test_get_vertical_location_from_dims(): + # Test with nz1 dimension + assert get_vertical_location_from_dims(("nz1", "time")) == "center" + + # Test with nz dimension + assert get_vertical_location_from_dims(("nz", "time")) == "face" + + # Test with both dimensions + with pytest.raises(ValueError): + get_vertical_location_from_dims(("nz1", "nz", "time")) + + # Test with no vertical dimension + with pytest.raises(ValueError): + get_vertical_location_from_dims(("time", "x", "y")) + + +def test_get_vertical_dim_name_from_location(): + # Test with center location + assert get_vertical_dim_name_from_location("center") == "nz1" + + # Test with face location + assert get_vertical_dim_name_from_location("face") == "nz" + + with pytest.raises(KeyError): + get_vertical_dim_name_from_location("invalid_location")