diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md
index 6d58c0adca1..ccde13925bc 100644
--- a/RELEASE-NOTES.md
+++ b/RELEASE-NOTES.md
@@ -86,8 +86,9 @@ This includes API changes we did not warn about since at least `3.11.0` (2021-01
 - `pm.DensityDist` can now accept an optional `logcdf` keyword argument to pass in a function to compute the cummulative density function of the distribution (see [5026](https://github.com/pymc-devs/pymc/pull/5026)).
 - `pm.DensityDist` can now accept an optional `get_moment` keyword argument to pass in a function to compute the moment of the distribution (see [5026](https://github.com/pymc-devs/pymc/pull/5026)).
 - New features for BART:
-  - Added linear response, increased number of trees fitted per step [5044](https://github.com/pymc-devs/pymc3/pull/5044).
-  - Added partial dependence plots and individual conditional expectation plots [5091](https://github.com/pymc-devs/pymc3/pull/5091).
+  - Added linear response, increased number of trees fitted per step [5044](https://github.com/pymc-devs/pymc3/pull/5044).
+  - Added partial dependence plots and individual conditional expectation plots [5091](https://github.com/pymc-devs/pymc3/pull/5091).
+  - Modified how particle weights are computed. This improves the accuracy of the modeled function (see [5177](https://github.com/pymc-devs/pymc3/pull/5177)).
 - `pm.Data` now passes additional kwargs to `aesara.shared`. [#5098](https://github.com/pymc-devs/pymc/pull/5098)
 - ...
diff --git a/pymc/bart/pgbart.py b/pymc/bart/pgbart.py
index 733641038d1..3bc7ac0a25b 100644
--- a/pymc/bart/pgbart.py
+++ b/pymc/bart/pgbart.py
@@ -101,9 +101,10 @@ class PGBART(ArrayStepShared):
         Number of particles for the conditional SMC sampler. Defaults to 10
     max_stages : int
         Maximum number of iterations of the conditional SMC sampler. Defaults to 100.
-    batch : int
+    batch : int or tuple
         Number of trees fitted per step. Defaults to "auto", which is the 10% of the `m` trees
-        during tuning and 20% after tuning.
+        during tuning and 20% after tuning. If a tuple is passed, the first element is the batch
+        size during tuning and the second is the batch size after tuning.
     model: PyMC Model
         Optional model for sampling step. Defaults to None (taken from context).
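The resolved `batch` is now always a two-element tuple holding the batch sizes during and after tuning. Below is a minimal sketch of the resolution rule described in the docstring above; the standalone `resolve_batch` helper is hypothetical, written only to mirror the logic that `PGBART.__init__` applies in the hunks that follow:

```python
# Hypothetical helper mirroring how PGBART.__init__ resolves the `batch`
# argument: "auto" -> (10% of m, 20% of m); a tuple/list is taken as
# (tuning, post-tuning) batch sizes; a plain int is used for both phases.
def resolve_batch(batch, m):
    if batch == "auto":
        return (max(1, int(m * 0.1)), max(1, int(m * 0.2)))
    if isinstance(batch, (tuple, list)):
        return tuple(batch)
    return (batch, batch)


print(resolve_batch("auto", 50))    # (5, 10)
print(resolve_batch((10, 25), 50))  # (10, 25)
print(resolve_batch(8, 50))         # (8, 8)
```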
@@ -138,9 +139,9 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo
         self.alpha = self.bart.alpha
         self.k = self.bart.k
         self.response = self.bart.response
-        self.split_prior = self.bart.split_prior
-        if self.split_prior is None:
-            self.split_prior = np.ones(self.X.shape[1])
+        self.alpha_vec = self.bart.split_prior
+        if self.alpha_vec is None:
+            self.alpha_vec = np.ones(self.X.shape[1])
         self.init_mean = self.Y.mean()
         # if data is binary
@@ -149,7 +150,7 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo
             self.mu_std = 6 / (self.k * self.m ** 0.5)
         # maybe we need to check for count data
         else:
-            self.mu_std = self.Y.std() / (self.k * self.m ** 0.5)
+            self.mu_std = (2 * self.Y.std()) / (self.k * self.m ** 0.5)
         self.num_observations = self.X.shape[0]
         self.num_variates = self.X.shape[1]
@@ -167,14 +168,18 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo
         self.normal = NormalSampler()
         self.prior_prob_leaf_node = compute_prior_probability(self.alpha)
-        self.ssv = SampleSplittingVariable(self.split_prior)
+        self.ssv = SampleSplittingVariable(self.alpha_vec)
         self.tune = True
-        self.idx = 0
-        self.batch = batch
-        if self.batch == "auto":
-            self.batch = max(1, int(self.m * 0.1))
+        if batch == "auto":
+            self.batch = (max(1, int(self.m * 0.1)), max(1, int(self.m * 0.2)))
+        else:
+            if isinstance(batch, (tuple, list)):
+                self.batch = batch
+            else:
+                self.batch = (batch, batch)
+
         self.log_num_particles = np.log(num_particles)
         self.indices = list(range(1, num_particles))
         self.len_indices = len(self.indices)
@@ -187,6 +192,9 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo
         self.all_particles = []
         for i in range(self.m):
             self.a_tree.tree_id = i
+            self.a_tree.leaf_node_value = (
+                self.init_mean / self.m + self.normal.random() * self.mu_std,
+            )
             p = ParticleTree(
                 self.a_tree,
                 self.init_log_weight,
@@ -201,12 +209,8 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
         sum_trees_output = q.data
         variable_inclusion = np.zeros(self.num_variates, dtype="int")

-        if self.idx == self.m:
-            self.idx = 0
-
-        for tree_id in range(self.idx, self.idx + self.batch):
-            if tree_id >= self.m:
-                break
+        tree_ids = np.random.randint(0, self.m, size=self.batch[~self.tune])
+        for tree_id in tree_ids:
             # Generate an initial set of SMC particles
             # at the end of the algorithm we return one of these particles as the new tree
             particles = self.init_particles(tree_id)
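One subtlety in the `tree_ids` line above: `self.batch[~self.tune]` relies on Python's bitwise NOT on booleans, where `~True == -2` and `~False == -1`, so indexing the two-element `self.batch` tuple picks the tuning batch size while `self.tune` is `True` and the post-tuning size afterwards. A minimal sketch (the values of `m` and `batch` are illustrative only):

```python
import numpy as np

m = 50           # number of trees (illustrative)
batch = (5, 10)  # (tuning, post-tuning) batch sizes, as built in __init__

for tune in (True, False):
    # ~True == -2 selects index 0; ~False == -1 selects index 1.
    tree_ids = np.random.randint(0, m, size=batch[~tune])
    print(tune, len(tree_ids))  # True -> 5, False -> 10
```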
@@ -214,7 +218,7 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
             self.sum_trees_output_noi = sum_trees_output - particles[0].tree.predict_output()
             # The old tree is not growing so we update the weights only once.
-            self.update_weight(particles[0])
+            self.update_weight(particles[0], new=True)
             for t in range(self.max_stages):
                 # Sample each particle (try to grow each tree), except for the first one.
                 for p in particles[1:]:
@@ -235,15 +239,15 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
                     if tree_grew:
                         self.update_weight(p)
                 # Normalize weights
-                W_t, normalized_weights = self.normalize(particles)
+                W_t, normalized_weights = self.normalize(particles[1:])
                 # Resample all but first particle
-                re_n_w = normalized_weights[1:] / normalized_weights[1:].sum()
+                re_n_w = normalized_weights
                 new_indices = np.random.choice(self.indices, size=self.len_indices, p=re_n_w)
                 particles[1:] = particles[new_indices]
                 # Set the new weights
-                for p in particles:
+                for p in particles[1:]:
                     p.log_weight = W_t
                 # Check if particles can keep growing, otherwise stop iterating
@@ -254,23 +258,25 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
                 if all(non_available_nodes_for_expansion):
                     break

+            for p in particles[1:]:
+                p.log_weight = p.old_likelihood_logp
+
+            _, normalized_weights = self.normalize(particles)
             # Get the new tree and update
             new_particle = np.random.choice(particles, p=normalized_weights)
             new_tree = new_particle.tree
-            self.all_trees[self.idx] = new_tree
+            self.all_trees[tree_id] = new_tree
             new_particle.log_weight = new_particle.old_likelihood_logp - self.log_num_particles
             self.all_particles[tree_id] = new_particle
             sum_trees_output = self.sum_trees_output_noi + new_tree.predict_output()

             if self.tune:
+                self.ssv = SampleSplittingVariable(self.alpha_vec)
                 for index in new_particle.used_variates:
-                    self.split_prior[index] += 1
-                self.ssv = SampleSplittingVariable(self.split_prior)
+                    self.alpha_vec[index] += 1
             else:
-                self.batch = max(1, int(self.m * 0.2))
                 for index in new_particle.used_variates:
                     variable_inclusion[index] += 1
-            self.idx += 1

         stats = {"variable_inclusion": variable_inclusion, "bart_trees": copy(self.all_trees)}
         sum_trees_output = RaveledVars(sum_trees_output, point_map_info)
@@ -323,7 +329,7 @@ def init_particles(self, tree_id: int) -> np.ndarray:
         return np.array(particles)

-    def update_weight(self, particle: List[ParticleTree]) -> None:
+    def update_weight(self, particle: ParticleTree, new=False) -> None:
         """
         Update the weight of a particle

@@ -333,20 +339,22 @@ def update_weight(self, particle: List[ParticleTree]) -> None:
         new_likelihood = self.likelihood_logp(
             self.sum_trees_output_noi + particle.tree.predict_output()
         )
-        particle.log_weight += new_likelihood - particle.old_likelihood_logp
-        particle.old_likelihood_logp = new_likelihood
+        if new:
+            particle.log_weight = new_likelihood
+        else:
+            particle.log_weight += new_likelihood - particle.old_likelihood_logp
+            particle.old_likelihood_logp = new_likelihood


 class SampleSplittingVariable:
-    def __init__(self, alpha_prior):
+    def __init__(self, alpha_vec):
         """
-        Sample splitting variables proportional to `alpha_prior`.
+        Sample splitting variables proportional to `alpha_vec`.

-        This is equivalent as sampling weights from a Dirichlet distribution with `alpha_prior`
-        parameter and then using those weights to sample from the available spliting variables.
+        This is equivalent to computing the posterior mean of a Dirichlet-Multinomial model.
         This enforce sparsity.
         """
-        self.enu = list(enumerate(np.cumsum(alpha_prior / alpha_prior.sum())))
+        self.enu = list(enumerate(np.cumsum(alpha_vec / alpha_vec.sum())))

     def rvs(self):
         r = np.random.random()
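The hunk above truncates `rvs` right after the uniform draw. Assuming it ends with the usual inverse-CDF walk over the cumulative weights built in `__init__` (the class-name suffix and the demo values below are illustrative, not part of the patch), a self-contained sketch of the renamed sampler:

```python
import numpy as np


class SampleSplittingVariableSketch:
    def __init__(self, alpha_vec):
        # Cumulative normalized weights: a splitting variable is sampled
        # proportionally to alpha_vec, the posterior mean of a
        # Dirichlet-Multinomial model as alpha_vec accumulates counts.
        self.enu = list(enumerate(np.cumsum(alpha_vec / alpha_vec.sum())))

    def rvs(self):
        # Inverse-CDF lookup: return the first index whose cumulative
        # weight reaches a uniform draw (assumed continuation of rvs).
        r = np.random.random()
        for i, v in self.enu:
            if r <= v:
                return i


alpha_vec = np.array([1.0, 1.0, 8.0])  # counts favoring the third variable
ssv = SampleSplittingVariableSketch(alpha_vec)
draws = [ssv.rvs() for _ in range(10_000)]
print([draws.count(i) / len(draws) for i in range(3)])  # roughly [0.1, 0.1, 0.8]
```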