From d843baea5fc9fc7db47ae9f2dc50665662767765 Mon Sep 17 00:00:00 2001 From: Thomas Aarholt Date: Thu, 10 Oct 2024 14:54:56 +0200 Subject: [PATCH] Simplify Model __new__ and metaclass (#7473) * Type get_context correctly get_context returns an instance of a Model, not a ContextMeta object We don't need the typevar, since we don't use it for anything special * Import from future to use delayed evaluation of annotations All of these are supported on python>=3.9. * New ModelManager class for managing model contexts We create a global instance of it within this module, which is similar to how it worked before, where a `context_class` attribute was attached to the Model class. We inherit from threading.local to ensure thread safety when working with models on multiple threads. See #1552 for the reasoning. This is already tested in `test_thread_safety`. * Model class is now the context manager directly * Fix type of UNSET in type definition UNSET is the instance of the _UnsetType type. We should be typing the latter here. * Set model parent in init rather than in __new__ We use the new ModelManager.parent_context property to reliably set any parent context, or else set it to None. * Replace get_context in metaclass with classmethod We set this directly on the class as a classmethod, which is clearer than going via the metaclass. * Remove get_contexts from metaclass The original function does not behave as I expected. In the following example I expected that it would return only the final model, not root. This method is not used anywhere in the pymc codebase, so I have dropped it from the codebase. I originally included the following code to replace it, but since it is not used anyway, it is better to remove it. ```python` @classmethod def get_contexts(cls) -> list[Model]: """Return a list of the currently active model contexts.""" return MODEL_MANAGER.active_contexts ``` Example for testing behaviour in current main branch: ```python import pymc as pm with pm.Model(name="root") as root: print([c.name for c in pm.Model.get_contexts()]) with pm.Model(name="first") as first: print([c.name for c in pm.Model.get_contexts()]) with pm.Model(name="m_with_model_None", model=None) as m_with_model_None: # This one doesn't make much sense: print([c.name for c in pm.Model.get_contexts()]) ``` * Simplify ContextMeta We only keep the __call__ method, which is necessary to keep the model context itself active during that model's __init__. * Type Model.register_rv for for downstream typing In pymc/distributions/distribution.py, this change allows the type checker to infer that `rv_out` can only be a TensorVariable. Thanks to @ricardoV94 for type hint on rv_var. * Include np.ndarray as possible type for coord values I originally tried numpy's ArrayLike, replacing Sequence entirely, but then I realized that ArrayLike also allows non-sequences like integers and floats. I am not certain if `values="a string"` should be legal. With the type hint sequence, it is. Might be more accurate, but verbose to use `list | tuple | set | np.ndarray | None`. * Use function-scoped new_dims to handle type hint varying throughout function We don't want to allow the user to pass a `dims=[None, None]` to our function, but current behaviour set `dims=[None] * N` at the end of `determine_coords`. To handle this, I created a `new_dims` with a larger type scope which matches the return type of `dims` in `determine_coords`. Then I did the same within def Data to support this new type hint. * Fix case of dims = [None, None, ...] The only case where dims=[None, ...] is when the user has passed dims=None. Since the user passed dims=None, they shouldn't be expecting any coords to match that dimension. Thus we don't need to try to add any more coords to the model. * Remove unused hack --- pymc/data.py | 26 +++--- pymc/model/core.py | 194 ++++++++++++++------------------------------- 2 files changed, 74 insertions(+), 146 deletions(-) diff --git a/pymc/data.py b/pymc/data.py index 247825981f1..22fc8717c3f 100644 --- a/pymc/data.py +++ b/pymc/data.py @@ -221,9 +221,9 @@ def Minibatch(variable: TensorVariable, *variables: TensorVariable, batch_size: def determine_coords( model, value: pd.DataFrame | pd.Series | xr.DataArray, - dims: Sequence[str | None] | None = None, + dims: Sequence[str] | None = None, coords: dict[str, Sequence | np.ndarray] | None = None, -) -> tuple[dict[str, Sequence | np.ndarray], Sequence[str | None]]: +) -> tuple[dict[str, Sequence | np.ndarray], Sequence[str] | Sequence[None]]: """Determine coordinate values from data or the model (via ``dims``).""" if coords is None: coords = {} @@ -268,9 +268,10 @@ def determine_coords( if dims is None: # TODO: Also determine dim names from the index - dims = [None] * np.ndim(value) - - return coords, dims + new_dims: Sequence[str] | Sequence[None] = [None] * np.ndim(value) + else: + new_dims = dims + return coords, new_dims def ConstantData( @@ -366,7 +367,7 @@ def Data( The name for this variable. value : array_like or pandas.Series, pandas.Dataframe A value to associate with this variable. - dims : str or tuple of str, optional + dims : str, tuple of str or tuple of None, optional Dimension names of the random variables (as opposed to the shapes of these random variables). Use this when ``value`` is a pandas Series or DataFrame. The ``dims`` will then be the name of the Series / DataFrame's columns. See ArviZ @@ -451,14 +452,17 @@ def Data( expected=x.ndim, ) + new_dims: Sequence[str] | Sequence[None] | None if infer_dims_and_coords: - coords, dims = determine_coords(model, value, dims) + coords, new_dims = determine_coords(model, value, dims) + else: + new_dims = dims - if dims: + if new_dims: xshape = x.shape # Register new dimension lengths - for d, dname in enumerate(dims): - if dname not in model.dim_lengths: + for d, dname in enumerate(new_dims): + if dname not in model.dim_lengths and dname is not None: model.add_coord( name=dname, # Note: Coordinate values can't be taken from @@ -467,6 +471,6 @@ def Data( length=xshape[d], ) - model.register_data_var(x, dims=dims) + model.register_data_var(x, dims=new_dims) return x diff --git a/pymc/model/core.py b/pymc/model/core.py index 48d2117eb26..ad60a84dfb9 100644 --- a/pymc/model/core.py +++ b/pymc/model/core.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import functools import sys @@ -19,13 +20,8 @@ import warnings from collections.abc import Iterable, Sequence -from sys import modules from typing import ( - TYPE_CHECKING, Literal, - Optional, - TypeVar, - Union, cast, overload, ) @@ -42,7 +38,6 @@ from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.random.type import RandomType from pytensor.tensor.variable import TensorConstant, TensorVariable -from typing_extensions import Self from pymc.blocking import DictToArrayBijection, RaveledVars from pymc.data import is_valid_observed @@ -73,6 +68,7 @@ VarName, WithMemoization, _add_future_warning_tag, + _UnsetType, get_transformed_name, get_value_vars_from_user_vars, get_var_name, @@ -92,118 +88,36 @@ ] -T = TypeVar("T", bound="ContextMeta") +class ModelManager(threading.local): + """Keeps track of currently active model contexts. + A global instance of this is created in this module on import. + Use that instance, `MODEL_MANAGER` to inspect current contexts. -class ContextMeta(type): - """Functionality for objects that put themselves in a context manager.""" - - def __new__(cls, name, bases, dct, **kwargs): - """Add __enter__ and __exit__ methods to the class.""" - - def __enter__(self): - self.__class__.context_class.get_contexts().append(self) - return self - - def __exit__(self, typ, value, traceback): - self.__class__.context_class.get_contexts().pop() + It inherits from threading.local so is thread-safe, if models + can be entered/exited within individual threads. + """ - dct[__enter__.__name__] = __enter__ - dct[__exit__.__name__] = __exit__ + def __init__(self): + self.active_contexts: list[Model] = [] - # We strip off keyword args, per the warning from - # StackExchange: - # DO NOT send "**kwargs" to "type.__new__". It won't catch them and - # you'll get a "TypeError: type() takes 1 or 3 arguments" exception. - return super().__new__(cls, name, bases, dct) + @property + def current_context(self) -> Model | None: + """Return the innermost context of any current contexts.""" + return self.active_contexts[-1] if self.active_contexts else None - # FIXME: is there a more elegant way to automatically add methods to the class that - # are instance methods instead of class methods? - def __init__(cls, name, bases, nmspc, context_class: type | None = None, **kwargs): - """Add ``__enter__`` and ``__exit__`` methods to the new class automatically.""" - if context_class is not None: - cls._context_class = context_class - super().__init__(name, bases, nmspc) + @property + def parent_context(self) -> Model | None: + """Return the parent context to the active context, if any.""" + return self.active_contexts[-2] if len(self.active_contexts) > 1 else None - def get_context(cls, error_if_none=True, allow_block_model_access=False) -> T | None: - """Return the most recently pushed context object of type ``cls`` on the stack, or ``None``. - If ``error_if_none`` is True (default), raise a ``TypeError`` instead of returning ``None``. - """ - try: - candidate: T | None = cls.get_contexts()[-1] - except IndexError: - # Calling code expects to get a TypeError if the entity - # is unfound, and there's too much to fix. - if error_if_none: - raise TypeError(f"No {cls} on context stack") - return None - if isinstance(candidate, BlockModelAccess) and not allow_block_model_access: - raise BlockModelAccessError(candidate.error_msg_on_access) - return candidate - - def get_contexts(cls) -> list[T]: - """Return a stack of context instances for the ``context_class`` of ``cls``.""" - # This lazily creates the context class's contexts - # thread-local object, as needed. This seems inelegant to me, - # but since the context class is not guaranteed to exist when - # the metaclass is being instantiated, I couldn't figure out a - # better way. [2019/10/11:rpg] - - # no race-condition here, contexts is a thread-local object - # be sure not to override contexts in a subclass however! - context_class = cls.context_class - assert isinstance( - context_class, type - ), f"Name of context class, {context_class} was not resolvable to a class" - if not hasattr(context_class, "contexts"): - context_class.contexts = threading.local() - - contexts = context_class.contexts - - if not hasattr(contexts, "stack"): - contexts.stack = [] - return contexts.stack - - # the following complex property accessor is necessary because the - # context_class may not have been created at the point it is - # specified, so the context_class may be a class *name* rather - # than a class. - @property - def context_class(cls) -> type: - def resolve_type(c: type | str) -> type: - if isinstance(c, str): - c = getattr(modules[cls.__module__], c) - if isinstance(c, type): - return c - raise ValueError(f"Cannot resolve context class {c}") - - assert cls is not None - if isinstance(cls._context_class, str): - cls._context_class = resolve_type(cls._context_class) - if not isinstance(cls._context_class, str | type): - raise ValueError( - f"Context class for {cls.__name__}, {cls._context_class}, is not of the right type" - ) - return cls._context_class - - # Inherit context class from parent - def __init_subclass__(cls, **kwargs): - super().__init_subclass__(**kwargs) - cls.context_class = super().context_class - - # Initialize object in its own context... - # Merged from InitContextMeta in the original. - def __call__(cls, *args, **kwargs): - # We type hint Model here so type checkers understand that Model is a context manager. - # This metaclass is only used for Model, so this is safe to do. See #6809 for more info. - instance: Model = cls.__new__(cls, *args, **kwargs) - with instance: # appends context - instance.__init__(*args, **kwargs) - return instance +# MODEL_MANAGER is instantiated at import, and serves as a truth for +# what any currently active model contexts are. +MODEL_MANAGER = ModelManager() -def modelcontext(model: Optional["Model"]) -> "Model": +def modelcontext(model: Model | None) -> Model: """Return the given model or, if None was supplied, try to find one in the context stack.""" if model is None: model = Model.get_context(error_if_none=False) @@ -372,6 +286,18 @@ def profile(self): return self._pytensor_function.profile +class ContextMeta(type): + """A metaclass in order to apply a model's context during `Model.__init__``.""" + + # We want the Model's context to be active during __init__. In order for this + # to apply to subclasses of Model as well, we need to use a metaclass. + def __call__(cls: type[Model], *args, **kwargs): + instance = cls.__new__(cls, *args, **kwargs) + with instance: # applies context + instance.__init__(*args, **kwargs) + return instance + + class Model(WithMemoization, metaclass=ContextMeta): """Encapsulates the variables and likelihood factors of a model. @@ -495,22 +421,14 @@ class Model(WithMemoization, metaclass=ContextMeta): """ - if TYPE_CHECKING: + def __enter__(self): + """Enter the context manager.""" + MODEL_MANAGER.active_contexts.append(self) + return self - def __enter__(self: Self) -> Self: - """Enter the context manager.""" - - def __exit__(self, exc_type: None, exc_val: None, exc_tb: None) -> None: - """Exit the context manager.""" - - def __new__(cls, *args, model: Union[Literal[UNSET], None, "Model"] = UNSET, **kwargs): - # resolves the parent instance - instance = super().__new__(cls) - if model is UNSET: - instance._parent = cls.get_context(error_if_none=False) - else: - instance._parent = model - return instance + def __exit__(self, exc_type: None, exc_val: None, exc_tb: None) -> None: + """Exit the context manager.""" + _ = MODEL_MANAGER.active_contexts.pop() @staticmethod def _validate_name(name): @@ -525,11 +443,11 @@ def __init__( check_bounds=True, *, coords_mutable=None, - model: Union[Literal[UNSET], None, "Model"] = UNSET, + model: _UnsetType | None | Model = UNSET, ): - del model # used in __new__ to define the parent of this model self.name = self._validate_name(name) self.check_bounds = check_bounds + self._parent = model if not isinstance(model, _UnsetType) else MODEL_MANAGER.parent_context if coords_mutable is not None: warnings.warn( @@ -577,6 +495,17 @@ def __init__( functools.partial(str_for_model, formatting="latex"), self ) + @classmethod + def get_context( + cls, error_if_none: bool = True, allow_block_model_access: bool = False + ) -> Model | None: + model = MODEL_MANAGER.current_context + if isinstance(model, BlockModelAccess) and not allow_block_model_access: + raise BlockModelAccessError(model.error_msg_on_access) + if model is None and error_if_none: + raise TypeError("No model on context stack") + return model + @property def parent(self): return self._parent @@ -967,7 +896,7 @@ def shape_from_dims(self, dims): def add_coord( self, name: str, - values: Sequence | None = None, + values: Sequence | np.ndarray | None = None, mutable: bool | None = None, *, length: int | Variable | None = None, @@ -1233,8 +1162,8 @@ def set_data( def register_rv( self, - rv_var, - name, + rv_var: RandomVariable, + name: str, *, observed=None, total_size=None, @@ -1242,7 +1171,7 @@ def register_rv( default_transform=UNSET, transform=UNSET, initval=None, - ): + ) -> TensorVariable: """Register an (un)observed random variable with the model. Parameters @@ -2074,11 +2003,6 @@ def to_graphviz( ) -# this is really disgusting, but it breaks a self-loop: I can't pass Model -# itself as context class init arg. -Model._context_class = Model - - class BlockModelAccess(Model): """Can be used to prevent user access to Model contexts."""