diff --git a/parcels/application_kernels/interpolation.py b/parcels/application_kernels/interpolation.py index 18efd02763..1622ffcac7 100644 --- a/parcels/application_kernels/interpolation.py +++ b/parcels/application_kernels/interpolation.py @@ -1,9 +1,16 @@ """Collection of pre-built interpolation kernels.""" +from __future__ import annotations + +from typing import TYPE_CHECKING + import numpy as np from parcels.field import Field +if TYPE_CHECKING: + from parcels.uxgrid import _UXGRID_AXES + __all__ = [ "UXPiecewiseConstantFace", "UXPiecewiseLinearNode", @@ -13,8 +20,7 @@ def UXPiecewiseConstantFace( field: Field, ti: int, - ei: int, - bcoords: np.ndarray, + position: dict[_UXGRID_AXES, tuple[int, float | np.ndarray]], tau: np.float32 | np.float64, t: np.float32 | np.float64, z: np.float32 | np.float64, @@ -26,15 +32,13 @@ def UXPiecewiseConstantFace( This interpolation method is appropriate for fields that are face registered, such as u,v in FESOM. """ - zi, fi = field.grid.unravel_index(ei) - return field.data.values[ti, zi, fi] + return field.data.values[ti, position["Z"][0], position["FACE"][0]] def UXPiecewiseLinearNode( field: Field, ti: int, - ei: int, - bcoords: np.ndarray, + position: dict[_UXGRID_AXES, tuple[int, float | np.ndarray]], tau: np.float32 | np.float64, t: np.float32 | np.float64, z: np.float32 | np.float64, @@ -47,7 +51,8 @@ def UXPiecewiseLinearNode( velocity W in FESOM2. Effectively, it applies barycentric interpolation in the lateral direction and piecewise linear interpolation in the vertical direction. """ - k, fi = field.grid.unravel_index(ei) + k, fi = position["Z"][0], position["FACE"][0] + bcoords = position["FACE"][1] node_ids = field.grid.uxgrid.face_node_connectivity[fi, :] # The zi refers to the vertical layer index. The field in this routine are assumed to be defined at the vertical interface levels. # For interface zi, the interface indices are [zi, zi+1], so we need to use the values at zi and zi+1. diff --git a/parcels/field.py b/parcels/field.py index 9c7ae9f631..ce82d4c781 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -105,8 +105,7 @@ class Field: def _interp_template( self, ti: int, - ei: int, - bcoords: np.ndarray, + position: dict[str, tuple[int, float | np.ndarray]], tau: np.float32 | np.float64, t: np.float32 | np.float64, z: np.float32 | np.float64, @@ -316,8 +315,8 @@ def eval(self, time: datetime, z, y, x, particle=None, applyConversion=True): try: tau, ti = _search_time_index(self, time) - bcoords, _ei = self.grid.search(z, y, x, ei=_ei) - value = self._interp_method(self, ti, _ei, bcoords, tau, time, z, y, x) + position = self.grid.search(z, y, x, ei=_ei) + value = self._interp_method(self, ti, position, tau, time, z, y, x) if np.isnan(value): # Detect Out-of-bounds sampling and raise exception @@ -445,14 +444,14 @@ def eval(self, time: datetime, z, y, x, particle=None, applyConversion=True): try: tau, ti = _search_time_index(self.U, time) - bcoords, _ei = self.grid.search(z, y, x, ei=_ei) + position = self.grid.search(z, y, x, ei=_ei) if self._vector_interp_method is None: - u = self.U._interp_method(self.U, ti, _ei, bcoords, tau, time, z, y, x) - v = self.V._interp_method(self.V, ti, _ei, bcoords, tau, time, z, y, x) + u = self.U._interp_method(self.U, ti, position, tau, time, z, y, x) + v = self.V._interp_method(self.V, ti, position, tau, time, z, y, x) if "3D" in self.vector_type: - w = self.W._interp_method(self.W, ti, _ei, bcoords, tau, time, z, y, x) + w = self.W._interp_method(self.W, ti, position, tau, time, z, y, x) else: - (u, v, w) = self._vector_interp_method(self, ti, _ei, bcoords, time, z, y, x) + (u, v, w) = self._vector_interp_method(self, ti, position, time, z, y, x) # print(u,v) if applyConversion: diff --git a/parcels/uxgrid.py b/parcels/uxgrid.py index 850b7b8a00..641213b7a0 100644 --- a/parcels/uxgrid.py +++ b/parcels/uxgrid.py @@ -1,13 +1,18 @@ from __future__ import annotations +from typing import Literal + import numpy as np import uxarray as ux from uxarray.grid.neighbors import _barycentric_coordinates from parcels.field import FieldOutOfBoundError # Adjust import as necessary +from parcels.xgrid import _search_1d_array from .basegrid import BaseGrid +_UXGRID_AXES = Literal["Z", "FACE"] + class UxGrid(BaseGrid): """ @@ -49,9 +54,7 @@ def depth(self): return np.zeros(1) return self.z.values - def search( - self, z: float, y: float, x: float, ei: int | None = None, search2D: bool = False - ) -> tuple[np.ndarray, int]: + def search(self, z, y, x, ei=None): tol = 1e-10 def try_face(fid): @@ -60,21 +63,7 @@ def try_face(fid): return bcoords, fid return None, None - def find_vertical_index() -> int: - if search2D: - return 0 - else: - nz = self.z.shape[0] - if nz == 1: - return 0 - zf = self.z.values - # Return zi such that zf[zi] <= z < zf[zi+1] - zi = np.searchsorted(zf, z, side="right") - 1 # Search assumes that z is positive and increasing with i - if zi < 0 or zi >= nz - 1: - raise FieldOutOfBoundError(z, y, x) - return zi - - zi = find_vertical_index() # Find the vertical cell center nearest to z + zi, zeta = _search_1d_array(self.z.values, z) if ei is not None: _, fi = self.unravel_index(ei) @@ -94,7 +83,7 @@ def find_vertical_index() -> int: if fi == -1: raise FieldOutOfBoundError(z, y, x) - return bcoords[0], self.ravel_index(zi, fi[0]) + return {"Z": (zi, zeta), "FACE": (fi, bcoords[0])} def _get_barycentric_coordinates(self, y, x, fi): """Checks if a point is inside a given face id on a UxGrid.""" @@ -113,40 +102,10 @@ def _get_barycentric_coordinates(self, y, x, fi): err = abs(np.dot(bcoord, nodes[:, 0]) - coord[0]) + abs(np.dot(bcoord, nodes[:, 1]) - coord[1]) return bcoord, err - def ravel_index(self, zi, fi): - """ - Converts a face index and a vertical index into a single encoded index. - - Parameters - ---------- - zi : int - Vertical index (not used in unstructured grids, but kept for compatibility). - fi : int - Face index. - - Returns - ------- - int - Encoded index combining the face index and vertical index. - """ - return fi + self.uxgrid.n_face * zi - - def unravel_index(self, ei): - """ - Converts a single encoded index back into a vertical index and face index. + def ravel_index(self, axis_indices: dict[_UXGRID_AXES, int]): + return axis_indices["FACE"] + self.uxgrid.n_face * axis_indices["Z"] - Parameters - ---------- - ei : int - Encoded index to be unraveled. - - Returns - ------- - zi : int - Vertical index. - fi : int - Face index. - """ + def unravel_index(self, ei) -> dict[_UXGRID_AXES, int]: zi = ei // self.uxgrid.n_face fi = ei % self.uxgrid.n_face - return zi, fi + return {"Z": zi, "FACE": fi} diff --git a/parcels/xgrid.py b/parcels/xgrid.py index 5d536a39ed..7ca1af7f00 100644 --- a/parcels/xgrid.py +++ b/parcels/xgrid.py @@ -15,7 +15,6 @@ _XGCM_AXIS_DIRECTION = Literal["X", "Y", "Z", "T"] _XGCM_AXIS_POSITION = Literal["center", "left", "right", "inner", "outer"] -_AXIS_DIRECTION = Literal["X", "Y", "Z"] _XGCM_AXES = Mapping[_XGCM_AXIS_DIRECTION, xgcm.Axis] @@ -196,13 +195,13 @@ def search(self, z, y, x, ei=None): raise NotImplementedError("Searching in >2D lon/lat arrays is not implemented yet.") - def ravel_index(self, axis_indices: dict[_AXIS_DIRECTION, int]) -> int: + def ravel_index(self, axis_indices: dict[_XGRID_AXES, int]) -> int: xi = axis_indices.get("X", 0) yi = axis_indices.get("Y", 0) zi = axis_indices.get("Z", 0) return xi + self.xdim * yi + self.xdim * self.ydim * zi - def unravel_index(self, ei) -> dict[_AXIS_DIRECTION, int]: + def unravel_index(self, ei) -> dict[_XGRID_AXES, int]: zi = ei // (self.xdim * self.ydim) ei = ei % (self.xdim * self.ydim)