Skip to content

Commit

Permalink
BART: add partial dependence plots and individual conditional expecta…
Browse files Browse the repository at this point in the history
…tion plots (#5091)

* add utils for prediction and interpretability

* remove file

* add tests

* test mixed response

* fix tests

* update release notes

* remove unused import
  • Loading branch information
aloctavodia authored Nov 3, 2021
1 parent 80bf823 commit 52a126d
Show file tree
Hide file tree
Showing 8 changed files with 340 additions and 70 deletions.
4 changes: 4 additions & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
- `pm.DensityDist` no longer accepts `logp` as its first positional argument. It is now an optional keyword argument. If you pass a callable as the first positional argument, a `TypeError` will be raised (see [5026](https://github.com/pymc-devs/pymc3/pull/5026)).
- `pm.DensityDist` now accepts distribution parameters as positional arguments. Passing them as a dictionary in the `observed` keyword argument is no longer supported and will raise an error (see [5026](https://github.com/pymc-devs/pymc3/pull/5026)).
- The signature of the `logp` and `random` functions that can be passed into a `pm.DensityDist` has been changed (see [5026](https://github.com/pymc-devs/pymc3/pull/5026)).
- Generalize BART. A BART variable can be combined with other random variables. The `inv_link` argument has been removed (see [4914](https://github.com/pymc-devs/pymc3/pull/4914)).
- Move BART to its own module (see [5058](https://github.com/pymc-devs/pymc3/pull/5058)).
- ...

### New Features
Expand All @@ -32,6 +34,8 @@
- New experimental mass matrix tuning method jitter+adapt_diag_grad. [#5004](https://github.com/pymc-devs/pymc/pull/5004)
- `pm.DensityDist` can now accept an optional `logcdf` keyword argument to pass in a function to compute the cumulative distribution function of the distribution (see [5026](https://github.com/pymc-devs/pymc3/pull/5026)).
- `pm.DensityDist` can now accept an optional `get_moment` keyword argument to pass in a function to compute the moment of the distribution (see [5026](https://github.com/pymc-devs/pymc3/pull/5026)).
- BART: add linear response, increase number of trees fitted per step [5044](https://github.com/pymc-devs/pymc3/pull/5044).
- BART: add partial dependence plots and individual conditional expectation plots [5091](https://github.com/pymc-devs/pymc3/pull/5091).
- ...

### Maintenance
Expand Down
1 change: 1 addition & 0 deletions pymc/bart/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,6 @@

from pymc.bart.bart import BART
from pymc.bart.pgbart import PGBART
from pymc.bart.utils import plot_dependence, predict

__all__ = ["BART", "PGBART"]
32 changes: 1 addition & 31 deletions pymc/bart/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,35 +41,7 @@ def _shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):

@classmethod
def rng_fn(cls, rng=np.random.default_rng(), *args, **kwargs):
size = kwargs.pop("size", None)
X_new = kwargs.pop("X_new", None)
all_trees = cls.all_trees
if all_trees:

if size is None:
size = ()
elif isinstance(size, int):
size = [size]

flatten_size = 1
for s in size:
flatten_size *= s

idx = rng.randint(len(all_trees), size=flatten_size)

if X_new is None:
pred = np.zeros((flatten_size, all_trees[0][0].num_observations))
for ind, p in enumerate(pred):
for tree in all_trees[idx[ind]]:
p += tree.predict_output()
else:
pred = np.zeros((flatten_size, X_new.shape[0]))
for ind, p in enumerate(pred):
for tree in all_trees[idx[ind]]:
p += np.array([tree.predict_out_of_sample(x, cls.m) for x in X_new])
return pred.reshape((*size, -1))
else:
return np.full_like(cls.Y, cls.Y.mean())
return np.full_like(cls.Y, cls.Y.mean())


bart = BARTRV()
Expand Down Expand Up @@ -117,15 +89,13 @@ def __new__(
**kwargs,
):

cls.all_trees = []
X, Y = preprocess_XY(X, Y)

bart_op = type(
f"BART_{name}",
(BARTRV,),
dict(
name="BART",
all_trees=cls.all_trees,
inplace=False,
initval=Y.mean(),
X=X,
Expand Down
28 changes: 11 additions & 17 deletions pymc/bart/pgbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import logging

from copy import copy
from typing import Any, Dict, List, Tuple

import aesara
Expand Down Expand Up @@ -121,7 +122,7 @@ class PGBART(ArrayStepShared):
name = "bartsampler"
default_blocked = False
generates_stats = True
stats_dtypes = [{"variable_inclusion": np.ndarray}]
stats_dtypes = [{"variable_inclusion": np.ndarray, "bart_trees": np.ndarray}]

def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", model=None):
_log.warning("BART is experimental. Use with caution.")
Expand Down Expand Up @@ -159,6 +160,7 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo
tree_id=0,
leaf_node_value=self.init_mean / self.m,
idx_data_points=np.arange(self.num_observations, dtype="int32"),
m=self.m,
)
self.mean = fast_mean()
self.linear_fit = fast_linear_fit()
Expand All @@ -169,8 +171,6 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo

self.tune = True
self.idx = 0
self.iter = 0
self.sum_trees = []
self.batch = batch

if self.batch == "auto":
Expand All @@ -193,12 +193,12 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo
self.init_likelihood,
)
self.all_particles.append(p)
self.all_trees = np.array([p.tree for p in self.all_particles])
super().__init__(vars, shared)

def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
point_map_info = q.point_map_info
sum_trees_output = q.data

variable_inclusion = np.zeros(self.num_variates, dtype="int")

if self.idx == self.m:
Expand All @@ -212,7 +212,6 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
particles = self.init_particles(tree_id)
# Compute the sum of trees without the tree we are attempting to replace
self.sum_trees_output_noi = sum_trees_output - particles[0].tree.predict_output()
self.idx += 1

# The old tree is not growing so we update the weights only once.
self.update_weight(particles[0])
Expand Down Expand Up @@ -258,6 +257,7 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
# Get the new tree and update
new_particle = np.random.choice(particles, p=normalized_weights)
new_tree = new_particle.tree
self.all_trees[self.idx] = new_tree
new_particle.log_weight = new_particle.old_likelihood_logp - self.log_num_particles
self.all_particles[tree_id] = new_particle
sum_trees_output = self.sum_trees_output_noi + new_tree.predict_output()
Expand All @@ -268,17 +268,11 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
self.ssv = SampleSplittingVariable(self.split_prior)
else:
self.batch = max(1, int(self.m * 0.2))
self.iter += 1
self.sum_trees.append(new_tree)
if not self.iter % self.m:
# XXX update the all_trees variable in BARTRV to be used in the rng_fn method
# this fails for chains > 1 as the variable is not shared between processes
self.bart.all_trees.append(self.sum_trees)
self.sum_trees = []
for index in new_particle.used_variates:
variable_inclusion[index] += 1
self.idx += 1

stats = {"variable_inclusion": variable_inclusion}
stats = {"variable_inclusion": variable_inclusion, "bart_trees": copy(self.all_trees)}
sum_trees_output = RaveledVars(sum_trees_output, point_map_info)
return sum_trees_output, [stats]

Expand Down Expand Up @@ -526,11 +520,11 @@ def linear_fit(X, Y):
xbar = np.sum(X) / n
ybar = np.sum(Y) / n

if np.all(X == xbar):
b = 0
den = X @ X - n * xbar ** 2
if den > 1e-10:
b = (X @ Y - n * xbar * ybar) / den
else:
b = (X @ Y - n * xbar * ybar) / (X @ X - n * xbar ** 2)

b = 0
a = ybar - b * xbar
Y_fit = a + b * X
return Y_fit, [a, b, 0]
Expand Down
18 changes: 10 additions & 8 deletions pymc/bart/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,21 +45,23 @@ class Tree:
Identifier used to get the previous tree in the ParticleGibbs algorithm used in BART.
num_observations : int
Number of observations used to fit BART.
m : int
Number of trees
Parameters
----------
tree_id : int, optional
num_observations : int, optional
"""

def __init__(self, tree_id=0, num_observations=0):
def __init__(self, tree_id=0, num_observations=0, m=0):
self.tree_structure = {}
self.num_nodes = 0
self.idx_leaf_nodes = []
self.idx_prunable_split_nodes = []
self.tree_id = tree_id
self.num_observations = num_observations
self.m = m

def __getitem__(self, index):
return self.get_node(index)
Expand Down Expand Up @@ -94,16 +96,14 @@ def predict_output(self):

return output.astype(aesara.config.floatX)

def predict_out_of_sample(self, X, m):
def predict_out_of_sample(self, X):
"""
Predict output of tree for an unobserved point x.
Parameters
----------
X : numpy array
Unobserved point
m : int
Number of trees
Returns
-------
Expand All @@ -116,7 +116,7 @@ def predict_out_of_sample(self, X, m):
return leaf_node.value
else:
x = X[split_variable].item()
y_x = (linear_params[0] + linear_params[1] * x) / m
y_x = (linear_params[0] + linear_params[1] * x) / self.m
return y_x + linear_params[2]

def _traverse_tree(self, x, node_index=0, split_variable=None):
Expand Down Expand Up @@ -170,20 +170,22 @@ def grow_tree(self, index_leaf_node, new_split_node, new_left_node, new_right_no
self.idx_prunable_split_nodes.remove(parent_index)

@staticmethod
def init_tree(tree_id, leaf_node_value, idx_data_points):
def init_tree(tree_id, leaf_node_value, idx_data_points, m):
"""
Parameters
----------
tree_id
leaf_node_value
idx_data_points
m : int
number of trees in BART
Returns
-------
"""
new_tree = Tree(tree_id, len(idx_data_points))
new_tree = Tree(tree_id, len(idx_data_points), m)
new_tree[0] = LeafNode(index=0, value=leaf_node_value, idx_data_points=idx_data_points)
return new_tree

Expand Down
Loading

0 comments on commit 52a126d

Please sign in to comment.