Skip to content

Commit

Permalink
Fix bug with labels in variable importance, add reference line, remov…
Browse files Browse the repository at this point in the history
…e deprecation warning (#207)

* fix bug labels variable importance, add reference line

* revert change
  • Loading branch information
aloctavodia authored Dec 20, 2024
1 parent 77116d1 commit 1ec251b
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 21 deletions.
4 changes: 1 addition & 3 deletions pymc_bart/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from pymc_bart.utils import (
compute_variable_importance,
plot_convergence,
plot_dependence,
plot_ice,
plot_pdp,
plot_scatter_submodels,
Expand All @@ -35,14 +34,13 @@
"SubsetSplitRule",
"compute_variable_importance",
"plot_convergence",
"plot_dependence",
"plot_ice",
"plot_pdp",
"plot_scatter_submodels",
"plot_variable_importance",
"plot_variable_inclusion",
]
__version__ = "0.8.0"
__version__ = "0.8.1"


pm.STEP_METHODS = list(pm.STEP_METHODS) + [PGBART]
32 changes: 14 additions & 18 deletions pymc_bart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,22 +137,6 @@ def plot_convergence(
return ax


def plot_dependence(*args, kind="pdp", **kwargs): # pylint: disable=unused-argument
"""
Partial dependence or individual conditional expectation plot.
"""
if kind == "pdp":
warnings.warn(
"This function has been deprecated. Use plot_pdp instead.",
FutureWarning,
)
elif kind == "ice":
warnings.warn(
"This function has been deprecated. Use plot_ice instead.",
FutureWarning,
)


def plot_ice(
bartrv: Variable,
X: npt.NDArray[np.float64],
Expand Down Expand Up @@ -307,6 +291,7 @@ def plot_pdp(
var_discrete: Optional[list[int]] = None,
func: Optional[Callable] = None,
samples: int = 200,
ref_line: bool = True,
random_seed: Optional[int] = None,
sharey: bool = True,
smooth: bool = True,
Expand Down Expand Up @@ -347,6 +332,8 @@ def plot_pdp(
Arbitrary function to apply to the predictions. Defaults to the identity function.
samples : int
Number of posterior samples used in the predictions. Defaults to 200
ref_line : bool
If True a reference line is plotted at the mean of the partial dependence. Defaults to True.
random_seed : Optional[int], by default None.
Seed used to sample from the posterior. Defaults to None.
sharey : bool
Expand Down Expand Up @@ -402,6 +389,7 @@ def identity(x):

count = 0
fake_X = _create_pdp_data(X, xs_interval, xs_values)
null_pd = []
for var in range(len(var_idx)):
excluded = indices[:]
excluded.remove(var)
Expand All @@ -413,6 +401,7 @@ def identity(x):
new_x = fake_X[:, var]
for s_i in range(shape):
p_di = func(p_d[:, :, s_i])
null_pd.append(p_di.mean())
if var in var_discrete:
_, idx_uni = np.unique(new_x, return_index=True)
y_means = p_di.mean(0)[idx_uni]
Expand Down Expand Up @@ -442,6 +431,11 @@ def identity(x):

count += 1

if ref_line:
ref_val = sum(null_pd) / len(null_pd)
for ax_ in np.ravel(axes):
ax_.axhline(ref_val, color="0.7", linestyle="--")

fig.text(-0.05, 0.5, y_label, va="center", rotation="vertical", fontsize=15)

return axes
Expand Down Expand Up @@ -949,11 +943,13 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912

indices = least_important_vars[::-1]

labels = np.array(["+ " + ele if index != 0 else ele for index, ele in enumerate(labels)])
labels = np.array(
["+ " + ele if index != 0 else ele for index, ele in enumerate(labels[indices])]
)

vi_results = {
"indices": np.asarray(indices),
"labels": labels[indices],
"labels": labels,
"r2_mean": r2_mean,
"r2_hdi": r2_hdi,
"preds": preds,
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ line-length = 100
select = ["E", "F", "I", "PL", "UP", "W"]
ignore = [
"PLR2004", # Checks for the use of unnamed numerical constants ("magic") values in comparisons.
"PLR0913", #Too many arguments in function definition

]

[tool.ruff.lint.pylint]
Expand Down

0 comments on commit 1ec251b

Please sign in to comment.