diff --git a/pymc/model/core.py b/pymc/model/core.py index 3e59e723152..41d4e0864a9 100644 --- a/pymc/model/core.py +++ b/pymc/model/core.py @@ -1577,6 +1577,41 @@ def __getitem__(self, key): def __contains__(self, key): return key in self.named_vars or self.name_for(key) in self.named_vars + def __copy__(self): + return self.copy() + + def __deepcopy__(self, _): + return self.copy() + + def copy(self): + """ + Clone the model + + To access variables in the cloned model use `cloned_model["var_name"]`. + + Examples + -------- + .. code-block:: python + + import pymc as pm + import copy + + with pm.Model() as m: + p = pm.Beta("p", 1, 1) + x = pm.Bernoulli("x", p=p, shape=(3,)) + + clone_m = copy.copy(m) + + # Access cloned variables by name + clone_x = clone_m["x"] + + # z will be part of clone_m but not m + z = pm.Deterministic("z", clone_x + 1) + """ + from pymc.model.fgraph import clone_model + + return clone_model(self) + def replace_rvs_by_values( self, graphs: Sequence[TensorVariable], diff --git a/pymc/model/fgraph.py b/pymc/model/fgraph.py index b1d67fd07b0..8c37861c8f3 100644 --- a/pymc/model/fgraph.py +++ b/pymc/model/fgraph.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import warnings + from copy import copy, deepcopy import pytensor @@ -158,6 +160,14 @@ def fgraph_from_model( "Nested sub-models cannot be converted to fgraph. Convert the parent model instead" ) + if any( + ("_rotated_" in var_name or "_hsgp_coeffs_" in var_name) for var_name in model.named_vars + ): + warnings.warn( + "Detected variables likely created by GP objects. Further use of these old GP objects should be avoided as it may reintroduce variables from the old model. See issue: https://github.com/pymc-devs/pymc/issues/6883", + UserWarning, + ) + # Collect PyTensor variables rvs_to_values = model.rvs_to_values rvs = list(rvs_to_values.keys()) diff --git a/tests/model/test_core.py b/tests/model/test_core.py index 95a46873fbb..cea9edf6471 100644 --- a/tests/model/test_core.py +++ b/tests/model/test_core.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import copy import pickle import threading import traceback @@ -1761,3 +1762,48 @@ def test_graphviz_call_function(self, var_names, filenames) -> None: figsize=None, dpi=300, ) + + +class TestModelCopy: + @pytest.mark.parametrize("copy_method", (copy.copy, copy.deepcopy)) + def test_copy_model(self, copy_method) -> None: + with pm.Model() as simple_model: + pm.Normal("y") + + copy_simple_model = copy_method(simple_model) + + with simple_model: + simple_model_prior_predictive = pm.sample_prior_predictive(samples=1, random_seed=42) + + with copy_simple_model: + z = pm.Deterministic("z", copy_simple_model["y"] + 1) + copy_simple_model_prior_predictive = pm.sample_prior_predictive( + samples=1, random_seed=42 + ) + + assert ( + simple_model_prior_predictive["prior"]["y"].values + == copy_simple_model_prior_predictive["prior"]["y"].values + ) + + assert "z" in copy_simple_model.named_vars + assert "z" not in simple_model.named_vars + assert ( + copy_simple_model_prior_predictive["prior"]["z"].values + == 1 + simple_model_prior_predictive["prior"]["y"].values + ) + + @pytest.mark.parametrize("copy_method", (copy.copy, copy.deepcopy)) + def test_guassian_process_copy_failure(self, copy_method) -> None: + with pm.Model() as gaussian_process_model: + ell = pm.Gamma("ell", alpha=2, beta=1) + cov = 2 * pm.gp.cov.ExpQuad(1, ell) + gp = pm.gp.Latent(cov_func=cov) + f = gp.prior("f", X=np.arange(10)[:, None]) + pm.Normal("y", f * 2) + + with pytest.warns( + UserWarning, + match="Detected variables likely created by GP objects. Further use of these old GP objects should be avoided as it may reintroduce variables from the old model. See issue: https://github.com/pymc-devs/pymc/issues/6883", + ): + copy_method(gaussian_process_model)