Skip to content

Commit

Permalink
Allow copy and deepcopy of PYMC models (#7492)
Browse files: browse the repository at this point in the history
  • Loading branch information
Dekermanjian authored Oct 3, 2024
1 parent 67f43ae commit cdcdb58
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 0 deletions.
35 changes: 35 additions & 0 deletions pymc/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1577,6 +1577,41 @@ def __getitem__(self, key):
def __contains__(self, key):
    """Return whether ``key`` is registered in ``self.named_vars``.

    ``key`` is looked up as given first; if that fails, the name produced
    by ``self.name_for(key)`` is looked up instead.
    """
    if key in self.named_vars:
        return True
    resolved_name = self.name_for(key)
    return resolved_name in self.named_vars

def __copy__(self):
    """Support ``copy.copy`` by returning the result of ``self.copy()``."""
    duplicate = self.copy()
    return duplicate

def __deepcopy__(self, _):
    """Support ``copy.deepcopy`` by returning the result of ``self.copy()``.

    The memo dictionary supplied by ``copy.deepcopy`` is accepted but not used.
    """
    duplicate = self.copy()
    return duplicate

def copy(self):
    """
    Clone the model.

    The clone is produced by ``pymc.model.fgraph.clone_model``. To access
    variables in the cloned model use ``cloned_model["var_name"]``.

    Examples
    --------
    .. code-block:: python

        import copy

        import pymc as pm

        with pm.Model() as m:
            p = pm.Beta("p", 1, 1)
            x = pm.Bernoulli("x", p=p, shape=(3,))

        clone_m = copy.copy(m)

        # Access cloned variables by name
        clone_x = clone_m["x"]

        # z will be part of clone_m but not m
        with clone_m:
            z = pm.Deterministic("z", clone_x + 1)
    """
    # Imported at call time rather than at module import time;
    # presumably to avoid a circular import -- TODO confirm.
    from pymc.model.fgraph import clone_model

    cloned_model = clone_model(self)
    return cloned_model

def replace_rvs_by_values(
self,
graphs: Sequence[TensorVariable],
Expand Down
10 changes: 10 additions & 0 deletions pymc/model/fgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings

from copy import copy, deepcopy

import pytensor
Expand Down Expand Up @@ -158,6 +160,14 @@ def fgraph_from_model(
"Nested sub-models cannot be converted to fgraph. Convert the parent model instead"
)

if any(
("_rotated_" in var_name or "_hsgp_coeffs_" in var_name) for var_name in model.named_vars
):
warnings.warn(
"Detected variables likely created by GP objects. Further use of these old GP objects should be avoided as it may reintroduce variables from the old model. See issue: https://github.com/pymc-devs/pymc/issues/6883",
UserWarning,
)

# Collect PyTensor variables
rvs_to_values = model.rvs_to_values
rvs = list(rvs_to_values.keys())
Expand Down
46 changes: 46 additions & 0 deletions tests/model/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import pickle
import threading
import traceback
Expand Down Expand Up @@ -1761,3 +1762,48 @@ def test_graphviz_call_function(self, var_names, filenames) -> None:
figsize=None,
dpi=300,
)


class TestModelCopy:
    """Tests for copying models through ``copy.copy`` and ``copy.deepcopy``."""

    @pytest.mark.parametrize("copy_method", (copy.copy, copy.deepcopy))
    def test_copy_model(self, copy_method) -> None:
        """A copied model samples like the original and can be extended on its own."""
        with pm.Model() as simple_model:
            pm.Normal("y")

        copy_simple_model = copy_method(simple_model)

        with simple_model:
            simple_model_prior_predictive = pm.sample_prior_predictive(samples=1, random_seed=42)

        with copy_simple_model:
            # Registered on the copy only; the original must not gain "z".
            pm.Deterministic("z", copy_simple_model["y"] + 1)
            copy_simple_model_prior_predictive = pm.sample_prior_predictive(
                samples=1, random_seed=42
            )

        original_y = simple_model_prior_predictive["prior"]["y"].values
        copied_prior = copy_simple_model_prior_predictive["prior"]

        # Compare whole arrays explicitly instead of relying on the truthiness
        # of an element-wise ``==`` result, which only works for single-element
        # arrays and reports nothing useful on failure.
        np.testing.assert_array_equal(original_y, copied_prior["y"].values)

        assert "z" in copy_simple_model.named_vars
        assert "z" not in simple_model.named_vars
        np.testing.assert_array_equal(copied_prior["z"].values, 1 + original_y)

    @pytest.mark.parametrize("copy_method", (copy.copy, copy.deepcopy))
    def test_guassian_process_copy_failure(self, copy_method) -> None:
        """Copying a model that holds GP-created variables emits a ``UserWarning``."""
        with pm.Model() as gaussian_process_model:
            ell = pm.Gamma("ell", alpha=2, beta=1)
            cov = 2 * pm.gp.cov.ExpQuad(1, ell)
            gp = pm.gp.Latent(cov_func=cov)
            f = gp.prior("f", X=np.arange(10)[:, None])
            pm.Normal("y", f * 2)

        with pytest.warns(
            UserWarning,
            match="Detected variables likely created by GP objects. Further use of these old GP objects should be avoided as it may reintroduce variables from the old model. See issue: https://github.com/pymc-devs/pymc/issues/6883",
        ):
            copy_method(gaussian_process_model)

0 comments on commit cdcdb58

Please sign in to comment.