Ensure attributes are kept upon regridding #27

Merged
merged 9 commits on Feb 13, 2024
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -6,6 +6,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/).

## Unreleased

Fixed:
- Ensured all attributes are kept upon regridding (dataset, variable and coordinate attrs).

Changed:
- Moved to the Ruff formatter, instead of black.

## v0.2.2 (2023-11-24)

Added:
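As a quick illustration of the fix described in the changelog (not part of the diff; the dataset, its attribute values, and the grids are made up, and the checks mirror the new tests added further down):

```python
import numpy as np
import xarray as xr
import xarray_regrid  # noqa: F401  (importing registers the .regrid accessor)

# Dummy source data with attrs on the dataset and on the coordinates.
source = xr.Dataset(
    {"t2m": (("latitude", "longitude"), np.random.rand(4, 4), {"units": "K"})},
    coords={
        "latitude": ("latitude", np.linspace(-1.5, 1.5, 4), {"units": "degrees_north"}),
        "longitude": ("longitude", np.linspace(-1.5, 1.5, 4), {"units": "degrees_east"}),
    },
    attrs={"source": "dummy data"},
)
# Coarser target grid to regrid onto.
target = xr.Dataset(
    coords={
        "latitude": ("latitude", np.linspace(-1.0, 1.0, 3)),
        "longitude": ("longitude", np.linspace(-1.0, 1.0, 3)),
    }
)

regridded = source.regrid.linear(target)

# With this PR, dataset and coordinate attributes survive the regridding.
assert regridded.attrs == source.attrs
assert regridded["longitude"].attrs == source["longitude"].attrs
```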
20 changes: 9 additions & 11 deletions pyproject.toml
@@ -46,7 +46,6 @@ benchmarking = [
dev = [
"hatch",
"ruff",
"black[jupyter]",
"mypy",
"pytest",
"pytest-cov",
@@ -63,22 +62,19 @@ features = ["dev", "benchmarking"]
lint = [
"ruff check .",
"mypy src",
"black --check --diff .",
"ruff format . --check",
]
format = ["black .", "lint",]
format = ["ruff format .", "lint",]
test = ["pytest ./src/ ./tests/ --doctest-modules",]
coverage = [
"pytest --cov --cov-report term --cov-report xml --junitxml=xunit-result.xml tests/",
]

[tool.black]
target-version = ["py310"]
line-length = 88
skip-string-normalization = true

[tool.ruff]
target-version = "py310"
line-length = 88

[tool.ruff.lint]
select = [
"A",
"ARG",
@@ -115,19 +111,21 @@ ignore = [
"S105", "S106", "S107",
# Ignore complexity
"C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915",
# Causes conflicts with formatter
"ISC001",
]
unfixable = [
# Don't touch unused imports
"F401",
]

[tool.ruff.isort]
[tool.ruff.lint.isort]
known-first-party = ["xarray_regrid"]

[tool.ruff.flake8-tidy-imports]
[tool.ruff.lint.flake8-tidy-imports]
ban-relative-imports = "all"

[tool.ruff.per-file-ignores]
[tool.ruff.lint.per-file-ignores]
# Tests can use magic values, assertions, and relative imports
"tests/**/*" = ["PLR2004", "S101", "TID252"]

29 changes: 27 additions & 2 deletions src/xarray_regrid/methods/conservative.py
@@ -77,9 +77,14 @@ def conservative_regrid_dataset(
latitude_coord: str,
) -> xr.Dataset:
"""Dataset implementation of the conservative regridding method."""
data_vars: list[str] = list(data.data_vars)
data_vars = list(data.data_vars)
data_coords = list(data.coords)
dataarrays = [data[var] for var in data_vars]

attrs = data.attrs
da_attrs = [da.attrs for da in dataarrays]
coord_attrs = [data[coord].attrs for coord in data_coords]

for coord in coords:
target_coords = coords[coord].to_numpy()
source_coords = data[coord].to_numpy()
@@ -98,7 +103,17 @@ def conservative_regrid_dataset(
da = dataarrays[i].transpose(coord, ...)
dataarrays[i] = apply_weights(da, weights, coord, target_coords)

return xr.merge(dataarrays) # TODO: add other coordinates/data variables back in.
for da, attr in zip(dataarrays, da_attrs, strict=True):
da.attrs = attr
regridded = xr.merge(dataarrays)

regridded.attrs = attrs

new_coords = [regridded[coord] for coord in data_coords]
for coord, attr in zip(new_coords, coord_attrs, strict=True):
coord.attrs = attr

return regridded # TODO: add other coordinates/data variables back in.


def conservative_regrid_dataarray(
@@ -107,6 +122,11 @@ def conservative_regrid_dataarray(
latitude_coord: str,
) -> xr.DataArray:
"""DataArray implementation of the conservative regridding method."""
data_coords = list(data.coords)

attrs = data.attrs
coord_attrs = [data[coord].attrs for coord in data_coords]

for coord in coords:
if coord in data.coords:
target_coords = coords[coord].to_numpy()
@@ -125,6 +145,11 @@ def conservative_regrid_dataarray(
data = data.transpose(coord, ...)
data = apply_weights(data, weights, coord, target_coords)

new_coords = [data[coord] for coord in data_coords]
for coord, attr in zip(new_coords, coord_attrs, strict=True):
coord.attrs = attr
data.attrs = attrs

return data


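The stash-and-restore pattern added above is needed because xarray arithmetic typically drops attrs unless `keep_attrs` is set. A standalone sketch of that behaviour (the dot product below is only a stand-in for the package's `apply_weights`, and the data are made up):

```python
import numpy as np
import xarray as xr

# Dummy field with attrs, plus a stand-in weight matrix for one dimension.
da = xr.DataArray(
    np.ones((2, 3)),
    dims=("latitude", "longitude"),
    attrs={"units": "K", "long_name": "temperature"},
)
weights = xr.DataArray(np.full((3, 2), 0.5), dims=("longitude", "longitude_target"))

regridded = xr.dot(da, weights)  # the contraction drops the attrs
print(regridded.attrs)           # {}

regridded.attrs = da.attrs       # the explicit restore this PR performs
print(regridded.attrs)           # {'units': 'K', 'long_name': 'temperature'}
```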
9 changes: 8 additions & 1 deletion src/xarray_regrid/methods/interp.py
@@ -39,8 +39,15 @@ def interp_regrid(
"""
coord_names = set(target_ds.coords).intersection(set(data.coords))
coords = {name: target_ds[name] for name in coord_names}
coord_attrs = {coord: data[coord].attrs for coord in coord_names}

return data.interp(
interped = data.interp(
coords=coords,
method=method,
)

# xarray's interp drops some of the coordinate's attributes (e.g. long_name)
for coord in coord_names:
interped[coord].attrs = coord_attrs[coord]

return interped
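A self-contained reproduction (dummy data) of the behaviour noted in the comment above: the interpolated coordinate comes back without its attrs, so they have to be copied over manually.

```python
import numpy as np
import xarray as xr

da = xr.DataArray(
    np.arange(4.0),
    dims="longitude",
    coords={"longitude": ("longitude", np.arange(4.0), {"long_name": "Longitude"})},
)

out = da.interp(longitude=[0.5, 1.5, 2.5])
print(out["longitude"].attrs)  # {} -- the long_name is gone

out["longitude"].attrs = da["longitude"].attrs  # the restore added above
```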
3 changes: 3 additions & 0 deletions src/xarray_regrid/methods/most_common.py
@@ -151,6 +151,8 @@ def most_common(data: xr.Dataset, target_ds: xr.Dataset, time_dim: str) -> xr.Da
"""
dim_order = data.dims
coords = utils.common_coords(data, target_ds, remove_coord=time_dim)
coord_attrs = {coord: data[coord].attrs for coord in target_ds.coords}

bounds = tuple(
_construct_intervals(target_ds[coord].to_numpy()) for coord in coords
)
@@ -185,6 +187,7 @@ def most_common(data: xr.Dataset, target_ds: xr.Dataset, time_dim: str) -> xr.Da
ds_regrid = ds_regrid.rename({f"{coord}_bins": coord for coord in coords})
for coord in coords:
ds_regrid[coord] = target_ds[coord]
ds_regrid[coord].attrs = coord_attrs[coord]

return ds_regrid.transpose(*dim_order)

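For context, a hypothetical mini-example (not from this repository) of why the coordinate attrs need to be copied back: a coordinate produced by `groupby_bins`, which the `_bins` rename above points at, is built from the bin intervals and carries none of the original coordinate's attrs.

```python
import numpy as np
import xarray as xr

da = xr.DataArray(
    np.arange(6.0),
    dims="longitude",
    coords={"longitude": ("longitude", np.arange(6.0), {"units": "degrees_east"})},
)

binned = da.groupby_bins("longitude", bins=[0, 2, 4, 6]).mean()
print(binned["longitude_bins"].attrs)  # {} -- the units attr did not carry over
```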
18 changes: 18 additions & 0 deletions tests/test_most_common.py
@@ -33,6 +33,7 @@ def dummy_lc_data():
"longitude": (["longitude"], lon_coords),
"latitude": (["latitude"], lat_coords),
},
attrs={"test": "not empty"},
)


@@ -77,3 +78,20 @@ def test_most_common(dummy_lc_data, dummy_target_grid):
dummy_lc_data.regrid.most_common(dummy_target_grid)["lc"],
expected["lc"],
)


def test_attrs_dataarray(dummy_lc_data, dummy_target_grid):
dummy_lc_data["lc"].attrs = {"test": "testing"}
da_regrid = dummy_lc_data["lc"].regrid.most_common(dummy_target_grid)
assert da_regrid.attrs != {}
assert da_regrid.attrs == dummy_lc_data["lc"].attrs
assert da_regrid["longitude"].attrs == dummy_lc_data["longitude"].attrs


def test_attrs_dataset(dummy_lc_data, dummy_target_grid):
ds_regrid = dummy_lc_data.regrid.most_common(
dummy_target_grid,
)
assert ds_regrid.attrs != {}
assert ds_regrid.attrs == dummy_lc_data.attrs
assert ds_regrid["longitude"].attrs == dummy_lc_data["longitude"].attrs
33 changes: 33 additions & 0 deletions tests/test_regrid.py
@@ -123,3 +123,36 @@ def test_conservative_nans(conservative_input_data, conservative_sample_grid):
rtol=0.002,
atol=2e-6,
)


@pytest.mark.parametrize("method", ["linear", "nearest", "cubic"])
def test_attrs_dataarray(sample_input_data, sample_grid_ds, method):
regridder = getattr(sample_input_data["d2m"].regrid, method)
da_regrid = regridder(sample_grid_ds)
assert da_regrid.attrs == sample_input_data["d2m"].attrs
assert da_regrid["longitude"].attrs == sample_input_data["longitude"].attrs


def test_attrs_dataarray_conservative(sample_input_data, sample_grid_ds):
da_regrid = sample_input_data["d2m"].regrid.conservative(
sample_grid_ds, latitude_coord="latitude"
)
assert da_regrid.attrs == sample_input_data["d2m"].attrs
assert da_regrid["longitude"].attrs == sample_input_data["longitude"].attrs


@pytest.mark.parametrize("method", ["linear", "nearest", "cubic"])
def test_attrs_dataset(sample_input_data, sample_grid_ds, method):
regridder = getattr(sample_input_data.regrid, method)
ds_regrid = regridder(sample_grid_ds)
assert ds_regrid.attrs == sample_input_data.attrs
assert ds_regrid["longitude"].attrs == sample_input_data["longitude"].attrs


def test_attrs_dataset_conservative(sample_input_data, sample_grid_ds):
ds_regrid = sample_input_data.regrid.conservative(
sample_grid_ds, latitude_coord="latitude"
)
assert ds_regrid.attrs == sample_input_data.attrs
assert ds_regrid["d2m"].attrs == sample_input_data["d2m"].attrs
assert ds_regrid["longitude"].attrs == sample_input_data["longitude"].attrs