Skip to content

Commit

Permalink
Deprecate pytensor_config
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Nov 24, 2023
1 parent bbfd69e commit 0b85d02
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 14 deletions.
8 changes: 7 additions & 1 deletion pymc/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,13 @@ def __new__(cls, *args, **kwargs):
instance._parent = kwargs.get("model")
else:
instance._parent = cls.get_context(error_if_none=False)
instance._pytensor_config = kwargs.get("pytensor_config", {})
pytensor_config = kwargs.get("pytensor_config", {})
if pytensor_config:
warnings.warn(
"pytensor_config is deprecated. Use pytensor.config or pytensor.config.change_flags context manager instead.",
FutureWarning,
)
instance._pytensor_config = pytensor_config
return instance

@staticmethod
Expand Down
4 changes: 3 additions & 1 deletion tests/model/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1102,7 +1102,9 @@ def test_compile_fn():

def test_model_pytensor_config():
assert pytensor.config.mode != "JAX"
with pm.Model(pytensor_config=dict(mode="JAX")) as model:
with pytest.warns(FutureWarning, match="pytensor_config is deprecated"):
m = pm.Model(pytensor_config=dict(mode="JAX"))
with m:
assert pytensor.config.mode == "JAX"
assert pytensor.config.mode != "JAX"

Expand Down
8 changes: 0 additions & 8 deletions tests/sampling/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,14 +797,6 @@ def test_step_vars_in_model(self):
class TestType:
samplers = (Metropolis, Slice, HamiltonianMC, NUTS)

def setup_method(self):
# save PyTensor config object
self.pytensor_config = copy(pytensor.config)

def teardown_method(self):
# restore pytensor config
pytensor.config = self.pytensor_config

@pytensor.config.change_flags({"floatX": "float64", "warn_float64": "ignore"})
def test_float64(self):
with pm.Model() as model:
Expand Down
4 changes: 0 additions & 4 deletions tests/variational/test_approximations.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,6 @@ def test_scale_cost_to_minibatch_works(aux_total_size):
y_obs = np.array([1.6, 1.4])
beta = len(y_obs) / float(aux_total_size)

# TODO: pytensor_config
# with pm.Model(pytensor_config=dict(floatX='float64')):
# did not not work as expected
# there were some numeric problems, so float64 is forced
with pytensor.config.change_flags(floatX="float64", warn_float64="ignore"):
assert pytensor.config.floatX == "float64"
assert pytensor.config.warn_float64 == "ignore"
Expand Down

0 comments on commit 0b85d02

Please sign in to comment.