Skip to content

Commit

Permalink
remove linear and mix response
Browse files Browse the repository at this point in the history
  • Loading branch information
aloctavodia committed Nov 30, 2021
1 parent 0167c88 commit f023e32
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 89 deletions.
5 changes: 0 additions & 5 deletions pymc/bart/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,6 @@ class BART(NoDistribution):
k : float
Scale parameter for the values of the leaf nodes. Defaults to 2. Recomended to be between 1
and 3.
response : str
How the leaf_node values are computed. Available options are ``constant`` (default),
``linear`` or ``mix``.
split_prior : array-like
Each element of split_prior should be in the [0, 1] interval and the elements should sum to
1. Otherwise they will be normalized.
Expand All @@ -84,7 +81,6 @@ def __new__(
m=50,
alpha=0.25,
k=2,
response="constant",
split_prior=None,
**kwargs,
):
Expand All @@ -103,7 +99,6 @@ def __new__(
m=m,
alpha=alpha,
k=k,
response=response,
split_prior=split_prior,
),
)()
Expand Down
71 changes: 9 additions & 62 deletions pymc/bart/pgbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ def __init__(self, vars=None, num_particles=40, max_stages=100, batch="auto", mo
self.m = self.bart.m
self.alpha = self.bart.alpha
self.k = self.bart.k
self.response = self.bart.response
self.alpha_vec = self.bart.split_prior
if self.alpha_vec is None:
self.alpha_vec = np.ones(self.X.shape[1])
Expand All @@ -90,10 +89,8 @@ def __init__(self, vars=None, num_particles=40, max_stages=100, batch="auto", mo
self.a_tree = Tree.init_tree(
leaf_node_value=self.init_mean / self.m,
idx_data_points=np.arange(self.num_observations, dtype="int32"),
m=self.m,
)
self.mean = fast_mean()
self.linear_fit = fast_linear_fit()

self.normal = NormalSampler()
self.prior_prob_leaf_node = compute_prior_probability(self.alpha)
Expand Down Expand Up @@ -140,11 +137,9 @@ def astep(self, _):
self.sum_trees,
self.X,
self.mean,
self.linear_fit,
self.m,
self.normal,
self.mu_std,
self.response,
)

# The old tree and the one with new leafs do not grow so we update the weights only once
Expand All @@ -162,11 +157,9 @@ def astep(self, _):
self.missing_data,
self.sum_trees,
self.mean,
self.linear_fit,
self.m,
self.normal,
self.mu_std,
self.response,
)
if tree_grew:
self.update_weight(p)
Expand Down Expand Up @@ -286,11 +279,9 @@ def sample_tree(
missing_data,
sum_trees,
mean,
linear_fit,
m,
normal,
mu_std,
response,
):
tree_grew = False
if self.expansion_nodes:
Expand All @@ -308,11 +299,9 @@ def sample_tree(
missing_data,
sum_trees,
mean,
linear_fit,
m,
normal,
mu_std,
response,
)
if index_selected_predictor is not None:
new_indexes = self.tree.idx_leaf_nodes[-2:]
Expand All @@ -322,9 +311,9 @@ def sample_tree(

return tree_grew

def sample_leafs(self, sum_trees, X, mean, linear_fit, m, normal, mu_std, response):
def sample_leafs(self, sum_trees, X, mean, m, normal, mu_std):

sample_leaf_values(self.tree, sum_trees, X, mean, linear_fit, m, normal, mu_std, response)
sample_leaf_values(self.tree, sum_trees, X, mean, m, normal, mu_std)


class SampleSplittingVariable:
Expand Down Expand Up @@ -379,11 +368,9 @@ def grow_tree(
missing_data,
sum_trees,
mean,
linear_fit,
m,
normal,
mu_std,
response,
):
current_node = tree.get_node(index_leaf_node)
idx_data_points = current_node.idx_data_points
Expand All @@ -409,28 +396,22 @@ def grow_tree(
current_node.get_idx_right_child(),
)

if response == "mix":
response = "linear" if np.random.random() >= 0.5 else "constant"

new_nodes = []
for idx in range(2):
idx_data_point = new_idx_data_points[idx]
node_value, node_linear_params = draw_leaf_value(
node_value = draw_leaf_value(
sum_trees[idx_data_point],
X[idx_data_point, selected_predictor],
mean,
linear_fit,
m,
normal,
mu_std,
response,
)

new_node = LeafNode(
index=current_node_children[idx],
value=node_value,
idx_data_points=idx_data_point,
linear_params=node_linear_params,
)
new_nodes.append(new_node)

Expand All @@ -449,26 +430,23 @@ def grow_tree(
return index_selected_predictor


def sample_leaf_values(tree, sum_trees, X, mean, linear_fit, m, normal, mu_std, response):
def sample_leaf_values(tree, sum_trees, X, mean, m, normal, mu_std):

for idx in tree.idx_leaf_nodes:
if idx > 0:
leaf = tree[idx]
idx_data_points = leaf.idx_data_points
parent_node = tree[leaf.get_idx_parent_node()]
selected_predictor = parent_node.idx_split_variable
node_value, node_linear_params = draw_leaf_value(
node_value = draw_leaf_value(
sum_trees[idx_data_points],
X[idx_data_points, selected_predictor],
mean,
linear_fit,
m,
normal,
mu_std,
response,
)
leaf.value = node_value
leaf.linear_params = node_linear_params


def get_new_idx_data_points(split_value, idx_data_points, selected_predictor, X):
Expand All @@ -480,24 +458,19 @@ def get_new_idx_data_points(split_value, idx_data_points, selected_predictor, X)
return left_node_idx_data_points, right_node_idx_data_points


def draw_leaf_value(Y_mu_pred, X_mu, mean, linear_fit, m, normal, mu_std, response):
def draw_leaf_value(Y_mu_pred, X_mu, mean, m, normal, mu_std):
"""Draw Gaussian distributed leaf values"""
linear_params = None
if Y_mu_pred.size == 0:
return 0, linear_params
return 0
else:
norm = normal.random() * mu_std
if Y_mu_pred.size == 1:
mu_mean = Y_mu_pred.item() / m
elif response == "constant":
else:
mu_mean = mean(Y_mu_pred) / m
elif response == "linear":
Y_fit, linear_params = linear_fit(X_mu, Y_mu_pred)
mu_mean = Y_fit / m
linear_params[2] = norm

draw = norm + mu_mean
return draw, linear_params
return draw


def fast_mean():
Expand All @@ -518,32 +491,6 @@ def mean(a):
return mean


def fast_linear_fit():
"""If available use Numba to speed up the computation of the linear fit"""

def linear_fit(X, Y):

n = len(Y)
xbar = np.sum(X) / n
ybar = np.sum(Y) / n

den = X @ X - n * xbar ** 2
if den > 1e-10:
b = (X @ Y - n * xbar * ybar) / den
else:
b = 0
a = ybar - b * xbar
Y_fit = a + b * X
return Y_fit, [a, b, 0]

try:
from numba import jit

return jit(linear_fit)
except ImportError:
return linear_fit


def discrete_uniform_sampler(upper_value):
"""Draw from the uniform distribution with bounds [0, upper_value).
Expand Down
33 changes: 12 additions & 21 deletions pymc/bart/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,10 @@ class Tree:
num_observations : int, optional
"""

def __init__(self, num_observations=0, m=0):
def __init__(self, num_observations=0):
self.tree_structure = {}
self.idx_leaf_nodes = []
self.num_observations = num_observations
self.m = m

def __getitem__(self, index):
return self.get_node(index)
Expand Down Expand Up @@ -97,16 +96,10 @@ def predict_out_of_sample(self, X):
float
Value of the leaf value where the unobserved point lies.
"""
leaf_node, split_variable = self._traverse_tree(X, node_index=0)
linear_params = leaf_node.linear_params
if linear_params is None:
return leaf_node.value
else:
x = X[split_variable].item()
y_x = (linear_params[0] + linear_params[1] * x) / self.m
return y_x + linear_params[2]

def _traverse_tree(self, x, node_index=0, split_variable=None):
leaf_node = self._traverse_tree(X, node_index=0)
return leaf_node.value

def _traverse_tree(self, x, node_index=0):
"""
Traverse the tree starting from a particular node given an unobserved point.
Expand All @@ -121,17 +114,16 @@ def _traverse_tree(self, x, node_index=0, split_variable=None):
"""
current_node = self.get_node(node_index)
if isinstance(current_node, SplitNode):
split_variable = current_node.idx_split_variable
if x[split_variable] <= current_node.split_value:
if x[current_node.idx_split_variable] <= current_node.split_value:
left_child = current_node.get_idx_left_child()
current_node, split_variable = self._traverse_tree(x, left_child, split_variable)
current_node = self._traverse_tree(x, left_child)
else:
right_child = current_node.get_idx_right_child()
current_node, split_variable = self._traverse_tree(x, right_child, split_variable)
return current_node, split_variable
current_node = self._traverse_tree(x, right_child)
return current_node

@staticmethod
def init_tree(leaf_node_value, idx_data_points, m):
def init_tree(leaf_node_value, idx_data_points):
"""
Parameters
Expand All @@ -145,7 +137,7 @@ def init_tree(leaf_node_value, idx_data_points, m):
-------
"""
new_tree = Tree(len(idx_data_points), m)
new_tree = Tree(len(idx_data_points))
new_tree[0] = LeafNode(index=0, value=leaf_node_value, idx_data_points=idx_data_points)
return new_tree

Expand Down Expand Up @@ -174,8 +166,7 @@ def __init__(self, index, idx_split_variable, split_value):


class LeafNode(BaseNode):
def __init__(self, index, value, idx_data_points, linear_params=None):
def __init__(self, index, value, idx_data_points):
super().__init__(index)
self.value = value
self.idx_data_points = idx_data_points
self.linear_params = linear_params
2 changes: 1 addition & 1 deletion pymc/tests/test_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class TestUtils:
Y = np.random.normal(0, 1, size=50)

with pm.Model() as model:
mu = pm.BART("mu", X, Y, m=10, response="mix")
mu = pm.BART("mu", X, Y, m=10)
sigma = pm.HalfNormal("sigma", 1)
y = pm.Normal("y", mu, sigma, observed=Y)
idata = pm.sample(random_seed=3415)
Expand Down

0 comments on commit f023e32

Please sign in to comment.