diff --git a/parcels/_index_search.py b/parcels/_index_search.py index 805e5e1eb8..ca82359451 100644 --- a/parcels/_index_search.py +++ b/parcels/_index_search.py @@ -21,6 +21,8 @@ from .grid import GridType if TYPE_CHECKING: + from parcels.xgrid import XGrid + from .field import Field # from .grid import Grid @@ -270,6 +272,89 @@ def _search_indices_rectilinear( return (zeta, eta, xsi, _ei) +def _search_indices_curvilinear_2d( + grid: XGrid, y: float, x: float, yi_guess: int | None = None, xi_guess: int | None = None +): + yi, xi = yi_guess, xi_guess + if yi is None: + yi = int(grid.ydim / 2) - 1 + + if xi is None: + xi = int(grid.xdim / 2) - 1 + + xsi = eta = -1.0 + invA = np.array( + [ + [1, 0, 0, 0], + [-1, 1, 0, 0], + [-1, 0, 0, 1], + [1, -1, 1, -1], + ] + ) + maxIterSearch = 1e6 + it = 0 + tol = 1.0e-10 + + # # ! Error handling for out of bounds + # TODO: Re-enable in some capacity + # if x < field.lonlat_minmax[0] or x > field.lonlat_minmax[1]: + # if grid.lon[0, 0] < grid.lon[0, -1]: + # _raise_field_out_of_bound_error(y, x) + # elif x < grid.lon[0, 0] and x > grid.lon[0, -1]: # This prevents from crashing in [160, -160] + # _raise_field_out_of_bound_error(z, y, x) + + # if y < field.lonlat_minmax[2] or y > field.lonlat_minmax[3]: + # _raise_field_out_of_bound_error(z, y, x) + + while xsi < -tol or xsi > 1 + tol or eta < -tol or eta > 1 + tol: + px = np.array([grid.lon[yi, xi], grid.lon[yi, xi + 1], grid.lon[yi + 1, xi + 1], grid.lon[yi + 1, xi]]) + + py = np.array([grid.lat[yi, xi], grid.lat[yi, xi + 1], grid.lat[yi + 1, xi + 1], grid.lat[yi + 1, xi]]) + a = np.dot(invA, px) + b = np.dot(invA, py) + + aa = a[3] * b[2] - a[2] * b[3] + bb = a[3] * b[0] - a[0] * b[3] + a[1] * b[2] - a[2] * b[1] + x * b[3] - y * a[3] + cc = a[1] * b[0] - a[0] * b[1] + x * b[1] - y * a[1] + if abs(aa) < 1e-12: # Rectilinear cell, or quasi + eta = -cc / bb + else: + det2 = bb * bb - 4 * aa * cc + if det2 > 0: # so, if det is nan we keep the xsi, eta from previous iter + det = np.sqrt(det2) + eta = (-bb + det) / (2 * aa) + if abs(a[1] + a[3] * eta) < 1e-12: # this happens when recti cell rotated of 90deg + xsi = ((y - py[0]) / (py[1] - py[0]) + (y - py[3]) / (py[2] - py[3])) * 0.5 + else: + xsi = (x - a[0] - a[2] * eta) / (a[1] + a[3] * eta) + if xsi < 0 and eta < 0 and xi == 0 and yi == 0: + _raise_field_out_of_bound_error(0, y, x) + if xsi > 1 and eta > 1 and xi == grid.xdim - 1 and yi == grid.ydim - 1: + _raise_field_out_of_bound_error(0, y, x) + if xsi < -tol: + xi -= 1 + elif xsi > 1 + tol: + xi += 1 + if eta < -tol: + yi -= 1 + elif eta > 1 + tol: + yi += 1 + (yi, xi) = _reconnect_bnd_indices(yi, xi, grid.ydim, grid.xdim, grid.mesh) + it += 1 + if it > maxIterSearch: + print(f"Correct cell not found after {maxIterSearch} iterations") + _raise_field_out_of_bound_error(0, y, x) + xsi = max(0.0, xsi) + eta = max(0.0, eta) + xsi = min(1.0, xsi) + eta = min(1.0, eta) + + if not ((0 <= xsi <= 1) and (0 <= eta <= 1)): + _raise_field_sampling_error(y, x) + + return (yi, eta, xi, xsi) + + ## TODO : Still need to implement the search_indices_curvilinear def _search_indices_curvilinear(field: Field, time, z, y, x, ti, particle=None, search2D=False): if particle: diff --git a/parcels/basegrid.py b/parcels/basegrid.py index 813a7ced7d..a7ebcfb42a 100644 --- a/parcels/basegrid.py +++ b/parcels/basegrid.py @@ -1,15 +1,21 @@ +from __future__ import annotations + from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import numpy as np class BaseGrid(ABC): @abstractmethod - def search(self, z: float, y: float, x: float, ei=None, search2D: bool = False): + def search(self, z: float, y: float, x: float, ei=None) -> dict[str, tuple[int, float | np.ndarray]]: """ Perform a spatial (and optionally vertical) search to locate the grid element that contains a given point (x, y, z). This method delegates to grid-type-specific logic (e.g., structured or unstructured) - to determine the appropriate indices and interpolation coordinates for evaluating a field. + to determine the appropriate indices and barycentric coordinates for evaluating a field. Parameters ---------- @@ -28,12 +34,21 @@ def search(self, z: float, y: float, x: float, ei=None, search2D: bool = False): Returns ------- - bcoords : np.ndarray or tuple - Interpolation weights or barycentric coordinates within the containing cell/face. - The interpretation of `bcoords` depends on the grid type. - ei : int - Encoded index of the identified grid cell or face. This value can be cached for - future lookups to accelerate repeated searches. + dict + A dictionary mapping spatial axis names to tuples of (index, barycentric_coordinates). + The returned axes depend on the grid dimensionality and type: + + - 3D structured grid: {"Z": (zi, zeta), "Y": (yi, eta), "X": (xi, xsi)} + - 2D structured grid: {"Y": (yi, eta), "X": (xi, xsi)} + - 1D structured grid (depth): {"Z": (zi, zeta)} + - Unstructured grid: {"Z": (zi, zeta), "FACE": (fi, bcoords)} + + Where: + - index (int): The cell position of a particle along the given axis + - barycentric_coordinates (float or np.ndarray): The coordinates defining + a particle's position within the grid cell. For structured grids, this + is a single coordinate per axis; for unstructured grids, this can be + an array of coordinates for the face polygon. Raises ------ @@ -43,3 +58,72 @@ def search(self, z: float, y: float, x: float, ei=None, search2D: bool = False): Raised if the search method is not implemented for the current grid type. """ ... + + @abstractmethod + def ravel_index(self, axis_indices: dict[str, int]) -> int: + """ + Convert a dictionary of axis indices to a single encoded index (ei). + + This method takes the individual indices for each spatial axis and combines them + into a single integer that uniquely identifies a grid cell. This encoded + index can be used for efficient caching and lookup operations. + + Parameters + ---------- + axis_indices : dict[str, int] + A dictionary mapping axis names to their corresponding indices. + The expected keys depend on the grid dimensionality and type: + + - 3D structured grid: {"Z": zi, "Y": yi, "X": xi} + - 2D structured grid: {"Y": yi, "X": xi} + - 1D structured grid: {"Z": zi} + - Unstructured grid: {"Z": zi, "FACE": fi} + + Returns + ------- + int + The encoded index (ei) representing the unique grid cell or face. + + Raises + ------ + KeyError + Raised when required axis keys are missing from axis_indices. + ValueError + Raised when index values are out of bounds for the grid. + NotImplementedError + Raised if the method is not implemented for the current grid type. + """ + ... + + @abstractmethod + def unravel_index(self, ei: int) -> dict[str, int]: + """ + Convert a single encoded index (ei) back to a dictionary of axis indices. + + This method is the inverse of ravel_index, taking an encoded index and + decomposing it back into the individual indices for each spatial axis. + + Parameters + ---------- + ei : int + The encoded index representing a unique grid cell or face. + + Returns + ------- + dict[str, int] + A dictionary mapping axis names to their corresponding indices. + The returned keys depend on the grid dimensionality and type: + + - 3D structured grid: {"Z": zi, "Y": yi, "X": xi} + - 2D structured grid: {"Y": yi, "X": xi} + - 1D structured grid: {"Z": zi} + - Unstructured grid: {"Z": zi, "FACE": fi} + + Raises + ------ + ValueError + Raised when the encoded index is out of bounds or invalid for the grid. + NotImplementedError + Raised if the method is not implemented for the current grid type. + """ + ... diff --git a/parcels/xgrid.py b/parcels/xgrid.py index 3948d82871..c0bb066513 100644 --- a/parcels/xgrid.py +++ b/parcels/xgrid.py @@ -1,29 +1,28 @@ -from typing import Literal +from collections.abc import Hashable, Mapping +from typing import Literal, cast import numpy as np import numpy.typing as npt +import xarray as xr from parcels import xgcm +from parcels._index_search import _search_indices_curvilinear_2d from parcels.basegrid import BaseGrid from parcels.tools.converters import TimeConverter +_XGCM_AXIS_DIRECTION = Literal["X", "Y", "Z", "T"] +_XGCM_AXIS_POSITION = Literal["center", "left", "right", "inner", "outer"] +_AXIS_DIRECTION = Literal["X", "Y", "Z"] +_XGCM_AXES = Mapping[_XGCM_AXIS_DIRECTION, xgcm.Axis] -def get_dimensionality(axis: xgcm.Axis | None) -> int: + +def get_cell_edge_count_along_dim(axis: xgcm.Axis | None) -> int: if axis is None: return 1 first_coord = list(axis.coords.items())[0] - pos, coord = first_coord - - pos_to_dim = { # TODO: These could do with being explicitly tested - "center": lambda x: x, - "left": lambda x: x, - "right": lambda x: x, - "inner": lambda x: x + 1, - "outer": lambda x: x - 1, - } + _, coord_var = first_coord - n = axis._ds[coord].size - return pos_to_dim[pos](n) + return axis._ds[coord_var].size def get_time(axis: xgcm.Axis) -> npt.NDArray: @@ -35,11 +34,17 @@ class XGrid(BaseGrid): Class to represent a structured grid in Parcels. Wraps a xgcm-like Grid object (we use a trimmed down version of the xgcm.Grid class that is vendored with Parcels). This class provides methods and properties required for indexing and interpolating on the grid. + + Assumptions: + - If using Parcels in the context of a spatially periodic simulation, the provided grid already has a halo + """ def __init__(self, grid: xgcm.Grid, mesh="flat"): self.xgcm_grid = grid self.mesh = mesh + ds = grid._ds + assert_valid_lat_lon(ds["lat"], ds["lon"], grid.axes) # ! Not ideal... Triggers computation on a throwaway item. Keeping for now for v3 compat, will be removed in v4. self.lonlat_minmax = np.array( @@ -107,19 +112,19 @@ def time(self): @property def xdim(self): - return get_dimensionality(self.xgcm_grid.axes.get("X")) + return get_cell_edge_count_along_dim(self.xgcm_grid.axes.get("X")) @property def ydim(self): - return get_dimensionality(self.xgcm_grid.axes.get("Y")) + return get_cell_edge_count_along_dim(self.xgcm_grid.axes.get("Y")) @property def zdim(self): - return get_dimensionality(self.xgcm_grid.axes.get("Z")) + return get_cell_edge_count_along_dim(self.xgcm_grid.axes.get("Z")) @property def tdim(self): - return get_dimensionality(self.xgcm_grid.axes.get("T")) + return get_cell_edge_count_along_dim(self.xgcm_grid.axes.get("T")) @property def time_origin(self): @@ -164,4 +169,192 @@ def _gtype(self): else: return GridType.CurvilinearSGrid - def search(self, z, y, x, ei=None, search2D=False): ... + def search(self, z, y, x, ei=None): + ds = self.xgcm_grid._ds + + zi, zeta = _search_1d_array(ds.depth.values, z) + + if ds.lon.ndim == 1: + yi, eta = _search_1d_array(ds.lat.values, y) + xi, xsi = _search_1d_array(ds.lon.values, x) + return {"Z": (zi, zeta), "Y": (yi, eta), "X": (xi, xsi)} + + yi, xi = None, None + if ei is not None: + axis_indices = self.unravel_index(ei) + xi = axis_indices.get("X") + yi = axis_indices.get("Y") + + if ds.lon.ndim == 2: + yi, eta, xi, xsi = _search_indices_curvilinear_2d(self, y, x, yi, xi) + + return {"Z": (zi, zeta), "Y": (yi, eta), "X": (xi, xsi)} + + raise NotImplementedError("Searching in >2D lon/lat arrays is not implemented yet.") + + def ravel_index(self, axis_indices: dict[_AXIS_DIRECTION, int]) -> int: + xi = axis_indices.get("X", 0) + yi = axis_indices.get("Y", 0) + zi = axis_indices.get("Z", 0) + return xi + self.xdim * yi + self.xdim * self.ydim * zi + + def unravel_index(self, ei) -> dict[_AXIS_DIRECTION, int]: + zi = ei // (self.xdim * self.ydim) + ei = ei % (self.xdim * self.ydim) + + yi = ei // self.xdim + xi = ei % self.xdim + return {"Z": zi, "Y": yi, "X": xi} + + +def get_axis_from_dim_name(axes: _XGCM_AXES, dim: str) -> _XGCM_AXIS_DIRECTION | None: + """For a given dimension name in a grid, returns the direction axis it is on.""" + for axis_name, axis in axes.items(): + if dim in axis.coords.values(): + return axis_name + return None + + +def get_xgcm_position_from_dim_name(axes: _XGCM_AXES, dim: str) -> _XGCM_AXIS_POSITION | None: + """For a given dimension, returns the position of the variable in the grid.""" + for axis in axes.values(): + var_to_position = {var: position for position, var in axis.coords.items()} + + if dim in var_to_position: + return var_to_position[dim] + return None + + +def assert_all_dimensions_correspond_with_axis(da: xr.DataArray, axes: _XGCM_AXES) -> None: + dim_to_axis = {dim: get_axis_from_dim_name(axes, dim) for dim in da.dims} + + for dim, direction in dim_to_axis.items(): + if direction is None: + raise ValueError( + f"Dimension {dim!r} for DataArray {da.name!r} with dims {da.dims} is not associated with a direction on the provided grid." + ) + + +def assert_valid_field_array(da: xr.DataArray, axes: _XGCM_AXES): + """ + Asserts that for a data array: + - All dimensions are associated with a direction on the grid + - These directions are T, Z, Y, X and the array is ordered as T, Z, Y, X + """ + assert_all_dimensions_correspond_with_axis(da, axes) + + dim_to_axis = {dim: get_axis_from_dim_name(axes, dim) for dim in da.dims} + dim_to_axis = cast(dict[Hashable, _XGCM_AXIS_DIRECTION], dim_to_axis) + + # Assert all dimensions are present + if set(dim_to_axis.values()) != {"T", "Z", "Y", "X"}: + raise ValueError( + f"DataArray {da.name!r} with dims {da.dims} has directions {tuple(dim_to_axis.values())}." + "Expected directions of 'T', 'Z', 'Y', and 'X'." + ) + + # Assert order is t, z, y, x + if list(dim_to_axis.values()) != ["T", "Z", "Y", "X"]: + raise ValueError( + f"Dimension order for array {da.name!r} is not valid. Got {tuple(dim_to_axis.keys())} with associated directions of {tuple(dim_to_axis.values())}. Expected directions of ('T', 'Z', 'Y', 'X'). Transpose your array accordingly." + ) + + +def assert_valid_lat_lon(da_lat, da_lon, axes: _XGCM_AXES): + """ + Asserts that the provided longitude and latitude DataArrays are defined appropriately + on the F points to match the internal representation in Parcels. + + - Longitude and latitude must be 1D or 2D (both must have the same dimensionality) + - Both are defined on the left points (i.e., not the centers) + - If 1D: + - Longitude is associated with the X axis + - Latitude is associated with the Y axis + - If 2D: + - Lon and lat are defined on the same dimensions + - Lon and lat are transposed such they're Y, X + """ + assert_all_dimensions_correspond_with_axis(da_lon, axes) + assert_all_dimensions_correspond_with_axis(da_lat, axes) + + dim_to_position = {dim: get_xgcm_position_from_dim_name(axes, dim) for dim in da_lon.dims} + dim_to_position.update({dim: get_xgcm_position_from_dim_name(axes, dim) for dim in da_lat.dims}) + + for dim in da_lon.dims: + if get_xgcm_position_from_dim_name(axes, dim) == "center": + raise ValueError( + f"Longitude DataArray {da_lon.name!r} with dims {da_lon.dims} is defined on the center of the grid, but must be defined on the F points." + ) + for dim in da_lat.dims: + if get_xgcm_position_from_dim_name(axes, dim) == "center": + raise ValueError( + f"Latitude DataArray {da_lat.name!r} with dims {da_lat.dims} is defined on the center of the grid, but must be defined on the F points." + ) + + if da_lon.ndim != da_lat.ndim: + raise ValueError( + f"Longitude DataArray {da_lon.name!r} with dims {da_lon.dims} and Latitude DataArray {da_lat.name!r} with dims {da_lat.dims} have different dimensionalities." + ) + if da_lon.ndim not in (1, 2): + raise ValueError( + f"Longitude DataArray {da_lon.name!r} with dims {da_lon.dims} and Latitude DataArray {da_lat.name!r} with dims {da_lat.dims} must be 1D or 2D." + ) + + if da_lon.ndim == 1: + if get_axis_from_dim_name(axes, da_lon.dims[0]) != "X": + raise ValueError( + f"Longitude DataArray {da_lon.name!r} with dims {da_lon.dims} is not associated with the X axis." + ) + if get_axis_from_dim_name(axes, da_lat.dims[0]) != "Y": + raise ValueError( + f"Latitude DataArray {da_lat.name!r} with dims {da_lat.dims} is not associated with the Y axis." + ) + + if not np.all(np.diff(da_lon.values) > 0): + raise ValueError( + f"Longitude DataArray {da_lon.name!r} with dims {da_lon.dims} must be strictly increasing." + ) + if not np.all(np.diff(da_lat.values) > 0): + raise ValueError(f"Latitude DataArray {da_lat.name!r} with dims {da_lat.dims} must be strictly increasing.") + + if da_lon.ndim == 2: + if da_lon.dims != da_lat.dims: + raise ValueError( + f"Longitude DataArray {da_lon.name!r} with dims {da_lon.dims} and Latitude DataArray {da_lat.name!r} with dims {da_lat.dims} must be defined on the same dimensions." + ) + + lon_axes = [get_axis_from_dim_name(axes, dim) for dim in da_lon.dims] + if lon_axes != ["Y", "X"]: + raise ValueError( + f"Longitude DataArray {da_lon.name!r} with dims {da_lon.dims} and Latitude DataArray {da_lat.name!r} with dims {da_lat.dims} must be defined on the X and Y axes and transposed to have dimensions in order of Y, X." + ) + + +def _search_1d_array( + arr: np.array, + x: float, +) -> tuple[int, int]: + """ + Searches for the particle location in a 1D array and returns barycentric coordinate along dimension. + + Assumptions: + - particle position x is within the bounds of the array + - array is strictly monotonically increasing. + + Parameters + ---------- + arr : np.array + 1D array to search in. + x : float + Position in the 1D array to search for. + + Returns + ------- + int + Index of the element just before the position x in the array. + float + Barycentric coordinate. + """ + i = np.argmin(arr <= x) - 1 + bcoord = (x - arr[i]) / (arr[i + 1] - arr[i]) + return i, bcoord diff --git a/tests/v4/test_index_search.py b/tests/v4/test_index_search.py new file mode 100644 index 0000000000..7f9290f12c --- /dev/null +++ b/tests/v4/test_index_search.py @@ -0,0 +1,54 @@ +import numpy as np +import pytest + +from parcels import xgcm +from parcels._datasets.structured.generic import datasets +from parcels._index_search import _search_indices_curvilinear_2d +from parcels.field import Field +from parcels.xgrid import ( + XGrid, +) + + +@pytest.fixture +def field_cone(): + ds = datasets["2d_left_unrolled_cone"] + grid = XGrid(xgcm.Grid(ds, periodic=False)) + field = Field( + name="test_field", + data=ds["data_g"], + grid=grid, + ) + return field + + +def test_grid_indexing_fpoints(field_cone): + grid = field_cone.grid + + for yi_expected in range(grid.ydim - 1): + for xi_expected in range(grid.xdim - 1): + x = grid.lon[yi_expected, xi_expected] + 0.00001 + y = grid.lat[yi_expected, xi_expected] + 0.00001 + + yi, eta, xi, xsi = _search_indices_curvilinear_2d(grid, y, x) + if eta > 0.9: + yi_expected -= 1 + if xsi > 0.9: + xi_expected -= 1 + assert yi == yi_expected, f"Expected yi {yi_expected} but got {yi}" + assert xi == xi_expected, f"Expected xi {xi_expected} but got {xi}" + + cell_lon = [ + grid.lon[yi, xi], + grid.lon[yi, xi + 1], + grid.lon[yi + 1, xi + 1], + grid.lon[yi + 1, xi], + ] + cell_lat = [ + grid.lat[yi, xi], + grid.lat[yi, xi + 1], + grid.lat[yi + 1, xi + 1], + grid.lat[yi + 1, xi], + ] + assert x > np.min(cell_lon) and x < np.max(cell_lon) + assert y > np.min(cell_lat) and y < np.max(cell_lat) diff --git a/tests/v4/test_xgrid.py b/tests/v4/test_xgrid.py index a1c0d9472d..39a5ee614d 100644 --- a/tests/v4/test_xgrid.py +++ b/tests/v4/test_xgrid.py @@ -2,13 +2,14 @@ import numpy as np import pytest +import xarray as xr from numpy.testing import assert_allclose from parcels import xgcm from parcels._datasets.structured.generic import T, X, Y, Z, datasets from parcels.grid import Grid as OldGrid from parcels.tools.converters import TimeConverter -from parcels.xgrid import XGrid +from parcels.xgrid import XGrid, _search_1d_array GridTestCase = namedtuple("GridTestCase", ["Grid", "attr", "expected"]) @@ -59,7 +60,7 @@ def test_xgrid_properties_ground_truth(ds, attr, expected): "_gtype", ], ) -@pytest.mark.parametrize("ds", datasets.values()) +@pytest.mark.parametrize("ds", [pytest.param(ds, id=key) for key, ds in datasets.items()]) def test_xgrid_against_old(ds, attr): grid = XGrid(xgcm.Grid(ds, periodic=False)) @@ -74,3 +75,121 @@ def test_xgrid_against_old(ds, attr): actual = getattr(grid, attr) expected = getattr(old_grid, attr) assert_equal(actual, expected) + + +@pytest.mark.parametrize("ds", [pytest.param(ds, id=key) for key, ds in datasets.items()]) +def test_grid_init_on_generic_datasets(ds): + XGrid(xgcm.Grid(ds, periodic=False)) + + +def test_invalid_xgrid_field_array(): + """Stress test initialiser by creating incompatible datasets that test the edge cases""" + ... + + +def test_invalid_lon_lat(): + """Stress test the grid initialiser by creating incompatible datasets that test the edge cases""" + ds = datasets["ds_2d_left"].copy() + ds["lon"], ds["lat"] = xr.broadcast(ds["YC"], ds["XC"]) + + with pytest.raises( + ValueError, + match=".*is defined on the center of the grid, but must be defined on the F points\.", + ): + XGrid(xgcm.Grid(ds, periodic=False)) + + ds = datasets["ds_2d_left"].copy() + ds["lon"], _ = xr.broadcast(ds["YG"], ds["XG"]) + with pytest.raises( + ValueError, + match=".*have different dimensionalities\.", + ): + XGrid(xgcm.Grid(ds, periodic=False)) + + ds = datasets["ds_2d_left"].copy() + ds["lon"], ds["lat"] = xr.broadcast(ds["YG"], ds["XG"]) + ds["lon"], ds["lat"] = ds["lon"].transpose(), ds["lat"].transpose() + + with pytest.raises( + ValueError, + match=".*must be defined on the X and Y axes and transposed to have dimensions in order of Y, X\.", + ): + XGrid(xgcm.Grid(ds, periodic=False)) + + +def test_xgrid_ravel_unravel_index(): + ds = datasets["ds_2d_left"] + grid = XGrid(xgcm.Grid(ds, periodic=False)) + + xdim = grid.xdim + ydim = grid.ydim + zdim = grid.zdim + + encountered_eis = [] + for xi in range(xdim): + for yi in range(ydim): + for zi in range(zdim): + axis_indices = {"Z": zi, "Y": yi, "X": xi} + ei = grid.ravel_index(axis_indices) + axis_indices_test = grid.unravel_index(ei) + assert axis_indices_test == axis_indices + encountered_eis.append(ei) + + encountered_eis = sorted(encountered_eis) + assert len(set(encountered_eis)) == len(encountered_eis), "Raveled indices are not unique." + assert np.allclose(np.diff(np.array(encountered_eis)), 1), "Raveled indices are not consecutive integers." + assert encountered_eis[0] == 0, "Raveled indices do not start at 0." + + +@pytest.mark.parametrize( + "ds", + [ + pytest.param(datasets["ds_2d_left"], id="1D lon/lat"), + pytest.param(datasets["2d_left_rotated"], id="2D lon/lat"), + ], +) # for key, ds in datasets.items()]) +def test_xgrid_search_cpoints(ds): + grid = XGrid(xgcm.Grid(ds, periodic=False)) + lat_array, lon_array = get_2d_fpoint_mesh(grid) + lat_array, lon_array = corner_to_cell_center_points(lat_array, lon_array) + + for xi in range(grid.xdim - 1): + for yi in range(grid.ydim - 1): + axis_indices = {"Z": 0, "Y": yi, "X": xi} + + lat, lon = lat_array[yi, xi], lon_array[yi, xi] + axis_indices_bcoords = grid.search(0, lat, lon, ei=None) + axis_indices_test = {k: v[0] for k, v in axis_indices_bcoords.items()} + assert axis_indices == axis_indices_test + + # assert np.isclose(bcoords[0], 0.5) #? Should this not be the case with the cell center points? + # assert np.isclose(bcoords[1], 0.5) + + +def get_2d_fpoint_mesh(grid: XGrid): + lat, lon = grid.lat, grid.lon + if lon.ndim == 1: + lat, lon = np.meshgrid(lat, lon, indexing="ij") + return lat, lon + + +def corner_to_cell_center_points(lat, lon): + """Convert F points to C points.""" + lon_c = (lon[:-1, :-1] + lon[:-1, 1:]) / 2 + lat_c = (lat[:-1, :-1] + lat[1:, :-1]) / 2 + return lat_c, lon_c + + +@pytest.mark.parametrize( + "array, x, expected_xi, expected_xsi", + [ + (np.array([1, 2, 3, 4, 5]), 1.1, 0, 0.1), + (np.array([1, 2, 3, 4, 5]), 2.1, 1, 0.1), + (np.array([1, 2, 3, 4, 5]), 3.1, 2, 0.1), + (np.array([1, 2, 3, 4, 5]), 4.5, 3, 0.5), + ], +) +def test_search_1d_array(array, x, expected_xi, expected_xsi): + xi, xsi = _search_1d_array(array, x) + assert xi == expected_xi + assert np.isclose(xsi, expected_xsi)