Robust generalized bart inference (#4709)
* robust generalized bart inference

* update docstring

* update test

* clarify role of link function

* revert a few changes

* raise a ValueError if inv_link string is not valid

* update release notes
aloctavodia authored May 25, 2021
1 parent fc31e00 commit 8cb87fe
Showing 4 changed files with 61 additions and 74 deletions.
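The user-facing change: ``inv_link`` can now be given as a string (``"logistic"`` or ``"exp"``), the ``scale`` argument is gone, and an invalid string raises a ``ValueError``. A minimal sketch of the new API, modeled on the updated test at the end of this diff (the shape of ``X`` here is made up):

```python
import numpy as np
import pymc3 as pm

X = np.random.normal(0, 1, size=(100, 2))  # hypothetical predictors
Y = np.repeat([0, 1], 50)                  # binary response, as in test_model

with pm.Model() as model:
    # "logistic" applies expit to the sum-of-trees output, so mu is a
    # probability that can feed a Bernoulli likelihood directly.
    mu = pm.BART("mu", X, Y, m=50, inv_link="logistic")
    y = pm.Bernoulli("y", mu, observed=Y)
    trace = pm.sample(1000, random_seed=212480)

# pm.BART("mu", X, Y, inv_link="probit") would raise a ValueError.
```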
2 changes: 1 addition & 1 deletion RELEASE-NOTES.md
@@ -7,7 +7,7 @@
+ Fix bug in the computation of the log pseudolikelihood values (SMC-ABC). (see [#4672](https://github.com/pymc-devs/pymc3/pull/4672)).

### New Features
- + BART with non-gaussian likelihoods (see [#4675](https://github.com/pymc-devs/pymc3/pull/4675)).
+ + BART with non-gaussian likelihoods (see [#4675](https://github.com/pymc-devs/pymc3/pull/4675) and [#4709](https://github.com/pymc-devs/pymc3/pull/4709)).

## PyMC3 3.11.2 (14 March 2021)

71 changes: 28 additions & 43 deletions pymc3/distributions/bart.py
@@ -15,6 +15,7 @@
import numpy as np

from pandas import DataFrame, Series
+ from scipy.special import expit

from pymc3.distributions.distribution import NoDistribution
from pymc3.distributions.tree import LeafNode, SplitNode, Tree
@@ -30,7 +31,6 @@ def __init__(
m=200,
alpha=0.25,
split_prior=None,
- scale=None,
inv_link=None,
jitter=False,
*args,
@@ -63,22 +63,32 @@
)
self.m = m
self.alpha = alpha
- self.y_std = Y.std()

- if scale is None:
- self.leaf_scale = NormalSampler(sigma=None)
- elif isinstance(scale, (int, float)):
- self.leaf_scale = NormalSampler(sigma=Y.std() / self.m ** scale)

if inv_link is None:
- self.inv_link = lambda x: x
+ self.inv_link = self.link = lambda x: x
+ elif isinstance(inv_link, str):
+ # The link function is just a rough approximation in order to allow the PGBART sampler
+ # to propose reasonable values for the leaf nodes.
+ if inv_link == "logistic":
+ self.inv_link = expit
+ self.link = lambda x: (x - 0.5) * 10
+ elif inv_link == "exp":
+ self.inv_link = np.exp
+ self.link = np.log
+ self.Y[self.Y == 0] += 0.0001
+ else:
+ raise ValueError("Accepted strings are 'logistic' or 'exp'")
else:
- self.inv_link = inv_link
+ self.inv_link, self.link = inv_link

+ self.init_mean = self.link(self.Y.mean())
+ self.Y_un = self.link(self.Y)

self.num_observations = X.shape[0]
self.num_variates = X.shape[1]
self.available_predictors = list(range(self.num_variates))
self.ssv = SampleSplittingVariable(split_prior, self.num_variates)
+ self.initial_value_leaf_nodes = self.init_mean / self.m
self.trees = self.init_list_of_trees()
self.all_trees = []
self.mean = fast_mean()
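To make the string options concrete, here is a small sketch of the three ways ``inv_link`` can now be specified, and of why ``"logistic"`` pairs expit with the linear surrogate ``(x - 0.5) * 10`` rather than the exact logit (the probit pair at the end is a hypothetical illustration of the tuple form, not a built-in option):

```python
import numpy as np
from scipy.special import expit
from scipy.stats import norm

# "logistic": exact inverse link, approximate link. The true logit is
# infinite at Y = 0 and Y = 1, so the linear surrogate keeps link(Y) finite
# while still giving the PGBART sampler roughly logit-scaled leaf targets.
inv_link = expit
link = lambda x: (x - 0.5) * 10
print(inv_link(link(0.73)))  # ~0.91, a rough (not exact) round-trip

# "exp": exact pair (log/exp); zeros in Y are nudged to 0.0001 first so
# that link(Y) = log(Y) stays finite.
print(np.exp(np.log(0.0001)))

# Tuple form: any (inv_link, link) pair of callables, e.g. a probit pair.
# The link must be finite on the observed Y for the residuals to be usable.
custom = (norm.cdf, norm.ppf)
```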
@@ -96,7 +106,7 @@ def preprocess_XY(self, X, Y):
return X, Y, missing_data

def init_list_of_trees(self):
- initial_value_leaf_nodes = self.Y.mean() / self.m
+ initial_value_leaf_nodes = self.initial_value_leaf_nodes
initial_idx_data_points_leaf_nodes = np.array(range(self.num_observations), dtype="int32")
list_of_trees = []
for i in range(self.m):
@@ -110,7 +120,7 @@
# bartMachine: A Powerful Tool for Machine Learning in R. ArXiv e-prints, 2013
# The sum_trees_output will contain the sum of the predicted output for all trees.
# When R_j is needed we subtract the current predicted output for tree T_j.
- self.sum_trees_output = np.full_like(self.Y, self.Y.mean())
+ self.sum_trees_output = np.full_like(self.Y, self.init_mean)

return list_of_trees

@@ -181,14 +191,13 @@ def get_new_idx_data_points(self, current_split_node, idx_data_points):

def get_residuals(self):
"""Compute the residuals."""
- R_j = self.Y - self.inv_link(self.sum_trees_output)
+ R_j = self.Y_un - self.sum_trees_output
return R_j

def draw_leaf_value(self, idx_data_points):
"""Draw the residual mean."""
R_j = self.get_residuals()[idx_data_points]
- draw = self.mean(R_j) + self.leaf_scale.random()
+ draw = self.mean(R_j)
return draw

def predict(self, X_new):
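With these changes the residuals live entirely in link space: ``Y_un = link(Y)`` is precomputed, and a leaf value is now just the residual mean, with no added Gaussian noise (``NormalSampler`` is deleted below). A schematic numpy version of the backfitting step these methods implement, with simplified names and invented values:

```python
import numpy as np

Y_un = np.array([1.2, 0.7, -0.3, 0.9])             # link(Y): link-space targets
sum_trees_output = np.array([1.0, 0.5, 0.0, 1.1])  # running sum over all m trees
old_prediction = np.array([0.2, 0.1, 0.0, 0.3])    # current tree's contribution

# astep() first removes the current tree: sum_trees_output -= old_prediction
sum_trees_output = sum_trees_output - old_prediction

# get_residuals(): what the other m - 1 trees leave unexplained
R_j = Y_un - sum_trees_output

# draw_leaf_value(): mean of the residuals at the leaf's data points
idx_data_points = np.array([0, 1])
leaf_value = R_j[idx_data_points].mean()
```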
@@ -278,24 +287,6 @@ def rvs(self):
return i


- class NormalSampler:
- def __init__(self, sigma):
- self.size = 5000
- self.cache = []
- self.sigma = sigma
-
- def random(self):
- if self.sigma is None:
- return 0
- else:
- if not self.cache:
- self.update()
- return self.cache.pop()
-
- def update(self):
- self.cache = np.random.normal(loc=0.0, scale=self.sigma, size=self.size).tolist()


class BART(BaseBART):
"""
BART distribution.
@@ -317,23 +308,17 @@ class BART(BaseBART):
Each element of split_prior should be in the [0, 1] interval and the elements should sum
to 1. Otherwise they will be normalized.
Defaults to None, all variables have the same prior probability.
- scale : float
- Controls the variance of the proposed leaf value. The leaf values are computed as a
- Gaussian with mean equal to the conditional residual mean and variance proportional to
- the variance of the response variable, and inversely proportional to the number of trees
- and the scale parameter. Defaults to None, i.e the variance is 0.
- inv_link : numpy function
- Inverse link function defaults to None, i.e. the identity function.
+ inv_link : str or tuple of functions
+ Inverse link function defaults to None, i.e. the identity function. Accepted strings are
+ ``logistic`` or ``exp``.
jitter : bool
Whether to jitter the X values or not. Defaults to False. When values of X are repeated,
jittering X has the effect of increasing the number of effective splitting variables,
otherwise it does not have any effect.
"""

- def __init__(
- self, X, Y, m=200, alpha=0.25, split_prior=None, scale=None, inv_link=None, jitter=False
- ):
- super().__init__(X, Y, m, alpha, split_prior, scale, inv_link)
+ def __init__(self, X, Y, m=200, alpha=0.25, split_prior=None, inv_link=None, jitter=False):
+ super().__init__(X, Y, m, alpha, split_prior, inv_link)

def _str_repr(self, name=None, dist=None, formatting="plain"):
if dist is None:
58 changes: 31 additions & 27 deletions pymc3/step_methods/pgbart.py
@@ -75,29 +75,29 @@ def __init__(self, vars=None, num_particles=10, max_stages=5000, chunk="auto", m
self.log_num_particles = np.log(num_particles)
self.indices = list(range(1, num_particles))
self.max_stages = max_stages
- self.old_trees_particles_list = []
- for i in range(self.bart.m):
- p = ParticleTree(self.bart.trees[i], self.bart.prior_prob_leaf_node)
- self.old_trees_particles_list.append(p)

shared = make_shared_replacements(vars, model)
self.likelihood_logp = logp([model.datalogpt], vars, shared)
+ self.init_leaf_nodes = self.bart.initial_value_leaf_nodes
+ self.init_likelihood = self.likelihood_logp(self.bart.inv_link(self.bart.sum_trees_output))
+ self.init_log_weight = self.init_likelihood - self.log_num_particles
+ self.old_trees_particles_list = []
+ for i in range(self.bart.m):
+ p = ParticleTree(
+ self.bart.trees[i],
+ self.bart.prior_prob_leaf_node,
+ self.init_log_weight,
+ self.init_likelihood,
+ )
+ self.old_trees_particles_list.append(p)
super().__init__(vars, shared)

def astep(self, _):
bart = self.bart

inv_link = bart.inv_link
- num_observations = bart.num_observations
variable_inclusion = np.zeros(bart.num_variates, dtype="int")

- # For the tunning phase we restrict max_stages to a low number, otherwise it is almost sure
- # we will reach max_stages given that our first set of m trees is not good at all.
- # Can set max_stages as a function of the number of variables/dimensions? XXX
- if self.tune:
- max_stages = 5
- else:
- max_stages = self.max_stages

if self.idx == bart.m:
self.idx = 0
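A quick check on the initial weights defined above: subtracting ``log(num_particles)`` from the shared initial log-likelihood starts every particle at the same normalized weight, 1/N (the log-likelihood value here is invented):

```python
import numpy as np

num_particles = 10
init_likelihood = -42.0  # hypothetical likelihood_logp of the initial output

init_log_weight = init_likelihood - np.log(num_particles)
weights = np.exp(np.full(num_particles, init_log_weight))
print(weights / weights.sum())  # uniform: every particle starts at 1/N
assert np.isclose(weights.sum(), np.exp(init_likelihood))
```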

@@ -110,25 +110,28 @@ def astep(self, _):
bart.sum_trees_output -= old_prediction
# Generate an initial set of SMC particles
# at the end of the algorithm we return one of these particles as the new tree
- particles = self.init_particles(tree.tree_id, num_observations, inv_link)
+ particles = self.init_particles(tree.tree_id)

- for t in range(1, max_stages):
+ for t in range(1, self.max_stages):
# Get old particle at stage t
particles[0] = self.get_old_tree_particle(tree.tree_id, t)
# sample each particle (try to grow each tree)
for c in range(1, self.num_particles):
particles[c].sample_tree_sequential(bart)
# Update weights. Since the prior is used as the proposal, the weights
# are updated additively, by the difference of the new and old log_likelihoods
- for p_idx, p in enumerate(particles):
- new_likelihood = self.likelihood_logp(inv_link(p.tree.predict_output()))
+ for p in particles:
+ new_likelihood = self.likelihood_logp(
+ inv_link(bart.sum_trees_output + p.tree.predict_output())
+ )
p.log_weight += new_likelihood - p.old_likelihood_logp
p.old_likelihood_logp = new_likelihood

# Normalize weights
W, normalized_weights = self.normalize(particles)
# Resample all but first particle
re_n_w = normalized_weights[1:] / normalized_weights[1:].sum()

new_indices = np.random.choice(self.indices, size=len(self.indices), p=re_n_w)
particles[1:] = particles[new_indices]
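A self-contained sketch of this normalize-and-resample step. ``normalize()`` itself is not shown in the diff, so the log-sum-exp normalization here is an assumption about its behavior; the resampling of every particle except the first (which carries the old tree) follows the two lines above:

```python
import numpy as np

rng = np.random.default_rng(0)
log_weights = np.array([-3.1, -2.7, -4.0, -2.9])  # hypothetical particle log-weights

# Assumed normalize(): stabilized softmax over the log-weights.
shifted = log_weights - log_weights.max()
w = np.exp(shifted)
normalized_weights = w / w.sum()

# Resample all but the first particle, as in astep() above.
re_n_w = normalized_weights[1:] / normalized_weights[1:].sum()
indices = list(range(1, len(log_weights)))
new_indices = rng.choice(indices, size=len(indices), p=re_n_w)
```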

@@ -149,8 +152,7 @@ def astep(self, _):
new_tree = np.random.choice(particles, p=normalized_weights)
self.old_trees_particles_list[tree.tree_id] = new_tree
bart.trees[idx] = new_tree.tree
- new_prediction = new_tree.tree.predict_output()
- bart.sum_trees_output += new_prediction
+ bart.sum_trees_output += new_tree.tree.predict_output()

if not self.tune:
self.iter += 1
@@ -194,26 +196,28 @@ def get_old_tree_particle(self, tree_id, t):
old_tree_particle.set_particle_to_step(t)
return old_tree_particle

- def init_particles(self, tree_id, num_observations, inv_link):
+ def init_particles(self, tree_id):
"""
Initialize particles
"""
# The first particle is from the tree we are trying to replace
prev_tree = self.get_old_tree_particle(tree_id, 0)
- likelihood = self.likelihood_logp(inv_link(prev_tree.tree.predict_output()))
+ likelihood = self.likelihood_logp(self.bart.inv_link(prev_tree.tree.predict_output()))
prev_tree.old_likelihood_logp = likelihood
prev_tree.log_weight = likelihood - self.log_num_particles
particles = [prev_tree]

# The rest of the particles are identically initialized
- initial_idx_data_points_leaf_nodes = np.arange(num_observations, dtype="int32")
+ initial_idx_data_points_leaf_nodes = np.arange(self.bart.num_observations, dtype="int32")
new_tree = Tree.init_tree(
tree_id=tree_id,
- leaf_node_value=0,
+ leaf_node_value=self.init_leaf_nodes,
idx_data_points=initial_idx_data_points_leaf_nodes,
)
- for i in range(1, self.num_particles):
- particles.append(ParticleTree(new_tree, self.bart.prior_prob_leaf_node, 0, 0))

+ prior_prob = self.bart.prior_prob_leaf_node
+ for _ in range(1, self.num_particles):
+ particles.append(
+ ParticleTree(new_tree, prior_prob, self.init_log_weight, self.init_likelihood)
+ )

return np.array(particles)

4 changes: 1 addition & 3 deletions pymc3/tests/test_bart.py
@@ -1,7 +1,5 @@
import numpy as np

- from scipy.special import expit

import pymc3 as pm


@@ -106,7 +104,7 @@ def test_model():

Y = np.repeat([0, 1], 50)
with pm.Model() as model:
- mu = pm.BART("mu", X, Y, m=50, inv_link=expit, scale=0.25)
+ mu = pm.BART("mu", X, Y, m=50, inv_link="logistic")
y = pm.Bernoulli("y", mu, observed=Y)
trace = pm.sample(1000, random_seed=212480)

