Implement the Tables interface for Dataset #141

Closed
sethaxen opened this issue Sep 1, 2021 · 4 comments

@sethaxen (Member) commented Sep 1, 2021

There are pros and cons to implementing the Tables interface for Dataset.

Pros:

  • Primarily: users can immediately use the Tables-compatible @df macro in StatsPlots or the data wrapper in AlgebraOfGraphics to construct plots.
  • Users can easily convert a Dataset to any Tables sink, such as a DataFrames.DataFrame, and if we also implement the sink side of the interface, users can construct a Dataset from anything else that implements the Tables interface. (A rough sketch of what the source side could look like follows this list.)

Cons:

  • Unlike MCMCChains.Chains, which also implements the Tables interface, data stored in a Dataset is typically inherently multidimensional. To flatten it into a table structure, we would need to discard that multidimensional information. From this perspective, it would be better for plotting purposes to have something like Tables that is friendly to multidimensional data and can use the coords and dims we store.
  • This will cause Pluto.jl to override our custom HTML representation for Dataset, since it checks for the Tables interface early in the pipeline.
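
For concreteness, here is a minimal sketch of the source side of the Tables.jl interface. The ToyDataset type, its NamedTuple-of-matrices storage, and the flattening strategy are illustrative assumptions, not the actual ArviZ.jl Dataset:

using Tables, DataFrames

# Hypothetical stand-in for Dataset: variables stored as a NamedTuple of
# (chain, draw) matrices. A real Dataset also carries coords/dims and variables
# with extra dimensions, which a full implementation would have to flatten.
struct ToyDataset{T<:NamedTuple}
    data::T
end

# Declare the type as a column-accessible Tables.jl source.
Tables.istable(::Type{<:ToyDataset}) = true
Tables.columnaccess(::Type{<:ToyDataset}) = true

# Return a column table: each variable flattened to a vector over (chain, draw).
Tables.columns(ds::ToyDataset) = map(vec, ds.data)

# Any Tables.jl sink then accepts the dataset:
ds = ToyDataset((mu = randn(4, 100), sigma = rand(4, 100)))
df = DataFrame(ds)  # 400×2 DataFrame with columns :mu and :sigma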
@ahartikainen

Can this do unpacking similar to what we do on the Python side to get a dataframe?

@sethaxen (Member Author) commented Sep 1, 2021

It would be similar, yeah. It would look something like this:

using ArviZ, DataFrames
idata = load_arviz_data("radon")
df = DataFrame(idata.posterior)

DataFrame's default constructor falls back to the Tables interface, so all we would need to do here is implement that, and then DataFrame could flatten our Datasets automatically. I'll need to look at more examples to understand how xarray.Dataset's to_dataframe method works when dimensions don't line up. It looks to me like it makes lots of copies to line everything up:

julia> idata = load_arviz_data("glycan_torsion_angles");

julia> idata.posterior
Dataset (xarray.Dataset)
Dimensions:      (chain: 1, draw: 2000, tors_dim_0: 4, cs_mu_dim_0: 4)
Coordinates:
  * chain        (chain) int64 0
  * draw         (draw) int64 0 1 2 3 4 5 6 ... 1994 1995 1996 1997 1998 1999
  * tors_dim_0   (tors_dim_0) int64 0 1 2 3
  * cs_mu_dim_0  (cs_mu_dim_0) int64 0 1 2 3
Data variables:
    alpha        (chain, draw) float64 ...
    beta         (chain, draw) float64 ...
    tors         (chain, draw, tors_dim_0) float64 ...
    E            (chain, draw) float64 ...
    cs_mu        (chain, draw, cs_mu_dim_0) float64 ...
    sd           (chain, draw) float64 ...
    nu           (chain, draw) float64 ...
Attributes:
    created_at:                 2020-06-03T12:48:04.865326
    inference_library:          pymc3
    inference_library_version:  3.6rc1

julia> idata.posterior.to_dataframe()
PyObject                                        alpha      beta      tors           E       cs_mu        sd         nu
chain draw tors_dim_0 cs_mu_dim_0                                                                            
0     0    0          0           -20.889500  1.218748 -1.155806  151.343733  106.420083  7.373658   5.238337
                      1           -20.889500  1.218748 -1.155806  151.343733   74.920956  7.373658   5.238337
                      2           -20.889500  1.218748 -1.155806  151.343733  105.464299  7.373658   5.238337
                      3           -20.889500  1.218748 -1.155806  151.343733   80.719928  7.373658   5.238337
           1          0           -20.889500  1.218748  2.395331  151.343733  106.420083  7.373658   5.238337
...                                      ...       ...       ...         ...         ...       ...        ...
      1999 2          3            33.172971  0.626514 -1.387980  154.906266   82.240872  4.118010  21.424700
           3          0            33.172971  0.626514 -1.949625  154.906266   99.113831  4.118010  21.424700
                      1            33.172971  0.626514 -1.949625  154.906266   83.297361  4.118010  21.424700
                      2            33.172971  0.626514 -1.949625  154.906266   97.532364  4.118010  21.424700
                      3            33.172971  0.626514 -1.949625  154.906266   82.240872  4.118010  21.424700

[32000 rows x 7 columns]

whereas InferenceData.to_dataframe doesn't:

julia> idata.to_dataframe()
PyObject       chain  draw  (posterior, alpha)  ...  (sample_stats, log_likelihood[2], 1)  (sample_stats, log_likelihood[3], 2)  (sample_stats, log_likelihood[4], 3)
0         0     0          -20.889500  ...                             -2.988165                             -3.157347                             -3.069121
1         0     1          -28.163740  ...                             -3.684156                             -2.864150                             -2.868946
2         0     2          -22.198927  ...                             -4.346044                             -3.470297                             -3.528483
3         0     3          -22.741758  ...                             -4.127992                             -2.929625                             -2.917777
4         0     4          -17.740277  ...                             -3.036136                             -3.279647                             -3.211577
...     ...   ...                 ...  ...                                   ...                                   ...                                   ...
1995      0  1995           15.058016  ...                             -3.301039                             -2.897864                             -3.423978
1996      0  1996           27.261025  ...                             -4.253844                             -2.772573                             -2.825215
1997      0  1997          -19.063158  ...                             -2.976555                             -2.954317                             -2.954626
1998      0  1998           14.574519  ...                             -2.963257                             -3.106524                             -2.983051
1999      0  1999           33.172971  ...                             -3.725655                             -2.753599                             -2.991688

[2000 rows x 19 columns]
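
To make the copying concrete, here is a rough pure-Julia reproduction of the xarray-style flattening for a toy posterior; the variable names and shapes below are made up for illustration:

using DataFrames

# Toy posterior: a scalar-per-draw variable and a vector-valued variable.
chains, draws, tors_len = 1, 3, 2
alpha = randn(chains, draws)             # dims (chain, draw)
tors  = randn(chains, draws, tors_len)   # dims (chain, draw, tors_dim_0)

# to_dataframe-style flattening: index rows by the Cartesian product of all
# dimensions, so alpha is copied for every value of tors_dim_0 even though it
# does not depend on that dimension.
rows = [
    (chain = c - 1, draw = d - 1, tors_dim_0 = t - 1,
     alpha = alpha[c, d], tors = tors[c, d, t])
    for c in 1:chains, d in 1:draws, t in 1:tors_len
]
df = DataFrame(vec(rows))  # chains * draws * tors_len rows, with alpha duplicated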

How useful have you found these in practice?

@ahartikainen

I'm the wrong guy to ask that :D

We should ask people who actually use the dataframe.

@sethaxen (Member Author) commented Aug 8, 2022

Closed by #191

sethaxen closed this as completed Aug 8, 2022