diff --git a/CHANGELOG.md b/CHANGELOG.md index a5b807ff4c..6c47440dd9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ ## v0.x.x Unreleased ### New features +- Support for `pytree`s and robust to nested dictionaries. (2291) ### Maintenance and fixes - Fix deprecations introduced in latest pandas and xarray versions, and prepare for numpy 2.0 ones ([2315](https://github.com/arviz-devs/arviz/pull/2315))) diff --git a/arviz/data/__init__.py b/arviz/data/__init__.py index 742fece161..572e9dfe6c 100644 --- a/arviz/data/__init__.py +++ b/arviz/data/__init__.py @@ -1,5 +1,5 @@ """Code for loading and manipulating data structures.""" -from .base import CoordSpec, DimSpec, dict_to_dataset, numpy_to_data_array +from .base import CoordSpec, DimSpec, dict_to_dataset, numpy_to_data_array, pytree_to_dataset from .converters import convert_to_dataset, convert_to_inference_data from .datasets import clear_data_home, list_datasets, load_arviz_data from .inference_data import InferenceData, concat @@ -7,7 +7,7 @@ from .io_cmdstan import from_cmdstan from .io_cmdstanpy import from_cmdstanpy from .io_datatree import from_datatree, to_datatree -from .io_dict import from_dict +from .io_dict import from_dict, from_pytree from .io_emcee import from_emcee from .io_json import from_json, to_json from .io_netcdf import from_netcdf, to_netcdf @@ -38,10 +38,12 @@ "from_cmdstanpy", "from_datatree", "from_dict", + "from_pytree", "from_json", "from_pyro", "from_numpyro", "from_netcdf", + "pytree_to_dataset", "to_datatree", "to_json", "to_netcdf", diff --git a/arviz/data/base.py b/arviz/data/base.py index cf0f281e87..bc85acc2df 100644 --- a/arviz/data/base.py +++ b/arviz/data/base.py @@ -8,6 +8,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union import numpy as np +import tree import xarray as xr try: @@ -67,6 +68,48 @@ def wrapped(cls: RequiresArgTypeT) -> Optional[RequiresReturnTypeT]: return wrapped +def _yield_flat_up_to(shallow_tree, input_tree, path=()): + """Yields (path, value) pairs of input_tree flattened up to shallow_tree. + + Adapted from dm-tree (https://github.com/google-deepmind/tree) to allow + lists as leaves. + + Args: + shallow_tree: Nested structure. Traverse no further than its leaf nodes. + input_tree: Nested structure. Return the paths and values from this tree. + Must have the same upper structure as shallow_tree. + path: Tuple. Optional argument, only used when recursing. The path from the + root of the original shallow_tree, down to the root of the shallow_tree + arg of this recursive call. + + Yields: + Pairs of (path, value), where path the tuple path of a leaf node in + shallow_tree, and value is the value of the corresponding node in + input_tree. + """ + # pylint: disable=protected-access + if isinstance(shallow_tree, tree._TEXT_OR_BYTES) or not ( + isinstance(shallow_tree, tree.collections_abc.Mapping) + or tree._is_namedtuple(shallow_tree) + or tree._is_attrs(shallow_tree) + ): + yield (path, input_tree) + else: + input_tree = dict(tree._yield_sorted_items(input_tree)) + for shallow_key, shallow_subtree in tree._yield_sorted_items(shallow_tree): + subpath = path + (shallow_key,) + input_subtree = input_tree[shallow_key] + for leaf_path, leaf_value in _yield_flat_up_to( + shallow_subtree, input_subtree, path=subpath + ): + yield (leaf_path, leaf_value) + # pylint: enable=protected-access + + +def _flatten_with_path(structure): + return list(_yield_flat_up_to(structure, structure)) + + def generate_dims_coords( shape, var_name, @@ -255,7 +298,7 @@ def numpy_to_data_array( return xr.DataArray(ary, coords=coords, dims=dims) -def dict_to_dataset( +def pytree_to_dataset( data, *, attrs=None, @@ -266,26 +309,29 @@ def dict_to_dataset( index_origin=None, skip_event_dims=None, ): - """Convert a dictionary of numpy arrays to an xarray.Dataset. + """Convert a dictionary or pytree of numpy arrays to an xarray.Dataset. + + See https://jax.readthedocs.io/en/latest/pytrees.html for what a pytree is, but + this inclues at least dictionaries and tuple types. Parameters ---------- - data : dict[str] -> ndarray + data : dict of {str : array_like or dict} or pytree Data to convert. Keys are variable names. - attrs : dict + attrs : dict, optional Json serializable metadata to attach to the dataset, in addition to defaults. - library : module + library : module, optional Library used for performing inference. Will be attached to the attrs metadata. - coords : dict[str] -> ndarray + coords : dict of {str : ndarray}, optional Coordinates for the dataset - dims : dict[str] -> list[str] + dims : dict of {str : list of str}, optional Dimensions of each variable. The keys are variable names, values are lists of coordinates. default_dims : list of str, optional Passed to :py:func:`numpy_to_data_array` index_origin : int, optional Passed to :py:func:`numpy_to_data_array` - skip_event_dims : bool + skip_event_dims : bool, optional If True, cut extra dims whenever present to match the shape of the data. Necessary for PPLs which have the same name in both observed data and log likelihood groups, to account for their different shapes when observations are @@ -293,15 +339,56 @@ def dict_to_dataset( Returns ------- - xr.Dataset + xarray.Dataset + In case of nested pytrees, the variable name will be a tuple of individual names. + + Notes + ----- + This function is available through two aliases: ``dict_to_dataset`` or ``pytree_to_dataset``. Examples -------- - dict_to_dataset({'x': np.random.randn(4, 100), 'y': np.random.rand(4, 100)}) + Convert a dictionary with two 2D variables to a Dataset. + + .. ipython:: + + In [1]: import arviz as az + ...: import numpy as np + ...: az.dict_to_dataset({'x': np.random.randn(4, 100), 'y': np.random.rand(4, 100)}) + + Note that unlike the :class:`xarray.Dataset` constructor, ArviZ has added extra + information to the generated Dataset such as default dimension names for sampled + dimensions and some attributes. + + The function is also general enough to work on pytrees such as nested dictionaries: + + .. ipython:: + + In [1]: az.pytree_to_dataset({'top': {'second': 1.}, 'top2': 1.}) + + which has two variables (as many as leafs) named ``('top', 'second')`` and ``top2``. + + Dimensions and co-ordinates can be defined as usual: + + .. ipython:: + + In [1]: datadict = { + ...: "top": {"a": np.random.randn(100), "b": np.random.randn(1, 100, 10)}, + ...: "d": np.random.randn(100), + ...: } + ...: az.dict_to_dataset( + ...: datadict, + ...: coords={"c": np.arange(10)}, + ...: dims={("top", "b"): ["c"]} + ...: ) """ if dims is None: dims = {} + try: + data = {k[0] if len(k) == 1 else k: v for k, v in _flatten_with_path(data)} + except TypeError: # probably unsortable keys -- the function will still work if + pass # it is an honest dictionary. data_vars = { key: numpy_to_data_array( @@ -318,6 +405,9 @@ def dict_to_dataset( return xr.Dataset(data_vars=data_vars, attrs=make_attrs(attrs=attrs, library=library)) +dict_to_dataset = pytree_to_dataset + + def make_attrs(attrs=None, library=None): """Make standard attributes to attach to xarray datasets. diff --git a/arviz/data/converters.py b/arviz/data/converters.py index 2961f0aaf1..a8f34bc490 100644 --- a/arviz/data/converters.py +++ b/arviz/data/converters.py @@ -1,5 +1,6 @@ """High level conversion functions.""" import numpy as np +import tree import xarray as xr from .base import dict_to_dataset @@ -105,6 +106,8 @@ def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None, dataset = obj.to_dataset() elif isinstance(obj, dict): dataset = dict_to_dataset(obj, coords=coords, dims=dims) + elif tree.is_nested(obj) and not isinstance(obj, (list, tuple)): + dataset = dict_to_dataset(obj, coords=coords, dims=dims) elif isinstance(obj, np.ndarray): dataset = dict_to_dataset({"x": obj}, coords=coords, dims=dims) elif isinstance(obj, (list, tuple)) and isinstance(obj[0], str) and obj[0].endswith(".csv"): @@ -118,6 +121,7 @@ def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None, "xarray dataarray", "xarray dataset", "dict", + "pytree", "netcdf filename", "numpy array", "pystan fit", diff --git a/arviz/data/io_dict.py b/arviz/data/io_dict.py index 4d34157ddc..d76a7511c6 100644 --- a/arviz/data/io_dict.py +++ b/arviz/data/io_dict.py @@ -458,3 +458,6 @@ def from_dict( attrs=attrs, **kwargs, ).to_inference_data() + + +from_pytree = from_dict diff --git a/arviz/plots/backends/matplotlib/pairplot.py b/arviz/plots/backends/matplotlib/pairplot.py index 7e50b43d8e..f17f51e19b 100644 --- a/arviz/plots/backends/matplotlib/pairplot.py +++ b/arviz/plots/backends/matplotlib/pairplot.py @@ -333,7 +333,7 @@ def plot_pair( if reference_values: x_name = flat_var_names[i] y_name = flat_var_names[j + not_marginals] - if x_name and y_name not in difference: + if (x_name not in difference) and (y_name not in difference): ax[j, i].plot( reference_values_copy[x_name], reference_values_copy[y_name], diff --git a/arviz/tests/base_tests/test_data.py b/arviz/tests/base_tests/test_data.py index a777898b96..bb1882059e 100644 --- a/arviz/tests/base_tests/test_data.py +++ b/arviz/tests/base_tests/test_data.py @@ -1077,6 +1077,20 @@ def test_dict_to_dataset(): assert set(dataset.b.coords) == {"chain", "draw", "c"} +def test_nested_dict_to_dataset(): + datadict = { + "top": {"a": np.random.randn(100), "b": np.random.randn(1, 100, 10)}, + "d": np.random.randn(100), + } + dataset = convert_to_dataset(datadict, coords={"c": np.arange(10)}, dims={("top", "b"): ["c"]}) + assert set(dataset.data_vars) == {("top", "a"), ("top", "b"), "d"} + assert set(dataset.coords) == {"chain", "draw", "c"} + + assert set(dataset[("top", "a")].coords) == {"chain", "draw"} + assert set(dataset[("top", "b")].coords) == {"chain", "draw", "c"} + assert set(dataset.d.coords) == {"chain", "draw"} + + def test_dict_to_dataset_event_dims_error(): datadict = {"a": np.random.randn(1, 100, 10)} coords = {"b": np.arange(10), "c": ["x", "y", "z"]} diff --git a/requirements.txt b/requirements.txt index d764477be1..549d4fa07b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ numpy>=1.22.0,<2.0 scipy>=1.8.0 packaging pandas>=1.4.0 +dm-tree>=0.1.8 xarray>=0.21.0 h5netcdf>=1.0.2 typing_extensions>=4.1.0