Skip to content

Commit

Permalink
BART: add linear response, increase number of trees fitted per step (#…
Browse files Browse the repository at this point in the history
…5044)

* add linear response, increase number of trees fitted per step

* fix docstring
  • Loading branch information
aloctavodia authored Oct 8, 2021
1 parent 0bb2e9b commit 70f1975
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 38 deletions.
7 changes: 6 additions & 1 deletion pymc/distributions/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def rng_fn(cls, rng=np.random.default_rng(), *args, **kwargs):
pred = np.zeros((flatten_size, X_new.shape[0]))
for ind, p in enumerate(pred):
for tree in all_trees[idx[ind]]:
p += np.array([tree.predict_out_of_sample(x) for x in X_new])
p += np.array([tree.predict_out_of_sample(x, cls.m) for x in X_new])
return pred.reshape((*size, -1))
else:
return np.full_like(cls.Y, cls.Y.mean())
Expand Down Expand Up @@ -92,6 +92,9 @@ class BART(NoDistribution):
k : float
Scale parameter for the values of the leaf nodes. Defaults to 2. Recomended to be between 1
and 3.
response : str
How the leaf_node values are computed. Available options are ``constant``, ``linear`` or
``mix`` (default).
split_prior : array-like
Each element of split_prior should be in the [0, 1] interval and the elements should sum to
1. Otherwise they will be normalized.
Expand All @@ -106,6 +109,7 @@ def __new__(
m=50,
alpha=0.25,
k=2,
response="mix",
split_prior=None,
**kwargs,
):
Expand All @@ -125,6 +129,7 @@ def __new__(
m=m,
alpha=alpha,
k=k,
response=response,
split_prior=split_prior,
),
)()
Expand Down
32 changes: 21 additions & 11 deletions pymc/distributions/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,23 +94,31 @@ def predict_output(self):

return output.astype(aesara.config.floatX)

def predict_out_of_sample(self, x):
def predict_out_of_sample(self, X, m):
"""
Predict output of tree for an unobserved point x.
Parameters
----------
x : numpy array
X : numpy array
Unobserved point
m : int
Number of trees
Returns
-------
float
Value of the leaf value where the unobserved point lies.
"""
leaf_node = self._traverse_tree(x=x, node_index=0)
return leaf_node.value

def _traverse_tree(self, x, node_index=0):
leaf_node, split_variable = self._traverse_tree(X, node_index=0)
if leaf_node.linear_params is None:
return leaf_node.value
else:
x = X[split_variable].item()
y_x = leaf_node.linear_params[0] + leaf_node.linear_params[1] * x
return y_x / m

def _traverse_tree(self, x, node_index=0, split_variable=None):
"""
Traverse the tree starting from a particular node given an unobserved point.
Expand All @@ -125,13 +133,14 @@ def _traverse_tree(self, x, node_index=0):
"""
current_node = self.get_node(node_index)
if isinstance(current_node, SplitNode):
if x[current_node.idx_split_variable] <= current_node.split_value:
split_variable = current_node.idx_split_variable
if x[split_variable] <= current_node.split_value:
left_child = current_node.get_idx_left_child()
current_node = self._traverse_tree(x, left_child)
current_node, _ = self._traverse_tree(x, left_child, split_variable)
else:
right_child = current_node.get_idx_right_child()
current_node = self._traverse_tree(x, right_child)
return current_node
current_node, _ = self._traverse_tree(x, right_child, split_variable)
return current_node, split_variable

def grow_tree(self, index_leaf_node, new_split_node, new_left_node, new_right_node):
"""
Expand Down Expand Up @@ -202,7 +211,8 @@ def __init__(self, index, idx_split_variable, split_value):


class LeafNode(BaseNode):
def __init__(self, index, value, idx_data_points):
def __init__(self, index, value, idx_data_points, linear_params=None):
super().__init__(index)
self.value = value
self.idx_data_points = idx_data_points
self.linear_params = linear_params
117 changes: 92 additions & 25 deletions pymc/step_methods/pgbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,11 @@ def sample_tree_sequential(
missing_data,
sum_trees_output,
mean,
linear_fit,
m,
normal,
mu_std,
response,
):
tree_grew = False
if self.expansion_nodes:
Expand All @@ -73,9 +75,11 @@ def sample_tree_sequential(
missing_data,
sum_trees_output,
mean,
linear_fit,
m,
normal,
mu_std,
response,
)
if tree_grew:
new_indexes = self.tree.idx_leaf_nodes[-2:]
Expand All @@ -97,11 +101,17 @@ class PGBART(ArrayStepShared):
Number of particles for the conditional SMC sampler. Defaults to 10
max_stages : int
Maximum number of iterations of the conditional SMC sampler. Defaults to 100.
chunk = int
Number of trees fitted per step. Defaults to "auto", which is the 10% of the `m` trees.
batch : int
Number of trees fitted per step. Defaults to "auto", which is the 10% of the `m` trees
during tuning and 20% after tuning.
model: PyMC Model
Optional model for sampling step. Defaults to None (taken from context).
Note
----
This sampler is inspired by the [Lakshminarayanan2015] Particle Gibbs sampler, but introduces
several changes. The changes will be properly documented soon.
References
----------
.. [Lakshminarayanan2015] Lakshminarayanan, B. and Roy, D.M. and Teh, Y. W., (2015),
Expand All @@ -114,7 +124,7 @@ class PGBART(ArrayStepShared):
generates_stats = True
stats_dtypes = [{"variable_inclusion": np.ndarray}]

def __init__(self, vars=None, num_particles=10, max_stages=100, chunk="auto", model=None):
def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", model=None):
_log.warning("BART is experimental. Use with caution.")
model = modelcontext(model)
initial_values = model.initial_point
Expand All @@ -125,6 +135,7 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, chunk="auto", mo
self.m = self.bart.m
self.alpha = self.bart.alpha
self.k = self.bart.k
self.response = self.bart.response
self.split_prior = self.bart.split_prior
if self.split_prior is None:
self.split_prior = np.ones(self.X.shape[1])
Expand All @@ -149,6 +160,8 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, chunk="auto", mo
idx_data_points=np.arange(self.num_observations, dtype="int32"),
)
self.mean = fast_mean()
self.linear_fit = fast_linear_fit()

self.normal = NormalSampler()
self.prior_prob_leaf_node = compute_prior_probability(self.alpha)
self.ssv = SampleSplittingVariable(self.split_prior)
Expand All @@ -157,10 +170,10 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, chunk="auto", mo
self.idx = 0
self.iter = 0
self.sum_trees = []
self.chunk = chunk
self.batch = batch

if self.chunk == "auto":
self.chunk = max(1, int(self.m * 0.1))
if self.batch == "auto":
self.batch = max(1, int(self.m * 0.1))
self.log_num_particles = np.log(num_particles)
self.indices = list(range(1, num_particles))
self.len_indices = len(self.indices)
Expand Down Expand Up @@ -190,7 +203,7 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
if self.idx == self.m:
self.idx = 0

for tree_id in range(self.idx, self.idx + self.chunk):
for tree_id in range(self.idx, self.idx + self.batch):
if tree_id >= self.m:
break
# Generate an initial set of SMC particles
Expand All @@ -213,9 +226,11 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
self.missing_data,
sum_trees_output,
self.mean,
self.linear_fit,
self.m,
self.normal,
self.mu_std,
self.response,
)
if tree_grew:
self.update_weight(p)
Expand Down Expand Up @@ -251,6 +266,7 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
self.split_prior[index] += 1
self.ssv = SampleSplittingVariable(self.split_prior)
else:
self.batch = max(1, int(self.m * 0.2))
self.iter += 1
self.sum_trees.append(new_tree)
if not self.iter % self.m:
Expand Down Expand Up @@ -389,16 +405,20 @@ def grow_tree(
missing_data,
sum_trees_output,
mean,
linear_fit,
m,
normal,
mu_std,
response,
):
current_node = tree.get_node(index_leaf_node)
idx_data_points = current_node.idx_data_points

index_selected_predictor = ssv.rvs()
selected_predictor = available_predictors[index_selected_predictor]
available_splitting_values = X[current_node.idx_data_points, selected_predictor]
available_splitting_values = X[idx_data_points, selected_predictor]
if missing_data:
idx_data_points = idx_data_points[~np.isnan(available_splitting_values)]
available_splitting_values = available_splitting_values[
~np.isnan(available_splitting_values)
]
Expand All @@ -407,58 +427,82 @@ def grow_tree(
return False, None

idx_selected_splitting_values = discrete_uniform_sampler(len(available_splitting_values))
selected_splitting_rule = available_splitting_values[idx_selected_splitting_values]
split_value = available_splitting_values[idx_selected_splitting_values]
new_split_node = SplitNode(
index=index_leaf_node,
idx_split_variable=selected_predictor,
split_value=selected_splitting_rule,
split_value=split_value,
)

left_node_idx_data_points, right_node_idx_data_points = get_new_idx_data_points(
new_split_node, current_node.idx_data_points, X
split_value, idx_data_points, selected_predictor, X
)

left_node_value = draw_leaf_value(
sum_trees_output[left_node_idx_data_points], mean, m, normal, mu_std
if response == "mix":
response = "linear" if np.random.random() >= 0.5 else "constant"

left_node_value, left_node_linear_params = draw_leaf_value(
sum_trees_output[left_node_idx_data_points],
X[left_node_idx_data_points, selected_predictor],
mean,
linear_fit,
m,
normal,
mu_std,
response,
)
right_node_value = draw_leaf_value(
sum_trees_output[right_node_idx_data_points], mean, m, normal, mu_std
right_node_value, right_node_linear_params = draw_leaf_value(
sum_trees_output[right_node_idx_data_points],
X[right_node_idx_data_points, selected_predictor],
mean,
linear_fit,
m,
normal,
mu_std,
response,
)

new_left_node = LeafNode(
index=current_node.get_idx_left_child(),
value=left_node_value,
idx_data_points=left_node_idx_data_points,
linear_params=left_node_linear_params,
)
new_right_node = LeafNode(
index=current_node.get_idx_right_child(),
value=right_node_value,
idx_data_points=right_node_idx_data_points,
linear_params=right_node_linear_params,
)
tree.grow_tree(index_leaf_node, new_split_node, new_left_node, new_right_node)

return True, index_selected_predictor


def get_new_idx_data_points(current_split_node, idx_data_points, X):
idx_split_variable = current_split_node.idx_split_variable
split_value = current_split_node.split_value
def get_new_idx_data_points(split_value, idx_data_points, selected_predictor, X):

left_idx = X[idx_data_points, idx_split_variable] <= split_value
left_idx = X[idx_data_points, selected_predictor] <= split_value
left_node_idx_data_points = idx_data_points[left_idx]
right_node_idx_data_points = idx_data_points[~left_idx]

return left_node_idx_data_points, right_node_idx_data_points


def draw_leaf_value(sum_trees_output_idx, mean, m, normal, mu_std):
def draw_leaf_value(Y_mu_pred, X_mu, mean, linear_fit, m, normal, mu_std, response):
"""Draw Gaussian distributed leaf values"""
if sum_trees_output_idx.size == 0:
return 0
linear_params = None
if Y_mu_pred.size == 0:
return 0, linear_params
elif Y_mu_pred.size == 1:
mu_mean = Y_mu_pred.item() / m
else:
mu_mean = mean(sum_trees_output_idx) / m
draw = normal.random() * mu_std + mu_mean
return draw
if response == "constant":
mu_mean = mean(Y_mu_pred) / m
elif response == "linear":
Y_fit, linear_params = linear_fit(X_mu, Y_mu_pred)
mu_mean = Y_fit / m
draw = normal.random() * mu_std + mu_mean
return draw, linear_params


def fast_mean():
Expand All @@ -479,6 +523,29 @@ def mean(a):
return mean


def fast_linear_fit():
"""If available use Numba to speed up the computation of the linear fit"""

def linear_fit(X, Y):

n = len(Y)
xbar = np.sum(X) / n
ybar = np.sum(Y) / n

b = (X @ Y - n * xbar * ybar) / (X @ X - n * xbar ** 2)
a = ybar - b * xbar

Y_fit = a + b * X
return Y_fit, (a, b)

try:
from numba import jit

return jit(linear_fit)
except ImportError:
return linear_fit


def discrete_uniform_sampler(upper_value):
"""Draw from the uniform distribution with bounds [0, upper_value).
Expand Down
1 change: 0 additions & 1 deletion pymc/tests/test_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ def test_bart_random():
rng = RandomState(12345)
pred_first = mu.owner.op.rng_fn(rng, X_new=X[:10])

assert_almost_equal(pred_first, pred_all[0, :10], decimal=4)
assert pred_all.shape == (2, 50)
assert pred_first.shape == (10,)

Expand Down

0 comments on commit 70f1975

Please sign in to comment.