
minor correction in sampling.py and starting.py #4458

Merged: 8 commits, Feb 5, 2021
Changes from 5 commits
pymc3/sampling.py (5 changes: 4 additions & 1 deletion)

@@ -22,7 +22,7 @@
import warnings

from collections import defaultdict
from copy import copy
from copy import copy, deepcopy
from typing import Any, Dict, Iterable, List, Optional, Set, Union, cast

import arviz
@@ -423,12 +423,15 @@ def sample(
p 0.609 0.047 0.528 0.699
"""
model = modelcontext(model)
start = deepcopy(start)
if start is None:
check_start_vals(model.test_point, model)
else:
if isinstance(start, dict):
start = {k: v for k, v in start.items()}
Member: Why not simply start.copy()?
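For context, a minimal standalone sketch (not from the PR) of why the two spellings are interchangeable: the dict comprehension and dict.copy() both build the same kind of shallow copy.

```python
# Sketch: the dict comprehension and dict.copy() build equivalent shallow copies.
start = {"mu1": 10, "sigma": [1.0, 2.0]}

via_comprehension = {k: v for k, v in start.items()}
via_copy = start.copy()

# Both are new top-level dicts with equal contents...
assert via_comprehension == via_copy == start
assert via_comprehension is not start and via_copy is not start

# ...but both still share nested mutable values with the original.
assert via_comprehension["sigma"] is start["sigma"]
```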

Contributor Author: This is what I was suggesting initially. Anyway, I will replace that with deepcopy.

Member: This line can be removed now.

update_start_vals(start, model.test_point, model)
else:
start = start[:]
Member: This will still change the dictionaries in place:

a = [dict(a=1, b=2), dict(a=1, b=2)]
b = a[:]
for b_ in b:
    b_['c'] = 3

a
# [{'a': 1, 'b': 2, 'c': 3}, {'a': 1, 'b': 2, 'c': 3}]

Member (@ricardoV94, Feb 3, 2021): You can try start = [s.copy() for s in start]

Contributor: Yeah, I think @chandan5362's original suggestion to use deepcopy was good.

Member: Since the function in question only adds new keys and does not change keys already in place, a shallow copy (and nested shallow copy) should be fine. But I have no strong objections to deepcopy.
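An illustrative sketch (not part of the PR) of the trade-off being discussed: a nested shallow copy protects the original against new keys, while only deepcopy also isolates nested mutable values.

```python
from copy import deepcopy

start = [{"mu": [0.0, 1.0]}]

# Nested shallow copy: new list, new dicts, but shared inner values.
shallow = [s.copy() for s in start]
shallow[0]["extra"] = 3          # adding a key leaves the original intact
assert "extra" not in start[0]
shallow[0]["mu"].append(2.0)     # mutating a shared value leaks through
assert start[0]["mu"] == [0.0, 1.0, 2.0]

# deepcopy: nested values are isolated as well.
deep = deepcopy(start)
deep[0]["mu"].append(9.0)
assert 9.0 not in start[0]["mu"]
```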

Contributor Author (@chandan5362, Feb 3, 2021): Yeah, we could use start.copy(), but to stay on the safer side we should probably use deepcopy, even though a shallow copy would suffice here. Also, with deepcopy we won't need start.copy() inside a list comprehension.

Member: You're right: deepcopy(None) is None, so it can even go before any "is not None" or isinstance checks.
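A quick check of that claim (standard copy-module behavior, not PR code): deepcopy passes None through unchanged, so the call is safe before any guards.

```python
from copy import deepcopy

# None is atomic for copying purposes, so deepcopy returns it unchanged;
# deepcopy(start) is therefore safe even when start was not provided.
assert deepcopy(None) is None
```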

Member: This line can also be removed now.

for chain_start_vals in start:
update_start_vals(chain_start_vals, model.test_point, model)
check_start_vals(start, model)
pymc3/tests/test_sampling.py (31 changes: 31 additions & 0 deletions)

@@ -121,6 +121,37 @@ def test_iter_sample(self):
for i, trace in enumerate(samps):
assert i == len(trace) - 1, "Trace does not have correct length."

def test_sample_does_not_modify_start_as_list_of_dicts(self):
Member: Does this test fail on master? It looks like no transforms would be added in this test model, and therefore the dictionary wouldn't be changed anyway, but I could be wrong. I imagined it worked something like your test below, but with a dictionary for each chain:

start_dict = [{"X0_mu": 25}, {"X0_mu": 25}]
with pm.Model() as m:
    X0_mu = pm.Lognormal("X0_mu", mu=np.log(0.25), sd=0.10)
    trace = pm.sample(
        step=pm.Metropolis(),
        tune=5,
        draws=10,
        chains=2,
        start=start_dict,
    )
    assert start_dict == ...

Member: Oh I see, you left one parameter missing on purpose in each chain...

Contributor Author:

> Does this test fail on master? It looks like no transforms would be added in this test model and therefore the dictionary wouldn't be changed anyway.

Unfortunately, it updates the dictionary with model.test_point even though there is no transformed variable to be added.

    # make sure pm.sample does not modify the 'start_list' passed as an argument
    # see https://github.com/pymc-devs/pymc3/pull/4458
    start_list = [{"mu1": 10}, {"mu2": 15}]
    with self.model:
        mu1 = pm.Normal("mu1", mu=0, sd=5)
        mu2 = pm.Normal("mu2", mu=0, sd=1)
        trace = pm.sample(
            step=pm.Metropolis(),
            tune=5,
            draws=10,
            chains=2,
            start=start_list,
        )
    assert start_list == [{"mu1": 10}, {"mu2": 15}]

def test_sample_does_not_modify_start_as_dict(self):
    # make sure pm.sample does not modify the 'start_dict' passed as an argument.
    # see https://github.com/pymc-devs/pymc3/pull/4458
    start_dict = {"X0_mu": 25}
    with self.model:
        X0_mu = pm.Lognormal("X0_mu", mu=np.log(0.25), sd=0.10)
        trace = pm.sample(
            step=pm.Metropolis(),
            tune=5,
            draws=10,
            chains=3,
            start=start_dict,
        )
    assert start_dict == {"X0_mu": 25}

Member: Let's not overcomplicate the tests. Also, we should test both pm.sample and pm.find_MAP.

  • no need to reuse the complicated self.model model
  • variable names and parameters don't matter
  • the distribution should be transformed by default (Uniform, Lognormal, ...)
  • everything in the same test case, so the compilation is done just once

with pm.Model():
    pm.Lognormal("untransformed")

    # test that find_MAP doesn't change the start dict
    start = {"untransformed": 2}
    pm.find_MAP(start=start, niter=5)
    assert start == {"untransformed": 2}

    # check that sample doesn't change it either
    start = {"untransformed": 0.5}
    ...

    # and also not if start is different for each chain
    start = [{"untransformed": 2}, {"untransformed": 0.5}]
    ...

Contributor Author: I did remove these lines, but they came back from nowhere (I might have fetched before committing). Anyway, this time I will make sure that these lines do not come back.

Contributor Author: It would be better if we moved this test outside the TestSample class.

def test_parallel_start(self):
with self.model:
tr = pm.sample(
pymc3/tuning/starting.py (4 changes: 3 additions & 1 deletion)

@@ -17,6 +17,8 @@

@author: johnsalvatier
"""
import copy

import numpy as np
import theano.gradient as tg

@@ -96,7 +98,7 @@ def find_MAP(
vars = inputvars(vars)
disc_vars = list(typefilter(vars, discrete_types))
allinmodel(vars, model)

start = copy.deepcopy(start)
if start is None:
start = model.test_point
else: