Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions parcels/application_kernels/interpolation.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
"""Collection of pre-built interpolation kernels."""

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np

from parcels.field import Field

if TYPE_CHECKING:
from parcels.uxgrid import _UXGRID_AXES

Check warning on line 12 in parcels/application_kernels/interpolation.py

View check run for this annotation

Codecov / codecov/patch

parcels/application_kernels/interpolation.py#L12

Added line #L12 was not covered by tests

__all__ = [
"UXPiecewiseConstantFace",
"UXPiecewiseLinearNode",
Expand All @@ -13,8 +20,7 @@
def UXPiecewiseConstantFace(
field: Field,
ti: int,
ei: int,
bcoords: np.ndarray,
position: dict[_UXGRID_AXES, tuple[int, float | np.ndarray]],
tau: np.float32 | np.float64,
t: np.float32 | np.float64,
z: np.float32 | np.float64,
Expand All @@ -26,15 +32,13 @@
This interpolation method is appropriate for fields that are
face registered, such as u,v in FESOM.
"""
zi, fi = field.grid.unravel_index(ei)
return field.data.values[ti, zi, fi]
return field.data.values[ti, position["Z"][0], position["FACE"][0]]


def UXPiecewiseLinearNode(
field: Field,
ti: int,
ei: int,
bcoords: np.ndarray,
position: dict[_UXGRID_AXES, tuple[int, float | np.ndarray]],
tau: np.float32 | np.float64,
t: np.float32 | np.float64,
z: np.float32 | np.float64,
Expand All @@ -47,7 +51,8 @@
velocity W in FESOM2. Effectively, it applies barycentric interpolation in the lateral direction
and piecewise linear interpolation in the vertical direction.
"""
k, fi = field.grid.unravel_index(ei)
k, fi = position["Z"][0], position["FACE"][0]
bcoords = position["FACE"][1]
node_ids = field.grid.uxgrid.face_node_connectivity[fi, :]
# The zi refers to the vertical layer index. The field in this routine are assumed to be defined at the vertical interface levels.
# For interface zi, the interface indices are [zi, zi+1], so we need to use the values at zi and zi+1.
Expand Down
17 changes: 8 additions & 9 deletions parcels/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,7 @@
def _interp_template(
self,
ti: int,
ei: int,
bcoords: np.ndarray,
position: dict[str, tuple[int, float | np.ndarray]],
tau: np.float32 | np.float64,
t: np.float32 | np.float64,
z: np.float32 | np.float64,
Expand Down Expand Up @@ -316,8 +315,8 @@

try:
tau, ti = _search_time_index(self, time)
bcoords, _ei = self.grid.search(z, y, x, ei=_ei)
value = self._interp_method(self, ti, _ei, bcoords, tau, time, z, y, x)
position = self.grid.search(z, y, x, ei=_ei)
value = self._interp_method(self, ti, position, tau, time, z, y, x)

if np.isnan(value):
# Detect Out-of-bounds sampling and raise exception
Expand Down Expand Up @@ -445,14 +444,14 @@

try:
tau, ti = _search_time_index(self.U, time)
bcoords, _ei = self.grid.search(z, y, x, ei=_ei)
position = self.grid.search(z, y, x, ei=_ei)

Check warning on line 447 in parcels/field.py

View check run for this annotation

Codecov / codecov/patch

parcels/field.py#L447

Added line #L447 was not covered by tests
if self._vector_interp_method is None:
u = self.U._interp_method(self.U, ti, _ei, bcoords, tau, time, z, y, x)
v = self.V._interp_method(self.V, ti, _ei, bcoords, tau, time, z, y, x)
u = self.U._interp_method(self.U, ti, position, tau, time, z, y, x)
v = self.V._interp_method(self.V, ti, position, tau, time, z, y, x)

Check warning on line 450 in parcels/field.py

View check run for this annotation

Codecov / codecov/patch

parcels/field.py#L449-L450

Added lines #L449 - L450 were not covered by tests
if "3D" in self.vector_type:
w = self.W._interp_method(self.W, ti, _ei, bcoords, tau, time, z, y, x)
w = self.W._interp_method(self.W, ti, position, tau, time, z, y, x)

Check warning on line 452 in parcels/field.py

View check run for this annotation

Codecov / codecov/patch

parcels/field.py#L452

Added line #L452 was not covered by tests
else:
(u, v, w) = self._vector_interp_method(self, ti, _ei, bcoords, time, z, y, x)
(u, v, w) = self._vector_interp_method(self, ti, position, time, z, y, x)

Check warning on line 454 in parcels/field.py

View check run for this annotation

Codecov / codecov/patch

parcels/field.py#L454

Added line #L454 was not covered by tests

# print(u,v)
if applyConversion:
Expand Down
65 changes: 12 additions & 53 deletions parcels/uxgrid.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
from __future__ import annotations

from typing import Literal

import numpy as np
import uxarray as ux
from uxarray.grid.neighbors import _barycentric_coordinates

from parcels.field import FieldOutOfBoundError # Adjust import as necessary
from parcels.xgrid import _search_1d_array

from .basegrid import BaseGrid

_UXGRID_AXES = Literal["Z", "FACE"]


class UxGrid(BaseGrid):
"""
Expand Down Expand Up @@ -49,9 +54,7 @@
return np.zeros(1)
return self.z.values

def search(
self, z: float, y: float, x: float, ei: int | None = None, search2D: bool = False
) -> tuple[np.ndarray, int]:
def search(self, z, y, x, ei=None):
tol = 1e-10

def try_face(fid):
Expand All @@ -60,21 +63,7 @@
return bcoords, fid
return None, None

def find_vertical_index() -> int:
if search2D:
return 0
else:
nz = self.z.shape[0]
if nz == 1:
return 0
zf = self.z.values
# Return zi such that zf[zi] <= z < zf[zi+1]
zi = np.searchsorted(zf, z, side="right") - 1 # Search assumes that z is positive and increasing with i
if zi < 0 or zi >= nz - 1:
raise FieldOutOfBoundError(z, y, x)
return zi

zi = find_vertical_index() # Find the vertical cell center nearest to z
zi, zeta = _search_1d_array(self.z.values, z)

if ei is not None:
_, fi = self.unravel_index(ei)
Expand All @@ -94,7 +83,7 @@
if fi == -1:
raise FieldOutOfBoundError(z, y, x)

return bcoords[0], self.ravel_index(zi, fi[0])
return {"Z": (zi, zeta), "FACE": (fi, bcoords[0])}

def _get_barycentric_coordinates(self, y, x, fi):
"""Checks if a point is inside a given face id on a UxGrid."""
Expand All @@ -113,40 +102,10 @@
err = abs(np.dot(bcoord, nodes[:, 0]) - coord[0]) + abs(np.dot(bcoord, nodes[:, 1]) - coord[1])
return bcoord, err

def ravel_index(self, zi, fi):
"""
Converts a face index and a vertical index into a single encoded index.

Parameters
----------
zi : int
Vertical index (not used in unstructured grids, but kept for compatibility).
fi : int
Face index.

Returns
-------
int
Encoded index combining the face index and vertical index.
"""
return fi + self.uxgrid.n_face * zi

def unravel_index(self, ei):
"""
Converts a single encoded index back into a vertical index and face index.
def ravel_index(self, axis_indices: dict[_UXGRID_AXES, int]):
return axis_indices["FACE"] + self.uxgrid.n_face * axis_indices["Z"]

Check warning on line 106 in parcels/uxgrid.py

View check run for this annotation

Codecov / codecov/patch

parcels/uxgrid.py#L106

Added line #L106 was not covered by tests

Parameters
----------
ei : int
Encoded index to be unraveled.

Returns
-------
zi : int
Vertical index.
fi : int
Face index.
"""
def unravel_index(self, ei) -> dict[_UXGRID_AXES, int]:
zi = ei // self.uxgrid.n_face
fi = ei % self.uxgrid.n_face
return zi, fi
return {"Z": zi, "FACE": fi}

Check warning on line 111 in parcels/uxgrid.py

View check run for this annotation

Codecov / codecov/patch

parcels/uxgrid.py#L111

Added line #L111 was not covered by tests
5 changes: 2 additions & 3 deletions parcels/xgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

_XGCM_AXIS_DIRECTION = Literal["X", "Y", "Z", "T"]
_XGCM_AXIS_POSITION = Literal["center", "left", "right", "inner", "outer"]
_AXIS_DIRECTION = Literal["X", "Y", "Z"]
_XGCM_AXES = Mapping[_XGCM_AXIS_DIRECTION, xgcm.Axis]


Expand Down Expand Up @@ -196,13 +195,13 @@ def search(self, z, y, x, ei=None):

raise NotImplementedError("Searching in >2D lon/lat arrays is not implemented yet.")

def ravel_index(self, axis_indices: dict[_AXIS_DIRECTION, int]) -> int:
def ravel_index(self, axis_indices: dict[_XGRID_AXES, int]) -> int:
xi = axis_indices.get("X", 0)
yi = axis_indices.get("Y", 0)
zi = axis_indices.get("Z", 0)
return xi + self.xdim * yi + self.xdim * self.ydim * zi

def unravel_index(self, ei) -> dict[_AXIS_DIRECTION, int]:
def unravel_index(self, ei) -> dict[_XGRID_AXES, int]:
zi = ei // (self.xdim * self.ydim)
ei = ei % (self.xdim * self.ydim)

Expand Down
Loading