Reimplement Dataset and InferenceData using DimensionalData (#191)
* Add DimensionalData as dependency

* Create lookup for group order

* Implement Dataset using DimensionalData

* Remove methods defined in DimensionalData

* Deprecate indexing with strings

* Fix show method

* Simplify dataset_to_dict implementation

* Implement converters to/from xarray

* Add accessor for wrapped data

* Update attribute get/setters

* Fix deprecation of indexing

* Deprecate using a string for group name

* Remove xarray designations

* Document non-allocating behavior

* Implement InferenceData as wrapping a dict of Datasets

* Update group accessors

* Update property accessors

* Remove redundant constructors

* Update function implementations

* Add hasgroup

* Deprecate addition of InferenceData

* Add converters to/from xarray

* Remove reference to xarray

* Implement merge and merge!

* Replace concat! with merge!

* Don't use deprecated method

* Update diagnostics tests

* Increment version number

* Increment ArviZ version

* Remove deprecation

* Don't use xarray html reprs

* Add back plaintext show method

* Make sure strings are mapped to symbols

* Let strings be strings

* Convert dims to range if possible

* Use namedtuple backing for InferenceData

* Make Dataset just wrap an AbstractDimStack

* Don't automatically convert xarray to Dataset

* Don't assume InferenceData is mutable

* Correctly convert mixed type arrays

* Move xarray code to its own file

* Collect Python interop at the end of file

* Collect dimensionaldata code

* Fix conversion from PyObject

* Work around filter not being implemented for NamedTuple

* Avoid using entire DimensionalData namespace

* Export Dataset

* Revert "Export Dataset"

This reverts commit a3aa171.

* Update dataset construction tests

* Fix rebuild overload

* Remove getindex

* Reorganize convert methods

* Add and fix more tests

* Add and repair tests

* Use DimensionalData.index

* Add DimensionalData to test suite

* Test Dataset to/from xarray conversion

* Update convert_to_dataset tests

* Use tuple method of index

* Generalize processing of python arrays

* Use issetequal

* Update convert_to_constant_dataset tests

* Update remaining dataset tests

* Make dataset tests a single testset

* Repair dataset_to_dict

* Run formatter

* Allow accessing variables with properties

* Rename varname to var_name

* Implement indexing and iteration interfaces

* Get all supported groups

* Make iteration type inferrable

* Update many InferenceData tests

* Run formatter

* Update many tests

* Update conversion tests

* Remove addition of InferenceDatas

* Test conversion to Python

* Add helper function to check idatas are equal

* Add dependencies

* Define default sample dims

* Define namedtuple_to_dataset

* Implement convert_to_inference_data

* Implement package_version

* Add simpler forms of rekey

* Use OrderedDict and add/get metadata to/from arrays

* Fix InferenceData docstring

* Update concat test

* Update helpers

* Add schema check

* Fix broken tests

* Run formatter

* Test namedtuple_to_dataset

* Actually use convert_to_dataset

* Update from_namedtuple

* Update docstrings

* Add broken test for inferrability

* Make sure dims and coords default to NamedTuple

* Update MCMCChains converter

* Update tests

* Remove dead code

* Make NamedTuple

* Update docstrings

* Update test dependencies

* Mark as dev release

* Update mpl_examples.md

* Update quickstart.jl

* Always pass an array

* Don't send a NamedTuple to Python

* Run formatter

* Update storage documentation

* Avoid patterns that poorly convert to Python types

* Revert back to String

* Reduce indentation in REPL display of InferenceData

* Convert special arrays to raw arrays for plot_kde

* Test package_version

* Remove outdated reference to convert_to_constant_dataset

* Fix typo

* Fix typo

* Don't export undefined function

* Remove concat! from API docs

* Make tests work for Julia 1.6

* Avoid setting dims and coords unless defined

* Remove unused function

* Merge PyObject conversion functions

* Don't drop directly to Python

* Add missing tests

* Test default_var_name

* Generalize convert function

* Add missing tests

* Test using type for dimension

* Test prior_predictive as Chains

* Add missing tests

* Repair rekey for Dict
sethaxen authored Jul 9, 2022
1 parent 4cb4841 commit 1f64237
Showing 28 changed files with 1,469 additions and 1,121 deletions.
7 changes: 6 additions & 1 deletion Project.toml
@@ -1,13 +1,16 @@
name = "ArviZ"
uuid = "131c737c-5715-5e2e-ad31-c244f01c1dc7"
authors = ["Seth Axen <seth.axen@gmail.com>"]
version = "0.5.22"
version = "0.6.0-DEV"

[deps]
Conda = "8f4d0f93-b110-5947-807f-2305c1781a2d"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
PSIS = "ce719bf2-d5d0-4fb9-925d-10a81b42ad04"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
PyPlot = "d330b81b-6aea-500a-939a-2ce795aea3ee"
@@ -18,7 +21,9 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
[compat]
Conda = "1.0"
DataFrames = "0.20, 0.21, 0.22, 1.0"
DimensionalData = "0.20"
LogExpFunctions = "0.2.0, 0.3"
OrderedCollections = "1"
PSIS = "0.2, 0.3, 0.4, 0.5"
PyCall = "1.91.2"
PyPlot = "2.8.2"
2 changes: 1 addition & 1 deletion docs/Project.toml
@@ -16,7 +16,7 @@ Soss = "8ce77f84-9b61-11e8-39ff-d17a774bf41c"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[compat]
ArviZ = "0.5"
ArviZ = "0.6"
CmdStan = "6.0"
Distributions = "0.25"
Documenter = "0.27"
1 change: 0 additions & 1 deletion docs/src/api.md
@@ -78,7 +78,6 @@
| [`from_mcmcchains`](@ref) | Convert `MCMCChains` data into an `InferenceData`. |
| [`from_samplechains`](@ref) | Convert `SampleChains` data into an `InferenceData`. |
| [`concat`](@ref) | Concatenate `InferenceData` objects. |
| [`concat!`](@ref) | Concatenate `InferenceData` objects in-place. |

## [Utils](@id utils-api)

5 changes: 3 additions & 2 deletions docs/src/index.md
@@ -57,8 +57,9 @@ These functions are not exported to reduce namespace clutter.
For `InferenceData` inputs, [`summarystats`](@ref) replaces `arviz.summary` to avoid confusion with `Base.summary`.
For arbitrary inputs and the full functionality of `arviz.summary`, use [`ArviZ.summary`](@ref), which is not exported.

ArviZ.jl transparently interconverts between `arviz.InferenceData` and our own [`InferenceData`](@ref), used for dispatch.
`InferenceData` has identical usage to its Python counterpart.
While Python ArviZ is built on xarray, and `InferenceData` groups are `xarray.Dataset`s, ArviZ.jl is built on [DimensionalData.jl](https://rafaqz.github.io/DimensionalData.jl/stable/), and `InferenceData` groups are `DimensionalData.AbstractDimStack`s that have identical usage to [`DimensionalData.DimStack`](https://rafaqz.github.io/DimensionalData.jl/stable/api/#DimensionalData.DimStack).
When ArviZ.jl uses functionality implemented in Python, it transparently interconverts `InferenceData` to/from `arviz.InferenceData`.
Going to Python, this is non-allocating.

Functions that in ArviZ return Pandas types here return [DataFrames.jl](https://github.com/JuliaData/DataFrames.jl) types.
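
For readers unfamiliar with DimensionalData.jl, the `DimStack` behavior that the new `index.md` text refers to can be sketched roughly as below. This is a minimal illustration using DimensionalData.jl directly, not ArviZ.jl's own `Dataset` type; the variable names `mu` and `tau`, the `chain`/`draw` dimension names, and the array sizes are assumptions chosen for the example.

```julia
using DimensionalData

# Hypothetical sample data: 4 chains x 100 draws (sizes and names are illustrative only).
sample_dims = (Dim{:chain}(1:4), Dim{:draw}(1:100))
mu = DimArray(randn(4, 100), sample_dims; name=:mu)
tau = DimArray(abs.(randn(4, 100)), sample_dims; name=:tau)

# A DimStack groups named arrays that share dimensions,
# loosely analogous to an xarray.Dataset.
stack = DimStack(mu, tau)

# Variables are looked up by Symbol rather than by String.
mu_draws = stack[:mu]

# Select the first chain by wrapping the index in its dimension.
first_chain = mu_draws[Dim{:chain}(1)]

println(keys(stack))
println(size(mu_draws))
println(length(first_chain))
```

The Symbol-based lookup mirrors the change visible throughout the `mpl_examples.md` hunks, where string variable names such as `"mu"` are replaced with `:mu`.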

86 changes: 43 additions & 43 deletions docs/src/mpl_examples.md
@@ -14,7 +14,7 @@ using ArviZ
ArviZ.use_style("arviz-darkgrid")
data = load_arviz_data("centered_eight")
plot_autocorr(data; var_names=("tau", "mu"))
plot_autocorr(data; var_names=[:tau, :mu])
gcf()
```
@@ -48,7 +48,7 @@ using ArviZ
ArviZ.use_style("arviz-darkgrid")
data = load_arviz_data("regression1d")
plot_bpv(data; kind="t_stat", t_stat="0.5")
plot_bpv(data; kind=:t_stat, t_stat="0.5")
gcf()
```
@@ -91,7 +91,7 @@ non_centered_data = load_arviz_data("non_centered_eight")
plot_density(
[centered_data, non_centered_data];
data_labels=["Centered", "Non Centered"],
var_names=["theta"],
var_names=[:theta],
shade=0.1,
)
@@ -169,7 +169,7 @@ using ArviZ
ArviZ.use_style("arviz-darkgrid")
idata = load_arviz_data("radon")
plot_ess(idata; var_names=["b"], kind="evolution")
plot_ess(idata; var_names=[:b], kind=:evolution)
gcf()
```
@@ -186,7 +186,7 @@ using ArviZ
ArviZ.use_style("arviz-darkgrid")
idata = load_arviz_data("non_centered_eight")
plot_ess(idata; var_names=["mu"], kind="local", marker="_", ms=20, mew=2, rug=true)
plot_ess(idata; var_names=[:mu], kind=:local, marker="_", ms=20, mew=2, rug=true)
gcf()
```
@@ -203,7 +203,7 @@ using ArviZ
ArviZ.use_style("arviz-darkgrid")
idata = load_arviz_data("radon")
plot_ess(idata; var_names=["sigma"], kind="quantile", color="C4")
plot_ess(idata; var_names=[:sigma], kind=:quantile, color="C4")
gcf()
```
@@ -224,7 +224,7 @@ non_centered_data = load_arviz_data("non_centered_eight")
plot_forest(
[centered_data, non_centered_data];
model_names=["Centered", "Non Centered"],
var_names=["mu"],
var_names=[:mu],
)
title("Estimated theta for eight schools model")
@@ -245,12 +245,12 @@ ArviZ.use_style("arviz-darkgrid")
rugby_data = load_arviz_data("rugby")
plot_forest(
rugby_data;
kind="ridgeplot",
var_names=["defs"],
kind=:ridgeplot,
var_names=[:defs],
linewidth=4,
combined=true,
ridgeplot_overlap=1.5,
colors="blue",
colors=:blue,
figsize=(9, 4),
)
title("Relative defensive strength\nof Six Nation rugby teams")
@@ -276,7 +276,7 @@ x_data = randn(100)
y_data = 2 .+ x_data .* 0.5
y_data_rep = 0.5 .* randn(200, 100) .+ transpose(y_data)
plot(x_data, y_data; color="C6")
plot_hdi(x_data, y_data_rep; color="k", plot_kwargs=Dict("ls" => "--"))
plot_hdi(x_data, y_data_rep; color=:k, plot_kwargs=Dict(:ls => "--"))
gcf()
```
@@ -295,9 +295,9 @@ ArviZ.use_style("arviz-darkgrid")
data = load_arviz_data("non_centered_eight")
plot_pair(
data;
var_names=["theta"],
coords=Dict("school" => ["Choate", "Phillips Andover"]),
kind="hexbin",
var_names=[:theta],
coords=Dict(:school => ["Choate", "Phillips Andover"]),
kind=:hexbin,
marginals=true,
figsize=(10, 10),
)
@@ -319,16 +319,16 @@ ArviZ.use_style("arviz-darkgrid")
data = load_arviz_data("centered_eight")
## Combine different posterior draws from different chains
obs = data.posterior_predictive["obs"].values
obs = data.posterior_predictive.obs
size_obs = size(obs)
y_hat = reshape(obs, prod(size_obs[1:2]), size_obs[3:end]...)
plot_kde(
y_hat;
label="Estimated Effect\n of SAT Prep",
rug=true,
plot_kwargs=Dict("linewidth" => 2, "color" => "black"),
rug_kwargs=Dict("color" => "black"),
plot_kwargs=Dict(:linewidth => 2, :color => :black),
rug_kwargs=Dict(:color => :black),
)
gcf()
@@ -405,7 +405,7 @@ ArviZ.use_style("arviz-darkgrid")
idata = load_arviz_data("radon")
plot_loo_pit(idata; y="y", ecdf=true, color="maroon")
plot_loo_pit(idata; y=:y, ecdf=true, color=:maroon)
gcf()
```
@@ -422,7 +422,7 @@ using ArviZ
ArviZ.use_style("arviz-darkgrid")
idata = load_arviz_data("non_centered_eight")
plot_loo_pit(; idata, y="obs", color="indigo")
plot_loo_pit(; idata, y=:obs, color=:indigo)
gcf()
```
@@ -439,7 +439,7 @@ using ArviZ
ArviZ.use_style("arviz-darkgrid")
data = load_arviz_data("centered_eight")
plot_mcse(data; var_names=["tau", "mu"], rug=true, extra_methods=true)
plot_mcse(data; var_names=[:tau, :mu], rug=true, extra_methods=true)
gcf()
```
@@ -456,7 +456,7 @@ using ArviZ
ArviZ.use_style("arviz-darkgrid")
data = load_arviz_data("radon")
plot_mcse(data; var_names=["sigma_a"], color="C4", errorbar=true)
plot_mcse(data; var_names=[:sigma_a], color="C4", errorbar=true)
gcf()
```
@@ -473,9 +473,9 @@ using ArviZ
ArviZ.use_style("arviz-darkgrid")
centered = load_arviz_data("centered_eight")
coords = Dict("school" => ["Choate", "Deerfield"])
coords = Dict(:school => ["Choate", "Deerfield"])
plot_pair(
centered; var_names=["theta", "mu", "tau"], coords, divergences=true, textsize=22
centered; var_names=[:theta, :mu, :tau], coords, divergences=true, textsize=22
)
gcf()
@@ -493,11 +493,11 @@ using ArviZ
ArviZ.use_style("arviz-darkgrid")
centered = load_arviz_data("centered_eight")
coords = Dict("school" => ["Choate", "Deerfield"])
coords = Dict(:school => ["Choate", "Deerfield"])
plot_pair(
centered;
var_names=["theta", "mu", "tau"],
kind="hexbin",
var_names=[:theta, :mu, :tau],
kind=:hexbin,
coords,
colorbar=true,
divergences=true,
@@ -518,11 +518,11 @@ using ArviZ
ArviZ.use_style("arviz-darkgrid")
centered = load_arviz_data("centered_eight")
coords = Dict("school" => ["Choate", "Deerfield"])
coords = Dict(:school => ["Choate", "Deerfield"])
plot_pair(
centered;
var_names=["theta", "mu", "tau"],
kind="kde",
var_names=[:theta, :mu, :tau],
kind=:kde,
coords,
divergences=true,
textsize=22,
@@ -543,15 +543,15 @@ using ArviZ
ArviZ.use_style("arviz-darkgrid")
centered = load_arviz_data("centered_eight")
coords = Dict("school" => ["Choate", "Deerfield"])
coords = Dict(:school => ["Choate", "Deerfield"])
plot_pair(
centered;
var_names=["mu", "theta"],
kind=["scatter", "kde"],
kde_kwargs=Dict("fill_last" => false),
var_names=[:mu, :theta],
kind=[:scatter, :kde],
kde_kwargs=Dict(:fill_last => false),
marginals=true,
coords,
point_estimate="median",
point_estimate=:median,
figsize=(10, 8),
)
@@ -570,7 +570,7 @@ using ArviZ
ArviZ.use_style("arviz-darkgrid")
data = load_arviz_data("centered_eight")
ax = plot_parallel(data; var_names=["theta", "tau", "mu"])
ax = plot_parallel(data; var_names=[:theta, :tau, :mu])
ax.set_xticklabels(ax.get_xticklabels(); rotation=70)
draw()
@@ -589,8 +589,8 @@ using ArviZ
ArviZ.use_style("arviz-darkgrid")
data = load_arviz_data("centered_eight")
coords = Dict("school" => ["Choate"])
plot_posterior(data; var_names=["mu", "theta"], coords, rope=(-1, 1))
coords = Dict(:school => ["Choate"])
plot_posterior(data; var_names=[:mu, :theta], coords, rope=(-1, 1))
gcf()
```
@@ -607,7 +607,7 @@ using ArviZ
ArviZ.use_style("arviz-darkgrid")
data = load_arviz_data("non_centered_eight")
plot_ppc(data; data_pairs=Dict("obs" => "obs"), alpha=0.03, figsize=(12, 6), textsize=14)
plot_ppc(data; data_pairs=Dict(:obs => :obs), alpha=0.03, figsize=(12, 6), textsize=14)
gcf()
```
@@ -624,7 +624,7 @@ using ArviZ
ArviZ.use_style("arviz-darkgrid")
data = load_arviz_data("non_centered_eight")
plot_ppc(data; alpha=0.3, kind="cumulative", figsize=(12, 6), textsize=14)
plot_ppc(data; alpha=0.3, kind=:cumulative, figsize=(12, 6), textsize=14)
gcf()
```
@@ -641,7 +641,7 @@ using ArviZ
ArviZ.use_style("arviz-darkgrid")
data = load_arviz_data("centered_eight")
plot_rank(data; var_names=("tau", "mu"))
plot_rank(data; var_names=[:tau, :mu])
gcf()
```
@@ -658,7 +658,7 @@ using ArviZ
ArviZ.use_style("arviz-darkgrid")
data = load_arviz_data("classification10d")
plot_separation(data; y="outcome", y_hat="outcome", figsize=(8, 1))
plot_separation(data; y=:outcome, y_hat=:outcome, figsize=(8, 1))
gcf()
```
@@ -675,7 +675,7 @@ using ArviZ
ArviZ.use_style("arviz-darkgrid")
data = load_arviz_data("non_centered_eight")
plot_trace(data; var_names=("tau", "mu"))
plot_trace(data; var_names=[:tau, :mu])
gcf()
```
@@ -692,7 +692,7 @@ using ArviZ
ArviZ.use_style("arviz-darkgrid")
data = load_arviz_data("non_centered_eight")
plot_violin(data; var_names=["mu", "tau"])
plot_violin(data; var_names=[:mu, :tau])
gcf()
```