Skip to content

Commit

Permalink
Define .dims property. FutureWarning on .dimensions. Fixes #217
Browse files Browse the repository at this point in the history
  • Loading branch information
Huite committed Aug 23, 2024
1 parent 2e0e04a commit 4f16821
Show file tree
Hide file tree
Showing 10 changed files with 79 additions and 42 deletions.
6 changes: 4 additions & 2 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,8 @@ UGRID1D Topology
Ugrid1d

Ugrid1d.topology_dimension
Ugrid1d.dimensions
Ugrid1d.dims
Ugrid1d.sizes
Ugrid1d.attrs

Ugrid1d.n_node
Expand Down Expand Up @@ -249,7 +250,8 @@ UGRID2D Topology
Ugrid2d

Ugrid2d.topology_dimension
Ugrid2d.dimensions
Ugrid2d.dims
Ugrid2d.sizes
Ugrid2d.attrs

Ugrid2d.n_node
Expand Down
2 changes: 1 addition & 1 deletion tests/test_partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_labels_to_indices():

def test_single_ugrid_chunk():
grid = generate_mesh_2d(3, 3)
ugrid_dims = set(grid.dimensions)
ugrid_dims = set(grid.dims)
da = xr.DataArray(np.ones(grid.n_face), dims=(grid.face_dimension,))
assert pt.single_ugrid_chunk(da, ugrid_dims) is da

Expand Down
5 changes: 4 additions & 1 deletion tests/test_ugrid1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,10 +268,13 @@ def test_dimensions():
grid = grid1d()
assert grid.node_dimension == f"{NAME}_nNodes"
assert grid.edge_dimension == f"{NAME}_nEdges"
assert grid.dimensions == {
assert grid.dims == (f"{NAME}_nNodes", f"{NAME}_nEdges")
assert grid.sizes == {
f"{NAME}_nNodes": 3,
f"{NAME}_nEdges": 2,
}
with pytest.warns(FutureWarning):
assert grid.dimensions == grid.sizes


@requires_meshkernel
Expand Down
9 changes: 8 additions & 1 deletion tests/test_ugrid2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,11 +444,18 @@ def test_dimensions():
assert grid.node_dimension == f"{NAME}_nNodes"
assert grid.edge_dimension == f"{NAME}_nEdges"
assert grid.face_dimension == f"{NAME}_nFaces"
assert grid.dimensions == {
assert grid.dims == (
f"{NAME}_nNodes",
f"{NAME}_nEdges",
f"{NAME}_nFaces",
)
assert grid.sizes == {
f"{NAME}_nNodes": 7,
f"{NAME}_nEdges": 10,
f"{NAME}_nFaces": 4,
}
with pytest.warns(FutureWarning):
assert grid.dimensions == grid.sizes


def test_edge_node_connectivity():
Expand Down
16 changes: 6 additions & 10 deletions xugrid/core/wrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ def maybe_xugrid(obj, topology, old_indexes=None):

# Topology can either be a sequence of grids or a grid.
if isinstance(topology, (list, set, tuple)):
grids = {dim: grid for grid in topology for dim in grid.dimensions}
grids = {dim: grid for grid in topology for dim in grid.dims}
else:
grids = {dim: topology for dim in topology.dimensions}
grids = {dim: topology for dim in topology.dims}

item_grids = unique_grids([grids[dim] for dim in obj.dims if dim in grids])

Expand Down Expand Up @@ -183,7 +183,7 @@ class DatasetForwardMixin:


def assign_ugrid_coords(obj, grids):
grid_dims = ChainMap(*(grid.dimensions for grid in grids))
grid_dims = ChainMap(*(grid.sizes for grid in grids))
ugrid_dims = set(grid_dims.keys()).intersection(obj.dims)
ugrid_coords = {dim: RangeIndex(0, grid_dims[dim]) for dim in ugrid_dims}
obj = obj.assign_coords(ugrid_coords)
Expand Down Expand Up @@ -385,16 +385,12 @@ def __setitem__(self, key, value):
# Check if the dimensions occur in self.
# if they don't, the grid should be added.
if self.grids is not None:
alldims = set(
chain.from_iterable([grid.dimensions for grid in self.grids])
)
matching_dims = set(value.grid.dimensions).intersection(alldims)
alldims = set(chain.from_iterable([grid.dims for grid in self.grids]))
matching_dims = set(value.grid.dims).intersection(alldims)
if matching_dims:
append = False
# If they do match: the grids should match.
grids = {
dim: grid for grid in self.grids for dim in grid.dimensions
}
grids = {dim: grid for grid in self.grids for dim in grid.dims}
firstdim = next(iter(matching_dims))
grid_to_check = grids[firstdim]
if not grid_to_check.equals(value.grid):
Expand Down
4 changes: 2 additions & 2 deletions xugrid/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,11 +625,11 @@ def __init__(self, obj):
darray = obj.obj
grid = obj.grid

invalid = set(darray.dims) - set(grid.dimensions)
invalid = set(darray.dims) - set(grid.dims)
if invalid:
raise ValueError(
f"UgridDataArray contains non-topology dimensions: {invalid}.\n"
f"Expected only one of {set(grid.dimensions.keys())}."
f"Expected only one of {grid.dims}."
)

self.grid = grid
Expand Down
4 changes: 2 additions & 2 deletions xugrid/ugrid/partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def validate_partition_topology(grouped: defaultdict[str, UgridType]) -> None:
f"same type, received: {types}"
)

griddims = list({tuple(grid.dimensions) for grid in grids})
griddims = list({grid.dims for grid in grids})
if len(griddims) > 1:
raise ValueError(
f"Dimension names on UGRID topology {name} do not match "
Expand Down Expand Up @@ -365,7 +365,7 @@ def merge_partitions(partitions, merge_ugrid_chunks: bool = True):

# Collect grids
grids = [grid for p in partitions for grid in p.grids]
ugrid_dims = {dim for grid in grids for dim in grid.dimensions}
ugrid_dims = {dim for grid in grids for dim in grid.dims}
grids_by_name = group_grids_by_name(partitions)

data_objects_by_name = group_data_objects_by_gridname(partitions)
Expand Down
15 changes: 9 additions & 6 deletions xugrid/ugrid/ugrid1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,12 @@ def core_dimension(self):
return self.node_dimension

@property
def dimensions(self):
def dims(self):
# Tuple to preserve order, unlike set.
return (self.node_dimension, self.edge_dimension)

@property
def sizes(self):
return {self.node_dimension: self.n_node, self.edge_dimension: self.n_edge}

def connectivity_matrix(self, dim: str, xy_weights: bool):
Expand Down Expand Up @@ -432,17 +437,15 @@ def isel(self, indexers=None, return_index=False, **indexers_kwargs):
their respective index. Only returned if return_index is True.
"""
indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "isel")
alldims = set(self.dimensions)
alldims = self.dims
invalid = indexers.keys() - alldims
if invalid:
raise ValueError(
f"Dimensions {invalid} do not exist. Expected one of {alldims}"
)

indexers = {
k: as_pandas_index(v, self.dimensions[k]) for k, v in indexers.items()
}
nodedim, edgedim = self.dimensions
indexers = {k: as_pandas_index(v, self.sizes[k]) for k, v in indexers.items()}
nodedim, edgedim = self.dims
edge_index = {}
if nodedim in indexers:
node_index = indexers[nodedim]
Expand Down
19 changes: 13 additions & 6 deletions xugrid/ugrid/ugrid2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,16 @@ def core_dimension(self):
return self.face_dimension

@property
def dimensions(self):
def dims(self):
# Tuple to preserve order, unlike set.
return (
self.node_dimension,
self.edge_dimension,
self.face_dimension,
)

@property
def sizes(self):
return {
self.node_dimension: self.n_node,
self.edge_dimension: self.n_edge,
Expand Down Expand Up @@ -1160,17 +1169,15 @@ def isel(self, indexers=None, return_index=False, **indexers_kwargs):
True.
"""
indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "isel")
alldims = set(self.dimensions)
alldims = set(self.dims)
invalid = indexers.keys() - alldims
if invalid:
raise ValueError(
f"Dimensions {invalid} do not exist. Expected one of {alldims}"
)

indexers = {
k: as_pandas_index(v, self.dimensions[k]) for k, v in indexers.items()
}
nodedim, edgedim, facedim = self.dimensions
indexers = {k: as_pandas_index(v, self.sizes[k]) for k, v in indexers.items()}
nodedim, edgedim, facedim = self.dims
face_index = {}
if nodedim in indexers:
node_index = indexers[nodedim]
Expand Down
41 changes: 30 additions & 11 deletions xugrid/ugrid/ugridbase.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import abc
import copy
import warnings
from itertools import chain
from typing import Tuple, Type, Union, cast
from typing import Dict, Tuple, Type, Union, cast

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -60,9 +61,9 @@ def align(obj, grids, old_indexes):
if old_indexes is None:
return obj, grids

ugrid_dims = set(
chain.from_iterable(grid.dimensions for grid in grids)
).intersection(old_indexes)
ugrid_dims = set(chain.from_iterable(grid.dims for grid in grids)).intersection(
old_indexes
)
new_indexes = {
k: index
for k, index in obj.indexes.items()
Expand All @@ -74,7 +75,7 @@ def align(obj, grids, old_indexes):
# Group the indexers by grid
new_grids = []
for grid in grids:
ugrid_dims = set(grid.dimensions).intersection(new_indexes)
ugrid_dims = set(grid.dims).intersection(new_indexes)
ugrid_indexes = {dim: new_indexes[dim] for dim in ugrid_dims}
newgrid, indexers = grid.isel(indexers=ugrid_indexes, return_index=True)
indexers = {
Expand All @@ -98,7 +99,29 @@ def core_dimension(self):

@property
@abc.abstractmethod
def dimensions(self):
def dims(self) -> Tuple[str]:
pass

@property
def dimensions(self) -> Dict[str, Dict[str, str]]:
"""
Mapping from UGRID dimension names to lengths.
This property will be changed to return a type more consistent with
DataArray.dims in the future, i.e. a set of dimension names.
"""

warnings.warn(
".dimensions will is replaced by .dims and its return type is a set "
"of dimension names in future. To access a mapping of names to "
"lengths, use .sizes instead.",
FutureWarning,
)
return self.sizes

@property
@abc.abstractmethod
def sizes(self):
pass

@property
Expand Down Expand Up @@ -175,7 +198,7 @@ def _create_data_array(self, data: ArrayLike, dimension: str):
f"Data has {data.ndim} dimensions."
)
len_data = len(data)
len_grid = self.dimensions[dimension]
len_grid = self.sizes[dimension]
if len_data != len_grid:
raise ValueError(
f"Conflicting sizes for dimension {dimension}: length "
Expand Down Expand Up @@ -342,10 +365,6 @@ def max_connectivity_dimensions(self) -> tuple[str]:
def max_connectivity_sizes(self) -> dict[str, int]:
return {}

@property
def sizes(self) -> dict[str, int]:
return self.dimensions

@property
def node_coordinates(self) -> FloatArray:
"""Coordinates (x, y) of the nodes (vertices)"""
Expand Down

0 comments on commit 4f16821

Please sign in to comment.