Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enabling dask kwargs to be passed into xr.open_netcdf #1749

Merged
merged 37 commits into from
Nov 14, 2021

Conversation

mortonjt
Copy link
Contributor

@mortonjt mortonjt commented Jul 29, 2021

Description

This pull request allows for dask kwargs to be passed into az.InferenceData.from_netcdf via xr.open_dataset.
Ultimately, this allows for chunking information to be passed into the az.InferenceData objects.

Relevant PRs
#1655
Relevant issues
#786
#1066
#1631

I'm noticing orders of magnitude improvement for rhat calculations, which are one of the bottlenecks in my current pipeline.

Checklist

  • Follows official PR format
  • Includes a sample plot to visually illustrate the changes (only for plot-related functions)
  • New features are properly documented (with an example if appropriate)?
  • Includes new or updated tests to cover the new feature
  • Code style correct (follows pylint and black guidelines)
  • Changes are listed in changelog

@mortonjt
Copy link
Contributor Author

mortonjt commented Jul 29, 2021

Moving conversation from #1655 to here
Addressing the comment from @OriolAbril

We should be a bit careful with the api I think though. We can continue the discussion on the PR, but I have some questions we should have answers for before merging. For example, not all groups will have the same dimensions, what happens if the chunk argument passed to xr.open_dataset has non existing dimensions? or if we want to chunk the posterior predictive and log likelihood in blocks of 10 draws but we don't want to chunk posterior nor sample stats ones?

I don't know the right answer to this question -- maybe better documentation? The only thing that is done here is further exposing the xr.open_dataset api. I can't comment on the possible failure modes (I bet there is a ton of them), but currently the chunking information is offloaded to the user. For instance, it doesn't make sense to chunk across chains, so the following will break the az.rhat calculations

import arviz as az
res = az.from_netcdf(fname, chunks="auto")
az.rhat(res).compute()
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-23-4be8a6df097d> in <module>
----> 1 res = az.InferenceData.from_netcdf(fname, chunks={'auto'})
      2 az.rhat(res).compute()

~/software/arviz/arviz/data/inference_data.py in from_netcdf(filename, **kwargs)
    340 
    341             for group in data_groups:
--> 342                 with xr.open_dataset(filename, group=group, **kwargs) as data:
    343                     if rcParams["data.load"] == "eager":
    344                         groups[group] = data.load()

~/miniconda3/envs/qiime2-2021.4/lib/python3.8/site-packages/xarray/backends/api.py in open_dataset(filename_or_obj, engine, chunks, cache, decode_cf, mask_and_scale, decode_times, decode_timedelta, use_cftime, concat_characters, decode_coords, drop_variables, backend_kwargs, *args, **kwargs)
    500         **kwargs,
    501     )
--> 502     ds = _dataset_from_backend_dataset(
    503         backend_ds,
    504         filename_or_obj,

~/miniconda3/envs/qiime2-2021.4/lib/python3.8/site-packages/xarray/backends/api.py in _dataset_from_backend_dataset(backend_ds, filename_or_obj, engine, chunks, cache, overwrite_encoded_chunks, **extra_tokens)
    308 ):
    309     if not isinstance(chunks, (int, dict)) and chunks not in {None, "auto"}:
--> 310         raise ValueError(
    311             f"chunks must be an int, dict, 'auto', or None. Instead found {chunks}."
    312         )

ValueError: chunks must be an int, dict, 'auto', or None. Instead found {'auto'}.

A way around this is to chunk by draws instead

res = az.InferenceData.from_netcdf(fname, chunks={'draws': 50})
az.rhat(res).compute()

I'm not sure about the other functions, it maybe a case-by-case basis. But from what I can tell, it is useful to first expose the full functionality of xr.open_dataset first and provide the flexibility to specify chunking infomration

@OriolAbril
Copy link
Member

The point about better documentation is very relevant but different from what I was asking.

One thing is that users will need to learn how to properly chunk for things to work properly. Deciding how to chunk is generally tricky and has a huge effect on performance, so it is definitely something users will need to be aware of. I hope we can externalize that guidance however and in our docs include only links to dask and xarray docs, with a highlights paragraph at most.

Another question is once this chunking has been decide and it's optimal, how do we make sure the user can actually transmit that downstream to xarray/dask. I'll try to use an example. Take the centered_eight example. For rhat and ess, the ideal chunking (it it were needed) would be to chunk over the school dimension (see next paragraph for details). This raises two questions:

  • Does az.InferenceData.from_netcdf("centered_eight.nc", chunks={'school': 2}) work? not all groups have a school dimension, so using the same set of kwargs for all groups may not really be a possibility
  • Can we apply the chunking over the school dimension to all groups with school dimension except for the observed data one? There as we don't have samples the memory it needs will be much smaller making chunking probably worse performant and more error prone than pure numpy

Back to your current examples and the selection of chunking dimension. For things to work, the dimensions that are chunked should not be core dimensions of the operations we want to perform. For rhat and ess the core dimensions are both chain and draw, so if a dataset is chunked on those dimensions, one can't call rhat nor ess without rechunking. In your examples, the first is erroring out when reading (not sure why though), and the second one is working even though the data is chunked on draw which makes me suspect that dask was not activated, the data was read into memory as a numpy array and the computation went ahead without chunking problems 🤔 .

As for possible solutions, maybe we could have some dict of kwargs? so that we have first keys group names and values kwarg dicts for that group instead of kwargs directly. It would be great to allow regex or metagroups like we do with idata methods (i.e. https://arviz-devs.github.io/arviz/api/generated/arviz.InferenceData.sel.html#arviz.InferenceData.sel) so that {".": kwargs} would use the same kwargs on all groups

@mortonjt
Copy link
Contributor Author

Fair enough. I pushed in a dict of dict argument. You should be able to verify the speed up on regression10

import arviz as az
from arviz.utils import Dask
from dask.distributed import Client
centered_data = az.load_arviz_data("regression10d")
%timeit az.rhat(centered_data)

22.8 ms ± 204 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
vs the dask chunked version

Dask.enable_dask(dask_kwargs={"dask": "parallelized", "output_dtypes": [float]})
client = Client(threads_per_worker=4, n_workers=2, memory_limit="2GB")
group_kwargs = {
    'posterior': {'chunks': {'true_w_dim_0': 2, 'true_w_dim_0' : 2}},
    'posterior_predictive': {'chunks': {'true_w_dim_0': 2, 'true_w_dim_0': 2}}
}
centered_data = az.load_arviz_data("regression10d", group_kwargs=group_kwargs)
%timeit az.rhat(centered_data)

4.06 ms ± 22.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

The eight schools problem is a bit too small - I suspect data exchange time is overwhelming gains made from parallelism.

@mortonjt
Copy link
Contributor Author

Regarding the earlier comment about the user interface, perhaps configparser is the way to go. The dask interface is sufficiently complex enough where this may be helpful (at least as a user, I would find configuring a config more intuitive than passing in a triple layer dict).

Copy link
Member

@OriolAbril OriolAbril left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

'chunks': {'true_w_dim_0': 2, 'true_w_dim_0' : 2}

I think this part should not have the dimension repeated, two keys should be used for 2d chunks. 'chunks': {'true_w_dim_0': 2} should do the same.

22.8 ms ± 204 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
vs the dask chunked version
4.06 ms ± 22.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

This is a quite good speedup. This is probably due to our make_ufunc not being very efficient because dask should not provide such speedups for data that comfortably fits into memory.

Regarding the earlier comment about the user interface, perhaps configparser is the way to go. The dask interface is sufficiently complex enough where this may be helpful (at least as a user, I would find configuring a config more intuitive than passing in a triple layer dict).

I think it would be better to include an example (maybe also some helper configparser already configured?) of how to use configparser to define this nested dict instead of trying to get the function to use configparser. for example:

import configparser

sample_config = """
[posterior]
  [chunks] 
  true_w_dim_0: 2

[posterior_predictive]
  [chunks] 
  true_w_dim_0: 2
"""

config = configparser.ConfigParser()
config.read_string(sample_config)
idata = az.from_netcdf(..., group_kwargs=config)

Note that I have not used configparser and only skimmed the linked page. I have added another comment with more functionality that would be useful in my opinion, not sure if configparser would be useful for that. Not everything needs to be implemennted in this PR directly, once we have chosen the api we can extend the functionality later (i.e. if we use dict of dicts, getting the 2nd option in the comment to work and adding tests can be done in a follow up PR)

arviz/data/inference_data.py Outdated Show resolved Hide resolved
@ahartikainen
Copy link
Contributor

This is a bit of "silly" question, but is there a way to chunk against all but chain and draw dimension? E.g. if inf data has 100k parameters -> chunk against "non-dim" vs 100k long vector -> chunk against vector dim? Or is this something dask does automatically?

@mortonjt
Copy link
Contributor Author

@OriolAbril I've modified the regex option and confirmed that it works, and I like this option since it is simple and easy to extend. I don't believe that it will necessarily exclude the other options we discussed, since the extended group_kwargs could be passed directly into from_netcdf. In this way future improvements on the user interface could just modify group_kwargs wkthout breaking backwards compatibility. I've also added unittests to check if the chunks are configured appropriately. These tests will currently break since dask is currently not installed in the CI (not sure where the CI configuration files are).

@ahartikainen yes, but at the moment, this will have to be done manually by specifying each dimension -- I anticipate that future iterations could help solve this by creating helper functions to output the appropriate group_kwargs.

Copy link
Member

@OriolAbril OriolAbril left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is great for now, but I would probably add some warning in the docstring for group_kwargs and regex about them being experimental, or somewhat experimental. The timeline is not completely clear but xarray is working on hierarchical netcdf like objects support which would be some extended version of InferenceData (pydata/xarray#4118), and we should use that once available also with its much probably better dask support.

arviz/data/datasets.py Outdated Show resolved Hide resolved
arviz/data/inference_data.py Outdated Show resolved Hide resolved
arviz/data/inference_data.py Outdated Show resolved Hide resolved
arviz/tests/base_tests/test_data.py Outdated Show resolved Hide resolved
@OriolAbril
Copy link
Member

I am not sure why tests are not working for basetests. I think dask should be installed in https://github.com/arviz-devs/arviz/blob/main/.azure-pipelines/azure-pipelines-base.yml#L57, is dask not the right package? should it be dask[all] or something of the sort?

@mortonjt
Copy link
Contributor Author

Probably needs the dask distributed dependency. See https://distributed.dask.org/en/latest/install.html

@OriolAbril
Copy link
Member

OriolAbril commented Aug 24, 2021

I am a bit hazy about dask components, so I don't know which should be installed, but I do think it would be better, especially for users who don't know about dask (nor need to know much if interfacing with it via xarray only), to use these dask[component] pseudo-packages as they might not know what distributed alone is. reference: https://docs.dask.org/en/latest/install.html

My main issue is that I don't know what distributed and delayed are so I am not sure which should be installed: http://xarray.pydata.org/en/stable/user-guide/dask.html

@mortonjt
Copy link
Contributor Author

Right, dask is quite a complex ecosystem ...

Regarding dask components, the distributed package is currently the recommended way to setup single-machine parallelism, since it comes with a more sophisticated scheduler (see here : https://docs.dask.org/en/latest/scheduling.html) -- and it is also easier to keep track of scope. xarray also has the dependency in the ci. So I don't think it is unreasonable to default to xarray's dask optional dependencies.

@mortonjt
Copy link
Contributor Author

ok, it looks like most of the tests are passing. Although I'm seeing a weird error [error]Circuit Breaker "HttpClientThrottler-LocationHttpClient-mmsproduks1.vstsmms.visualstudio.com" short-circuited. In the last 10000 milliseconds, there were: 1 failure, 0 timeout, 21 short circuited, 0 concurrency rejected, and 0 limit rejected. Last exception: Microsoft.VisualStudio.Services.WebApi.VssServiceResponseException TF24668: The following team project collection is stopped: tnnattbiggw. Start the collection and then try again. Administrator Reason: abuse . Thoughts?

@OriolAbril
Copy link
Member

No idea really, it looks like we are trying to do something forbidden in Azure somehow. Are the tests failing always the same?

@mortonjt
Copy link
Contributor Author

From what I can tell no. Pushing another random commit to see if it isn't a random failure.

@mortonjt
Copy link
Contributor Author

ok, so I'm a little stumped why there isn't any chunking information returned in the azure ci. On my system, I'm able to confirm that the chunking works

=========================================================================================== test session starts ============================================================================================
platform linux -- Python 3.8.8, pytest-6.2.3, py-1.10.0, pluggy-0.13.1
rootdir: /mnt/home/jmorton/software/arviz, configfile: pytest.ini
plugins: anyio-3.1.0
collected 2 items

test_data_dask.py ..                                                                                                                                                                                  [2/2]

=========================================================================================== slowest 20 durations ===========================================================================================
0.19s call     arviz/tests/base_tests/test_data_dask.py::TestDataDask::test_dask_chunk_group_kwds
0.14s call     arviz/tests/base_tests/test_data_dask.py::TestDataDask::test_dask_chunk_group_regex

(4 durations < 0.005s hidden.  Use -vv to show these durations.)
============================================================================================ 2 passed in 0.34s =============================================================================================

@mortonjt
Copy link
Contributor Author

Hi @OriolAbril / @ahartikainen I'm still a bit puzzled by the errors, its as if xarray isn't chunking correctly.
Unfortunately, debugging via garbage commits is quite slow. Is there by any chance a VM of the azure ci that is available to download (I'm not as familiar with azure sorry).

@mortonjt mortonjt force-pushed the parallel-netcdf branch 4 times, most recently from a037840 to 8f0a08f Compare October 12, 2021 03:55
@mortonjt
Copy link
Contributor Author

ok, so most of the tests are passing now; but there appears to be some random tests failing that aren't relevant for this PR.
Thoughts?

@ahartikainen
Copy link
Contributor

I think cloudpickle has a new version and tfp can not use it. Let me try to find a fix for it.

@ahartikainen
Copy link
Contributor

It should work after rebase with main

@mortonjt
Copy link
Contributor Author

mortonjt commented Oct 15, 2021

@ahartikainen when you mean rebase, you mean just merging with main correct? I did notice some windows carriage returns in .pylinterc, io_tfp.py and stats_utils.py, which is why I didn't pull those down earlier. But as you can see, merging those changes in doesn't change the error messages ...

@OriolAbril
Copy link
Member

OriolAbril commented Oct 16, 2021

I noticed that cloudpickle version was 2.0 in CI even though it is set to be <1.5 in https://github.com/arviz-devs/arviz/blob/main/requirements-dev.txt#L12. I am not sure why. I tried rebasing myself to see if it helped but it looks like it did not.

I have also opened #1890 as I believe the best way forward is probably to remove the tfp converter altogether as it is severely outdated and will generate more and more problems every day that passes.

@codecov
Copy link

codecov bot commented Nov 14, 2021

Codecov Report

Merging #1749 (92fcbf0) into main (ff796cf) will decrease coverage by 0.01%.
The diff coverage is 95.96%.

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #1749      +/-   ##
==========================================
- Coverage   91.72%   91.70%   -0.02%     
==========================================
  Files         116      116              
  Lines       12362    12373      +11     
==========================================
+ Hits        11339    11347       +8     
- Misses       1023     1026       +3     
Impacted Files Coverage Δ
arviz/data/inference_data.py 83.54% <75.00%> (-0.22%) ⬇️
arviz/stats/stats_utils.py 96.71% <96.71%> (ø)
arviz/data/datasets.py 98.38% <100.00%> (ø)
arviz/data/io_netcdf.py 100.00% <100.00%> (ø)

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update ff796cf...92fcbf0. Read the comment docs.

@OriolAbril
Copy link
Member

Tests are working! Thanks for your patience @mortonjt!

@ahartikainen ahartikainen merged commit 44b908f into arviz-devs:main Nov 14, 2021
@mortonjt
Copy link
Contributor Author

Thank you @OriolAbril @ahartikainen !

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants