diff --git a/pymc_bart/__init__.py b/pymc_bart/__init__.py index eee1881..440f7f2 100644 --- a/pymc_bart/__init__.py +++ b/pymc_bart/__init__.py @@ -19,7 +19,6 @@ from pymc_bart.utils import ( compute_variable_importance, plot_convergence, - plot_dependence, plot_ice, plot_pdp, plot_scatter_submodels, @@ -35,14 +34,13 @@ "SubsetSplitRule", "compute_variable_importance", "plot_convergence", - "plot_dependence", "plot_ice", "plot_pdp", "plot_scatter_submodels", "plot_variable_importance", "plot_variable_inclusion", ] -__version__ = "0.8.0" +__version__ = "0.8.1" pm.STEP_METHODS = list(pm.STEP_METHODS) + [PGBART] diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index e10a511..d9738dd 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -137,22 +137,6 @@ def plot_convergence( return ax -def plot_dependence(*args, kind="pdp", **kwargs): # pylint: disable=unused-argument - """ - Partial dependence or individual conditional expectation plot. - """ - if kind == "pdp": - warnings.warn( - "This function has been deprecated. Use plot_pdp instead.", - FutureWarning, - ) - elif kind == "ice": - warnings.warn( - "This function has been deprecated. Use plot_ice instead.", - FutureWarning, - ) - - def plot_ice( bartrv: Variable, X: npt.NDArray[np.float64], @@ -307,6 +291,7 @@ def plot_pdp( var_discrete: Optional[list[int]] = None, func: Optional[Callable] = None, samples: int = 200, + ref_line: bool = True, random_seed: Optional[int] = None, sharey: bool = True, smooth: bool = True, @@ -347,6 +332,8 @@ def plot_pdp( Arbitrary function to apply to the predictions. Defaults to the identity function. samples : int Number of posterior samples used in the predictions. Defaults to 200 + ref_line : bool + If True a reference line is plotted at the mean of the partial dependence. Defaults to True. random_seed : Optional[int], by default None. Seed used to sample from the posterior. Defaults to None. sharey : bool @@ -402,6 +389,7 @@ def identity(x): count = 0 fake_X = _create_pdp_data(X, xs_interval, xs_values) + null_pd = [] for var in range(len(var_idx)): excluded = indices[:] excluded.remove(var) @@ -413,6 +401,7 @@ def identity(x): new_x = fake_X[:, var] for s_i in range(shape): p_di = func(p_d[:, :, s_i]) + null_pd.append(p_di.mean()) if var in var_discrete: _, idx_uni = np.unique(new_x, return_index=True) y_means = p_di.mean(0)[idx_uni] @@ -442,6 +431,11 @@ def identity(x): count += 1 + if ref_line: + ref_val = sum(null_pd) / len(null_pd) + for ax_ in np.ravel(axes): + ax_.axhline(ref_val, color="0.7", linestyle="--") + fig.text(-0.05, 0.5, y_label, va="center", rotation="vertical", fontsize=15) return axes @@ -949,11 +943,13 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912 indices = least_important_vars[::-1] - labels = np.array(["+ " + ele if index != 0 else ele for index, ele in enumerate(labels)]) + labels = np.array( + ["+ " + ele if index != 0 else ele for index, ele in enumerate(labels[indices])] + ) vi_results = { "indices": np.asarray(indices), - "labels": labels[indices], + "labels": labels, "r2_mean": r2_mean, "r2_hdi": r2_hdi, "preds": preds, diff --git a/pyproject.toml b/pyproject.toml index bc94137..f8f3e7a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,8 @@ line-length = 100 select = ["E", "F", "I", "PL", "UP", "W"] ignore = [ "PLR2004", # Checks for the use of unnamed numerical constants ("magic") values in comparisons. + "PLR0913", #Too many arguments in function definition + ] [tool.ruff.lint.pylint]