Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added parcels/_core/utils/__init__.py
Empty file.
Empty file added parcels/_core/utils/common.py
Empty file.
Empty file.
28 changes: 28 additions & 0 deletions parcels/_core/utils/unstructured.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from collections.abc import Hashable

DIM_TO_VERTICAL_LOCATION_MAP = {
"nz1": "center",
"nz": "face",
}


def get_vertical_location_from_dims(dims: tuple[Hashable, ...]):
"""
Determine the vertical location of the field based on the uxarray.UxDataArray object variables.

Only used for unstructured grids.
"""
vertical_dims_in_data = set(dims) & set(DIM_TO_VERTICAL_LOCATION_MAP.keys())

if len(vertical_dims_in_data) != 1:
raise ValueError(
f"Expected exactly one vertical dimension ({set(DIM_TO_VERTICAL_LOCATION_MAP.keys())}) in the data, got {vertical_dims_in_data}"
)

return DIM_TO_VERTICAL_LOCATION_MAP[vertical_dims_in_data.pop()]


def get_vertical_dim_name_from_location(location: str):
"""Determine the vertical location of the field based on the uxarray.UxGrid object variables."""
location_to_dim_map = {v: k for k, v in DIM_TO_VERTICAL_LOCATION_MAP.items()}
return location_to_dim_map[location]
117 changes: 36 additions & 81 deletions parcels/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import xarray as xr
from uxarray.grid.neighbors import _barycentric_coordinates

from parcels._core.utils.unstructured import get_vertical_location_from_dims
from parcels._typing import (
Mesh,
VectorType,
Expand All @@ -28,6 +29,7 @@
_raise_field_out_of_bound_error,
)
from parcels.v4.grid import Grid
from parcels.v4.gridadapter import GridAdapter

from ._index_search import _search_indices_rectilinear, _search_time_index

Expand Down Expand Up @@ -166,6 +168,13 @@ def __init__(
self.data = data
self.grid = grid

# For compatibility with parts of the codebase that rely on v3 definition of Grid.
# Should be worked to be removed in v4
if isinstance(grid, Grid):
self.gridadapter = GridAdapter(grid)
else:
self.gridadapter = None

try:
if isinstance(data, ux.UxDataArray):
_assert_valid_uxdataarray(data)
Expand All @@ -178,8 +187,6 @@ def __init__(

self._parent_mesh = data.attrs["mesh"]
self._mesh_type = mesh_type
self._location = data.attrs["location"]
self._vertical_location = None

# Setting the interpolation method dynamically
if interp_method is None:
Expand All @@ -202,55 +209,9 @@ def __init__(
else:
self.allow_time_extrapolation = allow_time_extrapolation

if type(self.data) is ux.UxDataArray:
self._spatialhash = self.grid.get_spatial_hash()
self._gtype = None
# Set the vertical location
if "nz1" in data.dims:
self._vertical_location = "center"
elif "nz" in data.dims:
self._vertical_location = "face"
else: # TODO Nick : This bit probably needs an overhaul once the parcels.Grid class is integrated.
self._spatialhash = None
# Set the grid type
if "x_g" in self.data.coords:
lon = self.data.x_g
elif "x_c" in self.data.coords:
lon = self.data.x_c
else:
lon = self.data.lon

if "nz1" in self.data.coords:
depth = self.data.nz1
elif "nz" in self.data.coords:
depth = self.data.nz
elif "depth" in self.data.coords:
depth = self.data.depth
else:
depth = None

if len(lon.shape) <= 1:
if depth is None or len(depth.shape) <= 1:
self._gtype = GridType.RectilinearZGrid
else:
self._gtype = GridType.RectilinearSGrid
else:
if depth is None or len(depth.shape) <= 1:
self._gtype = GridType.CurvilinearZGrid
else:
self._gtype = GridType.CurvilinearSGrid

self._lonlat_minmax = np.array(
[np.nanmin(self.lon), np.nanmax(self.lon), np.nanmin(self.lat), np.nanmax(self.lat)], dtype=np.float32
)

def __repr__(self):
return field_repr(self)

@property
def lonlat_minmax(self):
return self._lonlat_minmax

@property
def units(self):
return self._units
Expand All @@ -264,69 +225,63 @@ def units(self, value):
@property
def lat(self):
if type(self.data) is ux.UxDataArray:
if self._location == "node":
if self.data.attrs["location"] == "node":
return self.grid.node_lat
elif self._location == "face":
elif self.data.attrs["location"] == "face":
return self.grid.face_lat
elif self._location == "edge":
elif self.data.attrs["location"] == "edge":
return self.grid.edge_lat
else:
return self.data.lat
return self.gridadapter.lat

@property
def lon(self):
if type(self.data) is ux.UxDataArray:
if self._location == "node":
if self.data.attrs["location"] == "node":
return self.grid.node_lon
elif self._location == "face":
elif self.data.attrs["location"] == "face":
return self.grid.face_lon
elif self._location == "edge":
elif self.data.attrs["location"] == "edge":
return self.grid.edge_lon
else:
return self.data.lon
return self.gridadapter.lon

@property
def depth(self):
if type(self.data) is ux.UxDataArray:
if self._vertical_location == "center":
vertical_location = get_vertical_location_from_dims(self.data.dims)
if vertical_location == "center":
return self.grid.nz1
elif self._vertical_location == "face":
elif vertical_location == "face":
return self.grid.nz
else:
return self.data.depth
return self.gridadapter.depth

@property
def xdim(self):
if type(self.data) is xr.DataArray:
if "face_lon" in self.data.dims:
return self.data.sizes["face_lon"]
elif "node_lon" in self.data.dims:
return self.data.sizes["node_lon"]
else:
return self.data.sizes["lon"]
return self.gridadapter.xdim
else:
return 0 # TODO : Discuss what we want to return as xdim for uxdataarray obj
raise NotImplementedError("xdim not implemented for unstructured grids")

@property
def ydim(self):
if type(self.data) is xr.DataArray:
if "face_lat" in self.data.dims:
return self.data.sizes["face_lat"]
elif "node_lat" in self.data.dims:
return self.data.sizes["node_lat"]
else:
return self.data.sizes["lat"]
return self.gridadapter.ydim
else:
return 0 # TODO : Discuss what we want to return as ydim for uxdataarray obj
raise NotImplementedError("ydim not implemented for unstructured grids")

@property
def zdim(self):
if "nz1" in self.data.dims:
return self.data.sizes["nz1"]
elif "nz" in self.data.dims:
return self.data.sizes["nz"]
if type(self.data) is xr.DataArray:
return self.gridadapter.zdim
else:
return 0
if "nz1" in self.data.dims:
return self.data.sizes["nz1"]
elif "nz" in self.data.dims:
return self.data.sizes["nz"]
else:
return 0

@property
def n_face(self):
Expand Down Expand Up @@ -365,7 +320,7 @@ def _search_indices_unstructured(self, z, y, x, ei=None, search2D=False):
tol = 1e-10
if ei is None:
# Search using global search
fi, bcoords = self._spatialhash.query([[x, y]]) # Get the face id for the particle
fi, bcoords = self.grid.get_spatial_hash().query([[x, y]]) # Get the face id for the particle
if fi == -1:
raise FieldOutOfBoundError(z, y, x)
# TODO Joe : Do the vertical grid search
Expand All @@ -389,12 +344,12 @@ def _search_indices_unstructured(self, z, y, x, ei=None, search2D=False):
return bcoords, self.ravel_index(zi, 0, neighbor)

# If we reach this point, we do a global search as a last ditch effort the particle is out of bounds
fi, bcoords = self._spatialhash.query([[x, y]]) # Get the face id for the particle
fi, bcoords = self.grid.get_spatial_hash().query([[x, y]]) # Get the face id for the particle
if fi == -1:
raise FieldOutOfBoundError(z, y, x)

def _search_indices_structured(self, z, y, x, ei=None, search2D=False):
if self._gtype in [GridType.RectilinearSGrid, GridType.RectilinearZGrid]:
if self.gridadapter._gtype in [GridType.RectilinearSGrid, GridType.RectilinearZGrid]:
(zeta, eta, xsi, zi, yi, xi) = _search_indices_rectilinear(self, z, y, x, ei=ei, search2D=search2D)
else:
## TODO : Still need to implement the search_indices_curvilinear
Expand Down
37 changes: 19 additions & 18 deletions parcels/v4/gridadapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,67 +30,68 @@ def get_time(axis: Axis) -> npt.NDArray:
return axis._ds[axis.coords["center"]].values


class GridAdapter(NewGrid):
def __init__(self, ds, mesh="flat", *args, **kwargs):
super().__init__(ds, *args, **kwargs)
class GridAdapter:
def __init__(self, grid: NewGrid, mesh="flat"):
self.grid = grid
self.mesh = mesh

# ! Not ideal... Triggers computation on a throwaway item. If adapter is still needed in codebase, and this is prohibitively expensive, perhaps store GridAdapter on Field object instead of Grid
self.lonlat_minmax = np.array(
[
np.nanmin(self._ds["lon"]),
np.nanmax(self._ds["lon"]),
np.nanmin(self._ds["lat"]),
np.nanmax(self._ds["lat"]),
np.nanmin(self.grid._ds["lon"]),
np.nanmax(self.grid._ds["lon"]),
np.nanmin(self.grid._ds["lat"]),
np.nanmax(self.grid._ds["lat"]),
]
)

@property
def lon(self):
try:
_ = self.axes["X"]
_ = self.grid.axes["X"]
except KeyError:
return np.zeros(1)
return self._ds["lon"].values
return self.grid._ds["lon"].values

@property
def lat(self):
try:
_ = self.axes["Y"]
_ = self.grid.axes["Y"]
except KeyError:
return np.zeros(1)
return self._ds["lat"].values
return self.grid._ds["lat"].values

@property
def depth(self):
try:
_ = self.axes["Z"]
_ = self.grid.axes["Z"]
except KeyError:
return np.zeros(1)
return self._ds["depth"].values
return self.grid._ds["depth"].values

@property
def time(self):
try:
axis = self.axes["T"]
axis = self.grid.axes["T"]
except KeyError:
return np.zeros(1)
return get_time(axis)

@property
def xdim(self):
return get_dimensionality(self.axes.get("X"))
return get_dimensionality(self.grid.axes.get("X"))

@property
def ydim(self):
return get_dimensionality(self.axes.get("Y"))
return get_dimensionality(self.grid.axes.get("Y"))

@property
def zdim(self):
return get_dimensionality(self.axes.get("Z"))
return get_dimensionality(self.grid.axes.get("Z"))

@property
def tdim(self):
return get_dimensionality(self.axes.get("T"))
return get_dimensionality(self.grid.axes.get("T"))

@property
def time_origin(self):
Expand Down
11 changes: 9 additions & 2 deletions tests/v4/test_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,18 @@ def test_field_init_param_types():
"data,grid",
[
pytest.param(ux.UxDataArray(), Grid(xr.Dataset()), id="uxdata-grid"),
pytest.param(xr.DataArray(), ux.UxDataArray().uxgrid, id="xarray-uxgrid"),
pytest.param(
xr.DataArray(),
ux.UxDataArray().uxgrid,
id="xarray-uxgrid",
marks=pytest.mark.xfail(
reason="Replace uxDataArray object with one that actually has a grid (once unstructured example datasets are in the codebase)."
),
),
],
)
def test_field_incompatible_combination(data, grid):
with pytest.raises(ValueError, msg="Incompatible data-grid combination."):
with pytest.raises(ValueError, match="Incompatible data-grid combination."):
Field(
name="test_field",
data=data,
Expand Down
5 changes: 3 additions & 2 deletions tests/v4/test_gridadapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from parcels._datasets.structured.grid_datasets import N, T, datasets
from parcels.grid import Grid as OldGrid
from parcels.tools.converters import TimeConverter
from parcels.v4.grid import Grid as NewGrid
from parcels.v4.gridadapter import GridAdapter

TestCase = namedtuple("TestCase", ["Grid", "attr", "expected"])
Expand Down Expand Up @@ -38,7 +39,7 @@ def assert_equal(actual, expected):

@pytest.mark.parametrize("ds, attr, expected", test_cases)
def test_grid_adapter_properties_ground_truth(ds, attr, expected):
adapter = GridAdapter(ds, periodic=False)
adapter = GridAdapter(NewGrid(ds, periodic=False))
actual = getattr(adapter, attr)
assert_equal(actual, expected)

Expand All @@ -60,7 +61,7 @@ def test_grid_adapter_properties_ground_truth(ds, attr, expected):
)
@pytest.mark.parametrize("ds", datasets.values())
def test_grid_adapter_against_old(ds, attr):
adapter = GridAdapter(ds, periodic=False)
adapter = GridAdapter(NewGrid(ds, periodic=False))

grid = OldGrid.create_grid(
lon=ds.lon.values,
Expand Down
Loading
Loading