
Reading data from zarr or netcdf files loads all data into memory with distributed scheduler #8969

Closed · schlunma opened this issue Dec 18, 2024 · 8 comments
Labels: memory, performance, scheduling, stability (Issue or feature related to cluster stability, e.g. deadlock)

@schlunma

Describe the issue:

I am working with climate model output (mostly in zarr/netcdf format) and came across issues when reading these data and calculating averages with a distributed scheduler. Looking at the Dask dashboard, I find that essentially all of the data is read into memory at some point.

Here is a screenshot of an example dashboard (input data size: ~45 GiB):

[Dashboard screenshot: workers hold nearly the full ~45 GiB of input data in memory]

See also https://dask.discourse.group/t/reading-data-from-netcdf-or-zarr-files-loads-all-data-into-memory/3733 for more details and discussion about this.

Minimal Complete Verifiable Example:

from pathlib import Path
from dask.distributed import Client
import dask.array as da

if __name__ == '__main__':
    # Create Dummy data
    shape = (365, 1280, 2560)
    chunks = (10, -1, -1)
    out_dir = Path("tmp_dir")
    zarr_paths = [
        out_dir / "1.zarr",
        out_dir / "2.zarr",
        out_dir / "3.zarr",
        out_dir / "4.zarr",
        out_dir / "5.zarr",
    ]
    for path in zarr_paths:
        arr = da.random.random(shape, chunks=chunks)
        da.to_zarr(arr, path)

    client = Client(n_workers=2, threads_per_worker=2, memory_limit="4GiB")

    # Read dummy data
    arrs = [da.from_zarr(p) for p in zarr_paths]

    # Concatenate and calculate weighted average
    arr = da.concatenate(arrs, axis=0)
    weights = da.ones_like(arr, chunks=arr.chunks)
    avg = da.sum(arr * weights, axis=(1, 2)) / da.sum(weights, axis=(1, 2))  # da.average with weights behaves identically
    print(avg.dask)
    print(avg.compute())
Corresponding Dask graph
HighLevelGraph with 18 layers.
<dask.highlevelgraph.HighLevelGraph object at 0x7fae1a6c8980>
 0. ones_like-b2eeb017d903fe0162dbf9fabb034160
 1. sum-bb0b49a151ee1ec3649751f9df4dc22d
 2. sum-aggregate-799a254cbdfcab0da6691f0d1f8abc3a
 3. original-from-zarr-c97365cf561afe64ebf973c34a876e62
 4. from-zarr-c97365cf561afe64ebf973c34a876e62
 5. original-from-zarr-6d55f3c94c52e1e2179024850b15b00a
 6. from-zarr-6d55f3c94c52e1e2179024850b15b00a
 7. original-from-zarr-4e162aaecc691c159a5d43200b078036
 8. from-zarr-4e162aaecc691c159a5d43200b078036
 9. original-from-zarr-92b1e21aaf6078648da539f59f16adb2
 10. from-zarr-92b1e21aaf6078648da539f59f16adb2
 11. original-from-zarr-76c9fa38a30197be5753e24414e5aee7
 12. from-zarr-76c9fa38a30197be5753e24414e5aee7
 13. concatenate-c92880ca8542262e5ae5848e7d65c648
 14. mul-eed970fe66ae9b8888ccb4b1c1830a65
 15. sum-3e71a2b5934567d7f2ad5eb642933a5f
 16. sum-aggregate-90c144806d5ebbe9efba7664c1927918
 17. truediv-4636ae24fac9a860e7a1882e7dd30811

Anything else we need to know?:

This behavior is more or less independent of the scheduler parameters and of the actual calculation that is performed. I also tested this with a dask_jobqueue.SLURMCluster and saw the same behavior, and likewise when reading netcdf files with xarray (see the MWE below).
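
For reference, a dask_jobqueue setup of this kind might look roughly like the sketch below; the core, memory, walltime, and job counts are placeholder values, not the ones actually used.

from dask.distributed import Client
from dask_jobqueue import SLURMCluster

# Placeholder resources; adjust to the actual HPC system.
cluster = SLURMCluster(cores=4, memory="16GiB", walltime="01:00:00")
cluster.scale(jobs=2)  # request two SLURM jobs as workers
client = Client(cluster)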

The issue does not appear if the two sums are computed separately:

# Computing the numerator and denominator in separate compute() calls avoids the problem
sum_res = da.sum(arr * weights, axis=(1, 2)).compute()
weight_res = da.sum(weights, axis=(1, 2)).compute()
avg_res = sum_res / weight_res
Example with xarray (same problem as with zarr)
from pathlib import Path
from dask.distributed import Client
import dask.array as da
import xarray as xr

if __name__ == '__main__':
    # Create Dummy data
    shape = (365, 1280, 2560)
    chunks = (10, -1, -1)
    out_dir = Path("tmp_dir")
    nc_paths = [
        out_dir / "1.nc",
        out_dir / "2.nc",
        out_dir / "3.nc",
        out_dir / "4.nc",
        out_dir / "5.nc",
    ]
    for path in nc_paths:
        arr = xr.DataArray(
            da.random.random(shape, chunks=chunks),
            name="x",
            dims=("time", "lat", "lon"),
        )
        arr.to_netcdf(path)

    client = Client(n_workers=2, threads_per_worker=2, memory_limit="4GiB")

    # Read dummy data
    dss = [xr.open_dataset(p, chunks={"time": 10, "lat": -1, "lon": -1}) for p in nc_paths]
    arrs = [ds.x.data for ds in dss]

    # Concatenate and calculate weighted average
    arr = da.concatenate(arrs, axis=0)
    weights = da.ones_like(arr, chunks=arr.chunks)
    avg = da.sum(arr * weights, axis=(1, 2)) / da.sum(weights, axis=(1, 2))  # da.average with weights behaves identically
    print(avg.dask)
    print(avg.compute())
Corresponding Dask graph using xarray
HighLevelGraph with 18 layers.
<dask.highlevelgraph.HighLevelGraph object at 0x7f03ec0b1c70>
 0. ones_like-c7c9cf4f6871a4068d6bda6f4e5141ea
 1. sum-142584b6bc297f1d60f68230643654e5
 2. sum-aggregate-4939d78fcafe476a8465a60e20df2c65
 3. original-open_dataset-x-760bd344ae92ba60483bc3f282e1b70e
 4. open_dataset-x-760bd344ae92ba60483bc3f282e1b70e
 5. original-open_dataset-x-59206dd4469c13d5957b7e3c3e3da309
 6. open_dataset-x-59206dd4469c13d5957b7e3c3e3da309
 7. original-open_dataset-x-1bafe17e170d74a2bbae3a9518808761
 8. open_dataset-x-1bafe17e170d74a2bbae3a9518808761
 9. original-open_dataset-x-61de7e5d23d744435a3d0116d35e6333
 10. open_dataset-x-61de7e5d23d744435a3d0116d35e6333
 11. original-open_dataset-x-25d1b20a8f923f1a653cd0ef1ce5163b
 12. open_dataset-x-25d1b20a8f923f1a653cd0ef1ce5163b
 13. concatenate-8f55d06266aaeabe8c000c30ce2a4048
 14. mul-d8a6804c719b299cad52f97707b11a86
 15. sum-012c92287950bccb2aba104d10dea9fa
 16. sum-aggregate-8a1d67d086a9a1d7841895d269782fd7
 17. truediv-8338ef8062243673d277f81853de876d
Example using dask.array directly (no problem: memory usage per worker is basically always < 2 GiB, dashboard looks clean)
from dask.distributed import Client
import dask.array as da

if __name__ == '__main__':
    # Create Dummy data
    shape = (365, 1280, 2560)
    chunks = (10, -1, -1)

    client = Client(n_workers=2, threads_per_worker=2, memory_limit="4GiB")

    # Create data
    arrs = [da.random.random(shape, chunks=chunks) for _ in range(5)]  # five arrays, matching the five files in the examples above

    # Concatenate and calculate weighted average
    arr = da.concatenate(arrs, axis=0)
    weights = da.ones_like(arr, chunks=arr.chunks)
    avg = da.sum(arr * weights, axis=(1, 2)) / da.sum(weights, axis=(1, 2))  # da.average with weights behaves identically
    print(avg.dask)
    print(avg.compute())
Corresponding Dask graph using dask.array
HighLevelGraph with 13 layers.
<dask.highlevelgraph.HighLevelGraph object at 0x7f84bb681280>
  0. ones_like-b2eeb017d903fe0162dbf9fabb034160
  1. sum-bb0b49a151ee1ec3649751f9df4dc22d
  2. sum-aggregate-a4961dcbcb818a25426264c800426259
  3. random_sample-b773ea283c4c668852b9a00421c3055d
  4. random_sample-ef776b5e484b45a0e538b513523db6cb
  5. random_sample-87da63b5fee3774ddb748ed498c9ec89
  6. random_sample-e37dcf6cb5d6d272ca5ab3f1d1c2df00
  7. random_sample-32b9374ec13f6972c0f19c85eb6fd784
  8. concatenate-f892adc0c62d53897e0e5293ecdbb26e
  9. mul-bcbc2daf6a3f04985c3af0a5ad0d7d5d
  10. sum-ddf13f4205911e5bb7cbfe019e1c7f4d
  11. sum-aggregate-e8835cb74a96aed802fd656f976893b5
  12. truediv-373a8a1ea714c9051d5ea535ae246621

Environment:

  • Dask version: 2024.12.1 (build: pyhd8ed1ab_0)
  • Python version: 3.13.1 (build: ha99a958_102_cp313)
  • Operating System: openSUSE Tumbleweed
  • Install method (conda, pip, source): conda
fjetter (Member) commented Dec 18, 2024

Thanks for opening this issue. I can reproduce this behavior. Looking at the dashboard, my first suspicion was that this was an ordering issue, but inspecting the task ordering I cannot see any faults.

[Task-ordering visualization attached: GH8969-order.pdf]

Will have to look a little deeper.

fjetter (Member) commented Dec 18, 2024

Ah, I missed one thing: the from_zarr tasks are not detected as root-ish. #8950 will fix this.

cc @jrbourbeau @phofl

fjetter added the performance, stability, memory, and scheduling labels and removed the needs triage label on Dec 18, 2024
fjetter (Member) commented Dec 18, 2024

Using four zarr paths works fine, but adding the fifth changes the graph topology in a way that keeps the scheduler from detecting the tasks as root-ish.

@schlunma, a mitigation until a proper fix is merged is to adjust a configuration value of our scheduler:

import dask
dask.config.set({"distributed.scheduler.rootish-taskgroup-dependencies": len(zarr_paths) + 1})

This parameter controls how our "root-ish" detection works. That is basically a heuristic that tries to detect data-generating tasks so that they can be scheduled a little differently. The default value is 5, and each of your zarr paths represents an individual "taskgroup", i.e. adding more and more datasets to the concatenation interferes with our heuristic. Ensuring that the value is above the number of datasets you are loading will fix that.
If the number is too high, this will cause other (likely non-critical) scheduler artifacts that might slow your computation down a little, but it should otherwise not hurt.
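
To see where the five datasets enter the picture, one can count how many other layers each layer in the MCVE's graph depends on. This inspects HighLevelGraph layers, which only approximate the scheduler's task groups, but the structure is the same: the concatenate layer depends on one from-zarr layer per dataset.

# Assumes `avg` from the MCVE above. With five zarr datasets, the
# concatenate layer reports five dependencies, which lines up with the
# default value of 5 mentioned above.
for name, deps in avg.dask.dependencies.items():
    print(f"{name}: {len(deps)} dependencies")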

Be aware that this config has to be set before the cluster is started.
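
For illustration, here is a minimal sketch of how the mitigation slots into the MCVE from above (same dummy zarr paths); the key point is that dask.config.set runs before the Client is created:

import dask
import dask.array as da
from dask.distributed import Client
from pathlib import Path

if __name__ == "__main__":
    out_dir = Path("tmp_dir")
    zarr_paths = [out_dir / f"{i}.zarr" for i in range(1, 6)]

    # Raise the root-ish detection threshold above the number of datasets
    # *before* the cluster is started, so the scheduler picks it up.
    dask.config.set(
        {"distributed.scheduler.rootish-taskgroup-dependencies": len(zarr_paths) + 1}
    )
    client = Client(n_workers=2, threads_per_worker=2, memory_limit="4GiB")

    arrs = [da.from_zarr(p) for p in zarr_paths]
    arr = da.concatenate(arrs, axis=0)
    weights = da.ones_like(arr, chunks=arr.chunks)
    avg = da.sum(arr * weights, axis=(1, 2)) / da.sum(weights, axis=(1, 2))
    print(avg.compute())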

@schlunma (Author)

Thanks for the very quick answer @fjetter! Just tested this for the small MWE, and I can confirm that this fixes the problem! Amazing!

I will run our more complicated pipeline tomorrow and check if this will also be fixed by this. Cheers!

@schlunma (Author)

And one more question: do you think that #8950 will also fix this behavior for netcdf data read with external libraries such as xarray or Iris?

phofl (Collaborator) commented Dec 18, 2024

And one more question: do you think that #8950 will also fix this behavior for netcdf data read with external libraries such as xarray or Iris?

Yes for xarray; probably for Iris too, but I haven't looked yet.

@schlunma (Author)

I just tested dask/dask#11558 and #8950 with zarr, xarray, and Iris. Everything works smoothly with this!

Thanks so much, looking forward to the new release!

phofl (Collaborator) commented Jan 17, 2025

The PRs were merged, so closing this one.

phofl closed this as completed on Jan 17, 2025