Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Bayesian Additive Regression Trees (BARTs) #4183

Merged
merged 26 commits into from
Nov 14, 2020
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
273187b
update from master
aloctavodia Jan 25, 2020
9d4f73f
black
aloctavodia Jan 25, 2020
beaf184
minor fix
aloctavodia Jan 27, 2020
f11a57f
clean code
aloctavodia Jan 30, 2020
43fed87
blackify
aloctavodia Oct 9, 2020
f847f66
fix error residuals
aloctavodia Oct 9, 2020
6700a74
use a low number of max_stages for the first iteration, remove not ne…
aloctavodia Oct 16, 2020
ac96b1a
use Rockova prior, refactor prior leaf prob computaion
aloctavodia Oct 16, 2020
7d54bfa
clean code add docstring
aloctavodia Oct 19, 2020
b566d50
reduce code
aloctavodia Oct 20, 2020
0ff5833
speed-up by fitting a subset of trees per step
aloctavodia Oct 21, 2020
3419e70
choose max
aloctavodia Oct 21, 2020
51165d4
improve docstrings
aloctavodia Oct 21, 2020
03758a4
refactor and clean code
aloctavodia Oct 22, 2020
c3c3929
clean docstrings
aloctavodia Oct 22, 2020
6a58daa
add tests and minor fixes.
aloctavodia Oct 23, 2020
9050469
remove space.
aloctavodia Oct 23, 2020
5fdd999
add variable importance report
aloctavodia Oct 29, 2020
acc5290
use ValueError
aloctavodia Nov 4, 2020
7ac976b
wip return mean and std variable importance
aloctavodia Nov 6, 2020
78b6f79
update variable importance report
aloctavodia Nov 9, 2020
2dda3b0
update release notes, remove vi hdi report
aloctavodia Nov 14, 2020
2050958
Merge branch 'master' into BART
aloctavodia Nov 14, 2020
bb69a76
test variable importance
aloctavodia Nov 14, 2020
a473028
Merge branch 'BART' of https://github.com/aloctavodia/pymc3 into BART
aloctavodia Nov 14, 2020
5a7b552
fix test
aloctavodia Nov 14, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pymc3/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,11 @@
from .timeseries import MvGaussianRandomWalk
from .timeseries import MvStudentTRandomWalk

from .bart import BART

from .bound import Bound


__all__ = [
"Uniform",
"Flat",
Expand Down Expand Up @@ -175,4 +178,5 @@
"Moyal",
"Simulator",
"fast_sample_posterior_predictive",
"BART",
]
257 changes: 257 additions & 0 deletions pymc3/distributions/bart.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
# Copyright 2020 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from .distribution import NoDistribution
from .tree import Tree, SplitNode, LeafNode

__all__ = ["BART"]


class BaseBART(NoDistribution):
def __init__(self, X, Y, m=200, alpha=0.25, *args, **kwargs):
self.X = X
self.Y = Y
super().__init__(shape=X.shape[0], dtype="float64", testval=0, *args, **kwargs)

if self.X.ndim != 2:
raise BARTParamsError("The design matrix X must have two dimensions")

if self.Y.ndim != 1:
raise BARTParamsError("The response matrix Y must have one dimension")
if self.X.shape[0] != self.Y.shape[0]:
raise BARTParamsError(
"The design matrix X and the response matrix Y must have the same number of elements"
)
if not isinstance(m, int):
raise BARTParamsError("The number of trees m type must be int")
if m < 1:
raise BARTParamsError("The number of trees m must be greater than zero")

if alpha <= 0 or 1 <= alpha:
raise BARTParamsError(
"The value for the alpha parameter for the tree structure "
"must be in the interval (0, 1)"
)

self.num_observations = X.shape[0]
self.num_variates = X.shape[1]
self.m = m
self.alpha = alpha
self.trees = self.init_list_of_trees()
self.mean = fast_mean()
self.prior_prob_leaf_node = compute_prior_probability(alpha)

def init_list_of_trees(self):
initial_value_leaf_nodes = self.Y.mean() / self.m
initial_idx_data_points_leaf_nodes = np.array(range(self.num_observations), dtype="int32")
list_of_trees = []
for i in range(self.m):
new_tree = Tree.init_tree(
tree_id=i,
leaf_node_value=initial_value_leaf_nodes,
idx_data_points=initial_idx_data_points_leaf_nodes,
)
list_of_trees.append(new_tree)
# Diff trick to speed computation of residuals. From Section 3.1 of Kapelner, A and Bleich, J.
# bartMachine: A Powerful Tool for Machine Learning in R. ArXiv e-prints, 2013
# The sum_trees_output will contain the sum of the predicted output for all trees.
# When R_j is needed we subtract the current predicted output for tree T_j.
self.sum_trees_output = np.full_like(self.Y, self.Y.mean())

return list_of_trees

def __iter__(self):
return iter(self.trees)

def __repr_latex(self):
raise NotImplementedError

def get_available_predictors(self, idx_data_points_split_node):
possible_splitting_variables = []
for j in range(self.num_variates):
x_j = self.X[idx_data_points_split_node, j]
x_j = x_j[~np.isnan(x_j)]
for i in range(1, len(x_j)):
if x_j[i - 1] != x_j[i]:
possible_splitting_variables.append(j)
break
return possible_splitting_variables

def get_available_splitting_rules(self, idx_data_points_split_node, idx_split_variable):
x_j = self.X[idx_data_points_split_node, idx_split_variable]
x_j = x_j[~np.isnan(x_j)]
values, indices = np.unique(x_j, return_index=True)
# The last value is not consider since if we choose it as the value of
# the splitting rule assignment, it would leave the right subtree empty.
return values[:-1], indices[:-1]

def grow_tree(self, tree, index_leaf_node):
# This can be unsuccessful when there are not available predictors
successful_grow_tree = False
current_node = tree.get_node(index_leaf_node)

available_predictors = self.get_available_predictors(current_node.idx_data_points)

if not available_predictors:
return successful_grow_tree

index_selected_predictor = discrete_uniform_sampler(len(available_predictors))
selected_predictor = available_predictors[index_selected_predictor]

available_splitting_rules, _ = self.get_available_splitting_rules(
current_node.idx_data_points, selected_predictor
)
index_selected_splitting_rule = discrete_uniform_sampler(len(available_splitting_rules))
selected_splitting_rule = available_splitting_rules[index_selected_splitting_rule]

new_split_node = SplitNode(
index=index_leaf_node,
idx_split_variable=selected_predictor,
split_value=selected_splitting_rule,
)

left_node_idx_data_points, right_node_idx_data_points = self.get_new_idx_data_points(
new_split_node, current_node.idx_data_points
)

left_node_value = self.draw_leaf_value(left_node_idx_data_points)
right_node_value = self.draw_leaf_value(right_node_idx_data_points)

new_left_node = LeafNode(
index=current_node.get_idx_left_child(),
value=left_node_value,
idx_data_points=left_node_idx_data_points,
)
new_right_node = LeafNode(
index=current_node.get_idx_right_child(),
value=right_node_value,
idx_data_points=right_node_idx_data_points,
)
tree.grow_tree(index_leaf_node, new_split_node, new_left_node, new_right_node)
successful_grow_tree = True

return successful_grow_tree

def get_new_idx_data_points(self, current_split_node, idx_data_points):
idx_split_variable = current_split_node.idx_split_variable
split_value = current_split_node.split_value

left_idx = np.nonzero(self.X[idx_data_points, idx_split_variable] <= split_value)
left_node_idx_data_points = idx_data_points[left_idx]
right_idx = np.nonzero(~(self.X[idx_data_points, idx_split_variable] <= split_value))
right_node_idx_data_points = idx_data_points[right_idx]

return left_node_idx_data_points, right_node_idx_data_points

def get_residuals(self):
"""Compute the residuals."""
R_j = self.Y - self.sum_trees_output
return R_j

def get_residuals_loo(self, tree):
"""Compute the residuals without leaving the passed tree out."""
R_j = self.Y - (self.sum_trees_output - tree.predict_output(self.num_observations))
return R_j

def draw_leaf_value(self, idx_data_points):
""" Draw the residual mean."""
R_j = self.get_residuals()[idx_data_points]
draw = self.mean(R_j)
return draw


def compute_prior_probability(alpha):
"""
Calculate the probability of the node being a LeafNode (1 - p(being SplitNode)).
Taken from equation 19 in [Rockova2018].

Parameters
----------
alpha : float

Returns
-------
list with probabilities for leaf nodes

References
----------
.. [Rockova2018] Veronika Rockova, Enakshi Saha (2018). On the theory of BART.
arXiv, `link <https://arxiv.org/abs/1810.00787>`__
"""
prior_leaf_prob = [0]
depth = 1
while prior_leaf_prob[-1] < 1:
prior_leaf_prob.append(1 - alpha ** depth)
depth += 1
return prior_leaf_prob


def fast_mean():
"""If available use Numba to speed up the computation of the mean."""
try:
from numba import jit
except ImportError:
return np.mean

@jit
def mean(a):
count = a.shape[0]
suma = 0
for i in range(count):
suma += a[i]
return suma / count

return mean


def discrete_uniform_sampler(upper_value):
"""Draw from the uniform distribution with bounds [0, upper_value)."""
return int(np.random.random() * upper_value)


class BART(BaseBART):
"""
BART distribution.

Distribution representing a sum over trees

Parameters
----------
X :
The design matrix.
Y :
The response vector.
m : int
Number of trees
alpha : float
Control the prior probability over the depth of the trees. Must be in the interval (0, 1),
altought it is recomenned to be in the interval (0, 0.5].
"""

def __init__(self, X, Y, m=200, alpha=0.25):
super().__init__(X, Y, m, alpha)

def _str_repr(self, name=None, dist=None, formatting="plain"):
if dist is None:
dist = self
X = (type(self.X),)
Y = (type(self.Y),)
alpha = self.alpha
m = self.m

if formatting == "latex":
return f"$\\text{{{name}}} \\sim \\text{{BART}}(\\text{{alpha = }}\\text{{{alpha}}}, \\text{{m = }}\\text{{{m}}})$"
else:
return f"{name} ~ BART(alpha = {alpha}, m = {m})"
Loading