Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add filter feature for var_names #1154

Merged
merged 21 commits into from
Apr 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## v0.x.x Unreleased

### New features
* Stats and plotting functions that provide `var_names` arg can now filter parameters based on partial naming (`filter="like"`) or regular expressions (`filter="regex"`) (see [#1154](https://github.com/arviz-devs/arviz/pull/1154)).
* Add `true_values` argument for `plot_pair`. It allows for a scatter plot showing the true values of the variables #1140
* Integrate jointplot into pairplot, add point-estimate and overlay of plot kinds #1079
* Add out-of-sample groups (`predictions` and `predictions_constant_data`) and `constant_data` group to pyro translation #1090
Expand All @@ -24,7 +25,7 @@
* Updated benchmarks and moved to asv_benchmarks/benchmarks (#1142)
* Moved `_fast_kde`, `_fast_kde_2d`, `get_bins` and `_sturges_formula` to `numeric_utils` and `get_coords` to `utils` (#1142)
* Rank plot: rename `axes` argument to `ax` (#1144)
* Added a warning specifying log scale is now the default in compare/loo/waic functions ([#1150](https://github.com/arviz-devs/arviz/pull/1150))
* Added a warning specifying log scale is now the default in compare/loo/waic functions ([#1150](https://github.com/arviz-devs/arviz/pull/1150)).
* Fixed bug in `plot_posterior` with rcParam "plot.matplotlib.show" = True (#1151)

### Deprecation
Expand Down
2 changes: 1 addition & 1 deletion arviz/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def dict_to_dataset(data, *, attrs=None, library=None, coords=None, dims=None):

Examples
--------
dict_to_dataset({'x': np.random.randn(4, 100), 'y', np.random.rand(4, 100)})
dict_to_dataset({'x': np.random.randn(4, 100), 'y': np.random.rand(4, 100)})

"""
if dims is None:
Expand Down
31 changes: 19 additions & 12 deletions arviz/plots/autocorrplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
def plot_autocorr(
data,
var_names=None,
filter_vars=None,
max_lag=None,
combined=False,
figsize=None,
Expand All @@ -30,18 +31,24 @@ def plot_autocorr(

Parameters
----------
data : obj
data: obj
Any object that can be converted to an az.InferenceData object
Refer to documentation of az.convert_to_dataset for details
var_names : list of variable names, optional
Variables to be plotted, if None all variable are plotted.
Vector-value stochastics are handled automatically.
max_lag : int, optional
var_names: list of variable names, optional
Variables to be plotted, if None all variable are plotted. Prefix the
variables by `~` when you want to exclude them from the plot. Vector-value
stochastics are handled automatically.
filter_vars: {None, "like", "regex"}, optional, default=None
If `None` (default), interpret var_names as the real variables names. If "like",
interpret var_names as substrings of the real variables names. If "regex",
interpret var_names as regular expressions on the real variables names. A la
`pandas.filter`.
max_lag: int, optional
Maximum lag to calculate autocorrelation. Defaults to 100 or num draws, whichever is smaller
combined : bool
combined: bool
Flag for combining multiple chains into a single chain. If False (default), chains will be
plotted separately.
figsize : tuple
figsize: tuple
Figure size. If None it will be defined automatically.
Note this is not used if ax is supplied.
textsize: float
Expand All @@ -57,12 +64,12 @@ def plot_autocorr(
backend_kwargs: dict, optional
These are kwargs specific to the backend being used. For additional documentation
check the plotting method of the backend.
show : bool, optional
show: bool, optional
Call backend show function.

Returns
-------
axes : matplotlib axes or bokeh figures
axes: matplotlib axes or bokeh figures

Examples
--------
Expand All @@ -83,12 +90,12 @@ def plot_autocorr(
>>> az.plot_autocorr(data, var_names=['mu', 'tau'] )


Combine chains collapsing by variable
Combine chains by variable and select variables by excluding some with partial naming

.. plot::
:context: close-figs

>>> az.plot_autocorr(data, var_names=['mu', 'tau'], combined=True)
>>> az.plot_autocorr(data, var_names=['~thet'], filter_vars="like", combined=True)


Specify maximum lag (x axis bound)
Expand All @@ -99,7 +106,7 @@ def plot_autocorr(
>>> az.plot_autocorr(data, var_names=['mu', 'tau'], max_lag=200, combined=True)
"""
data = convert_to_dataset(data, group="posterior")
var_names = _var_names(var_names, data)
var_names = _var_names(var_names, data, filter_vars)

# Default max lag to 100 or max length of chain
if max_lag is None:
Expand Down
49 changes: 28 additions & 21 deletions arviz/plots/essplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
def plot_ess(
idata,
var_names=None,
filter_vars=None,
kind="local",
relative=False,
coords=None,
Expand All @@ -43,60 +44,66 @@ def plot_ess(

Parameters
----------
idata : obj
idata: obj
Any object that can be converted to an az.InferenceData object
Refer to documentation of az.convert_to_dataset for details
var_names : list of variable names, optional
Variables to be plotted.
kind : str, optional
var_names: list of variable names, optional
Variables to be plotted. Prefix the variables by `~` when you want to exclude
them from the plot.
filter_vars: {None, "like", "regex"}, optional, default=None
If `None` (default), interpret var_names as the real variables names. If "like",
interpret var_names as substrings of the real variables names. If "regex",
interpret var_names as regular expressions on the real variables names. A la
`pandas.filter`.
kind: str, optional
Options: ``local``, ``quantile`` or ``evolution``, specify the kind of plot.
relative : bool
relative: bool
Show relative ess in plot ``ress = ess / N``.
coords : dict, optional
coords: dict, optional
Coordinates of var_names to be plotted. Passed to `Dataset.sel`
figsize : tuple, optional
figsize: tuple, optional
Figure size. If None it will be defined automatically.
textsize: float, optional
Text size scaling factor for labels, titles and lines. If None it will be autoscaled based
on figsize.
rug : bool
rug: bool
Plot rug plot of values diverging or that reached the max tree depth.
rug_kind : bool
rug_kind: bool
Variable in sample stats to use as rug mask. Must be a boolean variable.
n_points : int
n_points: int
Number of points for which to plot their quantile/local ess or number of subsets
in the evolution plot.
extra_methods : bool, optional
extra_methods: bool, optional
Plot mean and sd ESS as horizontal lines. Not taken into account in evolution kind
min_ess : int
min_ess: int
Minimum number of ESS desired.
ax: numpy array-like of matplotlib axes or bokeh figures, optional
A 2D array of locations into which to plot the densities. If not supplied, Arviz will create
its own array of plot areas (and return it).
extra_kwargs : dict, optional
extra_kwargs: dict, optional
If evolution plot, extra_kwargs is used to plot ess tail and differentiate it
from ess bulk. Otherwise, passed to extra methods lines.
text_kwargs : dict, optional
text_kwargs: dict, optional
Only taken into account when ``extra_methods=True``. kwargs passed to ax.annotate
for extra methods lines labels. It accepts the additional
key ``x`` to set ``xy=(text_kwargs["x"], mcse)``
hline_kwargs : dict, optional
hline_kwargs: dict, optional
kwargs passed to ax.axhline for the horizontal minimum ESS line.
rug_kwargs : dict
rug_kwargs: dict
kwargs passed to rug plot.
backend: str, optional
Select plotting backend {"matplotlib","bokeh"}. Default "matplotlib".
backend_kwargs: bool, optional
These are kwargs specific to the backend being used. For additional documentation
check the plotting method of the backend.
show : bool, optional
show: bool, optional
Call backend show function.
**kwargs
Passed as-is to plt.hist() or plt.plot() function depending on the value of `kind`.

Returns
-------
axes : matplotlib axes or bokeh figures
axes: matplotlib axes or bokeh figures

References
----------
Expand All @@ -119,13 +126,13 @@ def plot_ess(
... idata, kind="local", var_names=["mu", "theta"], coords=coords
... )

Plot quantile ESS.
Plot quantile ESS and exclude variables with partial naming

.. plot::
:context: close-figs

>>> az.plot_ess(
... idata, kind="quantile", var_names=["mu", "theta"], coords=coords
... idata, kind="quantile", var_names=['~thet'], filter_vars="like", coords=coords
... )

Plot ESS evolution as the number of samples increase. When the model is converging properly,
Expand Down Expand Up @@ -172,7 +179,7 @@ def plot_ess(
extra_methods = False if kind == "evolution" else extra_methods

data = get_coords(convert_to_dataset(idata, group="posterior"), coords)
var_names = _var_names(var_names, data)
var_names = _var_names(var_names, data, filter_vars)
n_draws = data.dims["draw"]
n_samples = n_draws * data.dims["chain"]

Expand Down
54 changes: 31 additions & 23 deletions arviz/plots/forestplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ def plot_forest(
kind="forestplot",
model_names=None,
var_names=None,
filter_vars=None,
transform=None,
coords=None,
combined=False,
Expand Down Expand Up @@ -40,61 +41,67 @@ def plot_forest(

Parameters
----------
data : obj or list[obj]
data: obj or list[obj]
Any object that can be converted to an az.InferenceData object
Refer to documentation of az.convert_to_dataset for details
kind : str
kind: str
Choose kind of plot for main axis. Supports "forestplot" or "ridgeplot"
model_names : list[str], optional
model_names: list[str], optional
List with names for the models in the list of data. Useful when
plotting more that one dataset
var_names: list[str], optional
List of variables to plot (defaults to None, which results in all
variables plotted)
transform : callable
variables plotted) Prefix the variables by `~` when you want to exclude them
from the plot.
filter_vars: {None, "like", "regex"}, optional, default=None
If `None` (default), interpret var_names as the real variables names. If "like",
interpret var_names as substrings of the real variables names. If "regex",
interpret var_names as regular expressions on the real variables names. A la
`pandas.filter`.
transform: callable
Function to transform data (defaults to None i.e.the identity function)
coords : dict, optional
coords: dict, optional
Coordinates of var_names to be plotted. Passed to `Dataset.sel`
combined : bool
combined: bool
Flag for combining multiple chains into a single chain. If False (default),
chains will be plotted separately.
credible_interval : float, optional
credible_interval: float, optional
Credible interval to plot. Defaults to 0.94.
rope: tuple or dictionary of tuples
Lower and upper values of the Region Of Practical Equivalence. If a list with one
interval only is provided, the ROPE will be displayed across the y-axis. If more than one
interval is provided the length of the list should match the number of variables.
quartiles : bool, optional
quartiles: bool, optional
Flag for plotting the interquartile range, in addition to the credible_interval intervals.
Defaults to True
r_hat : bool, optional
r_hat: bool, optional
Flag for plotting Split R-hat statistics. Requires 2 or more chains. Defaults to False
ess : bool, optional
ess: bool, optional
Flag for plotting the effective sample size. Defaults to False
colors : list or string, optional
colors: list or string, optional
list with valid matplotlib colors, one color per model. Alternative a string can be passed.
If the string is `cycle`, it will automatically chose a color per model from the
matplotlibs cycle. If a single color is passed, eg 'k', 'C2', 'red' this color will be used
for all models. Defauls to 'cycle'.
textsize: float
Text size scaling factor for labels, titles and lines. If None it will be autoscaled based
on figsize.
linewidth : int
linewidth: int
Line width throughout. If None it will be autoscaled based on figsize.
markersize : int
markersize: int
Markersize throughout. If None it will be autoscaled based on figsize.
ridgeplot_alpha : float
ridgeplot_alpha: float
Transparency for ridgeplot fill. If 0, border is colored by model, otherwise
a black outline is used.
ridgeplot_overlap : float
ridgeplot_overlap: float
Overlap height for ridgeplots.
ridgeplot_kind : string
ridgeplot_kind: string
By default ("auto") continuous variables are plotted using KDEs and discrete ones using
histograms. To override this use "hist" to plot histograms and "density" for KDEs
ridgeplot_quantiles : list
ridgeplot_quantiles: list
Quantiles in ascending order used to segment the KDE. Use [.25, .5, .75] for quartiles.
Defaults to None.
figsize : tuple
figsize: tuple
Figure size. If None it will be defined automatically.
ax: axes, optional
Matplotlib axes or bokeh figures.
Expand All @@ -105,12 +112,12 @@ def plot_forest(
backend_kwargs: bool, optional
These are kwargs specific to the backend being used. For additional documentation
check the plotting method of the backend.
show : bool, optional
show: bool, optional
Call backend show function.

Returns
-------
gridspec : matplotlib GridSpec or bokeh figures
gridspec: matplotlib GridSpec or bokeh figures

Examples
--------
Expand All @@ -123,7 +130,8 @@ def plot_forest(
>>> non_centered_data = az.load_arviz_data('non_centered_eight')
>>> axes = az.plot_forest(non_centered_data,
>>> kind='forestplot',
>>> var_names=['theta'],
>>> var_names=["^the"],
>>> filter_vars="regex",
>>> combined=True,
>>> ridgeplot_overlap=3,
>>> figsize=(9, 7))
Expand Down Expand Up @@ -156,7 +164,7 @@ def plot_forest(
datasets, list(reversed(coords)) if isinstance(coords, (list, tuple)) else coords
)

var_names = _var_names(var_names, datasets)
var_names = _var_names(var_names, datasets, filter_vars)

ncols, width_ratios = 1, [3]

Expand Down
Loading