-
Notifications
You must be signed in to change notification settings - Fork 246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Improper Distribution #612
Changes from 4 commits
909a82c
338042e
a51b132
3f7660b
689b740
9ae051d
e412521
dedc6c4
bbd668d
2789071
0abe8a0
25413be
5f43c3c
c4ba460
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,7 +12,8 @@ | |
|
||
def init_to_median(site=None, num_samples=15): | ||
""" | ||
Initialize to the prior median. | ||
Initialize to the prior median. For priors with no `.sample` method implemented, | ||
we defer to the :func:`init_to_uniform` strategy. | ||
|
||
:param int num_samples: number of prior points to calculate median. | ||
""" | ||
|
@@ -22,13 +23,17 @@ def init_to_median(site=None, num_samples=15): | |
if site['type'] == 'sample' and not site['is_observed']: | ||
rng_key = site['kwargs'].get('rng_key') | ||
sample_shape = site['kwargs'].get('sample_shape') | ||
samples = site['fn'].sample(rng_key, sample_shape=(num_samples,) + sample_shape) | ||
return np.median(samples, axis=0) | ||
try: | ||
samples = site['fn'].sample(rng_key, sample_shape=(num_samples,) + sample_shape) | ||
return np.median(samples, axis=0) | ||
except NotImplementedError: | ||
return init_to_uniform(site) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we just return There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For substitute_fn, I think it only works if we substitue(seed(model), ...) because seed will give rng_key for the site we want to apply substitute_fn. Inside substitute_fn, we dont use sample primitive, so there is no random.split here IIUC. |
||
|
||
|
||
def init_to_prior(site=None): | ||
""" | ||
Initialize to a prior sample. | ||
Initialize to a prior sample. For priors with no `.sample` method implemented, | ||
we defer to the :func:`init_to_uniform` strategy. | ||
""" | ||
return init_to_median(site, num_samples=1) | ||
|
||
|
@@ -47,13 +52,23 @@ def init_to_uniform(site=None, radius=2): | |
rng_key = site['kwargs'].get('rng_key') | ||
sample_shape = site['kwargs'].get('sample_shape') | ||
rng_key, subkey = random.split(rng_key) | ||
transform = biject_to(fn.support) | ||
# this is used to interpret the changes of event_shape in | ||
# domain and codomain spaces | ||
prototype_value = fn.sample(subkey, sample_shape=()) | ||
transform = biject_to(fn.support) | ||
unconstrained_event_shape = np.shape(transform.inv(prototype_value)) | ||
try: | ||
prototype_value = fn.sample(subkey, sample_shape=()) | ||
unconstrained_shape = np.shape(transform.inv(prototype_value)) | ||
except NotImplementedError: | ||
# XXX: this works for ImproperUniform prior, | ||
# we can't use this logic for general priors | ||
# because some distributions such as TransformedDistribution might | ||
# have wrong event_shape. | ||
prototype_value = np.full(fn.event_shape, np.nan) | ||
unconstrained_event_shape = np.shape(transform.inv(prototype_value)) | ||
unconstrained_shape = fn.batch_shape + unconstrained_event_shape | ||
|
||
unconstrained_samples = dist.Uniform(-radius, radius).sample( | ||
rng_key, sample_shape=sample_shape + unconstrained_event_shape) | ||
rng_key, sample_shape=sample_shape + unconstrained_shape) | ||
return transform(unconstrained_samples) | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -253,3 +253,17 @@ def model(data): | |
for name, p in init_params[0].items(): | ||
# XXX: the result is equal if we disable fast-math-mode | ||
assert_allclose(p[i], init_params_i[0][name], atol=1e-6) | ||
|
||
|
||
def test_improper_expand(): | ||
|
||
def model(): | ||
population = np.array([1000., 2000., 3000.]) | ||
with numpyro.plate("region", 3): | ||
numpyro.sample("incidence", | ||
dist.ImproperUniform(support=constraints.interval(0, population), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we also add a d = dist.ImproperUniform(support=constraints.interval(0, population)
incidence = numpyro.sample("incidence",
d,
batch_shape=(3,),
event_shape=event_shape))
assert d.log_prob(incidence).shape == (3,) |
||
batch_shape=(3,), | ||
event_shape=(3,))) | ||
|
||
model_info = initialize_model(random.PRNGKey(0), model) | ||
assert model_info.param_info.z['incidence'].shape == (3, 3) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What I wanted to verify for this example is that with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let me test it. I don't know what happens when users provide an invalid event_shape. edit: Oh, now I see what you and Fritz meant before. Here, support is batched... interesting. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm, the result is still There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks, @neerajprad! I just added that test case and fixed the issue at
prototype_value = np.full(site['fn'].event_shape, np.nan)
unconstrained_event_shape = np.shape(transform.inv(prototype_value))
unconstrained_shape = site['fn'].batch_shape + unconstrained_event_shape
prototype_value = np.full(site['fn'].shape(), np.nan)
unconstrained_shape = np.shape(transform.inv(prototype_value)) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Great, that makes sense. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We expose this doc so I want to match PyTorch behavior here.