Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

regularize binomial bart #4720

Merged
merged 2 commits into from
May 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
+ Fix bug in the computation of the log pseudolikelihood values (SMC-ABC). (see [#4672](https://github.com/pymc-devs/pymc3/pull/4672)).

### New Features
+ BART with non-gaussian likelihoods (see [#4675](https://github.com/pymc-devs/pymc3/pull/4675) and [#4709](https://github.com/pymc-devs/pymc3/pull/4709)).
+ Generalized BART, bounded distributions like Binomial and Poisson can now be used as likelihoods (see [#4675](https://github.com/pymc-devs/pymc3/pull/4675), [#4709](https://github.com/pymc-devs/pymc3/pull/4709) and
[#4720](https://github.com/pymc-devs/pymc3/pull/4720)).

## PyMC3 3.11.2 (14 March 2021)

Expand Down
15 changes: 8 additions & 7 deletions pymc3/distributions/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(
self,
X,
Y,
m=200,
m=50,
alpha=0.25,
split_prior=None,
inv_link=None,
Expand Down Expand Up @@ -67,11 +67,12 @@ def __init__(
if inv_link is None:
self.inv_link = self.link = lambda x: x
elif isinstance(inv_link, str):
# The link function is just a rough approximation in order to allow the PGBART sampler
# to propose reasonable values for the leaf nodes.
if inv_link == "logistic":
self.inv_link = expit
self.link = lambda x: (x - 0.5) * 10
# The link function is just a rough approximation in order to allow the PGBART
# sampler to propose reasonable values for the leaf nodes. The regularizing term
# 2 * self.m ** 0.5 is inspired by Chipman's DOI: 10.1214/09-AOAS285
self.link = lambda x: (x - 0.5) * 2 * self.m ** 0.5
elif inv_link == "exp":
self.inv_link = np.exp
self.link = np.log
Expand Down Expand Up @@ -302,8 +303,8 @@ class BART(BaseBART):
m : int
Number of trees
alpha : float
Control the prior probability over the depth of the trees. Must be in the interval (0, 1),
altought it is recomenned to be in the interval (0, 0.5].
Control the prior probability over the depth of the trees. Even when it can takes values in
the interval (0, 1), it is recommended to be in the interval (0, 0.5].
split_prior : array-like
Each element of split_prior should be in the [0, 1] interval and the elements should sum
to 1. Otherwise they will be normalized.
Expand All @@ -317,7 +318,7 @@ class BART(BaseBART):
otherwise it does not have any effect.
"""

def __init__(self, X, Y, m=200, alpha=0.25, split_prior=None, inv_link=None, jitter=False):
def __init__(self, X, Y, m=50, alpha=0.25, split_prior=None, inv_link=None, jitter=False):
super().__init__(X, Y, m, alpha, split_prior, inv_link)

def _str_repr(self, name=None, dist=None, formatting="plain"):
Expand Down