
Use xarray.map_blocks to speed up pyramid_reproject #10

Open
jhamman opened this issue Nov 23, 2021 · 5 comments
Comments


jhamman commented Nov 23, 2021

@norlandrhagen and I have been experimenting with approaches for speeding up pyramid generation when using rioxarray's reproject functionality. We have this rough prototype to share:

import dask.array
import datatree as dt
import numpy as np
import xarray as xr


def pyramid_reproject(
    ds, levels: int, pixels_per_tile=128, resampling="average", extra_dim=None
) -> dt.DataTree:
    # make_grid_ds, _multiscales_template, and _get_version are helpers
    # from this package
    from rasterio.transform import Affine
    from rasterio.warp import Resampling

    # multiscales spec
    save_kwargs = {"levels": levels, "pixels_per_tile": pixels_per_tile}
    attrs = {
        "multiscales": _multiscales_template(
            datasets=[{"path": str(i)} for i in range(levels)],
            type="reduce",
            method="pyramid_reproject",
            version=_get_version(),
            kwargs=save_kwargs,
        )
    }

    # set up pyramid
    root = xr.Dataset(attrs=attrs)
    pyramid = dt.DataTree(data_objects={"root": root})

    def make_template(da, dim, dst_transform, shape=None):
        template = xr.DataArray(
            data=dask.array.empty(shape, chunks=shape), dims=("y", "x"), attrs=da.attrs
        )
        template = make_grid_ds(template, dim, dst_transform)
        template.coords["spatial_ref"] = xr.DataArray(np.array(1.0))
        return template

    def reproject(da, shape=None, dst_transform=None, resampling="average"):
        return da.rio.reproject(
            "EPSG:3857",
            resampling=Resampling[resampling],
            shape=shape,
            transform=dst_transform,
        )

    for level in range(levels):
        lkey = str(level)
        dim = 2 ** level * pixels_per_tile

        dst_transform = Affine.translation(-20026376.39, 20048966.10) * Affine.scale(
            (20026376.39 * 2) / dim, -(20048966.10 * 2) / dim
        )

        pyramid[lkey] = xr.Dataset(attrs=ds.attrs)
        for k, da in ds.items():
            template = make_template(da, dim, dst_transform, (dim, dim))
            pyramid[lkey].ds[k] = xr.map_blocks(
                reproject,
                da,
                kwargs=dict(shape=(dim, dim), dst_transform=dst_transform),
                template=template,
            )

    return pyramid
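The core trick above is `xr.map_blocks` with a `template`: the template is a lazy dask-backed array that declares the output's dims, shape, dtype, and chunking, so xarray can build the task graph without computing anything. Here is a minimal, self-contained sketch of that pattern, with a toy pixel-repeating upsample standing in for `da.rio.reproject` (all names here are illustrative, not ndpyramid's API):

```python
import dask.array
import numpy as np
import xarray as xr


def make_template(shape):
    # Lazy placeholder describing the result's dims/shape/dtype;
    # map_blocks uses it to build the graph without computing.
    return xr.DataArray(
        dask.array.empty(shape, chunks=shape, dtype="float64"),
        dims=("y", "x"),
    )


def upsample(da, shape=None):
    # Toy stand-in for da.rio.reproject: repeat each pixel until the
    # block matches the requested output shape.
    factor = shape[0] // da.sizes["y"]
    data = np.repeat(np.repeat(da.data, factor, axis=0), factor, axis=1)
    return xr.DataArray(data, dims=("y", "x"))


da = xr.DataArray(np.arange(16.0).reshape(4, 4), dims=("y", "x")).chunk(-1)
out = xr.map_blocks(
    upsample, da, kwargs={"shape": (8, 8)}, template=make_template((8, 8))
)
result = out.compute()  # each 4x4 pixel block becomes 8x8
```

Because the function runs once per chunk, the reproject work for each pyramid level parallelizes across whatever chunking the input has.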
@dcherian

I thought @djhoese was working hard on dask-aware reprojection in pyresample?


djhoese commented Jul 29, 2022

"working hard" has mostly been in my head as I haven't had time for any of the "real" work. Luckily, I'm not the only one worried about this. @mraspaud did the work on that resample_blocks function and it looks like it might be a game changer for some of our algorithms. The basic idea is:

  1. Get the bounds of each target/output chunk.
  2. Slice the input array that covers this output chunk.
  3. Resample that input slice to the output chunk.

I initially wasn't a fan of this strategy as it requires slicing and rechunking of the input data, but @mraspaud's experience shows that it performs much better than resampling all overlapping input chunks and then merging/reducing them later.
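For illustration, here is a toy 1-D sketch of the three steps above, assuming axis-aligned grids that share an origin and using block-averaging in place of real resampling. None of these names correspond to pyresample's actual `resample_blocks` API:

```python
import numpy as np


def input_slice_for_output_chunk(out_slice, in_res, out_res):
    # Step 1-2: map an output chunk's pixel range back to the input grid.
    # Toy stand-in for the geolocation math done with real projections.
    scale = out_res / in_res
    start = int(np.floor(out_slice.start * scale))
    stop = int(np.ceil(out_slice.stop * scale))
    return slice(start, stop)


def resample_chunk(src, out_slice, in_res, out_res):
    # Step 3: cut out the covering input slice and reduce it to the
    # output chunk by block-averaging ("average" resampling stand-in).
    isl = input_slice_for_output_chunk(out_slice, in_res, out_res)
    block = src[isl]
    factor = int(out_res / in_res)
    n_out = out_slice.stop - out_slice.start
    return block[: n_out * factor].reshape(n_out, factor).mean(axis=1)


src = np.arange(8.0)  # 1-D "input" at resolution 1
first = resample_chunk(src, slice(0, 2), in_res=1.0, out_res=2.0)
second = resample_chunk(src, slice(2, 4), in_res=1.0, out_res=2.0)
```

The point of the strategy is that each output chunk only ever touches the input pixels that cover it, at the cost of slicing/rechunking the input.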

@dcherian

> resampling all overlapping input chunks and then merging/reducing them later.

I think @gjoseph92 had something clever in stackstac for doing something like this while avoiding shuffling a bunch of NaNs (or other useless data) around. I can't find it now, though, and I might have totally misinterpreted it.


djhoese commented Jul 29, 2022

I have another very hacky implementation in pyresample for the "EWA" resampling algorithm (very specific to the VIIRS and MODIS instruments) where I do a dask reduction but pass tuples of values between functions. If the data is destined for the output chunk, the tuple contains arrays; if not, it contains Nones. It isn't how dask intends the functions to be used (array functions should return arrays), but it lets me skip unnecessary processing of chunks that I know would be all NaNs/fill values.
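A rough, dask-free sketch of that tuple trick (hypothetical names; the real implementation lives in pyresample's EWA code and runs inside a dask reduction):

```python
import numpy as np


def map_chunk(chunk, target_bounds, chunk_bounds):
    # Return (array,) only if this chunk can contribute to the target;
    # otherwise pass (None,) so downstream steps can skip it cheaply.
    lo, hi = chunk_bounds
    t_lo, t_hi = target_bounds
    if hi <= t_lo or lo >= t_hi:
        return (None,)
    return (chunk,)


def combine(parts):
    # Reduce step: ignore the None placeholders, merge the real arrays.
    arrays = [p[0] for p in parts if p[0] is not None]
    if not arrays:
        return (None,)
    return (np.sum(arrays, axis=0),)


chunks = [np.ones(4), np.ones(4) * 2, np.ones(4) * 3]
bounds = [(0, 4), (4, 8), (8, 12)]
parts = [map_chunk(c, (0, 6), b) for c, b in zip(chunks, bounds)]
result = combine(parts)[0]  # only the first two chunks contribute
```

The third chunk never allocates a NaN-filled array, which is the whole point: non-contributing chunks cost almost nothing in the reduction.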
