diff --git a/parcels/fieldset.py b/parcels/fieldset.py index 8de81421b9..572c0224d8 100644 --- a/parcels/fieldset.py +++ b/parcels/fieldset.py @@ -161,15 +161,6 @@ def add_constant_field(self, name: str, value, mesh: Mesh = "flat"): da = xr.DataArray( data=np.full((1, 1, 1, 1), value), dims=["time", "ZG", "YG", "XG"], - coords={ - "ZG": (["ZG"], np.arange(1), {"axis": "Z"}), - "YG": (["YG"], np.arange(1), {"axis": "Y"}), - "XG": (["XG"], np.arange(1), {"axis": "X"}), - "lon": (["XG"], np.arange(1), {"axis": "X"}), - "lat": (["YG"], np.arange(1), {"axis": "Y"}), - "depth": (["ZG"], np.arange(1), {"axis": "Z"}), - "time": (["time"], np.arange(1), {"axis": "T"}), - }, ) grid = XGrid(xgcm.Grid(da)) self.add_field( diff --git a/parcels/xgrid.py b/parcels/xgrid.py index c0bb066513..5d536a39ed 100644 --- a/parcels/xgrid.py +++ b/parcels/xgrid.py @@ -10,6 +10,9 @@ from parcels.basegrid import BaseGrid from parcels.tools.converters import TimeConverter +_XGRID_AXES_ORDERING = "ZYX" +_XGRID_AXES = Literal["X", "Y", "Z"] + _XGCM_AXIS_DIRECTION = Literal["X", "Y", "Z", "T"] _XGCM_AXIS_POSITION = Literal["center", "left", "right", "inner", "outer"] _AXIS_DIRECTION = Literal["X", "Y", "Z"] @@ -29,6 +32,11 @@ def get_time(axis: xgcm.Axis) -> npt.NDArray: return axis._ds[axis.coords["center"]].values +def _get_xgrid_axes(grid: xgcm.Grid) -> list[_XGRID_AXES]: + spatial_axes = [a for a in grid.axes.keys() if a in ["X", "Y", "Z"]] + return sorted(spatial_axes, key=_XGRID_AXES_ORDERING.index) + + class XGrid(BaseGrid): """ Class to represent a structured grid in Parcels. Wraps a xgcm-like Grid object (we use a trimmed down version of the xgcm.Grid class that is vendored with Parcels). @@ -44,17 +52,13 @@ def __init__(self, grid: xgcm.Grid, mesh="flat"): self.xgcm_grid = grid self.mesh = mesh ds = grid._ds - assert_valid_lat_lon(ds["lat"], ds["lon"], grid.axes) - - # ! Not ideal... Triggers computation on a throwaway item. Keeping for now for v3 compat, will be removed in v4. - self.lonlat_minmax = np.array( - [ - np.nanmin(self.xgcm_grid._ds["lon"]), - np.nanmax(self.xgcm_grid._ds["lon"]), - np.nanmin(self.xgcm_grid._ds["lat"]), - np.nanmax(self.xgcm_grid._ds["lat"]), - ] - ) + + if len(set(grid.axes) & {"X", "Y", "Z"}) > 0: # Only if spatial grid is >0D (see #2054 for further development) + assert_valid_lat_lon(ds["lat"], ds["lon"], grid.axes) + + @property + def axes(self) -> list[_XGRID_AXES]: + return _get_xgrid_axes(self.xgcm_grid) @property def lon(self): diff --git a/tests/v4/test_xgrid.py b/tests/v4/test_xgrid.py index 39a5ee614d..af84f125dd 100644 --- a/tests/v4/test_xgrid.py +++ b/tests/v4/test_xgrid.py @@ -78,10 +78,15 @@ def test_xgrid_against_old(ds, attr): @pytest.mark.parametrize("ds", [pytest.param(ds, id=key) for key, ds in datasets.items()]) -def test_grid_init_on_generic_datasets(ds): +def test_xgrid_init_on_generic_datasets(ds): XGrid(xgcm.Grid(ds, periodic=False)) +def test_xgrid_axes(): + # Tests that the xgrid.axes property correctly identifies the axes and ordering + ... + + def test_invalid_xgrid_field_array(): """Stress test initialiser by creating incompatible datasets that test the edge cases""" ...