Skip to content

Commit

Permalink
Simplify Model __new__ and metaclass (pymc-devs#7473)
Browse files Browse the repository at this point in the history
* Type get_context correctly

get_context returns an instance of a Model, not a ContextMeta object
We don't need the typevar, since we don't use it for anything special

* Import from future to use delayed evaluation of annotations

All of these are supported on python>=3.9.

* New ModelManager class for managing model contexts

We create a global instance of it within this module, which is similar
to how it worked before, where a `context_class` attribute was attached
to the Model class.

We inherit from threading.local to ensure thread safety when working
with models on multiple threads. See pymc-devs#1552 for the reasoning. This is
already tested in `test_thread_safety`.

* Model class is now the context manager directly

* Fix type of UNSET in type definition

UNSET is the instance of the _UnsetType type.
We should be typing the latter here.

* Set model parent in init rather than in __new__

We use the new ModelManager.parent_context property to reliably set any
parent context, or else set it to None.

* Replace get_context in metaclass with classmethod

We set this directly on the class as a classmethod, which is clearer
than going via the metaclass.

* Remove get_contexts from metaclass

The original function does not behave as I expected.
In the following example I expected that it would return only the final
model, not root.

This method is not used anywhere in the pymc codebase, so I have dropped
it from the codebase. I originally included the following code to replace
it, but since it is not used anyway, it is better to remove it.

```python`
@classmethod
def get_contexts(cls) -> list[Model]:
    """Return a list of the currently active model contexts."""
    return MODEL_MANAGER.active_contexts
```

Example for testing behaviour in current main branch:
```python
import pymc as pm

with pm.Model(name="root") as root:
    print([c.name for c in pm.Model.get_contexts()])
    with pm.Model(name="first") as first:
        print([c.name for c in pm.Model.get_contexts()])
    with pm.Model(name="m_with_model_None", model=None) as m_with_model_None:
        # This one doesn't make much sense:
        print([c.name for c in pm.Model.get_contexts()])
```

* Simplify ContextMeta

We only keep the __call__ method, which is necessary to keep the
model context itself active during that model's __init__.

* Type Model.register_rv for for downstream typing

In pymc/distributions/distribution.py, this change allows the type
checker to infer that `rv_out` can only be a TensorVariable.

Thanks to @ricardoV94 for type hint on rv_var.

* Include np.ndarray as possible type for coord values

I originally tried numpy's ArrayLike, replacing Sequence entirely, but then I realized
that ArrayLike also allows non-sequences like integers and floats.

I am not certain if `values="a string"` should be legal. With the type hint sequence, it is.
Might be more accurate, but verbose to use `list | tuple | set | np.ndarray | None`.

* Use function-scoped new_dims to handle type hint varying throughout function

We don't want to allow the user to pass a `dims=[None, None]` to our function, but current behaviour
set `dims=[None] * N` at the end of `determine_coords`.

To handle this, I created a `new_dims` with a larger type scope which matches
the return type of `dims` in `determine_coords`.

Then I did the same within def Data to support this new type hint.

* Fix case of dims = [None, None, ...]

The only case where dims=[None, ...] is when the user has passed dims=None. Since the user passed dims=None,
they shouldn't be expecting any coords to match that dimension. Thus we don't need to try to add any
more coords to the model.

* Remove unused hack
  • Loading branch information
thomasaarholt authored and mkusnetsov committed Oct 26, 2024
1 parent 8d00bb2 commit d843bae
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 146 deletions.
26 changes: 15 additions & 11 deletions pymc/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,9 +221,9 @@ def Minibatch(variable: TensorVariable, *variables: TensorVariable, batch_size:
def determine_coords(
model,
value: pd.DataFrame | pd.Series | xr.DataArray,
dims: Sequence[str | None] | None = None,
dims: Sequence[str] | None = None,
coords: dict[str, Sequence | np.ndarray] | None = None,
) -> tuple[dict[str, Sequence | np.ndarray], Sequence[str | None]]:
) -> tuple[dict[str, Sequence | np.ndarray], Sequence[str] | Sequence[None]]:
"""Determine coordinate values from data or the model (via ``dims``)."""
if coords is None:
coords = {}
Expand Down Expand Up @@ -268,9 +268,10 @@ def determine_coords(

if dims is None:
# TODO: Also determine dim names from the index
dims = [None] * np.ndim(value)

return coords, dims
new_dims: Sequence[str] | Sequence[None] = [None] * np.ndim(value)
else:
new_dims = dims
return coords, new_dims


def ConstantData(
Expand Down Expand Up @@ -366,7 +367,7 @@ def Data(
The name for this variable.
value : array_like or pandas.Series, pandas.Dataframe
A value to associate with this variable.
dims : str or tuple of str, optional
dims : str, tuple of str or tuple of None, optional
Dimension names of the random variables (as opposed to the shapes of these
random variables). Use this when ``value`` is a pandas Series or DataFrame. The
``dims`` will then be the name of the Series / DataFrame's columns. See ArviZ
Expand Down Expand Up @@ -451,14 +452,17 @@ def Data(
expected=x.ndim,
)

new_dims: Sequence[str] | Sequence[None] | None
if infer_dims_and_coords:
coords, dims = determine_coords(model, value, dims)
coords, new_dims = determine_coords(model, value, dims)
else:
new_dims = dims

if dims:
if new_dims:
xshape = x.shape
# Register new dimension lengths
for d, dname in enumerate(dims):
if dname not in model.dim_lengths:
for d, dname in enumerate(new_dims):
if dname not in model.dim_lengths and dname is not None:
model.add_coord(
name=dname,
# Note: Coordinate values can't be taken from
Expand All @@ -467,6 +471,6 @@ def Data(
length=xshape[d],
)

model.register_data_var(x, dims=dims)
model.register_data_var(x, dims=new_dims)

return x
194 changes: 59 additions & 135 deletions pymc/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import functools
import sys
Expand All @@ -19,13 +20,8 @@
import warnings

from collections.abc import Iterable, Sequence
from sys import modules
from typing import (
TYPE_CHECKING,
Literal,
Optional,
TypeVar,
Union,
cast,
overload,
)
Expand All @@ -42,7 +38,6 @@
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.type import RandomType
from pytensor.tensor.variable import TensorConstant, TensorVariable
from typing_extensions import Self

from pymc.blocking import DictToArrayBijection, RaveledVars
from pymc.data import is_valid_observed
Expand Down Expand Up @@ -73,6 +68,7 @@
VarName,
WithMemoization,
_add_future_warning_tag,
_UnsetType,
get_transformed_name,
get_value_vars_from_user_vars,
get_var_name,
Expand All @@ -92,118 +88,36 @@
]


T = TypeVar("T", bound="ContextMeta")
class ModelManager(threading.local):
"""Keeps track of currently active model contexts.
A global instance of this is created in this module on import.
Use that instance, `MODEL_MANAGER` to inspect current contexts.
class ContextMeta(type):
"""Functionality for objects that put themselves in a context manager."""

def __new__(cls, name, bases, dct, **kwargs):
"""Add __enter__ and __exit__ methods to the class."""

def __enter__(self):
self.__class__.context_class.get_contexts().append(self)
return self

def __exit__(self, typ, value, traceback):
self.__class__.context_class.get_contexts().pop()
It inherits from threading.local so is thread-safe, if models
can be entered/exited within individual threads.
"""

dct[__enter__.__name__] = __enter__
dct[__exit__.__name__] = __exit__
def __init__(self):
self.active_contexts: list[Model] = []

# We strip off keyword args, per the warning from
# StackExchange:
# DO NOT send "**kwargs" to "type.__new__". It won't catch them and
# you'll get a "TypeError: type() takes 1 or 3 arguments" exception.
return super().__new__(cls, name, bases, dct)
@property
def current_context(self) -> Model | None:
"""Return the innermost context of any current contexts."""
return self.active_contexts[-1] if self.active_contexts else None

# FIXME: is there a more elegant way to automatically add methods to the class that
# are instance methods instead of class methods?
def __init__(cls, name, bases, nmspc, context_class: type | None = None, **kwargs):
"""Add ``__enter__`` and ``__exit__`` methods to the new class automatically."""
if context_class is not None:
cls._context_class = context_class
super().__init__(name, bases, nmspc)
@property
def parent_context(self) -> Model | None:
"""Return the parent context to the active context, if any."""
return self.active_contexts[-2] if len(self.active_contexts) > 1 else None

def get_context(cls, error_if_none=True, allow_block_model_access=False) -> T | None:
"""Return the most recently pushed context object of type ``cls`` on the stack, or ``None``.

If ``error_if_none`` is True (default), raise a ``TypeError`` instead of returning ``None``.
"""
try:
candidate: T | None = cls.get_contexts()[-1]
except IndexError:
# Calling code expects to get a TypeError if the entity
# is unfound, and there's too much to fix.
if error_if_none:
raise TypeError(f"No {cls} on context stack")
return None
if isinstance(candidate, BlockModelAccess) and not allow_block_model_access:
raise BlockModelAccessError(candidate.error_msg_on_access)
return candidate

def get_contexts(cls) -> list[T]:
"""Return a stack of context instances for the ``context_class`` of ``cls``."""
# This lazily creates the context class's contexts
# thread-local object, as needed. This seems inelegant to me,
# but since the context class is not guaranteed to exist when
# the metaclass is being instantiated, I couldn't figure out a
# better way. [2019/10/11:rpg]

# no race-condition here, contexts is a thread-local object
# be sure not to override contexts in a subclass however!
context_class = cls.context_class
assert isinstance(
context_class, type
), f"Name of context class, {context_class} was not resolvable to a class"
if not hasattr(context_class, "contexts"):
context_class.contexts = threading.local()

contexts = context_class.contexts

if not hasattr(contexts, "stack"):
contexts.stack = []
return contexts.stack

# the following complex property accessor is necessary because the
# context_class may not have been created at the point it is
# specified, so the context_class may be a class *name* rather
# than a class.
@property
def context_class(cls) -> type:
def resolve_type(c: type | str) -> type:
if isinstance(c, str):
c = getattr(modules[cls.__module__], c)
if isinstance(c, type):
return c
raise ValueError(f"Cannot resolve context class {c}")

assert cls is not None
if isinstance(cls._context_class, str):
cls._context_class = resolve_type(cls._context_class)
if not isinstance(cls._context_class, str | type):
raise ValueError(
f"Context class for {cls.__name__}, {cls._context_class}, is not of the right type"
)
return cls._context_class

# Inherit context class from parent
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
cls.context_class = super().context_class

# Initialize object in its own context...
# Merged from InitContextMeta in the original.
def __call__(cls, *args, **kwargs):
# We type hint Model here so type checkers understand that Model is a context manager.
# This metaclass is only used for Model, so this is safe to do. See #6809 for more info.
instance: Model = cls.__new__(cls, *args, **kwargs)
with instance: # appends context
instance.__init__(*args, **kwargs)
return instance
# MODEL_MANAGER is instantiated at import, and serves as a truth for
# what any currently active model contexts are.
MODEL_MANAGER = ModelManager()


def modelcontext(model: Optional["Model"]) -> "Model":
def modelcontext(model: Model | None) -> Model:
"""Return the given model or, if None was supplied, try to find one in the context stack."""
if model is None:
model = Model.get_context(error_if_none=False)
Expand Down Expand Up @@ -372,6 +286,18 @@ def profile(self):
return self._pytensor_function.profile


class ContextMeta(type):
"""A metaclass in order to apply a model's context during `Model.__init__``."""

# We want the Model's context to be active during __init__. In order for this
# to apply to subclasses of Model as well, we need to use a metaclass.
def __call__(cls: type[Model], *args, **kwargs):
instance = cls.__new__(cls, *args, **kwargs)
with instance: # applies context
instance.__init__(*args, **kwargs)
return instance


class Model(WithMemoization, metaclass=ContextMeta):
"""Encapsulates the variables and likelihood factors of a model.
Expand Down Expand Up @@ -495,22 +421,14 @@ class Model(WithMemoization, metaclass=ContextMeta):
"""

if TYPE_CHECKING:
def __enter__(self):
"""Enter the context manager."""
MODEL_MANAGER.active_contexts.append(self)
return self

def __enter__(self: Self) -> Self:
"""Enter the context manager."""

def __exit__(self, exc_type: None, exc_val: None, exc_tb: None) -> None:
"""Exit the context manager."""

def __new__(cls, *args, model: Union[Literal[UNSET], None, "Model"] = UNSET, **kwargs):
# resolves the parent instance
instance = super().__new__(cls)
if model is UNSET:
instance._parent = cls.get_context(error_if_none=False)
else:
instance._parent = model
return instance
def __exit__(self, exc_type: None, exc_val: None, exc_tb: None) -> None:
"""Exit the context manager."""
_ = MODEL_MANAGER.active_contexts.pop()

@staticmethod
def _validate_name(name):
Expand All @@ -525,11 +443,11 @@ def __init__(
check_bounds=True,
*,
coords_mutable=None,
model: Union[Literal[UNSET], None, "Model"] = UNSET,
model: _UnsetType | None | Model = UNSET,
):
del model # used in __new__ to define the parent of this model
self.name = self._validate_name(name)
self.check_bounds = check_bounds
self._parent = model if not isinstance(model, _UnsetType) else MODEL_MANAGER.parent_context

if coords_mutable is not None:
warnings.warn(
Expand Down Expand Up @@ -577,6 +495,17 @@ def __init__(
functools.partial(str_for_model, formatting="latex"), self
)

@classmethod
def get_context(
cls, error_if_none: bool = True, allow_block_model_access: bool = False
) -> Model | None:
model = MODEL_MANAGER.current_context
if isinstance(model, BlockModelAccess) and not allow_block_model_access:
raise BlockModelAccessError(model.error_msg_on_access)
if model is None and error_if_none:
raise TypeError("No model on context stack")
return model

@property
def parent(self):
return self._parent
Expand Down Expand Up @@ -967,7 +896,7 @@ def shape_from_dims(self, dims):
def add_coord(
self,
name: str,
values: Sequence | None = None,
values: Sequence | np.ndarray | None = None,
mutable: bool | None = None,
*,
length: int | Variable | None = None,
Expand Down Expand Up @@ -1233,16 +1162,16 @@ def set_data(

def register_rv(
self,
rv_var,
name,
rv_var: RandomVariable,
name: str,
*,
observed=None,
total_size=None,
dims=None,
default_transform=UNSET,
transform=UNSET,
initval=None,
):
) -> TensorVariable:
"""Register an (un)observed random variable with the model.
Parameters
Expand Down Expand Up @@ -2074,11 +2003,6 @@ def to_graphviz(
)


# this is really disgusting, but it breaks a self-loop: I can't pass Model
# itself as context class init arg.
Model._context_class = Model


class BlockModelAccess(Model):
"""Can be used to prevent user access to Model contexts."""

Expand Down

0 comments on commit d843bae

Please sign in to comment.