Skip to content
Merged
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/).

## Unreleased

Fixed:
- Ensured all attributes are kept upon regridding (dataset, variable and coordinate attrs).

Changed:
- Moved to the Ruff formatter, instead of black.

## v0.2.2 (2023-11-24)

Added:
Expand Down
20 changes: 9 additions & 11 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ benchmarking = [
dev = [
"hatch",
"ruff",
"black[jupyter]",
"mypy",
"pytest",
"pytest-cov",
Expand All @@ -63,22 +62,19 @@ features = ["dev", "benchmarking"]
lint = [
"ruff check .",
"mypy src",
"black --check --diff .",
"ruff format . --check",
]
format = ["black .", "lint",]
format = ["ruff format .", "lint",]
test = ["pytest ./src/ ./tests/ --doctest-modules",]
coverage = [
"pytest --cov --cov-report term --cov-report xml --junitxml=xunit-result.xml tests/",
]

[tool.black]
target-version = ["py310"]
line-length = 88
skip-string-normalization = true

[tool.ruff]
target-version = "py310"
line-length = 88

[tool.ruff.lint]
select = [
"A",
"ARG",
Expand Down Expand Up @@ -115,19 +111,21 @@ ignore = [
"S105", "S106", "S107",
# Ignore complexity
"C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915",
# Causes conflicts with formatter
"ISC001",
]
unfixable = [
# Don't touch unused imports
"F401",
]

[tool.ruff.isort]
[tool.ruff.lint.isort]
known-first-party = ["xarray_regrid"]

[tool.ruff.flake8-tidy-imports]
[tool.ruff.lint.flake8-tidy-imports]
ban-relative-imports = "all"

[tool.ruff.per-file-ignores]
[tool.ruff.lint.per-file-ignores]
# Tests can use magic values, assertions, and relative imports
"tests/**/*" = ["PLR2004", "S101", "TID252"]

Expand Down
29 changes: 27 additions & 2 deletions src/xarray_regrid/methods/conservative.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,14 @@ def conservative_regrid_dataset(
latitude_coord: str,
) -> xr.Dataset:
"""Dataset implementation of the conservative regridding method."""
data_vars: list[str] = list(data.data_vars)
data_vars = list(data.data_vars)
data_coords = list(data.coords)
dataarrays = [data[var] for var in data_vars]

attrs = data.attrs
da_attrs = [da.attrs for da in dataarrays]
coord_attrs = [data[coord].attrs for coord in data_coords]

for coord in coords:
target_coords = coords[coord].to_numpy()
source_coords = data[coord].to_numpy()
Expand All @@ -98,7 +103,17 @@ def conservative_regrid_dataset(
da = dataarrays[i].transpose(coord, ...)
dataarrays[i] = apply_weights(da, weights, coord, target_coords)

return xr.merge(dataarrays) # TODO: add other coordinates/data variables back in.
for da, attr in zip(dataarrays, da_attrs, strict=True):
da.attrs = attr
regridded = xr.merge(dataarrays)

regridded.attrs = attrs

new_coords = [regridded[coord] for coord in data_coords]
for coord, attr in zip(new_coords, coord_attrs, strict=True):
coord.attrs = attr

return regridded # TODO: add other coordinates/data variables back in.


def conservative_regrid_dataarray(
Expand All @@ -107,6 +122,11 @@ def conservative_regrid_dataarray(
latitude_coord: str,
) -> xr.DataArray:
"""DataArray implementation of the conservative regridding method."""
data_coords = list(data.coords)

attrs = data.attrs
coord_attrs = [data[coord].attrs for coord in data_coords]

for coord in coords:
if coord in data.coords:
target_coords = coords[coord].to_numpy()
Expand All @@ -125,6 +145,11 @@ def conservative_regrid_dataarray(
data = data.transpose(coord, ...)
data = apply_weights(data, weights, coord, target_coords)

new_coords = [data[coord] for coord in data_coords]
for coord, attr in zip(new_coords, coord_attrs, strict=True):
coord.attrs = attr
data.attrs = attrs

return data


Expand Down
9 changes: 8 additions & 1 deletion src/xarray_regrid/methods/interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,15 @@ def interp_regrid(
"""
coord_names = set(target_ds.coords).intersection(set(data.coords))
coords = {name: target_ds[name] for name in coord_names}
coord_attrs = {coord: data[coord].attrs for coord in coord_names}

return data.interp(
interped = data.interp(
coords=coords,
method=method,
)

# xarray's interp drops some of the coordinate's attributes (e.g. long_name)
for coord in coord_names:
interped[coord].attrs = coord_attrs[coord]

return interped
3 changes: 3 additions & 0 deletions src/xarray_regrid/methods/most_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ def most_common(data: xr.Dataset, target_ds: xr.Dataset, time_dim: str) -> xr.Da
"""
dim_order = data.dims
coords = utils.common_coords(data, target_ds, remove_coord=time_dim)
coord_attrs = {coord: data[coord].attrs for coord in target_ds.coords}

bounds = tuple(
_construct_intervals(target_ds[coord].to_numpy()) for coord in coords
)
Expand Down Expand Up @@ -185,6 +187,7 @@ def most_common(data: xr.Dataset, target_ds: xr.Dataset, time_dim: str) -> xr.Da
ds_regrid = ds_regrid.rename({f"{coord}_bins": coord for coord in coords})
for coord in coords:
ds_regrid[coord] = target_ds[coord]
ds_regrid[coord].attrs = coord_attrs[coord]

return ds_regrid.transpose(*dim_order)

Expand Down
18 changes: 18 additions & 0 deletions tests/test_most_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def dummy_lc_data():
"longitude": (["longitude"], lon_coords),
"latitude": (["latitude"], lat_coords),
},
attrs={"test": "not empty"},
)


Expand Down Expand Up @@ -77,3 +78,20 @@ def test_most_common(dummy_lc_data, dummy_target_grid):
dummy_lc_data.regrid.most_common(dummy_target_grid)["lc"],
expected["lc"],
)


def test_attrs_dataarray(dummy_lc_data, dummy_target_grid):
    """Variable and coordinate attrs survive DataArray most_common regridding."""
    source_da = dummy_lc_data["lc"]
    source_da.attrs = {"test": "testing"}
    result = source_da.regrid.most_common(dummy_target_grid)
    # Guard against a trivially-true comparison of two empty attr dicts.
    assert result.attrs
    assert result.attrs == source_da.attrs
    assert result["longitude"].attrs == dummy_lc_data["longitude"].attrs


def test_attrs_dataset(dummy_lc_data, dummy_target_grid):
    """Dataset-level attrs and coordinate attrs survive most_common regridding."""
    result = dummy_lc_data.regrid.most_common(dummy_target_grid)
    # The fixture sets non-empty attrs, so an empty dict here means they were dropped.
    assert result.attrs
    assert result.attrs == dummy_lc_data.attrs
    assert result["longitude"].attrs == dummy_lc_data["longitude"].attrs
33 changes: 33 additions & 0 deletions tests/test_regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,36 @@ def test_conservative_nans(conservative_input_data, conservative_sample_grid):
rtol=0.002,
atol=2e-6,
)


@pytest.mark.parametrize("method", ["linear", "nearest", "cubic"])
def test_attrs_dataarray(sample_input_data, sample_grid_ds, method):
    """Interpolation-based regridders keep DataArray and coordinate attrs."""
    source = sample_input_data["d2m"]
    result = getattr(source.regrid, method)(sample_grid_ds)
    assert result.attrs == source.attrs
    assert result["longitude"].attrs == sample_input_data["longitude"].attrs


def test_attrs_dataarray_conservative(sample_input_data, sample_grid_ds):
    """Conservative regridding keeps DataArray and coordinate attrs."""
    source = sample_input_data["d2m"]
    result = source.regrid.conservative(sample_grid_ds, latitude_coord="latitude")
    assert result.attrs == source.attrs
    assert result["longitude"].attrs == sample_input_data["longitude"].attrs


@pytest.mark.parametrize("method", ["linear", "nearest", "cubic"])
def test_attrs_dataset(sample_input_data, sample_grid_ds, method):
    """Interpolation-based regridders keep dataset, variable, and coordinate attrs.

    Mirrors test_attrs_dataset_conservative: besides the dataset-level and
    coordinate attrs, also verify that per-variable attrs are retained.
    """
    regridder = getattr(sample_input_data.regrid, method)
    ds_regrid = regridder(sample_grid_ds)
    assert ds_regrid.attrs == sample_input_data.attrs
    # Consistency with the conservative variant: variable attrs must survive too.
    assert ds_regrid["d2m"].attrs == sample_input_data["d2m"].attrs
    assert ds_regrid["longitude"].attrs == sample_input_data["longitude"].attrs


def test_attrs_dataset_conservative(sample_input_data, sample_grid_ds):
    """Conservative regridding keeps dataset, variable, and coordinate attrs."""
    result = sample_input_data.regrid.conservative(
        sample_grid_ds, latitude_coord="latitude"
    )
    assert result.attrs == sample_input_data.attrs
    assert result["d2m"].attrs == sample_input_data["d2m"].attrs
    assert result["longitude"].attrs == sample_input_data["longitude"].attrs