Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NaN threshold for conservative method #39

Closed

Conversation

slevang
Copy link
Collaborator

@slevang slevang commented May 21, 2024

Nice package! In testing it out where I've previously used xesmf, I noticed two features lacking from the conservative method:

  1. the regridded dataset can't be constructed lazily if dask-backed due to this line
  2. no ability to keep target cells where the input points are partially NaN, as noted in Adding NaN thresholding #32

Number 1 is easy, number 2 is trickier. I added a naive implementation for the nan_threshold capabilities of xesmf here for discussion. As noted in the previous issue, to do this 100% correctly we would need to track the NaN fraction as we reduce over each dimension, which I'm not doing here. The nan_threshold value doesn't translate directly to total fraction of NaN cells due to the sequential reduction. It would also get complicated for isolated NaNs in the temporal dimension.

I'm not sure any of this matters much for a dataset where you have consistent NaN's e.g. SST. Here's an example of the new functionality used on the MUR dataset. Note this is a 33TB array but we can now generate the (lazily) regridded dataset instantaneously.

import xarray as xr
import xarray_regrid

sst = xr.open_zarr("https://mur-sst.s3.us-west-2.amazonaws.com/zarr-v1").analysed_sst
grid = xarray_regrid.Grid(
    north=90,
    east=180,
    south=-90,
    west=-180,
    resolution_lat=1,
    resolution_lon=1,
)
target = xarray_regrid.create_regridding_dataset(grid)
target = target.rename(latitude="lat", longitude="lon")

ds0p0 = sst.regrid.conservative(target, nan_threshold=0)
ds0p5 = sst.regrid.conservative(target, nan_threshold=0.5)
ds1p0 = sst.regrid.conservative(target, nan_threshold=1)

ds0p0

ds0p5

ds1p0

@slevang slevang marked this pull request as ready for review May 22, 2024 15:15
Copy link
Contributor

@BSchilperoort BSchilperoort left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you so much for contributing this! I'll try to give it a full review and run it next week, but here are already some comments.

Did you compare the performance to xESMF to ensure that the masking has the same results here as there? I can also try that later.

Lastly, would you be able to write some unit tests to test the nan masking? And update the changelog? If anything is unclear please let me know.

src/xarray_regrid/methods/conservative.py Outdated Show resolved Hide resolved
src/xarray_regrid/methods/conservative.py Outdated Show resolved Hide resolved
src/xarray_regrid/methods/conservative.py Outdated Show resolved Hide resolved
src/xarray_regrid/methods/conservative.py Outdated Show resolved Hide resolved
@slevang
Copy link
Collaborator Author

slevang commented May 26, 2024

Thanks, and yes for sure, hopefully next week I will find time to add tests and benchmarks.

@slevang
Copy link
Collaborator Author

slevang commented May 26, 2024

Did some quick profiling on a ~4GB array of 1/4deg global data coarsening to 1deg. Dask array on a 32 CPU node. Results:

  • This PR, skipna=False: 32s
  • This PR, skipna=True: 64s
  • main: 96s

So adding skipna forces roughly one additional pass through the array with the weight renormalization. The reason this PR is faster than main is because the current code has the np.any(np.isnan()) check which forces computation, plus the separately calculated isnan array, which forces 3 passes through the data. If I cut out the logic branch of checking for NaNs on main and go straight to the einsum, we recover the ~32s run above.

Copy link
Collaborator Author

@slevang slevang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Finally found some time to get back to this.

I did manage to come up with an implementation that should accurately track the fraction of NaN points across dimensions (see the tests added). The problem currently is that there is a major performance penalty because we are tracking the nans independently across all dimensions (such as an additional time dim) which causes the einsum ops to have much larger dimensionality.

Also did some general reworking to use xr.dot, consolidate some things like different the DataArray/Dataset pathways, and knocked out some other small improvements to conservative that have been noted in the issues.

pyproject.toml Show resolved Hide resolved
src/xarray_regrid/methods/conservative.py Show resolved Hide resolved
src/xarray_regrid/methods/conservative.py Show resolved Hide resolved
src/xarray_regrid/methods/conservative.py Outdated Show resolved Hide resolved
src/xarray_regrid/utils.py Show resolved Hide resolved
src/xarray_regrid/methods/conservative.py Outdated Show resolved Hide resolved
src/xarray_regrid/methods/conservative.py Show resolved Hide resolved
src/xarray_regrid/methods/conservative.py Show resolved Hide resolved
@slevang
Copy link
Collaborator Author

slevang commented Jul 14, 2024

Made the modification to take notnull.any(non_regrid_dims) which leaves us at about a 3x performance penalty for skipna=True in the benchmarks I've run. I think this should maybe be a configurable arg though in cases where you want to track NaNs very carefully throughout the dataset.

src/xarray_regrid/utils.py Show resolved Hide resolved
xr.testing.assert_allclose(da_coarsen, da_regrid)


@pytest.mark.skip(reason="requires xesmf")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could use pytest.mark.skipif here, and run the test automatically if xesmf is available.

src/xarray_regrid/methods/conservative.py Show resolved Hide resolved
src/xarray_regrid/methods/conservative.py Show resolved Hide resolved
tests/test_most_common.py Show resolved Hide resolved
pyproject.toml Show resolved Hide resolved
benchmarks/benchmarking_conservative.ipynb Show resolved Hide resolved
src/xarray_regrid/methods/conservative.py Outdated Show resolved Hide resolved
Copy link
Contributor

@BSchilperoort BSchilperoort left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for all the changes you've made! I've invited you to this repository so you don't have to wait for me to have the CI run.

src/xarray_regrid/methods/conservative.py Show resolved Hide resolved
BSchilperoort added a commit that referenced this pull request Sep 4, 2024
* add nan_threshold option

* track nan frac across dims, use xr.dot, consolidate ds/da paths, tests

* fixes and cleanup plus initial notebook cell

* speed up by only tracking the max of nonnull points over non_regrid_dims

* fix tests for newer dependency versions

* Improve typing of call_on_dataset

* Fix typing in updated conservative routines

* Apply code formatting

* Ensure hashable is a valid input for coordinate identifier

* Make tests & typing pass

* Make `create_regridding_dataset` a method of `Grid` #38

* Update notebooks and dependencies

* Allow xesmf test to run if it's available

* Add @slevang to the contributors list

* Update changelog

* Ignore linter in test

* Update readme with badges and "why use..." text

---------

Co-authored-by: Sam Levang <slevang@salientpredictions.com>
@BSchilperoort
Copy link
Contributor

Merged as part of #41

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Adding NaN thresholding Dealing with 1x1 grids Conservative: detect/handle non-regularly spaced intervals
2 participants