-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
minor correction in sampling.py and starting.py #4458
Changes from 5 commits
d6c54be
c5d5551
c72090a
3f84d41
afa18fb
ae13cd3
ea64792
3270b45
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,7 +22,7 @@ | |
import warnings | ||
|
||
from collections import defaultdict | ||
from copy import copy | ||
from copy import copy, deepcopy | ||
from typing import Any, Dict, Iterable, List, Optional, Set, Union, cast | ||
|
||
import arviz | ||
|
@@ -423,12 +423,15 @@ def sample( | |
p 0.609 0.047 0.528 0.699 | ||
""" | ||
model = modelcontext(model) | ||
start = deepcopy(start) | ||
if start is None: | ||
check_start_vals(model.test_point, model) | ||
else: | ||
if isinstance(start, dict): | ||
start = {k: v for k, v in start.items()} | ||
update_start_vals(start, model.test_point, model) | ||
else: | ||
start = start[:] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This will still change the dictionary inplace. a = [dict(a=1, b=2), dict(a=1, b=2)]
b = a[:]
for b_ in b:
b_['c'] = 3
a
[{'a': 1, 'b': 2, 'c': 3}, {'a': 1, 'b': 2, 'c': 3}] There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can try There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah, I think @chandan5362 original suggestion to use deepcopy was good There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since the function in question only adds new keys, but does not change keys already in place, a shallow copy (and nested shallow copy) should be fine. But I have no strong objections to deepcopy There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yea, we could use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You're right, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This line can also be removed now. |
||
for chain_start_vals in start: | ||
update_start_vals(chain_start_vals, model.test_point, model) | ||
check_start_vals(start, model) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -121,6 +121,37 @@ def test_iter_sample(self): | |
for i, trace in enumerate(samps): | ||
assert i == len(trace) - 1, "Trace does not have correct length." | ||
|
||
def test_sample_does_not_modify_start_as_list_of_dicts(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this test fail in master? It looks like no transforms would be added in this test model and therefore the dictionary wouldn't be changed anyway, but I could be wrong. I imagined it worked something like your test below but with a dictionary for each chain: start_dict = [{"X0_mu": 25}, {"X0_mu": 25}]
with pm.model() as m:
X0_mu = pm.Lognormal("X0_mu", mu=np.log(0.25), sd=0.10)
trace = pm.sample(
step=pm.Metropolis(),
tune=5,
draws=10,
chains=2,
start=start_dict,
)
assert start_dict == ... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh I see, you made one parameter be missing on purpose in each chain... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
unfortunately, it updates the dictionary with |
||
# make sure pm.sample does not modify the 'start_list' passed as an argument | ||
# see https://github.com/pymc-devs/pymc3/pull/4458 | ||
start_list = [{"mu1": 10}, {"mu2": 15}] | ||
with self.model: | ||
mu1 = pm.Normal("mu1", mu=0, sd=5) | ||
mu2 = pm.Normal("mu2", mu=0, sd=1) | ||
trace = pm.sample( | ||
step=pm.Metropolis(), | ||
tune=5, | ||
draws=10, | ||
chains=2, | ||
start=start_list, | ||
) | ||
assert start_list == [{"mu1": 10}, {"mu2": 15}] | ||
|
||
def test_sample_does_not_modify_start_as_dict(self): | ||
# make sure pm.sample does not modify the 'start_dict' passed as an argument. | ||
# see https://github.com/pymc-devs/pymc3/pull/4458 | ||
start_dict = {"X0_mu": 25} | ||
with self.model: | ||
X0_mu = pm.Lognormal("X0_mu", mu=np.log(0.25), sd=0.10) | ||
trace = pm.sample( | ||
step=pm.Metropolis(), | ||
tune=5, | ||
draws=10, | ||
chains=3, | ||
start=start_dict, | ||
) | ||
assert start_dict == {"X0_mu": 25} | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's not overcomplicate the tests. Also we should test both
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I did remove these lines but It came back from nowhere (I might have fetched before committing). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be better if we take this test outside from the |
||
def test_parallel_start(self): | ||
with self.model: | ||
tr = pm.sample( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not simply
start.copy()
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is what I was suggesting initially,. Anyway, I will replace that with
deepcopy
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This line can be removed now.