Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add @as_model decorator #268

Merged
merged 10 commits into from
Nov 22, 2023
Merged
3 changes: 2 additions & 1 deletion docs/api_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ methods in the current release of PyMC experimental.
.. autosummary::
:toctree: generated/

marginal_model.MarginalModel
as_model
MarginalModel
model_builder.ModelBuilder

Inference
Expand Down
3 changes: 2 additions & 1 deletion pymc_experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,5 @@

from pymc_experimental import distributions, gp, utils
from pymc_experimental.inference.fit import fit
from pymc_experimental.marginal_model import MarginalModel
from pymc_experimental.model.marginal_model import MarginalModel
from pymc_experimental.model.model_api import as_model
Empty file.
46 changes: 46 additions & 0 deletions pymc_experimental/model/model_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from functools import wraps

from pymc import Model


def as_model(*model_args, **model_kwargs):
R"""
Decorator to provide context to PyMC models declared in a function.
This removes all need to think about context managers and lets you separate creating a generative model from using the model.

Adapted from `Rob Zinkov's blog post <https://www.zinkov.com/posts/2023-alternative-frontends-pymc/>`_ and inspired by the `sampled <https://github.com/colcarroll/sampled>`_ decorator for PyMC3.

Examples
--------
.. code:: python

import pymc as pm
import pymc_experimental as pmx

# The following are equivalent
theorashid marked this conversation as resolved.
Show resolved Hide resolved

# standard PyMC API with context manager
with pm.Model(coords={"obs": ["a", "b"]}) as model:
x = pm.Normal("x", 0., 1., dims="obs")
pm.sample()

# functional API using decorator
@pmx.as_model(coords={"obs": ["a", "b"]})
def basic_model():
pm.Normal("x", 0., 1., dims="obs")

m = basic_model()
pm.sample(model=m)

"""

def decorator(f):
@wraps(f)
def make_model(*args, **kwargs):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would also add a name kwarg here to change model name

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can pass it to model_kwargs

with Model(*model_args, **model_kwargs) as m:
f(*args, **kwargs)
return m

return make_model

return decorator
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from pymc.util import UNSET
from scipy.special import logsumexp

from pymc_experimental.marginal_model import (
from pymc_experimental.model.marginal_model import (
FiniteDiscreteMarginalRV,
MarginalModel,
is_conditional_dependent,
Expand Down
22 changes: 22 additions & 0 deletions pymc_experimental/tests/model/test_model_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import numpy as np
import pymc as pm

import pymc_experimental as pmx


def test_logp():
"""Compare standard PyMC `with pm.Model()` context API against `pmx.model` decorator
and a functional syntax. Checks whether the kwarg `coords` can be passed.
"""
coords = {"obs": ["a", "b"]}

with pm.Model(coords=coords) as model:
pm.Normal("x", 0.0, 1.0, dims="obs")

@pmx.as_model(coords=coords)
def model_wrapped():
pm.Normal("x", 0.0, 1.0, dims="obs")

mw = model_wrapped()

np.testing.assert_equal(model.point_logps(), mw.point_logps())
Loading