Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove theanof.set_theano_conf and instead use the config context #4329

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
### Maintenance
- Fixed bug whereby partial traces returns after keyboard interrupt during parallel sampling had fewer draws than would've been available [#4318](https://github.com/pymc-devs/pymc3/pull/4318)
- Make `sample_shape` same across all contexts in `draw_values` (see [#4305](https://github.com/pymc-devs/pymc3/pull/4305)).
- Removed `theanof.set_theano_config` because it illegally touched Theano's privates (see [#4329](https://github.com/pymc-devs/pymc3/pull/4329)).


## PyMC3 3.10.0 (7 December 2020)

Expand Down
17 changes: 6 additions & 11 deletions pymc3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,7 @@
from pymc3.exceptions import ImputationWarning
from pymc3.math import flatten_list
from pymc3.memoize import WithMemoization, memoize
from pymc3.theanof import (
floatX,
generator,
gradient,
hessian,
inputvars,
set_theano_conf,
)
from pymc3.theanof import floatX, generator, gradient, hessian, inputvars
from pymc3.util import get_transformed_name, get_var_name
from pymc3.vartypes import continuous_types, discrete_types, isgenerator, typefilter

Expand Down Expand Up @@ -288,15 +281,17 @@ def __new__(cls, name, bases, dct, **kargs): # pylint: disable=unused-argument
def __enter__(self):
self.__class__.context_class.get_contexts().append(self)
# self._theano_config is set in Model.__new__
self._config_context = None
if hasattr(self, "_theano_config"):
self._old_theano_config = set_theano_conf(self._theano_config)
self._config_context = theano.change_flags(**self._theano_config)
self._config_context.__enter__()
return self

def __exit__(self, typ, value, traceback): # pylint: disable=unused-argument
self.__class__.context_class.get_contexts().pop()
# self._theano_config is set in Model.__new__
if hasattr(self, "_old_theano_config"):
set_theano_conf(self._old_theano_config)
if self._config_context:
self._config_context.__exit__(typ, value, traceback)

dct[__enter__.__name__] = __enter__
dct[__exit__.__name__] = __exit__
Expand Down
25 changes: 1 addition & 24 deletions pymc3/tests/test_theanof.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import collections

from itertools import product

import numpy as np
import pytest
import theano
import theano.tensor as tt

from pymc3.theanof import _conversion_map, set_theano_conf, take_along_axis
from pymc3.theanof import _conversion_map, take_along_axis
from pymc3.vartypes import int_types

FLOATX = str(theano.config.floatX)
Expand Down Expand Up @@ -72,27 +70,6 @@ def np_take_along_axis(arr, indices, axis):
return arr[_make_along_axis_idx(arr.shape, indices, _axis)]


class TestSetTheanoConfig:
def test_invalid_key(self):
with pytest.raises(ValueError) as e:
set_theano_conf({"bad_key": True})
e.match("Unknown")

def test_restore_when_bad_key(self):
with theano.configparser.change_flags(compute_test_value="off"):
with pytest.raises(ValueError):
conf = collections.OrderedDict([("compute_test_value", "raise"), ("bad_key", True)])
set_theano_conf(conf)
assert theano.config.compute_test_value == "off"

def test_restore(self):
with theano.configparser.change_flags(compute_test_value="off"):
conf = set_theano_conf({"compute_test_value": "raise"})
assert conf == {"compute_test_value": "off"}
conf = set_theano_conf(conf)
assert conf == {"compute_test_value": "raise"}


class TestTakeAlongAxis:
def setup_class(self):
self.inputs_buffer = dict()
Expand Down
31 changes: 1 addition & 30 deletions pymc3/theanof.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@
import numpy as np
import theano

from theano import scalar
from theano import change_flags, scalar
from theano import tensor as tt
from theano.configparser import change_flags
from theano.gof import Op
from theano.gof.graph import inputs
from theano.sandbox.rng_mrg import MRG_RandomStreams
Expand Down Expand Up @@ -442,34 +441,6 @@ def floatX_array(x):
return floatX(np.array(x))


def set_theano_conf(values):
"""Change the theano configuration and return old values.

This is similar to `theano.configparser.change_flags`, but it
returns the original values in a pickleable form.
"""
variables = {}
unknown = set(values.keys())
for variable in theano.configparser._config_var_list:
if variable.fullname in values:
variables[variable.fullname] = variable
unknown.remove(variable.fullname)
if len(unknown) > 0:
raise ValueError("Unknown theano config settings: %s" % unknown)

old = {}
for name, variable in variables.items():
old_value = variable.__get__(True, None)
try:
variable.__set__(None, values[name])
except Exception:
for key, old_value in old.items():
variables[key].__set__(None, old_value)
raise
old[name] = old_value
return old


def ix_(*args):
"""
Theano np.ix_ analog
Expand Down