Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bring back BART to V4 and make it more general #4914

Merged
merged 26 commits into from
Sep 3, 2021
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
b0025b0
frowardporting from unreleased v3 plus generalization
aloctavodia Aug 5, 2021
4224969
aesarize
aloctavodia Aug 6, 2021
750ca75
improve docstrings
aloctavodia Aug 7, 2021
f9a0937
small fix docstring and variable names
aloctavodia Aug 9, 2021
05ce929
fix format variable importance
aloctavodia Aug 9, 2021
412845d
fix broadcasting issue and other minor fixes
aloctavodia Aug 10, 2021
139821c
add test and fix pylint
aloctavodia Aug 10, 2021
f0037c1
fix float32
aloctavodia Aug 11, 2021
afff2be
sample splitting variables non-uniformly
aloctavodia Aug 11, 2021
f4b7b13
remove xfail
aloctavodia Aug 12, 2021
45381ca
add back xfail on windows
aloctavodia Aug 12, 2021
a07f050
add back xfail on windows and for float32
aloctavodia Aug 12, 2021
35bc056
fix test
aloctavodia Aug 12, 2021
8f498fc
clean rnd
aloctavodia Aug 13, 2021
0bfaeba
add size argument and check for NoDistribution
aloctavodia Aug 17, 2021
c7da202
stop updating split_prior after tuning
aloctavodia Aug 17, 2021
4985e55
clean code and small speed-up
aloctavodia Aug 18, 2021
51b2c4d
clean code and small speed-up
aloctavodia Aug 18, 2021
14d2128
revert xfail
aloctavodia Aug 19, 2021
d1982dc
add tests
aloctavodia Aug 19, 2021
c1c3a0b
fix number of chains
aloctavodia Aug 19, 2021
9ea259a
revert test
aloctavodia Aug 19, 2021
6cebc10
clean code, refactor and small speed-up
aloctavodia Aug 23, 2021
efa8dc6
test random
aloctavodia Sep 3, 2021
64496cd
test random
aloctavodia Sep 3, 2021
d67c9a3
add missing data test
aloctavodia Sep 3, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
336 changes: 96 additions & 240 deletions pymc3/distributions/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,271 +14,127 @@

import numpy as np

from pandas import DataFrame, Series
from aesara.tensor.random.op import RandomVariable, default_shape_from_params

from pymc3.distributions.distribution import NoDistribution
from pymc3.distributions.tree import LeafNode, SplitNode, Tree

__all__ = ["BART"]


class BaseBART(NoDistribution):
    """
    Base class holding the data, hyperparameters and list of trees of a BART model.

    Parameters
    ----------
    X : array-like
        The design matrix. Must be two-dimensional (observations x covariates).
    Y : array-like
        The response vector. Must be one-dimensional with as many elements as rows in X.
    m : int
        Number of trees. Must be an int greater than zero.
    alpha : float
        Parameter of the prior over the tree structure. Must be in the interval (0, 1).
    split_prior : array-like, optional
        Weights used to pick the splitting covariate. It is forwarded to
        ``SampleSplittingVariable``.
    """

    def __init__(self, X, Y, m=200, alpha=0.25, split_prior=None, *args, **kwargs):
        self.X, self.Y, self.missing_data = self.preprocess_XY(X, Y)

        # Shapes are read from the preprocessed array so that inputs without a ``shape``
        # attribute (e.g. nested lists) are handled as well.
        super().__init__(*args, shape=self.X.shape[0], dtype="float64", initval=0, **kwargs)

        if self.X.ndim != 2:
            raise ValueError("The design matrix X must have two dimensions")

        if self.Y.ndim != 1:
            raise ValueError("The response matrix Y must have one dimension")
        if self.X.shape[0] != self.Y.shape[0]:
            raise ValueError(
                "The design matrix X and the response matrix Y must have the same number of elements"
            )
        if not isinstance(m, int):
            raise ValueError("The number of trees m type must be int")
        if m < 1:
            raise ValueError("The number of trees m must be greater than zero")

        if alpha <= 0 or 1 <= alpha:
            raise ValueError(
                "The value for the alpha parameter for the tree structure "
                "must be in the interval (0, 1)"
            )

        self.num_observations = self.X.shape[0]
        self.num_variates = self.X.shape[1]
        self.available_predictors = list(range(self.num_variates))
        self.ssv = SampleSplittingVariable(split_prior, self.num_variates)
        self.m = m
        self.alpha = alpha
        self.trees = self.init_list_of_trees()
        self.all_trees = []
        self.mean = fast_mean()
        self.prior_prob_leaf_node = compute_prior_probability(alpha)

    def preprocess_XY(self, X, Y):
        """
        Convert X and Y to arrays, detect missing values in X and jitter X.

        Returns
        -------
        tuple
            ``(X, Y, missing_data)`` where ``missing_data`` is True when X contains NaNs.
        """
        if isinstance(Y, (Series, DataFrame)):
            Y = Y.to_numpy()
        if isinstance(X, (Series, DataFrame)):
            X = X.to_numpy()
        # Accept plain sequences for Y; existing arrays (and subclasses) pass through unchanged.
        Y = np.asanyarray(Y)
        missing_data = np.any(np.isnan(X))
        # With NaNs present np.std returns NaN for the affected columns, which would turn the
        # whole column into NaN after the jitter. np.nanstd ignores the missing entries.
        column_std = np.nanstd(X, 0) if missing_data else np.std(X, 0)
        # Gaussian noise with a scale of 1/100 of each column's standard deviation.
        # NOTE(review): presumably added to break ties between repeated values -- confirm.
        X = np.random.normal(X, column_std / 100)
        return X, Y, missing_data

    def init_list_of_trees(self):
        """
        Build the initial list of ``m`` trees and initialize ``self.sum_trees_output``.

        Every tree is created with ``Tree.init_tree`` using ``Y.mean() / m`` as leaf value and
        the indices of all the observations.
        """
        initial_value_leaf_nodes = self.Y.mean() / self.m
        initial_idx_data_points_leaf_nodes = np.arange(self.num_observations, dtype="int32")
        list_of_trees = []
        for i in range(self.m):
            new_tree = Tree.init_tree(
                tree_id=i,
                leaf_node_value=initial_value_leaf_nodes,
                idx_data_points=initial_idx_data_points_leaf_nodes,
            )
            list_of_trees.append(new_tree)
        # Diff trick to speed computation of residuals. From Section 3.1 of Kapelner, A and Bleich, J.
        # bartMachine: A Powerful Tool for Machine Learning in R. ArXiv e-prints, 2013
        # The sum_trees_output will contain the sum of the predicted output for all trees.
        # When R_j is needed we subtract the current predicted output for tree T_j.
        # full_like inherits the dtype of Y, so for a non-floating Y the mean would be
        # truncated; force a floating dtype in that case.
        if np.issubdtype(self.Y.dtype, np.floating):
            output_dtype = self.Y.dtype
        else:
            output_dtype = np.float64
        self.sum_trees_output = np.full_like(self.Y, self.Y.mean(), dtype=output_dtype)

        return list_of_trees

    def __iter__(self):
        """Iterate over the current list of trees."""
        return iter(self.trees)

    # NOTE(review): the double leading underscore makes Python name-mangle this method;
    # presumably ``_repr_latex_`` was intended -- confirm before renaming.
    def __repr_latex(self):
        raise NotImplementedError

    def get_available_splitting_rules(self, idx_data_points_split_node, idx_split_variable):
        """
        Return the sorted unique non-missing values of a covariate that can be used to split.

        Parameters
        ----------
        idx_data_points_split_node : array of int
            Row indices of the observations that belong to the node.
        idx_split_variable : int
            Column index of the covariate.
        """
        x_j = self.X[idx_data_points_split_node, idx_split_variable]
        if self.missing_data:
            x_j = x_j[~np.isnan(x_j)]
        values = np.unique(x_j)
        # The last value is never available as it would leave the right subtree empty.
        return values[:-1]

    def grow_tree(self, tree, index_leaf_node):
        """
        Try to replace a leaf node of ``tree`` by a split node with two new leaf nodes.

        Returns
        -------
        tuple
            ``(False, None)`` when the selected covariate offers no splitting rule, otherwise
            ``(True, index_selected_predictor)``.
        """
        current_node = tree.get_node(index_leaf_node)

        index_selected_predictor = self.ssv.rvs()
        selected_predictor = self.available_predictors[index_selected_predictor]
        available_splitting_rules = self.get_available_splitting_rules(
            current_node.idx_data_points, selected_predictor
        )
        # This can be unsuccessful when there are not available splitting rules
        if available_splitting_rules.size == 0:
            return False, None

        index_selected_splitting_rule = discrete_uniform_sampler(len(available_splitting_rules))
        selected_splitting_rule = available_splitting_rules[index_selected_splitting_rule]
        new_split_node = SplitNode(
            index=index_leaf_node,
            idx_split_variable=selected_predictor,
            split_value=selected_splitting_rule,
        )

        left_node_idx_data_points, right_node_idx_data_points = self.get_new_idx_data_points(
            new_split_node, current_node.idx_data_points
        )

        left_node_value = self.draw_leaf_value(left_node_idx_data_points)
        right_node_value = self.draw_leaf_value(right_node_idx_data_points)

        new_left_node = LeafNode(
            index=current_node.get_idx_left_child(),
            value=left_node_value,
            idx_data_points=left_node_idx_data_points,
        )
        new_right_node = LeafNode(
            index=current_node.get_idx_right_child(),
            value=right_node_value,
            idx_data_points=right_node_idx_data_points,
        )
        tree.grow_tree(index_leaf_node, new_split_node, new_left_node, new_right_node)

        return True, index_selected_predictor

    def get_new_idx_data_points(self, current_split_node, idx_data_points):
        """
        Partition ``idx_data_points`` according to the rule of ``current_split_node``.

        Observations whose covariate value is less than or equal to the split value go to the
        left node, the remaining ones go to the right node.
        """
        idx_split_variable = current_split_node.idx_split_variable
        split_value = current_split_node.split_value

        left_idx = self.X[idx_data_points, idx_split_variable] <= split_value
        left_node_idx_data_points = idx_data_points[left_idx]
        right_node_idx_data_points = idx_data_points[~left_idx]

        return left_node_idx_data_points, right_node_idx_data_points

    def get_residuals(self):
        """Compute the residuals."""
        R_j = self.Y - self.sum_trees_output
        return R_j

    def get_residuals_loo(self, tree):
        """Compute the residuals leaving the passed tree out of the sum of trees."""
        R_j = self.Y - (self.sum_trees_output - tree.predict_output(self.num_observations))
        return R_j

    def draw_leaf_value(self, idx_data_points):
        """Draw the residual mean."""
        R_j = self.get_residuals()[idx_data_points]
        draw = self.mean(R_j)
        return draw

    def predict(self, X_new):
        """
        Compute out of sample predictions evaluated at X_new

        Returns
        -------
        array of shape ``(len(self.all_trees), X_new.shape[0])``
            One row per stored group of trees, each row being the sum of the predictions of
            the trees in that group.
        """
        trees = self.all_trees
        num_observations = X_new.shape[0]
        pred = np.zeros((len(trees), num_observations))
        for draw, trees_to_sum in enumerate(trees):
            new_Y = np.zeros(num_observations)
            for tree in trees_to_sum:
                new_Y += np.array([tree.predict_out_of_sample(x) for x in X_new])
            pred[draw] = new_Y
        return pred


def compute_prior_probability(alpha):
class BARTRV(RandomVariable):
"""
Calculate the probability of the node being a LeafNode (1 - p(being SplitNode)).
Taken from equation 19 in [Rockova2018].

Parameters
----------
alpha : float

Returns
-------
list with probabilities for leaf nodes

References
----------
.. [Rockova2018] Veronika Rockova, Enakshi Saha (2018). On the theory of BART.
arXiv, `link <https://arxiv.org/abs/1810.00787>`__
Base class for BART
"""
prior_leaf_prob = [0]
depth = 1
while prior_leaf_prob[-1] < 1:
prior_leaf_prob.append(1 - alpha ** depth)
depth += 1
return prior_leaf_prob


def fast_mean():
    """
    Return a function that computes the mean of a one-dimensional array.

    When Numba can be imported the returned function is a jitted explicit loop,
    otherwise ``np.mean`` is returned.
    """
    try:
        from numba import jit
    except ImportError:
        return np.mean
    else:

        @jit
        def mean(a):
            n_items = a.shape[0]
            total = 0
            for position in range(n_items):
                total += a[position]
            return total / n_items

        return mean

name = "BART"
ndim_supp = 1
ndims_params = [2, 1, 0, 0, 0, 1]
dtype = "floatX"
_print_name = ("BART", "\\operatorname{BART}")
all_trees = None

def _shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
    """Delegate the shape computation to ``default_shape_from_params`` using ``self.ndim_supp``."""
    support_ndim = self.ndim_supp
    return default_shape_from_params(support_ndim, dist_params, rep_param_idx, param_shapes)

@classmethod
def rng_fn(cls, rng=None, *args, **kwargs):
    """
    Draw predictions from the trees stored in ``cls.all_trees``.

    Parameters
    ----------
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness used to pick which stored group of trees backs each draw.
        A new ``np.random.default_rng()`` is created when omitted.
    size : int or sequence of int, optional (keyword only, read from ``kwargs``)
        Batch shape of the draws. ``None`` is treated as ``()``.
    X_new : array-like, optional (keyword only, read from ``kwargs``)
        When given, every tree is evaluated with ``predict_out_of_sample`` on each row of
        ``X_new``; otherwise ``predict_output`` is used.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(*size, -1)`` with the sum of the predictions of the selected trees.
        When ``cls.all_trees`` is empty an array shaped like ``cls.Y`` filled with the mean
        of ``cls.Y`` is returned instead (``size`` is not used in that case).
    """
    size = kwargs.pop("size", None)
    X_new = kwargs.pop("X_new", None)
    all_trees = cls.all_trees
    if not all_trees:
        return np.full_like(cls.Y, cls.Y.mean())

    if rng is None:
        rng = np.random.default_rng()

    if size is None:
        size = ()
    elif isinstance(size, (int, np.integer)):
        size = [size]

    flatten_size = 1
    for s in size:
        flatten_size *= s

    # Generator exposes ``integers`` while the legacy RandomState exposes ``randint``;
    # both sample from [0, len(all_trees)).
    draw_integers = getattr(rng, "integers", None) or rng.randint
    idx = draw_integers(len(all_trees), size=flatten_size)

    if X_new is None:
        pred = np.zeros((flatten_size, all_trees[0][0].num_observations))
        for ind, p in enumerate(pred):
            # p is a view on a row of pred, so the in-place sum updates pred.
            for tree in all_trees[idx[ind]]:
                p += tree.predict_output()
    else:
        pred = np.zeros((flatten_size, X_new.shape[0]))
        for ind, p in enumerate(pred):
            for tree in all_trees[idx[ind]]:
                p += np.array([tree.predict_out_of_sample(x) for x in X_new])
    return pred.reshape((*size, -1))

def discrete_uniform_sampler(upper_value):
    """Return a random integer in [0, upper_value), obtained by scaling a uniform draw."""
    unit_draw = np.random.random()
    scaled_draw = unit_draw * upper_value
    return int(scaled_draw)


class SampleSplittingVariable:
    """
    Sample the index of the covariate used to split a node.

    Parameters
    ----------
    prior : array-like or None
        Weights for each covariate. They are normalized to sum to one. When None every
        covariate is equally likely.
    num_variates : int
        Number of covariates.

    Raises
    ------
    ValueError
        If ``prior`` is given and its size differs from ``num_variates``.
    """

    def __init__(self, prior, num_variates):
        self.prior = prior
        self.num_variates = num_variates

        if self.prior is not None:
            self.prior = np.asarray(self.prior)
            if self.prior.size != self.num_variates:
                raise ValueError(
                    f"The size of split_prior ({self.prior.size}) should be the "
                    f"same as the number of covariates ({self.num_variates})"
                )
            self.prior = self.prior / self.prior.sum()
            # Pairs (index, cumulative probability) scanned in order by rvs.
            self.enu = list(enumerate(np.cumsum(self.prior)))

    def rvs(self):
        """Return the index of a covariate drawn according to the prior."""
        if self.prior is None:
            return int(np.random.random() * self.num_variates)

        r = np.random.random()
        for i, v in self.enu:
            if r <= v:
                return i
        # Rounding can leave the last cumulative value slightly below one; without this
        # fallback a draw above it would silently return None.
        return self.num_variates - 1
bart = BARTRV()


class BART(BaseBART):
class BART(NoDistribution):
"""
BART distribution.
Bayesian Additive Regression Tree distribution.

Distribution representing a sum over trees

Parameters
----------
X : array-like
The design matrix.
The covariate matrix.
Y : array-like
The response vector.
m : int
Number of trees
alpha : float
Control the prior probability over the depth of the trees. Must be in the interval (0, 1),
although it is recommended to be in the interval (0, 0.5].
Control the prior probability over the depth of the trees. Even though it can take values in
the interval (0, 1), it is recommended to be in the interval (0, 0.5].
k : float
Scale parameter for the values of the leaf nodes. Defaults to 2. Recommended to be between 1
and 3.
split_prior : array-like
Each element of split_prior should be in the [0, 1] interval and the elements should sum
to 1. Otherwise they will be normalized.
Defaults to None, i.e. all variables have the same prior probability.
Each element of split_prior should be in the [0, 1] interval and the elements should sum to
1. Otherwise they will be normalized.
Defaults to None, i.e. all covariates have the same prior probability to be selected.
"""

def __init__(self, X, Y, m=200, alpha=0.25, split_prior=None):
    """Forward the data and the hyperparameters to the parent initializer."""
    super().__init__(X, Y, m=m, alpha=alpha, split_prior=split_prior)
def __new__(
    cls,
    name,
    X,
    Y,
    m=50,
    alpha=0.25,
    k=2,
    split_prior=None,
    **kwargs,
):
    """
    Create a ``BARTRV`` subclass specific to this variable and register the variable.

    A new type named ``BART_{name}`` deriving from ``BARTRV`` is created with the data and
    hyperparameters stored as class attributes; an instance of it is stored in ``cls.rv_op``
    before delegating to the parent ``__new__`` with ``[X, Y, m, alpha, k]`` as parameters.
    """
    cls.all_trees = []

    op_name = f"BART_{name}"
    # The list object in cls.all_trees is shared with the op created below.
    op_attributes = {
        "name": "BART",
        "all_trees": cls.all_trees,
        "inplace": False,
        "initval": Y.mean(),
        "X": X,
        "Y": Y,
        "m": m,
        "alpha": alpha,
        "k": k,
        "split_prior": split_prior,
    }
    op_type = type(op_name, (BARTRV,), op_attributes)
    bart_op = op_type()

    NoDistribution.register(BARTRV)

    cls.rv_op = bart_op
    params = [X, Y, m, alpha, k]
    return super().__new__(cls, name, *params, **kwargs)

@classmethod
def dist(cls, *params, **kwargs):
    """Call the parent ``dist`` passing the positional parameters as a single tuple."""
    dist_params = params
    return super().dist(dist_params, **kwargs)
Loading