diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index f6ebbe8513d..2a7b8f5796a 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -5,6 +5,8 @@ ### Maintenance - Fixed bug whereby partial traces returns after keyboard interrupt during parallel sampling had fewer draws than would've been available [#4318](https://github.com/pymc-devs/pymc3/pull/4318) - Make `sample_shape` same across all contexts in `draw_values` (see [#4305](https://github.com/pymc-devs/pymc3/pull/4305)). +- Removed `theanof.set_theano_config` because it illegally touched Theano's privates (see [#4329](https://github.com/pymc-devs/pymc3/pull/4329)). + ## PyMC3 3.10.0 (7 December 2020) diff --git a/pymc3/model.py b/pymc3/model.py index 882c32cd6f6..015c089f5e4 100644 --- a/pymc3/model.py +++ b/pymc3/model.py @@ -36,14 +36,7 @@ from pymc3.exceptions import ImputationWarning from pymc3.math import flatten_list from pymc3.memoize import WithMemoization, memoize -from pymc3.theanof import ( - floatX, - generator, - gradient, - hessian, - inputvars, - set_theano_conf, -) +from pymc3.theanof import floatX, generator, gradient, hessian, inputvars from pymc3.util import get_transformed_name, get_var_name from pymc3.vartypes import continuous_types, discrete_types, isgenerator, typefilter @@ -288,15 +281,17 @@ def __new__(cls, name, bases, dct, **kargs): # pylint: disable=unused-argument def __enter__(self): self.__class__.context_class.get_contexts().append(self) # self._theano_config is set in Model.__new__ + self._config_context = None if hasattr(self, "_theano_config"): - self._old_theano_config = set_theano_conf(self._theano_config) + self._config_context = theano.change_flags(**self._theano_config) + self._config_context.__enter__() return self def __exit__(self, typ, value, traceback): # pylint: disable=unused-argument self.__class__.context_class.get_contexts().pop() # self._theano_config is set in Model.__new__ - if hasattr(self, "_old_theano_config"): - set_theano_conf(self._old_theano_config) + if self._config_context: + self._config_context.__exit__(typ, value, traceback) dct[__enter__.__name__] = __enter__ dct[__exit__.__name__] = __exit__ diff --git a/pymc3/tests/test_theanof.py b/pymc3/tests/test_theanof.py index 8d51554adbc..d54aed680d8 100644 --- a/pymc3/tests/test_theanof.py +++ b/pymc3/tests/test_theanof.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import collections - from itertools import product import numpy as np @@ -21,7 +19,7 @@ import theano import theano.tensor as tt -from pymc3.theanof import _conversion_map, set_theano_conf, take_along_axis +from pymc3.theanof import _conversion_map, take_along_axis from pymc3.vartypes import int_types FLOATX = str(theano.config.floatX) @@ -72,27 +70,6 @@ def np_take_along_axis(arr, indices, axis): return arr[_make_along_axis_idx(arr.shape, indices, _axis)] -class TestSetTheanoConfig: - def test_invalid_key(self): - with pytest.raises(ValueError) as e: - set_theano_conf({"bad_key": True}) - e.match("Unknown") - - def test_restore_when_bad_key(self): - with theano.configparser.change_flags(compute_test_value="off"): - with pytest.raises(ValueError): - conf = collections.OrderedDict([("compute_test_value", "raise"), ("bad_key", True)]) - set_theano_conf(conf) - assert theano.config.compute_test_value == "off" - - def test_restore(self): - with theano.configparser.change_flags(compute_test_value="off"): - conf = set_theano_conf({"compute_test_value": "raise"}) - assert conf == {"compute_test_value": "off"} - conf = set_theano_conf(conf) - assert conf == {"compute_test_value": "raise"} - - class TestTakeAlongAxis: def setup_class(self): self.inputs_buffer = dict() diff --git a/pymc3/theanof.py b/pymc3/theanof.py index 9227d2f5ad4..817488b4fd2 100644 --- a/pymc3/theanof.py +++ b/pymc3/theanof.py @@ -15,9 +15,8 @@ import numpy as np import theano -from theano import scalar +from theano import change_flags, scalar from theano import tensor as tt -from theano.configparser import change_flags from theano.gof import Op from theano.gof.graph import inputs from theano.sandbox.rng_mrg import MRG_RandomStreams @@ -442,34 +441,6 @@ def floatX_array(x): return floatX(np.array(x)) -def set_theano_conf(values): - """Change the theano configuration and return old values. - - This is similar to `theano.configparser.change_flags`, but it - returns the original values in a pickleable form. - """ - variables = {} - unknown = set(values.keys()) - for variable in theano.configparser._config_var_list: - if variable.fullname in values: - variables[variable.fullname] = variable - unknown.remove(variable.fullname) - if len(unknown) > 0: - raise ValueError("Unknown theano config settings: %s" % unknown) - - old = {} - for name, variable in variables.items(): - old_value = variable.__get__(True, None) - try: - variable.__set__(None, values[name]) - except Exception: - for key, old_value in old.items(): - variables[key].__set__(None, old_value) - raise - old[name] = old_value - return old - - def ix_(*args): """ Theano np.ix_ analog