From d6c54bead41383653db6395a14cd6659141662b0 Mon Sep 17 00:00:00 2001 From: chandan5362 Date: Wed, 3 Feb 2021 14:26:32 +0530 Subject: [PATCH 1/8] minor correction in sampling.py and starting.py --- pymc3/sampling.py | 2 ++ pymc3/tuning/starting.py | 1 + 2 files changed, 3 insertions(+) diff --git a/pymc3/sampling.py b/pymc3/sampling.py index bc77113772f..733896e2ba3 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -427,8 +427,10 @@ def sample( check_start_vals(model.test_point, model) else: if isinstance(start, dict): + start = {k: v for k, v in start.items()} update_start_vals(start, model.test_point, model) else: + start = start[:] for chain_start_vals in start: update_start_vals(chain_start_vals, model.test_point, model) check_start_vals(start, model) diff --git a/pymc3/tuning/starting.py b/pymc3/tuning/starting.py index 6ace7dc3b5e..b6894bffcc0 100644 --- a/pymc3/tuning/starting.py +++ b/pymc3/tuning/starting.py @@ -100,6 +100,7 @@ def find_MAP( if start is None: start = model.test_point else: + start = {k: v for k, v in start.items()} update_start_vals(start, model.test_point, model) check_start_vals(start, model) From c5d55511cfc0e6b301be868f6aa488fcd217b1e1 Mon Sep 17 00:00:00 2001 From: chandan5362 Date: Wed, 3 Feb 2021 23:57:19 +0530 Subject: [PATCH 2/8] test added and copy.deepcopy used --- pymc3/sampling.py | 4 ++-- pymc3/tests/test_sampling.py | 11 +++++++++++ pymc3/tuning/starting.py | 5 +++-- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/pymc3/sampling.py b/pymc3/sampling.py index 733896e2ba3..ea60a9c78cb 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -15,6 +15,7 @@ """Functions for MCMC sampling.""" import collections.abc as abc +import copy import logging import pickle import sys @@ -423,14 +424,13 @@ def sample( p 0.609 0.047 0.528 0.699 """ model = modelcontext(model) + start = copy.deepcopy(start) if start is None: check_start_vals(model.test_point, model) else: if isinstance(start, dict): - start = {k: v for k, v in start.items()} update_start_vals(start, model.test_point, model) else: - start = start[:] for chain_start_vals in start: update_start_vals(chain_start_vals, model.test_point, model) check_start_vals(start, model) diff --git a/pymc3/tests/test_sampling.py b/pymc3/tests/test_sampling.py index c95ad230cb5..7ecfad05d8f 100644 --- a/pymc3/tests/test_sampling.py +++ b/pymc3/tests/test_sampling.py @@ -121,6 +121,17 @@ def test_iter_sample(self): for i, trace in enumerate(samps): assert i == len(trace) - 1, "Trace does not have correct length." + def test_sample_does_not_modify_start(self): + start_dict = {"X0_mu": 25} + with self.model: + X0_mu = pm.Lognormal("X0_mu", mu=np.log(0.25), sd=0.10) + trace = pm.sample( + tune=50, + draws=100, + start=start_dict, + ) + assert len(start_dict) == 1 + def test_parallel_start(self): with self.model: tr = pm.sample( diff --git a/pymc3/tuning/starting.py b/pymc3/tuning/starting.py index b6894bffcc0..2a800b2b4dd 100644 --- a/pymc3/tuning/starting.py +++ b/pymc3/tuning/starting.py @@ -17,6 +17,8 @@ @author: johnsalvatier """ +import copy + import numpy as np import theano.gradient as tg @@ -96,11 +98,10 @@ def find_MAP( vars = inputvars(vars) disc_vars = list(typefilter(vars, discrete_types)) allinmodel(vars, model) - + start = copy.deepcopy(start) if start is None: start = model.test_point else: - start = {k: v for k, v in start.items()} update_start_vals(start, model.test_point, model) check_start_vals(start, model) From c72090a8b5f2bf04083d8e092f102eb6b06fcfa1 Mon Sep 17 00:00:00 2001 From: chandan5362 Date: Wed, 3 Feb 2021 14:26:32 +0530 Subject: [PATCH 3/8] minor correction in sampling.py and starting.py --- pymc3/sampling.py | 2 ++ pymc3/tuning/starting.py | 1 + 2 files changed, 3 insertions(+) diff --git a/pymc3/sampling.py b/pymc3/sampling.py index ea60a9c78cb..6931e5b7b41 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -429,8 +429,10 @@ def sample( check_start_vals(model.test_point, model) else: if isinstance(start, dict): + start = {k: v for k, v in start.items()} update_start_vals(start, model.test_point, model) else: + start = start[:] for chain_start_vals in start: update_start_vals(chain_start_vals, model.test_point, model) check_start_vals(start, model) diff --git a/pymc3/tuning/starting.py b/pymc3/tuning/starting.py index 2a800b2b4dd..63638c2105f 100644 --- a/pymc3/tuning/starting.py +++ b/pymc3/tuning/starting.py @@ -102,6 +102,7 @@ def find_MAP( if start is None: start = model.test_point else: + start = {k: v for k, v in start.items()} update_start_vals(start, model.test_point, model) check_start_vals(start, model) From 3f84d41781c7330964c4dcad4eab2c1a6caf99d3 Mon Sep 17 00:00:00 2001 From: chandan5362 Date: Thu, 4 Feb 2021 00:57:42 +0530 Subject: [PATCH 4/8] --- pymc3/sampling.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pymc3/sampling.py b/pymc3/sampling.py index 6931e5b7b41..9954abe0681 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -15,7 +15,6 @@ """Functions for MCMC sampling.""" import collections.abc as abc -import copy import logging import pickle import sys @@ -23,7 +22,7 @@ import warnings from collections import defaultdict -from copy import copy +from copy import copy, deepcopy from typing import Any, Dict, Iterable, List, Optional, Set, Union, cast import arviz @@ -424,7 +423,7 @@ def sample( p 0.609 0.047 0.528 0.699 """ model = modelcontext(model) - start = copy.deepcopy(start) + start = deepcopy(start) if start is None: check_start_vals(model.test_point, model) else: From afa18fb939bf8954201e0daa2ea205a7b164235a Mon Sep 17 00:00:00 2001 From: chandan5362 Date: Fri, 5 Feb 2021 01:41:27 +0530 Subject: [PATCH 5/8] further test added in sampling,py --- pymc3/tests/test_sampling.py | 28 ++++++++++++++++++++++++---- pymc3/tuning/starting.py | 1 - 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/pymc3/tests/test_sampling.py b/pymc3/tests/test_sampling.py index 7ecfad05d8f..473d612adeb 100644 --- a/pymc3/tests/test_sampling.py +++ b/pymc3/tests/test_sampling.py @@ -121,16 +121,36 @@ def test_iter_sample(self): for i, trace in enumerate(samps): assert i == len(trace) - 1, "Trace does not have correct length." - def test_sample_does_not_modify_start(self): + def test_sample_does_not_modify_start_as_list_of_dicts(self): + # make sure pm.sample does not modify the 'start_list' passed as an argument + # see https://github.com/pymc-devs/pymc3/pull/4458 + start_list = [{"mu1": 10}, {"mu2": 15}] + with self.model: + mu1 = pm.Normal("mu1", mu=0, sd=5) + mu2 = pm.Normal("mu2", mu=0, sd=1) + trace = pm.sample( + step=pm.Metropolis(), + tune=5, + draws=10, + chains=2, + start=start_list, + ) + assert start_list == [{"mu1": 10}, {"mu2": 15}] + + def test_sample_does_not_modify_start_as_dict(self): + # make sure pm.sample does not modify the 'start_dict' passed as an argument. + # see https://github.com/pymc-devs/pymc3/pull/4458 start_dict = {"X0_mu": 25} with self.model: X0_mu = pm.Lognormal("X0_mu", mu=np.log(0.25), sd=0.10) trace = pm.sample( - tune=50, - draws=100, + step=pm.Metropolis(), + tune=5, + draws=10, + chains=3, start=start_dict, ) - assert len(start_dict) == 1 + assert start_dict == {"X0_mu": 25} def test_parallel_start(self): with self.model: diff --git a/pymc3/tuning/starting.py b/pymc3/tuning/starting.py index 63638c2105f..2a800b2b4dd 100644 --- a/pymc3/tuning/starting.py +++ b/pymc3/tuning/starting.py @@ -102,7 +102,6 @@ def find_MAP( if start is None: start = model.test_point else: - start = {k: v for k, v in start.items()} update_start_vals(start, model.test_point, model) check_start_vals(start, model) From ae13cd3ebd94a943b738d1f905ab8a14e037f7c5 Mon Sep 17 00:00:00 2001 From: chandan5362 Date: Fri, 5 Feb 2021 11:12:16 +0530 Subject: [PATCH 6/8] further improvement of test --- pymc3/sampling.py | 2 -- pymc3/tests/test_sampling.py | 52 +++++++++++++++--------------------- 2 files changed, 21 insertions(+), 33 deletions(-) diff --git a/pymc3/sampling.py b/pymc3/sampling.py index 9954abe0681..481d20ff034 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -428,10 +428,8 @@ def sample( check_start_vals(model.test_point, model) else: if isinstance(start, dict): - start = {k: v for k, v in start.items()} update_start_vals(start, model.test_point, model) else: - start = start[:] for chain_start_vals in start: update_start_vals(chain_start_vals, model.test_point, model) check_start_vals(start, model) diff --git a/pymc3/tests/test_sampling.py b/pymc3/tests/test_sampling.py index 473d612adeb..6ef1899bff5 100644 --- a/pymc3/tests/test_sampling.py +++ b/pymc3/tests/test_sampling.py @@ -121,37 +121,6 @@ def test_iter_sample(self): for i, trace in enumerate(samps): assert i == len(trace) - 1, "Trace does not have correct length." - def test_sample_does_not_modify_start_as_list_of_dicts(self): - # make sure pm.sample does not modify the 'start_list' passed as an argument - # see https://github.com/pymc-devs/pymc3/pull/4458 - start_list = [{"mu1": 10}, {"mu2": 15}] - with self.model: - mu1 = pm.Normal("mu1", mu=0, sd=5) - mu2 = pm.Normal("mu2", mu=0, sd=1) - trace = pm.sample( - step=pm.Metropolis(), - tune=5, - draws=10, - chains=2, - start=start_list, - ) - assert start_list == [{"mu1": 10}, {"mu2": 15}] - - def test_sample_does_not_modify_start_as_dict(self): - # make sure pm.sample does not modify the 'start_dict' passed as an argument. - # see https://github.com/pymc-devs/pymc3/pull/4458 - start_dict = {"X0_mu": 25} - with self.model: - X0_mu = pm.Lognormal("X0_mu", mu=np.log(0.25), sd=0.10) - trace = pm.sample( - step=pm.Metropolis(), - tune=5, - draws=10, - chains=3, - start=start_dict, - ) - assert start_dict == {"X0_mu": 25} - def test_parallel_start(self): with self.model: tr = pm.sample( @@ -316,6 +285,27 @@ def callback(trace, draw): assert len(trace) == trace_cancel_length +def test_sample_find_MAP_does_not_modify_start(): + # see https://github.com/pymc-devs/pymc3/pull/4458 + with pm.Model(): + pm.Lognormal("untransformed") + + # make sure find_Map does not modify the start dict + start = {"untransformed": 2} + pm.find_MAP(start=start) + assert start == {"untransformed": 2} + + # make sure sample does not modify the start dict + start = {"untransformed": 0.2} + pm.sample(draws=10, step=pm.Metropolis(), tune=5, start=start, chains=3) + assert start == {"untransformed": 0.2} + + # make sure sample does not modify the start when passes as dict + start = [{"untransformed": 2}, {"untransformed": 0.2}] + pm.sample(draws=10, step=pm.Metropolis(), tune=5, start=start, chains=2) + assert start == [{"untransformed": 2}, {"untransformed": 0.2}] + + def test_empty_model(): with pm.Model(): pm.Normal("a", observed=1) From ea647929ea9a72495e5717b0b4646aa441b1fd8c Mon Sep 17 00:00:00 2001 From: chandan5362 Date: Fri, 5 Feb 2021 11:16:58 +0530 Subject: [PATCH 7/8] --- pymc3/tests/test_sampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc3/tests/test_sampling.py b/pymc3/tests/test_sampling.py index 6ef1899bff5..f3f2872c442 100644 --- a/pymc3/tests/test_sampling.py +++ b/pymc3/tests/test_sampling.py @@ -300,7 +300,7 @@ def test_sample_find_MAP_does_not_modify_start(): pm.sample(draws=10, step=pm.Metropolis(), tune=5, start=start, chains=3) assert start == {"untransformed": 0.2} - # make sure sample does not modify the start when passes as dict + # make sure sample does not modify the start when passes as list of dict start = [{"untransformed": 2}, {"untransformed": 0.2}] pm.sample(draws=10, step=pm.Metropolis(), tune=5, start=start, chains=2) assert start == [{"untransformed": 2}, {"untransformed": 0.2}] From 3270b450c9ffeb972371a69f8a54eb156cdb53bf Mon Sep 17 00:00:00 2001 From: chandan5362 Date: Fri, 5 Feb 2021 15:42:08 +0530 Subject: [PATCH 8/8] RELEASe-NOTES.md updated --- RELEASE-NOTES.md | 1 + 1 file changed, 1 insertion(+) diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index c931ddf2281..1223b547044 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -12,6 +12,7 @@ - `Theano-PyMC v1.1.2` also fixed an important issue in `tt.switch` that affected the behavior of several PyMC distributions, including at least the `Bernoulli` and `TruncatedNormal` (see[#4448](https://github.com/pymc-devs/pymc3/pull/4448)) - `math.log1mexp_numpy` no longer raises RuntimeWarning when given very small inputs. These were commonly observed during NUTS sampling (see [#4428](https://github.com/pymc-devs/pymc3/pull/4428)). - `ScalarSharedVariable` can now be used as an input to other RVs directly (see [#4445](https://github.com/pymc-devs/pymc3/pull/4445)). +- `pm.sample` and `pm.find_MAP` no longer change the `start` argument (see [#4458](https://github.com/pymc-devs/pymc3/pull/4458)). ## PyMC3 3.11.0 (21 January 2021)