Enabling dask kwargs to be passed into xr.open_netcdf #1749
Conversation
Moving conversation from #1655 to here
I don't know the right answer to this question -- maybe better documentation? The only thing that is done here is further exposing the xr.open_dataset kwargs:

```python
import arviz as az

res = az.from_netcdf(fname, chunks="auto")
az.rhat(res).compute()
```

A way around this is to chunk by draws instead:

```python
res = az.InferenceData.from_netcdf(fname, chunks={'draws': 50})
az.rhat(res).compute()
```

I'm not sure about the other functions; it may be a case-by-case basis. But from what I can tell, it is useful to first expose the full functionality of xr.open_dataset. |
The point about better documentation is very relevant but different from what I was asking. One thing is that users will need to learn how to chunk properly for things to work. Deciding how to chunk is generally tricky and has a huge effect on performance, so it is definitely something users will need to be aware of. I hope we can externalize that guidance, however, and include in our docs only links to the dask and xarray docs, with a highlights paragraph at most. Another question is, once this chunking has been decided and is optimal, how do we make sure the user can actually transmit it downstream to xarray/dask? I'll try to use an example. Take the centered_eight example. For rhat and ess, the ideal chunking (if it were needed) would be to chunk over the school dimension (see the next paragraph for details). This raises two questions:
Back to your current examples and the selection of the chunking dimension. For things to work, the dimensions that are chunked should not be core dimensions of the operations we want to perform. For rhat and ess the core dimensions are both chain and draw, so if a dataset is chunked on those dimensions, one can't call rhat or ess without rechunking. In your examples, the first one errors out when reading (not sure why, though), and the second one works even though the data is chunked on draw, which makes me suspect that dask was not activated: the data was read into memory as a numpy array and the computation went ahead without chunking problems 🤔. As for possible solutions, maybe we could have a dict of kwargs, with group names as first-level keys and per-group kwarg dicts as values, instead of kwargs directly. It would be great to allow regex or metagroups like we do with idata methods (i.e. https://arviz-devs.github.io/arviz/api/generated/arviz.InferenceData.sel.html#arviz.InferenceData.sel) so that the same kwargs can target several groups at once. |
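To make the core-dimension point concrete, here is a minimal sketch using plain xarray chunking on the centered_eight example (the exact arviz interface is what this thread is deciding):

```python
import arviz as az

idata = az.load_arviz_data("centered_eight")

# chunk over "school", which is not a core dim of rhat/ess;
# "chain" and "draw" each remain in a single chunk
post = idata.posterior.chunk({"school": 4})
print(post["theta"].chunks)

# if the data arrived chunked along a core dim instead (e.g. "draw"),
# rechunk so that chain/draw are whole before calling rhat/ess
post = post.chunk({"chain": -1, "draw": -1})
```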
Fair enough. I pushed in a dict-of-dicts argument. You should be able to verify the speedup on regression10d:

```python
import arviz as az
from arviz.utils import Dask
from dask.distributed import Client

centered_data = az.load_arviz_data("regression10d")
%timeit az.rhat(centered_data)

Dask.enable_dask(dask_kwargs={"dask": "parallelized", "output_dtypes": [float]})
client = Client(threads_per_worker=4, n_workers=2, memory_limit="2GB")

group_kwargs = {
    'posterior': {'chunks': {'true_w_dim_0': 2, 'true_w_dim_0' : 2}},
    'posterior_predictive': {'chunks': {'true_w_dim_0': 2, 'true_w_dim_0': 2}}
}
centered_data = az.load_arviz_data("regression10d", group_kwargs=group_kwargs)
%timeit az.rhat(centered_data)
```
The eight schools problem is a bit too small - I suspect data exchange time is overwhelming the gains made from parallelism. |
Regarding the earlier comment about the user interface, perhaps configparser is the way to go. The dask interface is sufficiently complex that this may be helpful (at least as a user, I would find writing a config file more intuitive than passing in a triple-layer dict). |
```
'chunks': {'true_w_dim_0': 2, 'true_w_dim_0' : 2}
```

I think this part should not have the dimension repeated; two different keys should be used for 2-D chunks. 'chunks': {'true_w_dim_0': 2} should do the same.
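A quick illustration of why the repeated key is a no-op (true_w_dim_1 below is made up purely for illustration):

```python
# duplicate keys in a dict literal collapse silently; the last one wins
chunks = {'true_w_dim_0': 2, 'true_w_dim_0': 2}
print(chunks)  # {'true_w_dim_0': 2}

# chunking in two dimensions needs two distinct dimension names
chunks_2d = {'true_w_dim_0': 2, 'true_w_dim_1': 2}
```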
```
22.8 ms ± 204 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
```

vs the dask-chunked version:

```
4.06 ms ± 22.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```

This is quite a good speedup, probably due to our make_ufunc not being very efficient: dask should not provide such speedups for data that comfortably fits into memory.
> Regarding the earlier comment about the user interface, perhaps configparser is the way to go. The dask interface is sufficiently complex that this may be helpful (at least as a user, I would find writing a config file more intuitive than passing in a triple-layer dict).

I think it would be better to include an example of how to use configparser to define this nested dict (maybe also some pre-configured helper?) rather than trying to get the function itself to consume configparser. For example:
```python
import configparser

# configparser has no nested sections, so dotted section names encode
# group -> kwarg here; repeating a plain [chunks] section twice would
# raise DuplicateSectionError
sample_config = """
[posterior.chunks]
true_w_dim_0: 2

[posterior_predictive.chunks]
true_w_dim_0: 2
"""

config = configparser.ConfigParser()
config.read_string(sample_config)
idata = az.from_netcdf(..., group_kwargs=config)
```
Note that I have not used configparser and only skimmed the linked page. I have added another comment with more functionality that would be useful in my opinion; I am not sure if configparser would be useful for that. Not everything needs to be implemented in this PR directly: once we have chosen the API we can extend the functionality later (i.e. if we use a dict of dicts, getting the 2nd option in the comment to work and adding tests can be done in a follow-up PR).
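Building on that sketch, one way the parsed config could be turned back into the nested group_kwargs dict (a hedged example that assumes the dotted group.kwarg section names used above; this is not an existing arviz API):

```python
group_kwargs = {}
for section in config.sections():
    group, _, kwarg = section.partition(".")
    # configparser values are strings, so cast chunk sizes to int
    group_kwargs.setdefault(group, {})[kwarg] = {
        dim: int(size) for dim, size in config[section].items()
    }
# -> {'posterior': {'chunks': {'true_w_dim_0': 2}},
#     'posterior_predictive': {'chunks': {'true_w_dim_0': 2}}}
```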
This is a bit of a "silly" question, but is there a way to chunk against all but the chain and draw dimensions? E.g. if the inference data has 100k parameters -> chunk against the "non-dim", vs a 100k-long vector -> chunk against the vector dim? Or is this something dask does automatically? |
@OriolAbril I've modified the regex option and confirmed that it works, and I like this option since it is simple and easy to extend. I don't believe that it will necessarily exclude the other options we discussed.

@ahartikainen yes, but at the moment this will have to be done manually by specifying each dimension -- I anticipate that future iterations could solve this by adding helper functions that output the appropriate chunking dicts. |
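A sketch of what such a helper might look like (chunks_excluding is a hypothetical name, not part of this PR):

```python
import arviz as az

def chunks_excluding(ds, exclude=("chain", "draw"), size="auto"):
    """Build a chunks dict covering every dimension except the sampling dims."""
    return {dim: size for dim in ds.dims if dim not in exclude}

post = az.load_arviz_data("regression10d").posterior
post = post.chunk(chunks_excluding(post))
```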
I think this is great for now, but I would probably add a warning to the docstrings of group_kwargs and regex about them being experimental, or somewhat experimental. The timeline is not completely clear, but xarray is working on support for hierarchical netcdf-like objects, which would be an extended version of InferenceData (pydata/xarray#4118), and once that is available we should use it, together with its most likely better dask support.
I am not sure why the tests are failing for basetests. I think dask should be installed in https://github.com/arviz-devs/arviz/blob/main/.azure-pipelines/azure-pipelines-base.yml#L57 -- is dask not the right package? Should it be something else? |
Probably needs the dask distributed dependency. See https://distributed.dask.org/en/latest/install.html |
I am a bit hazy about dask components, so I don't know which should be installed, but I do think it would be better, especially for users who don't know about dask (nor need to know much if interfacing with it via xarray only), to use xarray's optional dask dependencies. My main issue is that I don't know what distributed and delayed are, so I am not sure which should be installed: http://xarray.pydata.org/en/stable/user-guide/dask.html |
Right, dask is quite a complex ecosystem... Regarding dask components, the distributed package is currently the recommended way to set up single-machine parallelism, since it comes with a more sophisticated scheduler (see https://docs.dask.org/en/latest/scheduling.html) -- and it also makes it easier to keep track of scope. xarray also has the dependency in its CI, so I don't think it is unreasonable to default to xarray's dask optional dependencies. |
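For reference, a minimal sketch of the two single-machine setups being compared (standard dask/distributed APIs, nothing arviz-specific):

```python
import dask

# default single-machine scheduler: a thread pool, no extra dependency
dask.config.set(scheduler="threads")

# distributed scheduler: needs the distributed package; smarter
# scheduling plus a diagnostics dashboard, even on one machine
from dask.distributed import Client

client = Client(threads_per_worker=4, n_workers=2, memory_limit="2GB")
```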
Ok, it looks like most of the tests are passing, although I'm seeing a weird error. |
No idea, really; it looks like we are trying to do something forbidden in Azure somehow. Are the failing tests always the same? |
From what I can tell, no. Pushing another commit to see whether it is just a random failure. |
Ok, so I'm a little stumped about why there isn't any chunking information returned in the Azure CI. On my system, I'm able to confirm that the chunking works. |

Hi @OriolAbril / @ahartikainen, I'm still a bit puzzled by the errors; it's as if xarray isn't chunking correctly. |
Ok, so most of the tests are passing now, but there appear to be some random test failures that aren't relevant to this PR. |
I think cloudpickle has a new version and tfp cannot use it. Let me try to find a fix for it. |
It should work after a rebase onto main. |
@ahartikainen when you say rebase, do you mean just merging with main, correct? I did notice some Windows carriage returns in .pylintrc, io_tfp.py and stats_utils.py, which is why I didn't pull those down earlier. But as you can see, merging those changes in doesn't change the error messages... |
I noticed that the cloudpickle version was 2.0 in CI even though it is set to <1.5 in https://github.com/arviz-devs/arviz/blob/main/requirements-dev.txt#L12; I am not sure why. I tried rebasing myself to see if it helped, but it looks like it did not. I have also opened #1890, as I believe the best way forward is probably to remove the tfp converter altogether: it is severely outdated and will generate more and more problems with every passing day. |
Codecov Report

```
@@            Coverage Diff             @@
##             main    #1749      +/-   ##
==========================================
- Coverage   91.72%   91.70%   -0.02%
==========================================
  Files         116      116
  Lines       12362    12373      +11
==========================================
+ Hits        11339    11347       +8
- Misses       1023     1026       +3
```
|
Tests are working! Thanks for your patience @mortonjt! |
Thank you @OriolAbril @ahartikainen ! |
Description

This pull request allows dask kwargs to be passed into az.InferenceData.from_netcdf via xr.open_dataset. Ultimately, this allows chunking information to be passed into az.InferenceData objects.

Relevant PRs
#1655

Relevant issues
#786
#1066
#1631
I'm noticing an orders-of-magnitude improvement in rhat calculations, which are one of the bottlenecks in my current pipeline.
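A short usage sketch of the interface added here, as discussed in the conversation above (the file name and dimension name are placeholders; group_kwargs and regex were flagged as experimental in review):

```python
import arviz as az

# per-group kwargs are forwarded to xr.open_dataset for each group
idata = az.from_netcdf(
    "trace.nc",  # placeholder file name
    group_kwargs={"posterior": {"chunks": {"school": 4}}},
)
```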
Checklist