Commit

restructure docs
OriolAbril committed Mar 13, 2024
1 parent c68c913 commit 9330389
Showing 1 changed file with 44 additions and 29 deletions.
73 changes: 44 additions & 29 deletions arviz/data/base.py
@@ -309,61 +309,76 @@ def pytree_to_dataset(
index_origin=None,
skip_event_dims=None,
):
"""Convert a pytree of numpy arrays to an xarray.Dataset.
"""Convert a dictionary or pytree of numpy arrays to an xarray.Dataset.
See https://jax.readthedocs.io/en/latest/pytrees.html for what a pytree is, but
this includes at least dictionaries and tuple types.
In case of nested pytrees, the variable name will be a tuple of individual names.
For example,
pytree_to_dataset({'top': {'second': 1.}, 'top2': 1.})
will have `var_names` `('top', 'second')` and `top2`.
Dimensions and co-ordinates can be defined as usual:
datadict = {
"top": {"a": np.random.randn(100), "b": np.random.randn(1, 100, 10)},
"d": np.random.randn(100),
}
dataset = convert_to_dataset(datadict,
coords={"c": np.arange(10)},
dims={("top", "b"): ["c"]})
Then `dataset.data_vars` will be `('top', 'a'), ('top', 'b'), 'd'`.
Parameters
----------
data : pytree
data : dict of {str : array_like or dict} or pytree
Data to convert. Keys are variable names.
attrs : dict
attrs : dict, optional
Json serializable metadata to attach to the dataset, in addition to defaults.
library : module
library : module, optional
Library used for performing inference. Will be attached to the attrs metadata.
coords : dict[str] -> ndarray
coords : dict of {str : ndarray}, optional
Coordinates for the dataset
dims : dict[str] -> list[str]
dims : dict of {str : list of str}, optional
Dimensions of each variable. The keys are variable names, values are lists of
coordinates.
default_dims : list of str, optional
Passed to :py:func:`numpy_to_data_array`
index_origin : int, optional
Passed to :py:func:`numpy_to_data_array`
skip_event_dims : bool
skip_event_dims : bool, optional
If True, cut extra dims whenever present to match the shape of the data.
Necessary for PPLs which have the same name in both observed data and log
likelihood groups, to account for their different shapes when observations are
multivariate.
Returns
-------
xr.Dataset
xarray.Dataset
In case of nested pytrees, the variable name will be a tuple of individual names.
Notes
-----
This function is available through two aliases: ``dict_to_dataset`` or ``pytree_to_dataset``.
Examples
--------
dict_to_dataset({'x': np.random.randn(4, 100), 'y': np.random.rand(4, 100)})
Convert a dictionary with two 2D variables to a Dataset.
.. ipython::
dict_to_dataset({'x': np.random.randn(4, 100), 'y': np.random.rand(4, 100)})
Note that unlike the :class:`xarray.Dataset` constructor, ArviZ has added extra
information to the generated Dataset such as default dimension names for sampled
dimensions and some attributes.
The function is also general enough to work on pytrees such as nested dictionaries:
.. ipython::
pytree_to_dataset({'top': {'second': 1.}, 'top2': 1.})
which has two variables (as many as leafs) named ``('top', 'second')`` and ``top2``.
Dimensions and co-ordinates can be defined as usual:
.. ipython::
datadict = {
"top": {"a": np.random.randn(100), "b": np.random.randn(1, 100, 10)},
"d": np.random.randn(100),
}
dict_to_dataset(
datadict,
coords={"c": np.arange(10)},
dims={("top", "b"): ["c"]}
)
"""
if dims is None:
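The naming rule the docstring describes for nested inputs (top-level leaves keep their plain key, nested leaves are named by a tuple of keys) can be illustrated with a small standalone sketch. This is not ArviZ's implementation and does not call ArviZ at all; `flatten_names` is a hypothetical helper written only for this illustration, and it assumes the input is built from plain nested dictionaries.

```python
from collections.abc import Mapping


def flatten_names(tree, prefix=()):
    """Yield (name, leaf) pairs for a nested mapping.

    Hypothetical helper for illustration only; it is not part of ArviZ.
    """
    for key, value in tree.items():
        path = prefix + (key,)
        if isinstance(value, Mapping):
            # Recurse into nested dictionaries, extending the key path.
            yield from flatten_names(value, path)
        else:
            # Top-level leaves keep their plain key; nested leaves get a tuple.
            yield (path[0] if len(path) == 1 else path), value


names = [name for name, _ in flatten_names({"top": {"second": 1.0}, "top2": 1.0})]
print(names)  # [('top', 'second'), 'top2']
```

The same tuple names are what the `dims` argument in the docstring example is keyed by, for instance `dims={("top", "b"): ["c"]}`.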
