Skip to content

Commit

Permalink
finish labeling changes in summary
Browse files Browse the repository at this point in the history
  • Loading branch information
OriolAbril committed May 26, 2020
1 parent bfeee90 commit 6ac1faa
Show file tree
Hide file tree
Showing 8 changed files with 51 additions and 64 deletions.
4 changes: 1 addition & 3 deletions arviz/plots/backends/bokeh/essplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@

from . import backend_kwarg_defaults
from .. import show_layout
from ...plot_utils import (
_create_axes_grid,
)
from ...plot_utils import _create_axes_grid
from ....sel_utils import make_label


Expand Down
4 changes: 1 addition & 3 deletions arviz/plots/backends/bokeh/rankplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@

from . import backend_kwarg_defaults
from .. import show_layout
from ...plot_utils import (
_create_axes_grid,
)
from ...plot_utils import _create_axes_grid
from ....sel_utils import make_label
from ....stats.stats_utils import histogram

Expand Down
4 changes: 1 addition & 3 deletions arviz/plots/backends/matplotlib/ppcplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@

from . import backend_show
from ...kdeplot import plot_kde
from ...plot_utils import (
_create_axes_grid,
)
from ...plot_utils import _create_axes_grid
from ....sel_utils import make_label
from ....numeric_utils import _fast_kde, histogram, get_bins

Expand Down
4 changes: 1 addition & 3 deletions arviz/plots/backends/matplotlib/rankplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
import scipy.stats

from . import backend_show
from ...plot_utils import (
_create_axes_grid,
)
from ...plot_utils import _create_axes_grid
from ....sel_utils import make_label
from ....stats.stats_utils import histogram

Expand Down
11 changes: 7 additions & 4 deletions arviz/sel_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import xarray as xr


def selection_to_string(selection):
"""Convert dictionary of coordinates to a string for labels.
Expand Down Expand Up @@ -41,12 +42,14 @@ def make_label(var_name, selection, position="below"):
if selection:
sel = selection_to_string(selection)
if position == "below":
sep = "\n"
base = "{}\n{}"
elif position == "beside":
sep = " "
base = "{}[{}]"
else:
sep = sel = ""
return "{}{}{}".format(var_name, sep, sel)
sel = ""
base = "{}{}"
return base.format(var_name, sel)


def purge_duplicates(list_in):
"""Remove duplicates from list while preserving order.
Expand Down
26 changes: 15 additions & 11 deletions arviz/stats/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -984,7 +984,7 @@ def summary(
stat_funcs=None,
extend=True,
hdi_prob=None,
order="C",
order=None,
index_origin=None,
skipna=False,
coords: Optional[CoordSpec] = None,
Expand Down Expand Up @@ -1031,11 +1031,6 @@ def summary(
hdi_prob: float, optional
hdi interval to compute. Defaults to 0.94. This is only meaningful when ``stat_funcs`` is
None.
order: {"C", "F"}
If fmt is "wide", use either C or F unpacking order. Defaults to C.
index_origin: int
If fmt is "wide, select n-based indexing for multivariate parameters.
Defaults to rcParam data.index.origin, which is 0.
skipna: bool
If true ignores nan values when computing the summary statistics, it does not affect the
behaviour of the functions passed to ``stat_funcs``. Defaults to false.
Expand All @@ -1045,6 +1040,11 @@ def summary(
Dimensions specification for the variables to be used if the ``fmt`` is ``'xarray'``.
credible_interval: float, optional
deprecated: Please see hdi_prob
order
deprecated: order is now ignored.
index_origin
deprecated: index_origin is now ignored, modify the coordinate values to change the
value used in summary.
Returns
-------
Expand Down Expand Up @@ -1112,7 +1112,12 @@ def summary(
extra_args["coords"] = coords
if dims is not None:
extra_args["dims"] = dims
if index_origin is None:
if index_origin is not None:
warnings.warn(
"index_origin has been deprecated. summary now shows coordinate values, "
"to change the label shown, modify the coordinate values before calling sumary",
DeprecationWarning,
)
index_origin = rcParams["data.index_origin"]
if hdi_prob is None:
hdi_prob = rcParams["stats.hdi_prob"]
Expand All @@ -1127,10 +1132,9 @@ def summary(
if not isinstance(fmt, str) or (fmt.lower() not in fmt_group):
raise TypeError("Invalid format: '{}'. Formatting options are: {}".format(fmt, fmt_group))

unpack_order_group = ("C", "F")
if not isinstance(order, str) or (order.upper() not in unpack_order_group):
raise TypeError(
"Invalid order: '{}'. Unpacking options are: {}".format(order, unpack_order_group)
if order is not None:
warnings.warn(
"order has been deprecated. summary now shows coordinate values.", DeprecationWarning
)

alpha = 1 - hdi_prob
Expand Down
2 changes: 1 addition & 1 deletion arviz/tests/base_tests/test_rcparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def test_stats_information_criterion(models):
assert "loo" in df_comp.columns


def test_http_type_request(models, monkeypatch):
def test_http_type_request(monkeypatch):
def _urlretrive(url, _):
raise Exception("URL Retrieved: {}".format(url))

Expand Down
60 changes: 24 additions & 36 deletions arviz/tests/base_tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,40 +229,24 @@ def test_summary_fmt(centered_eight, fmt):
assert summary(centered_eight, fmt=fmt) is not None


@pytest.mark.parametrize("order", ["C", "F"])
def test_summary_unpack_order(order):
data = from_dict({"a": np.random.randn(4, 100, 4, 5, 3)})
az_summary = summary(data, order=order, fmt="wide")
def test_summary_labels():
coords1 = list("abcd")
coords2 = np.arange(1, 6)
data = from_dict(
{"a": np.random.randn(4, 100, 4, 5)},
coords={"dim1": coords1, "dim2": coords2},
dims={"a": ["dim1", "dim2"]},
)
az_summary = summary(data, fmt="wide")
assert az_summary is not None
if order != "F":
first_index = 4
second_index = 5
third_index = 3
else:
first_index = 3
second_index = 5
third_index = 4
column_order = []
for idx1 in range(first_index):
for idx2 in range(second_index):
for idx3 in range(third_index):
if order != "F":
column_order.append("a[{},{},{}]".format(idx1, idx2, idx3))
else:
column_order.append("a[{},{},{}]".format(idx3, idx2, idx1))
for coord1 in coords1:
for coord2 in coords2:
column_order.append("a[{}, {}]".format(coord1, coord2))
for col1, col2 in zip(list(az_summary.index), column_order):
assert col1 == col2


@pytest.mark.parametrize("origin", [0, 1, 2, 3])
def test_summary_index_origin(origin):
data = from_dict({"a": np.random.randn(2, 50, 10)})
az_summary = summary(data, index_origin=origin, fmt="wide")
assert az_summary is not None
for i, col in enumerate(list(az_summary.index)):
assert col == "a[{}]".format(i + origin)


@pytest.mark.parametrize(
"stat_funcs", [[np.var], {"var": np.var, "var2": lambda x: np.var(x) ** 2}]
)
Expand All @@ -274,12 +258,12 @@ def test_summary_stat_func(centered_eight, stat_funcs):

def test_summary_nan(centered_eight):
centered_eight = deepcopy(centered_eight)
centered_eight.posterior.theta[:, :, 0] = np.nan
centered_eight.posterior["theta"].loc[{"school": "Deerfield"}] = np.nan
summary_xarray = summary(centered_eight)
assert summary_xarray is not None
assert summary_xarray.loc["theta[0]"].isnull().all()
assert summary_xarray.loc["theta[Deerfield]"].isnull().all()
assert (
summary_xarray.loc[[ix for ix in summary_xarray.index if ix != "theta[0]"]]
summary_xarray.loc[[ix for ix in summary_xarray.index if ix != "theta[Deerfield]"]]
.notnull()
.all()
.all()
Expand All @@ -288,7 +272,7 @@ def test_summary_nan(centered_eight):

def test_summary_skip_nan(centered_eight):
centered_eight = deepcopy(centered_eight)
centered_eight.posterior.theta[:, :10, 1] = np.nan
centered_eight.posterior["theta"].loc[{"draw": slice(10), "school": "Deerfield"}] = np.nan
summary_xarray = summary(centered_eight)
theta_1 = summary_xarray.loc["theta[Deerfield]"].isnull()
assert summary_xarray is not None
Expand All @@ -302,10 +286,14 @@ def test_summary_bad_fmt(centered_eight, fmt):
summary(centered_eight, fmt=fmt)


@pytest.mark.parametrize("order", [1, "bad_order"])
def test_summary_bad_unpack_order(centered_eight, order):
with pytest.raises(TypeError):
summary(centered_eight, order=order)
def test_summary_order_deprecation(centered_eight):
with pytest.warns(DeprecationWarning, match="order"):
summary(centered_eight, order="C")


def test_summary_index_origin_deprecation(centered_eight):
with pytest.warns(DeprecationWarning, match="index_origin"):
summary(centered_eight, index_origin=1)


@pytest.mark.parametrize("scale", ["log", "negative_log", "deviance"])
Expand Down

0 comments on commit 6ac1faa

Please sign in to comment.