BART: improve accuracy and other minor fixes (#5177)
* improve accuracy and other minor fixes

* update release notes

* fix typo
aloctavodia authored Nov 17, 2021
1 parent a11eaa2 commit c0c5a80
Showing 2 changed files with 45 additions and 36 deletions.
5 changes: 3 additions & 2 deletions RELEASE-NOTES.md
@@ -86,8 +86,9 @@ This includes API changes we did not warn about since at least `3.11.0` (2021-01
- `pm.DensityDist` can now accept an optional `logcdf` keyword argument to pass in a function to compute the cumulative density function of the distribution (see [5026](https://github.com/pymc-devs/pymc/pull/5026)).
- `pm.DensityDist` can now accept an optional `get_moment` keyword argument to pass in a function to compute the moment of the distribution (see [5026](https://github.com/pymc-devs/pymc/pull/5026)).
- New features for BART:
- Added linear response, increased number of trees fitted per step [5044](https://github.com/pymc-devs/pymc3/pull/5044).
- Added partial dependence plots and individual conditional expectation plots [5091](https://github.com/pymc-devs/pymc3/pull/5091).
- Modified how particle weights are computed. This improves accuracy of the modeled function (see [5177](https://github.com/pymc-devs/pymc3/pull/5177)).
- `pm.Data` now passes additional kwargs to `aesara.shared`. [#5098](https://github.com/pymc-devs/pymc/pull/5098)
- ...

76 changes: 42 additions & 34 deletions pymc/bart/pgbart.py
@@ -101,9 +101,10 @@ class PGBART(ArrayStepShared):
Number of particles for the conditional SMC sampler. Defaults to 10.
max_stages : int
Maximum number of iterations of the conditional SMC sampler. Defaults to 100.
- batch : int
+ batch : int or tuple
Number of trees fitted per step. Defaults to "auto", which is 10% of the `m` trees
- during tuning and 20% after tuning.
+ during tuning and 20% after tuning. If a tuple is passed, the first element is the batch size
+ during tuning and the second is the batch size after tuning.
model: PyMC Model
Optional model for sampling step. Defaults to None (taken from context).
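
To make the new `batch` option concrete, here is a minimal usage sketch. It assumes `pm.BART` and this `PGBART` step are exposed at the package level as on this branch; `X` and `Y` are stand-in data:

```python
import numpy as np
import pymc as pm

X = np.random.normal(size=(100, 3))
Y = X[:, 0] + np.random.normal(scale=0.5, size=100)

with pm.Model():
    mu = pm.BART("mu", X, Y, m=50)
    sigma = pm.HalfNormal("sigma", 1.0)
    pm.Normal("y", mu, sigma, observed=Y)
    # fit 5 of the 50 trees per step while tuning, 10 per step afterwards
    step = pm.PGBART(batch=(5, 10))
    idata = pm.sample(step=step)
```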
@@ -138,9 +139,9 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo
self.alpha = self.bart.alpha
self.k = self.bart.k
self.response = self.bart.response
- self.split_prior = self.bart.split_prior
- if self.split_prior is None:
- self.split_prior = np.ones(self.X.shape[1])
+ self.alpha_vec = self.bart.split_prior
+ if self.alpha_vec is None:
+ self.alpha_vec = np.ones(self.X.shape[1])

self.init_mean = self.Y.mean()
# if data is binary
@@ -149,7 +150,7 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo
self.mu_std = 6 / (self.k * self.m ** 0.5)
# maybe we need to check for count data
else:
- self.mu_std = self.Y.std() / (self.k * self.m ** 0.5)
+ self.mu_std = (2 * self.Y.std()) / (self.k * self.m ** 0.5)

self.num_observations = self.X.shape[0]
self.num_variates = self.X.shape[1]
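
The `else` branch above doubles the prior scale of the leaf values for non-binary data. A back-of-envelope sketch with toy numbers shows the effect:

```python
import numpy as np

k, m = 2.0, 50                        # k and number of trees, toy values
Y = np.random.normal(10.0, 3.0, 200)  # stand-in response

old_mu_std = Y.std() / (k * m**0.5)        # before this commit, ~0.21 here
new_mu_std = (2 * Y.std()) / (k * m**0.5)  # after: twice as wide, ~0.42
```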
@@ -167,14 +168,18 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo

self.normal = NormalSampler()
self.prior_prob_leaf_node = compute_prior_probability(self.alpha)
- self.ssv = SampleSplittingVariable(self.split_prior)
+ self.ssv = SampleSplittingVariable(self.alpha_vec)

self.tune = True
- self.idx = 0
- self.batch = batch

- if self.batch == "auto":
- self.batch = max(1, int(self.m * 0.1))
+ if batch == "auto":
+ self.batch = (max(1, int(self.m * 0.1)), max(1, int(self.m * 0.2)))
+ else:
+ if isinstance(batch, (tuple, list)):
+ self.batch = batch
+ else:
+ self.batch = (batch, batch)

self.log_num_particles = np.log(num_particles)
self.indices = list(range(1, num_particles))
self.len_indices = len(self.indices)
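
The constructor now always stores `self.batch` as a two-element tuple of `(batch while tuning, batch after tuning)`. A standalone sketch of the three input cases handled above:

```python
def normalize_batch(batch, m):
    # mirrors the constructor logic above
    if batch == "auto":
        return (max(1, int(m * 0.1)), max(1, int(m * 0.2)))
    if isinstance(batch, (tuple, list)):
        return batch
    return (batch, batch)

normalize_batch("auto", 50)    # -> (5, 10)
normalize_batch(8, 50)         # -> (8, 8)
normalize_batch((5, 10), 50)   # -> (5, 10)
```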
@@ -187,6 +192,9 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo
self.all_particles = []
for i in range(self.m):
self.a_tree.tree_id = i
+ self.a_tree.leaf_node_value = (
+ self.init_mean / self.m + self.normal.random() * self.mu_std,
+ )
p = ParticleTree(
self.a_tree,
self.init_log_weight,
@@ -201,20 +209,16 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
sum_trees_output = q.data
variable_inclusion = np.zeros(self.num_variates, dtype="int")

- if self.idx == self.m:
- self.idx = 0

- for tree_id in range(self.idx, self.idx + self.batch):
- if tree_id >= self.m:
- break
+ tree_ids = np.random.randint(0, self.m, size=self.batch[~self.tune])
+ for tree_id in tree_ids:
# Generate an initial set of SMC particles
# at the end of the algorithm we return one of these particles as the new tree
particles = self.init_particles(tree_id)
# Compute the sum of trees without the tree we are attempting to replace
self.sum_trees_output_noi = sum_trees_output - particles[0].tree.predict_output()

# The old tree is not growing so we update the weights only once.
- self.update_weight(particles[0])
+ self.update_weight(particles[0], new=True)
for t in range(self.max_stages):
# Sample each particle (try to grow each tree), except for the first one.
for p in particles[1:]:
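
The `self.batch[~self.tune]` lookup above leans on Python's bitwise negation of booleans, which lets a 2-tuple double as a tuning/post-tuning switch:

```python
batch = (5, 10)   # (batch while tuning, batch after tuning)
tune = True
batch[~tune]      # ~True == -2, so this is batch[-2] == 5
tune = False
batch[~tune]      # ~False == -1, so this is batch[-1] == 10
```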
@@ -235,15 +239,15 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
if tree_grew:
self.update_weight(p)
# Normalize weights
- W_t, normalized_weights = self.normalize(particles)
+ W_t, normalized_weights = self.normalize(particles[1:])

# Resample all but first particle
- re_n_w = normalized_weights[1:] / normalized_weights[1:].sum()
+ re_n_w = normalized_weights
new_indices = np.random.choice(self.indices, size=self.len_indices, p=re_n_w)
particles[1:] = particles[new_indices]

# Set the new weights
- for p in particles:
+ for p in particles[1:]:
p.log_weight = W_t

# Check if particles can keep growing, otherwise stop iterating
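
With the weights now normalized over `particles[1:]` only, the resampling step can consume them directly. A standalone sketch of that multinomial resampling, assuming `normalize` exponentiates and rescales the log-weights (toy numbers):

```python
import numpy as np

log_w = np.array([-9.8, -10.5, -9.2])   # log-weights of particles[1:], toy values
w = np.exp(log_w - log_w.max())
normalized_weights = w / w.sum()        # sums to 1, a valid p for choice()

indices = np.arange(1, len(log_w) + 1)  # positions 1..n of the proposal particles
new_indices = np.random.choice(indices, size=len(indices), p=normalized_weights)
```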
@@ -254,23 +258,25 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
if all(non_available_nodes_for_expansion):
break

+ for p in particles[1:]:
+ p.log_weight = p.old_likelihood_logp

_, normalized_weights = self.normalize(particles)
# Get the new tree and update
new_particle = np.random.choice(particles, p=normalized_weights)
new_tree = new_particle.tree
- self.all_trees[self.idx] = new_tree
+ self.all_trees[tree_id] = new_tree
new_particle.log_weight = new_particle.old_likelihood_logp - self.log_num_particles
self.all_particles[tree_id] = new_particle
sum_trees_output = self.sum_trees_output_noi + new_tree.predict_output()

if self.tune:
+ self.ssv = SampleSplittingVariable(self.alpha_vec)
for index in new_particle.used_variates:
- self.split_prior[index] += 1
- self.ssv = SampleSplittingVariable(self.split_prior)
+ self.alpha_vec[index] += 1
else:
- self.batch = max(1, int(self.m * 0.2))
for index in new_particle.used_variates:
variable_inclusion[index] += 1
- self.idx += 1

stats = {"variable_inclusion": variable_inclusion, "bart_trees": copy(self.all_trees)}
sum_trees_output = RaveledVars(sum_trees_output, point_map_info)
@@ -323,7 +329,7 @@ def init_particles(self, tree_id: int) -> np.ndarray:

return np.array(particles)

- def update_weight(self, particle: List[ParticleTree]) -> None:
+ def update_weight(self, particle: List[ParticleTree], new=False) -> None:
"""
Update the weight of a particle
@@ -333,20 +339,22 @@ def update_weight(self, particle: List[ParticleTree]) -> None:
new_likelihood = self.likelihood_logp(
self.sum_trees_output_noi + particle.tree.predict_output()
)
- particle.log_weight += new_likelihood - particle.old_likelihood_logp
- particle.old_likelihood_logp = new_likelihood
+ if new:
+ particle.log_weight = new_likelihood
+ else:
+ particle.log_weight += new_likelihood - particle.old_likelihood_logp
+ particle.old_likelihood_logp = new_likelihood
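
The effect of the new `new` flag, in a standalone sketch with hypothetical numbers: a particle flagged as new takes its log-likelihood as its weight outright, while a grown particle is reweighted by the likelihood ratio:

```python
log_weight, old_ll = -12.0, -10.0   # toy starting values
new_ll = -9.5

# new=True: the unchanged old-tree particle starts from its raw likelihood
log_weight = new_ll                     # -9.5

# new=False: a grown particle is updated incrementally
log_weight = -12.0 + (new_ll - old_ll)  # -11.5
```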


class SampleSplittingVariable:
- def __init__(self, alpha_prior):
+ def __init__(self, alpha_vec):
"""
- Sample splitting variables proportional to `alpha_prior`.
+ Sample splitting variables proportional to `alpha_vec`.
This is equivalent to sampling weights from a Dirichlet distribution with `alpha_vec` as the
parameter and then using those weights to sample from the available splitting variables.
This is equivalent to computing the posterior mean of a Dirichlet-Multinomial model.
This enforces sparsity.
"""
- self.enu = list(enumerate(np.cumsum(alpha_prior / alpha_prior.sum())))
+ self.enu = list(enumerate(np.cumsum(alpha_vec / alpha_vec.sum())))

def rvs(self):
r = np.random.random()
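
`rvs` draws a splitting variable by inverse-CDF sampling over the cumulative normalized `alpha_vec`. A standalone sketch with toy pseudo-counts (the comparison in the loop is an assumption, since the `rvs` body is truncated above):

```python
import numpy as np

alpha_vec = np.array([1.0, 3.0, 1.0])         # per-predictor pseudo-counts, toy values
cum = np.cumsum(alpha_vec / alpha_vec.sum())  # [0.2, 0.8, 1.0]

r = np.random.random()
chosen = next(i for i, v in enumerate(cum) if r <= v)
# predictor 1 is drawn ~60% of the time, matching its larger count
```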
