From 064457e34d3041bc3886b66a2707b94f5554aac4 Mon Sep 17 00:00:00 2001 From: Osvaldo A Martin Date: Sun, 29 Dec 2024 08:11:29 -0300 Subject: [PATCH] Adds get_variable_inclusion function (#214) * add get_variable_inclusion function * add elements to API reference --- docs/api_reference.rst | 2 +- pymc_bart/__init__.py | 2 ++ pymc_bart/utils.py | 68 +++++++++++++++++++++++++++++++----------- 3 files changed, 53 insertions(+), 19 deletions(-) diff --git a/docs/api_reference.rst b/docs/api_reference.rst index 93afde1..b6fb8a5 100644 --- a/docs/api_reference.rst +++ b/docs/api_reference.rst @@ -13,4 +13,4 @@ methods in the current release of PyMC-BART. ============================= .. automodule:: pymc_bart - :members: BART, PGBART, plot_pdp, plot_ice, plot_variable_importance, plot_convergence, ContinuousSplitRule, OneHotSplitRule, SubsetSplitRule + :members: BART, PGBART, compute_variable_importance, get_variable_inclusion, plot_convergence, plot_ice, plot_pdp, plot_scatter_submodels, plot_variable_importance, plot_variable_inclusion, ContinuousSplitRule, OneHotSplitRule, SubsetSplitRule diff --git a/pymc_bart/__init__.py b/pymc_bart/__init__.py index 361be83..f4a1f7a 100644 --- a/pymc_bart/__init__.py +++ b/pymc_bart/__init__.py @@ -18,6 +18,7 @@ from pymc_bart.split_rules import ContinuousSplitRule, OneHotSplitRule, SubsetSplitRule from pymc_bart.utils import ( compute_variable_importance, + get_variable_inclusion, plot_convergence, plot_ice, plot_pdp, @@ -33,6 +34,7 @@ "OneHotSplitRule", "SubsetSplitRule", "compute_variable_importance", + "get_variable_inclusion", "plot_convergence", "plot_ice", "plot_pdp", diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index 58d14b8..df8f76f 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -693,6 +693,50 @@ def _smooth_mean( return x_data, y_data +def get_variable_inclusion(idata, X, labels=None, to_kulprit=False): + """ + Get the normalized variable inclusion from BART model. + + Parameters + ---------- + idata : InferenceData + InferenceData containing a collection of BART_trees in sample_stats group + X : npt.NDArray + The covariate matrix. + labels : Optional[list[str]] + List of the names of the covariates. If X is a DataFrame the names of the covariables will + be taken from it and this argument will be ignored. + to_kulprit : bool + If True, the function will return a list of list with the variables names. + This list can be passed as a path to Kulprit's project method. Defaults to False. + Returns + ------- + VI_norm : npt.NDArray + Normalized variable inclusion. + labels : list[str] + List of the names of the covariates. + """ + VIs = idata["sample_stats"]["variable_inclusion"].mean(("chain", "draw")).values + VI_norm = VIs / VIs.sum() + idxs = np.argsort(VI_norm) + + indices = idxs[::-1] + n_vars = len(indices) + + if hasattr(X, "columns") and hasattr(X, "to_numpy"): + labels = X.columns + + if labels is None: + labels = np.arange(n_vars).astype(str) + + label_list = labels.to_list() + + if to_kulprit: + return [label_list[:idx] for idx in range(n_vars)] + else: + return VI_norm[indices], label_list + + def plot_variable_inclusion(idata, X, labels=None, figsize=None, plot_kwargs=None, ax=None): """ Plot normalized variable inclusion from BART model. @@ -720,26 +764,15 @@ def plot_variable_inclusion(idata, X, labels=None, figsize=None, plot_kwargs=Non Returns ------- - idxs: indexes of the covariates from higher to lower relative importance axes: matplotlib axes """ if plot_kwargs is None: plot_kwargs = {} - VIs = idata["sample_stats"]["variable_inclusion"].mean(("chain", "draw")).values - VIs = VIs / VIs.sum() - idxs = np.argsort(VIs) - - indices = idxs[::-1] - n_vars = len(indices) - - if hasattr(X, "columns") and hasattr(X, "to_numpy"): - labels = X.columns + VI_norm, labels = get_variable_inclusion(idata, X, labels) + n_vars = len(labels) - if labels is None: - labels = np.arange(n_vars).astype(str) - - new_labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels[indices])] + new_labels = ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels)] ticks = np.arange(n_vars, dtype=int) @@ -749,19 +782,18 @@ def plot_variable_inclusion(idata, X, labels=None, figsize=None, plot_kwargs=Non if ax is None: _, ax = plt.subplots(1, 1, figsize=figsize) + ax.axhline(1 / n_vars, color="0.5", linestyle="--") ax.plot( - VIs[indices], + VI_norm, color=plot_kwargs.get("color", "k"), marker=plot_kwargs.get("marker", "o"), ls=plot_kwargs.get("ls", "-"), ) ax.set_xticks(ticks, new_labels, rotation=plot_kwargs.get("rotation", 0)) - - ax.axhline(1 / n_vars, color="0.5", linestyle="--") ax.set_ylim(0, 1) - return idxs, ax + return ax def compute_variable_importance( # noqa: PLR0915 PLR0912