-
-
Notifications
You must be signed in to change notification settings - Fork 50
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: Rob Zinkov <8529+zaxtax@users.noreply.github.com>
- Loading branch information
1 parent
5fc0463
commit 8046695
Showing
8 changed files
with
73 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
from functools import wraps | ||
|
||
from pymc import Model | ||
|
||
|
||
def as_model(*model_args, **model_kwargs): | ||
R""" | ||
Decorator to provide context to PyMC models declared in a function. | ||
This removes all need to think about context managers and lets you separate creating a generative model from using the model. | ||
Adapted from `Rob Zinkov's blog post <https://www.zinkov.com/posts/2023-alternative-frontends-pymc/>`_ and inspired by the `sampled <https://github.com/colcarroll/sampled>`_ decorator for PyMC3. | ||
Examples | ||
-------- | ||
.. code:: python | ||
import pymc as pm | ||
import pymc_experimental as pmx | ||
# The following are equivalent | ||
# standard PyMC API with context manager | ||
with pm.Model(coords={"obs": ["a", "b"]}) as model: | ||
x = pm.Normal("x", 0., 1., dims="obs") | ||
pm.sample() | ||
# functional API using decorator | ||
@pmx.as_model(coords={"obs": ["a", "b"]}) | ||
def basic_model(): | ||
pm.Normal("x", 0., 1., dims="obs") | ||
m = basic_model() | ||
pm.sample(model=m) | ||
""" | ||
|
||
def decorator(f): | ||
@wraps(f) | ||
def make_model(*args, **kwargs): | ||
with Model(*model_args, **model_kwargs) as m: | ||
f(*args, **kwargs) | ||
return m | ||
|
||
return make_model | ||
|
||
return decorator |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import numpy as np | ||
import pymc as pm | ||
|
||
import pymc_experimental as pmx | ||
|
||
|
||
def test_logp(): | ||
"""Compare standard PyMC `with pm.Model()` context API against `pmx.model` decorator | ||
and a functional syntax. Checks whether the kwarg `coords` can be passed. | ||
""" | ||
coords = {"obs": ["a", "b"]} | ||
|
||
with pm.Model(coords=coords) as model: | ||
pm.Normal("x", 0.0, 1.0, dims="obs") | ||
|
||
@pmx.as_model(coords=coords) | ||
def model_wrapped(): | ||
pm.Normal("x", 0.0, 1.0, dims="obs") | ||
|
||
mw = model_wrapped() | ||
|
||
np.testing.assert_equal(model.point_logps(), mw.point_logps()) |