Skip to content

Commit

Permalink
Make from_dict more flexible, and add from_pytree
Browse files Browse the repository at this point in the history
  • Loading branch information
ColCarroll committed Mar 13, 2024
1 parent 3fc5962 commit c68c913
Show file tree
Hide file tree
Showing 8 changed files with 102 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## v0.x.x Unreleased

### New features
- Add support for `pytree`s and make dictionary conversion robust to nested dictionaries ([2291](https://github.com/arviz-devs/arviz/pull/2291))

### Maintenance and fixes
- Fix deprecations introduced in latest pandas and xarray versions, and prepare for numpy 2.0 ones ([2315](https://github.com/arviz-devs/arviz/pull/2315))
Expand Down
3 changes: 2 additions & 1 deletion arviz/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .io_cmdstan import from_cmdstan
from .io_cmdstanpy import from_cmdstanpy
from .io_datatree import from_datatree, to_datatree
from .io_dict import from_dict
from .io_dict import from_dict, from_pytree
from .io_emcee import from_emcee
from .io_json import from_json, to_json
from .io_netcdf import from_netcdf, to_netcdf
Expand Down Expand Up @@ -38,6 +38,7 @@
"from_cmdstanpy",
"from_datatree",
"from_dict",
"from_pytree",
"from_json",
"from_pyro",
"from_numpyro",
Expand Down
79 changes: 76 additions & 3 deletions arviz/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union

import numpy as np
import tree
import xarray as xr

try:
Expand Down Expand Up @@ -67,6 +68,48 @@ def wrapped(cls: RequiresArgTypeT) -> Optional[RequiresReturnTypeT]:
return wrapped


def _yield_flat_up_to(shallow_tree, input_tree, path=()):
    """Yield ``(path, value)`` pairs of ``input_tree`` flattened up to ``shallow_tree``.

    Adapted from dm-tree (https://github.com/google-deepmind/tree) to allow
    lists as leaves: only mappings, namedtuples and attrs instances are
    descended into; every other object is yielded as a leaf.

    Parameters
    ----------
    shallow_tree : nested structure
        Traversal goes no deeper than the leaf nodes of this structure.
    input_tree : nested structure
        Structure the yielded values are taken from. Must have the same upper
        structure as ``shallow_tree``.
    path : tuple, optional
        Only used when recursing. The path from the root of the original
        ``shallow_tree`` down to the root of the ``shallow_tree`` argument of
        this recursive call.

    Yields
    ------
    tuple
        Pairs of ``(path, value)``, where ``path`` is the tuple path of a leaf
        node in ``shallow_tree`` and ``value`` is the value of the
        corresponding node in ``input_tree``.
    """
    # pylint: disable=protected-access
    # Text and bytes are always leaves; otherwise only mappings, namedtuples
    # and attrs instances count as containers worth recursing into.
    is_container = not isinstance(shallow_tree, tree._TEXT_OR_BYTES) and (
        isinstance(shallow_tree, tree.collections_abc.Mapping)
        or tree._is_namedtuple(shallow_tree)
        or tree._is_attrs(shallow_tree)
    )
    if not is_container:
        yield (path, input_tree)
        return

    # Index the input's children by key so they can be looked up while
    # walking the shallow structure in the order dm-tree provides.
    children = dict(tree._yield_sorted_items(input_tree))
    for key, shallow_child in tree._yield_sorted_items(shallow_tree):
        yield from _yield_flat_up_to(shallow_child, children[key], path=path + (key,))
    # pylint: enable=protected-access


def _flatten_with_path(structure):
    """Return a list of ``(path, value)`` pairs for every leaf of ``structure``.

    The structure is flattened against itself with ``_yield_flat_up_to``, so
    each yielded pair is collected into the returned list.
    """
    return [*_yield_flat_up_to(structure, structure)]


def generate_dims_coords(
shape,
var_name,
Expand Down Expand Up @@ -255,7 +298,7 @@ def numpy_to_data_array(
return xr.DataArray(ary, coords=coords, dims=dims)


def dict_to_dataset(
def pytree_to_dataset(
data,
*,
attrs=None,
Expand All @@ -266,11 +309,34 @@ def dict_to_dataset(
index_origin=None,
skip_event_dims=None,
):
"""Convert a dictionary of numpy arrays to an xarray.Dataset.
"""Convert a pytree of numpy arrays to an xarray.Dataset.
See https://jax.readthedocs.io/en/latest/pytrees.html for what a pytree is, but
this includes at least dictionaries and tuple types.
In case of nested pytrees, the variable name will be a tuple of individual names.
For example,
pytree_to_dataset({'top': {'second': 1.}, 'top2': 1.})
will have `var_names` `('top', 'second')` and `top2`.
Dimensions and co-ordinates can be defined as usual:
datadict = {
"top": {"a": np.random.randn(100), "b": np.random.randn(1, 100, 10)},
"d": np.random.randn(100),
}
dataset = convert_to_dataset(datadict,
coords={"c": np.arange(10)},
dims={("top", "b"): ["c"]})
Then `dataset.data_vars` will be `('top', 'a'), ('top', 'b'), 'd'`.
Parameters
----------
data : dict[str] -> ndarray
data : pytree
Data to convert. Keys are variable names.
attrs : dict
Json serializable metadata to attach to the dataset, in addition to defaults.
Expand Down Expand Up @@ -302,6 +368,10 @@ def dict_to_dataset(
"""
if dims is None:
dims = {}
try:
data = {k[0] if len(k) == 1 else k: v for k, v in _flatten_with_path(data)}
except TypeError: # probably unsortable keys -- the function will still work if
pass # it is an honest dictionary.

data_vars = {
key: numpy_to_data_array(
Expand All @@ -318,6 +388,9 @@ def dict_to_dataset(
return xr.Dataset(data_vars=data_vars, attrs=make_attrs(attrs=attrs, library=library))


dict_to_dataset = pytree_to_dataset


def make_attrs(attrs=None, library=None):
"""Make standard attributes to attach to xarray datasets.
Expand Down
4 changes: 4 additions & 0 deletions arviz/data/converters.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""High level conversion functions."""
import numpy as np
import tree
import xarray as xr

from .base import dict_to_dataset
Expand Down Expand Up @@ -105,6 +106,8 @@ def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None,
dataset = obj.to_dataset()
elif isinstance(obj, dict):
dataset = dict_to_dataset(obj, coords=coords, dims=dims)
elif tree.is_nested(obj) and not isinstance(obj, (list, tuple)):
dataset = dict_to_dataset(obj, coords=coords, dims=dims)

Check warning on line 110 in arviz/data/converters.py

View check run for this annotation

Codecov / codecov/patch

arviz/data/converters.py#L110

Added line #L110 was not covered by tests
elif isinstance(obj, np.ndarray):
dataset = dict_to_dataset({"x": obj}, coords=coords, dims=dims)
elif isinstance(obj, (list, tuple)) and isinstance(obj[0], str) and obj[0].endswith(".csv"):
Expand All @@ -118,6 +121,7 @@ def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None,
"xarray dataarray",
"xarray dataset",
"dict",
"pytree",
"netcdf filename",
"numpy array",
"pystan fit",
Expand Down
3 changes: 3 additions & 0 deletions arviz/data/io_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,3 +458,6 @@ def from_dict(
attrs=attrs,
**kwargs,
).to_inference_data()


from_pytree = from_dict
2 changes: 1 addition & 1 deletion arviz/plots/backends/matplotlib/pairplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ def plot_pair(
if reference_values:
x_name = flat_var_names[i]
y_name = flat_var_names[j + not_marginals]
if x_name and y_name not in difference:
if (x_name not in difference) and (y_name not in difference):
ax[j, i].plot(
reference_values_copy[x_name],
reference_values_copy[y_name],
Expand Down
14 changes: 14 additions & 0 deletions arviz/tests/base_tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1077,6 +1077,20 @@ def test_dict_to_dataset():
assert set(dataset.b.coords) == {"chain", "draw", "c"}


def test_nested_dict_to_dataset():
    """Check that a nested dict converts to a dataset with tuple-named variables."""
    nested = {
        "top": {"a": np.random.randn(100), "b": np.random.randn(1, 100, 10)},
        "d": np.random.randn(100),
    }
    dataset = convert_to_dataset(
        nested,
        coords={"c": np.arange(10)},
        dims={("top", "b"): ["c"]},
    )

    # Nested entries are exposed under the tuple of keys leading to them.
    assert set(dataset.data_vars) == {("top", "a"), ("top", "b"), "d"}
    assert set(dataset.coords) == {"chain", "draw", "c"}

    # Only ("top", "b") was given the extra "c" dimension.
    nested_expected = (
        (("top", "a"), {"chain", "draw"}),
        (("top", "b"), {"chain", "draw", "c"}),
    )
    for var_name, coord_names in nested_expected:
        assert set(dataset[var_name].coords) == coord_names
    assert set(dataset.d.coords) == {"chain", "draw"}


def test_dict_to_dataset_event_dims_error():
datadict = {"a": np.random.randn(1, 100, 10)}
coords = {"b": np.arange(10), "c": ["x", "y", "z"]}
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ numpy>=1.22.0,<2.0
scipy>=1.8.0
packaging
pandas>=1.4.0
dm-tree>=0.1.8
xarray>=0.21.0
h5netcdf>=1.0.2
typing_extensions>=4.1.0
Expand Down

0 comments on commit c68c913

Please sign in to comment.