BART: improve accuracy and other minor fixes #5177

Merged (3 commits) on Nov 17, 2021
5 changes: 3 additions & 2 deletions RELEASE-NOTES.md
@@ -86,8 +86,9 @@ This includes API changes we did not warn about since at least `3.11.0` (2021-01
 - `pm.DensityDist` can now accept an optional `logcdf` keyword argument to pass in a function to compute the cumulative density function of the distribution (see [5026](https://github.com/pymc-devs/pymc/pull/5026)).
 - `pm.DensityDist` can now accept an optional `get_moment` keyword argument to pass in a function to compute the moment of the distribution (see [5026](https://github.com/pymc-devs/pymc/pull/5026)).
 - New features for BART:
-- Added linear response, increased number of trees fitted per step [5044](https://github.com/pymc-devs/pymc3/pull/5044).
-- Added partial dependence plots and individual conditional expectation plots [5091](https://github.com/pymc-devs/pymc3/pull/5091).
+  - Added linear response, increased number of trees fitted per step [5044](https://github.com/pymc-devs/pymc3/pull/5044).
+  - Added partial dependence plots and individual conditional expectation plots [5091](https://github.com/pymc-devs/pymc3/pull/5091).
+  - Modify how particle weights are computed. This improves accuracy of the modeled function (see [5177](https://github.com/pymc-devs/pymc3/pull/5177)).
 - `pm.Data` now passes additional kwargs to `aesara.shared`. [#5098](https://github.com/pymc-devs/pymc/pull/5098)
 - ...

76 changes: 42 additions & 34 deletions pymc/bart/pgbart.py
@@ -101,9 +101,10 @@ class PGBART(ArrayStepShared):
         Number of particles for the conditional SMC sampler. Defaults to 10
     max_stages : int
         Maximum number of iterations of the conditional SMC sampler. Defaults to 100.
-    batch : int
+    batch : int or tuple
         Number of trees fitted per step. Defaults to "auto", which is 10% of the `m` trees
-        during tuning and 20% after tuning.
+        during tuning and 20% after tuning. If a tuple is passed, the first element is the
+        batch size during tuning and the second the batch size after tuning.
     model: PyMC Model
         Optional model for sampling step. Defaults to None (taken from context).
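As a quick sanity check of these rules, here is a hypothetical helper (not part of the PR) that mirrors the `__init__` logic further down in this diff:

```python
def resolve_batch(batch, m):
    """Return (batch size during tuning, batch size after tuning)."""
    if batch == "auto":
        return (max(1, int(m * 0.1)), max(1, int(m * 0.2)))
    if isinstance(batch, (tuple, list)):
        return tuple(batch)
    return (batch, batch)

assert resolve_batch("auto", 50) == (5, 10)   # 10% of m while tuning, 20% after
assert resolve_batch(4, 50) == (4, 4)         # a single int applies to both phases
assert resolve_batch((5, 20), 50) == (5, 20)  # a tuple is taken as given
```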

@@ -138,9 +139,9 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo
         self.alpha = self.bart.alpha
         self.k = self.bart.k
         self.response = self.bart.response
-        self.split_prior = self.bart.split_prior
-        if self.split_prior is None:
-            self.split_prior = np.ones(self.X.shape[1])
+        self.alpha_vec = self.bart.split_prior
+        if self.alpha_vec is None:
+            self.alpha_vec = np.ones(self.X.shape[1])
 
         self.init_mean = self.Y.mean()
         # if data is binary
@@ -149,7 +150,7 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo
             self.mu_std = 6 / (self.k * self.m ** 0.5)
         # maybe we need to check for count data
         else:
-            self.mu_std = self.Y.std() / (self.k * self.m ** 0.5)
+            self.mu_std = (2 * self.Y.std()) / (self.k * self.m ** 0.5)
 
         self.num_observations = self.X.shape[0]
         self.num_variates = self.X.shape[1]
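The doubled scale gives a less restrictive prior on the leaf values for non-binary data. A numeric sketch (data and hyperparameters below are assumptions for illustration):

```python
import numpy as np

Y = np.random.normal(loc=3.0, scale=2.0, size=1_000)  # assumed data
k, m = 2.0, 50                                        # assumed hyperparameters

old_mu_std = Y.std() / (k * m ** 0.5)        # previous leaf-value prior scale
new_mu_std = (2 * Y.std()) / (k * m ** 0.5)  # doubled by this PR
print(round(old_mu_std, 3), round(new_mu_std, 3))  # roughly 0.141 vs 0.283
```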
@@ -167,14 +168,18 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo
         self.normal = NormalSampler()
         self.prior_prob_leaf_node = compute_prior_probability(self.alpha)
-        self.ssv = SampleSplittingVariable(self.split_prior)
+        self.ssv = SampleSplittingVariable(self.alpha_vec)
 
         self.tune = True
-        self.idx = 0
-        self.batch = batch
 
-        if self.batch == "auto":
-            self.batch = max(1, int(self.m * 0.1))
+        if batch == "auto":
+            self.batch = (max(1, int(self.m * 0.1)), max(1, int(self.m * 0.2)))
+        else:
+            if isinstance(batch, (tuple, list)):
+                self.batch = batch
+            else:
+                self.batch = (batch, batch)
 
         self.log_num_particles = np.log(num_particles)
         self.indices = list(range(1, num_particles))
         self.len_indices = len(self.indices)
@@ -187,6 +192,9 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo
         self.all_particles = []
         for i in range(self.m):
             self.a_tree.tree_id = i
+            self.a_tree.leaf_node_value = (
+                self.init_mean / self.m + self.normal.random() * self.mu_std
+            )
             p = ParticleTree(
                 self.a_tree,
                 self.init_log_weight,
@@ -201,20 +209,16 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
         sum_trees_output = q.data
         variable_inclusion = np.zeros(self.num_variates, dtype="int")
 
-        if self.idx == self.m:
-            self.idx = 0
-
-        for tree_id in range(self.idx, self.idx + self.batch):
-            if tree_id >= self.m:
-                break
+        tree_ids = np.random.randint(0, self.m, size=self.batch[~self.tune])
+        for tree_id in tree_ids:
             # Generate an initial set of SMC particles
             # at the end of the algorithm we return one of these particles as the new tree
             particles = self.init_particles(tree_id)
             # Compute the sum of trees without the tree we are attempting to replace
             self.sum_trees_output_noi = sum_trees_output - particles[0].tree.predict_output()
 
             # The old tree is not growing so we update the weights only once.
-            self.update_weight(particles[0])
+            self.update_weight(particles[0], new=True)
             for t in range(self.max_stages):
                 # Sample each particle (try to grow each tree), except for the first one.
                 for p in particles[1:]:
@@ -235,15 +239,15 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
                     if tree_grew:
                         self.update_weight(p)
                 # Normalize weights
-                W_t, normalized_weights = self.normalize(particles)
+                W_t, normalized_weights = self.normalize(particles[1:])
 
                 # Resample all but first particle
-                re_n_w = normalized_weights[1:] / normalized_weights[1:].sum()
+                re_n_w = normalized_weights
                 new_indices = np.random.choice(self.indices, size=self.len_indices, p=re_n_w)
                 particles[1:] = particles[new_indices]
 
                 # Set the new weights
-                for p in particles:
+                for p in particles[1:]:
                     p.log_weight = W_t
 
                 # Check if particles can keep growing, otherwise stop iterating
@@ -254,23 +258,25 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
                 if all(non_available_nodes_for_expansion):
                     break
 
+            for p in particles[1:]:
+                p.log_weight = p.old_likelihood_logp
+
             _, normalized_weights = self.normalize(particles)
             # Get the new tree and update
             new_particle = np.random.choice(particles, p=normalized_weights)
             new_tree = new_particle.tree
-            self.all_trees[self.idx] = new_tree
+            self.all_trees[tree_id] = new_tree
             new_particle.log_weight = new_particle.old_likelihood_logp - self.log_num_particles
+            self.all_particles[tree_id] = new_particle
             sum_trees_output = self.sum_trees_output_noi + new_tree.predict_output()
 
             if self.tune:
+                self.ssv = SampleSplittingVariable(self.alpha_vec)
                 for index in new_particle.used_variates:
-                    self.split_prior[index] += 1
-                self.ssv = SampleSplittingVariable(self.split_prior)
+                    self.alpha_vec[index] += 1
             else:
-                self.batch = max(1, int(self.m * 0.2))
                 for index in new_particle.used_variates:
                     variable_inclusion[index] += 1
-            self.idx += 1
 
         stats = {"variable_inclusion": variable_inclusion, "bart_trees": copy(self.all_trees)}
         sum_trees_output = RaveledVars(sum_trees_output, point_map_info)
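Two details of the new `astep` are worth spelling out: trees are now refit in a random batch each step instead of cycling sequentially with `self.idx`, and the tuning flag picks the batch size by indexing the batch tuple with a bitwise negation. A small sketch (values are assumptions for illustration):

```python
import numpy as np

# For a bool, ~True == -2 and ~False == -1, so on a two-element tuple
# batch[~tune] selects the first element while tuning and the second after.
batch = (5, 10)  # (during tuning, after tuning)
tune = True
assert batch[~tune] == 5    # tuning phase
tune = False
assert batch[~tune] == 10   # sampling phase

# Random batch of trees to refit this step (m assumed for illustration):
m = 50
tree_ids = np.random.randint(0, m, size=batch[~tune])
```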
@@ -323,7 +329,7 @@ def init_particles(self, tree_id: int) -> np.ndarray:
 
         return np.array(particles)
 
-    def update_weight(self, particle: List[ParticleTree]) -> None:
+    def update_weight(self, particle: List[ParticleTree], new=False) -> None:
         """
         Update the weight of a particle
 
@@ -333,20 +339,22 @@ def update_weight(self, particle: List[ParticleTree]) -> None:
         new_likelihood = self.likelihood_logp(
             self.sum_trees_output_noi + particle.tree.predict_output()
         )
-        particle.log_weight += new_likelihood - particle.old_likelihood_logp
-        particle.old_likelihood_logp = new_likelihood
+        if new:
+            particle.log_weight = new_likelihood
+        else:
+            particle.log_weight += new_likelihood - particle.old_likelihood_logp
+            particle.old_likelihood_logp = new_likelihood


 class SampleSplittingVariable:
-    def __init__(self, alpha_prior):
+    def __init__(self, alpha_vec):
         """
-        Sample splitting variables proportional to `alpha_prior`.
+        Sample splitting variables proportional to `alpha_vec`.
 
-        This is equivalent as sampling weights from a Dirichlet distribution with `alpha_prior`
-        parameter and then using those weights to sample from the available splitting variables.
         This is equivalent to computing the posterior mean of a Dirichlet-Multinomial model.
+        This enforces sparsity.
         """
-        self.enu = list(enumerate(np.cumsum(alpha_prior / alpha_prior.sum())))
+        self.enu = list(enumerate(np.cumsum(alpha_vec / alpha_vec.sum())))
 
     def rvs(self):
         r = np.random.random()
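The cumulative sum above is an inverse-CDF lookup over a discrete distribution: `rvs` draws a uniform number and returns the first index whose cumulative weight exceeds it. A standalone sketch of the same idea (the `alpha_vec` values are assumptions for illustration):

```python
import numpy as np

alpha_vec = np.array([1.0, 3.0, 1.0])             # per-variable split weights
cum_weights = np.cumsum(alpha_vec / alpha_vec.sum())

r = np.random.random()
split_var = int(np.searchsorted(cum_weights, r))  # drawn with prob ∝ alpha_vec
```

Since tuning increments `alpha_vec[index]` each time a variable is chosen for a split, frequently used variables become increasingly likely to be proposed again, which is how the sampler enforces sparsity.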