Skip to content

Commit

Permalink
Adds get_variable_inclusion function (#214)
Browse files Browse the repository at this point in the history
* add get_variable_inclusion function

* add elements to API reference
  • Loading branch information
aloctavodia authored Dec 29, 2024
1 parent 3bad2c6 commit 064457e
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 19 deletions.
2 changes: 1 addition & 1 deletion docs/api_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ methods in the current release of PyMC-BART.
=============================

.. automodule:: pymc_bart
:members: BART, PGBART, plot_pdp, plot_ice, plot_variable_importance, plot_convergence, ContinuousSplitRule, OneHotSplitRule, SubsetSplitRule
:members: BART, PGBART, compute_variable_importance, get_variable_inclusion, plot_convergence, plot_ice, plot_pdp, plot_scatter_submodels, plot_variable_importance, plot_variable_inclusion, ContinuousSplitRule, OneHotSplitRule, SubsetSplitRule
2 changes: 2 additions & 0 deletions pymc_bart/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from pymc_bart.split_rules import ContinuousSplitRule, OneHotSplitRule, SubsetSplitRule
from pymc_bart.utils import (
compute_variable_importance,
get_variable_inclusion,
plot_convergence,
plot_ice,
plot_pdp,
Expand All @@ -33,6 +34,7 @@
"OneHotSplitRule",
"SubsetSplitRule",
"compute_variable_importance",
"get_variable_inclusion",
"plot_convergence",
"plot_ice",
"plot_pdp",
Expand Down
68 changes: 50 additions & 18 deletions pymc_bart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,50 @@ def _smooth_mean(
return x_data, y_data


def get_variable_inclusion(idata, X, labels=None, to_kulprit=False):
    """
    Get the normalized variable inclusion from BART model.

    Parameters
    ----------
    idata : InferenceData
        InferenceData containing a collection of BART_trees in sample_stats group
    X : npt.NDArray
        The covariate matrix.
    labels : Optional[list[str]]
        List of the names of the covariates. If X is a DataFrame the names of the covariates will
        be taken from it and this argument will be ignored. If no names are available, the
        positional indices of the covariates (as strings) are used.
    to_kulprit : bool
        If True, the function will return a list of list with the variables names.
        This list can be passed as a path to Kulprit's project method. Defaults to False.

    Returns
    -------
    VI_norm : npt.NDArray
        Normalized variable inclusion, sorted from higher to lower. Only returned when
        ``to_kulprit`` is False.
    labels : list[str]
        List of the names of the covariates, in the same order as ``VI_norm``. Only returned
        when ``to_kulprit`` is False.
    path : list[list[str]]
        Only returned when ``to_kulprit`` is True. Nested lists of covariate names, starting
        with the empty list and adding one covariate at a time in decreasing order of
        inclusion.

    Raises
    ------
    ValueError
        If the number of labels does not match the number of covariates in ``idata``.
    """
    VIs = idata["sample_stats"]["variable_inclusion"].mean(("chain", "draw")).values
    VI_norm = VIs / VIs.sum()

    # Positions of the covariates from higher to lower inclusion.
    indices = np.argsort(VI_norm)[::-1]
    n_vars = len(indices)

    if hasattr(X, "columns") and hasattr(X, "to_numpy"):
        labels = X.columns

    if labels is None:
        label_list = [str(idx) for idx in range(n_vars)]
    elif hasattr(labels, "tolist"):
        # NumPy arrays and pandas Index objects both expose ``tolist``.
        label_list = labels.tolist()
    else:
        label_list = list(labels)

    if len(label_list) != n_vars:
        raise ValueError(
            f"Expected {n_vars} labels, one per covariate, but got {len(label_list)}."
        )

    # Reorder the labels so that they line up with the sorted inclusion values.
    label_list = [label_list[idx] for idx in indices]

    if to_kulprit:
        return [label_list[:idx] for idx in range(n_vars)]

    return VI_norm[indices], label_list


def plot_variable_inclusion(idata, X, labels=None, figsize=None, plot_kwargs=None, ax=None):
"""
Plot normalized variable inclusion from BART model.
Expand Down Expand Up @@ -720,26 +764,15 @@ def plot_variable_inclusion(idata, X, labels=None, figsize=None, plot_kwargs=Non
Returns
-------
idxs: indexes of the covariates from higher to lower relative importance
axes: matplotlib axes
"""
if plot_kwargs is None:
plot_kwargs = {}

VIs = idata["sample_stats"]["variable_inclusion"].mean(("chain", "draw")).values
VIs = VIs / VIs.sum()
idxs = np.argsort(VIs)

indices = idxs[::-1]
n_vars = len(indices)

if hasattr(X, "columns") and hasattr(X, "to_numpy"):
labels = X.columns
VI_norm, labels = get_variable_inclusion(idata, X, labels)
n_vars = len(labels)

if labels is None:
labels = np.arange(n_vars).astype(str)

new_labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels[indices])]
new_labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels)]

ticks = np.arange(n_vars, dtype=int)

Expand All @@ -749,19 +782,18 @@ def plot_variable_inclusion(idata, X, labels=None, figsize=None, plot_kwargs=Non
if ax is None:
_, ax = plt.subplots(1, 1, figsize=figsize)

ax.axhline(1 / n_vars, color="0.5", linestyle="--")
ax.plot(
VIs[indices],
VI_norm,
color=plot_kwargs.get("color", "k"),
marker=plot_kwargs.get("marker", "o"),
ls=plot_kwargs.get("ls", "-"),
)

ax.set_xticks(ticks, new_labels, rotation=plot_kwargs.get("rotation", 0))

ax.axhline(1 / n_vars, color="0.5", linestyle="--")
ax.set_ylim(0, 1)

return idxs, ax
return ax


def compute_variable_importance( # noqa: PLR0915 PLR0912
Expand Down

0 comments on commit 064457e

Please sign in to comment.