Skip to content

Commit

Permalink
Add @as_model decorator (#268)
Browse files Browse the repository at this point in the history

Co-authored-by: Rob Zinkov <8529+zaxtax@users.noreply.github.com>
  • Loading branch information
theorashid and zaxtax authored Nov 22, 2023
1 parent 5fc0463 commit 8046695
Show file tree
Hide file tree
Showing 8 changed files with 73 additions and 3 deletions.
3 changes: 2 additions & 1 deletion docs/api_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ methods in the current release of PyMC experimental.
.. autosummary::
:toctree: generated/

marginal_model.MarginalModel
as_model
MarginalModel
model_builder.ModelBuilder

Inference
Expand Down
3 changes: 2 additions & 1 deletion pymc_experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,5 @@

from pymc_experimental import distributions, gp, utils
from pymc_experimental.inference.fit import fit
from pymc_experimental.marginal_model import MarginalModel
from pymc_experimental.model.marginal_model import MarginalModel
from pymc_experimental.model.model_api import as_model
Empty file.
File renamed without changes.
46 changes: 46 additions & 0 deletions pymc_experimental/model/model_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from functools import wraps

from pymc import Model


def as_model(*model_args, **model_kwargs):
R"""
Decorator to provide context to PyMC models declared in a function.
This removes all need to think about context managers and lets you separate creating a generative model from using the model.
Adapted from `Rob Zinkov's blog post <https://www.zinkov.com/posts/2023-alternative-frontends-pymc/>`_ and inspired by the `sampled <https://github.com/colcarroll/sampled>`_ decorator for PyMC3.
Examples
--------
.. code:: python
import pymc as pm
import pymc_experimental as pmx
# The following are equivalent
# standard PyMC API with context manager
with pm.Model(coords={"obs": ["a", "b"]}) as model:
x = pm.Normal("x", 0., 1., dims="obs")
pm.sample()
# functional API using decorator
@pmx.as_model(coords={"obs": ["a", "b"]})
def basic_model():
pm.Normal("x", 0., 1., dims="obs")
m = basic_model()
pm.sample(model=m)
"""

def decorator(f):
@wraps(f)
def make_model(*args, **kwargs):
with Model(*model_args, **model_kwargs) as m:
f(*args, **kwargs)
return m

return make_model

return decorator
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from pymc.util import UNSET
from scipy.special import logsumexp

from pymc_experimental.marginal_model import (
from pymc_experimental.model.marginal_model import (
FiniteDiscreteMarginalRV,
MarginalModel,
is_conditional_dependent,
Expand Down
22 changes: 22 additions & 0 deletions pymc_experimental/tests/model/test_model_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import numpy as np
import pymc as pm

import pymc_experimental as pmx


def test_logp():
"""Compare standard PyMC `with pm.Model()` context API against `pmx.model` decorator
and a functional syntax. Checks whether the kwarg `coords` can be passed.
"""
coords = {"obs": ["a", "b"]}

with pm.Model(coords=coords) as model:
pm.Normal("x", 0.0, 1.0, dims="obs")

@pmx.as_model(coords=coords)
def model_wrapped():
pm.Normal("x", 0.0, 1.0, dims="obs")

mw = model_wrapped()

np.testing.assert_equal(model.point_logps(), mw.point_logps())

0 comments on commit 8046695

Please sign in to comment.