Skip to content

Commit

Permalink
minor correction in sampling.py and starting.py (#4458)
Browse files Browse the repository at this point in the history
Make deepcopy of start dicts in pm.sample and `pm.find_MAP` to prevent inplace modification of user variables

closes #4456
  • Loading branch information
chandan5362 authored Feb 5, 2021
1 parent e467bb9 commit b6660f9
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 2 deletions.
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
- `Theano-PyMC v1.1.2` also fixed an important issue in `tt.switch` that affected the behavior of several PyMC distributions, including at least the `Bernoulli` and `TruncatedNormal` (see[#4448](https://github.com/pymc-devs/pymc3/pull/4448))
- `math.log1mexp_numpy` no longer raises RuntimeWarning when given very small inputs. These were commonly observed during NUTS sampling (see [#4428](https://github.com/pymc-devs/pymc3/pull/4428)).
- `ScalarSharedVariable` can now be used as an input to other RVs directly (see [#4445](https://github.com/pymc-devs/pymc3/pull/4445)).
- `pm.sample` and `pm.find_MAP` no longer change the `start` argument (see [#4458](https://github.com/pymc-devs/pymc3/pull/4458)).

## PyMC3 3.11.0 (21 January 2021)

Expand Down
3 changes: 2 additions & 1 deletion pymc3/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import warnings

from collections import defaultdict
from copy import copy
from copy import copy, deepcopy
from typing import Any, Dict, Iterable, List, Optional, Set, Union, cast

import arviz
Expand Down Expand Up @@ -423,6 +423,7 @@ def sample(
p 0.609 0.047 0.528 0.699
"""
model = modelcontext(model)
start = deepcopy(start)
if start is None:
check_start_vals(model.test_point, model)
else:
Expand Down
21 changes: 21 additions & 0 deletions pymc3/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,27 @@ def callback(trace, draw):
assert len(trace) == trace_cancel_length


def test_sample_find_MAP_does_not_modify_start():
# see https://github.com/pymc-devs/pymc3/pull/4458
with pm.Model():
pm.Lognormal("untransformed")

# make sure find_Map does not modify the start dict
start = {"untransformed": 2}
pm.find_MAP(start=start)
assert start == {"untransformed": 2}

# make sure sample does not modify the start dict
start = {"untransformed": 0.2}
pm.sample(draws=10, step=pm.Metropolis(), tune=5, start=start, chains=3)
assert start == {"untransformed": 0.2}

# make sure sample does not modify the start when passes as list of dict
start = [{"untransformed": 2}, {"untransformed": 0.2}]
pm.sample(draws=10, step=pm.Metropolis(), tune=5, start=start, chains=2)
assert start == [{"untransformed": 2}, {"untransformed": 0.2}]


def test_empty_model():
with pm.Model():
pm.Normal("a", observed=1)
Expand Down
4 changes: 3 additions & 1 deletion pymc3/tuning/starting.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
@author: johnsalvatier
"""
import copy

import numpy as np
import theano.gradient as tg

Expand Down Expand Up @@ -96,7 +98,7 @@ def find_MAP(
vars = inputvars(vars)
disc_vars = list(typefilter(vars, discrete_types))
allinmodel(vars, model)

start = copy.deepcopy(start)
if start is None:
start = model.test_point
else:
Expand Down

0 comments on commit b6660f9

Please sign in to comment.