Skip to content

Commit

Permalink
regularize binomial bart (#4720)
Browse files Browse the repository at this point in the history
* regularize binomial bart

* update release notes
  • Loading branch information
aloctavodia authored May 27, 2021
1 parent 8cb87fe commit 9e4c7f9
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
3 changes: 2 additions & 1 deletion RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
+ Fix bug in the computation of the log pseudolikelihood values (SMC-ABC). (see [#4672](https://github.com/pymc-devs/pymc3/pull/4672)).

### New Features
+ BART with non-gaussian likelihoods (see [#4675](https://github.com/pymc-devs/pymc3/pull/4675) and [#4709](https://github.com/pymc-devs/pymc3/pull/4709)).
+ Generalized BART, bounded distributions like Binomial and Poisson can now be used as likelihoods (see [#4675](https://github.com/pymc-devs/pymc3/pull/4675), [#4709](https://github.com/pymc-devs/pymc3/pull/4709) and
[#4720](https://github.com/pymc-devs/pymc3/pull/4720)).

## PyMC3 3.11.2 (14 March 2021)

Expand Down
15 changes: 8 additions & 7 deletions pymc3/distributions/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(
self,
X,
Y,
m=200,
m=50,
alpha=0.25,
split_prior=None,
inv_link=None,
Expand Down Expand Up @@ -67,11 +67,12 @@ def __init__(
if inv_link is None:
self.inv_link = self.link = lambda x: x
elif isinstance(inv_link, str):
# The link function is just a rough approximation in order to allow the PGBART sampler
# to propose reasonable values for the leaf nodes.
if inv_link == "logistic":
self.inv_link = expit
self.link = lambda x: (x - 0.5) * 10
# The link function is just a rough approximation in order to allow the PGBART
# sampler to propose reasonable values for the leaf nodes. The regularizing term
# 2 * self.m ** 0.5 is inspired by Chipman's DOI: 10.1214/09-AOAS285
self.link = lambda x: (x - 0.5) * 2 * self.m ** 0.5
elif inv_link == "exp":
self.inv_link = np.exp
self.link = np.log
Expand Down Expand Up @@ -302,8 +303,8 @@ class BART(BaseBART):
m : int
Number of trees
alpha : float
Control the prior probability over the depth of the trees. Must be in the interval (0, 1),
altought it is recomenned to be in the interval (0, 0.5].
Control the prior probability over the depth of the trees. Even when it can takes values in
the interval (0, 1), it is recommended to be in the interval (0, 0.5].
split_prior : array-like
Each element of split_prior should be in the [0, 1] interval and the elements should sum
to 1. Otherwise they will be normalized.
Expand All @@ -317,7 +318,7 @@ class BART(BaseBART):
otherwise it does not have any effect.
"""

def __init__(self, X, Y, m=200, alpha=0.25, split_prior=None, inv_link=None, jitter=False):
def __init__(self, X, Y, m=50, alpha=0.25, split_prior=None, inv_link=None, jitter=False):
super().__init__(X, Y, m, alpha, split_prior, inv_link)

def _str_repr(self, name=None, dist=None, formatting="plain"):
Expand Down

0 comments on commit 9e4c7f9

Please sign in to comment.