diff --git a/pymc/data.py b/pymc/data.py index 247825981f1..22fc8717c3f 100644 --- a/pymc/data.py +++ b/pymc/data.py @@ -221,9 +221,9 @@ def Minibatch(variable: TensorVariable, *variables: TensorVariable, batch_size: def determine_coords( model, value: pd.DataFrame | pd.Series | xr.DataArray, - dims: Sequence[str | None] | None = None, + dims: Sequence[str] | None = None, coords: dict[str, Sequence | np.ndarray] | None = None, -) -> tuple[dict[str, Sequence | np.ndarray], Sequence[str | None]]: +) -> tuple[dict[str, Sequence | np.ndarray], Sequence[str] | Sequence[None]]: """Determine coordinate values from data or the model (via ``dims``).""" if coords is None: coords = {} @@ -268,9 +268,10 @@ def determine_coords( if dims is None: # TODO: Also determine dim names from the index - dims = [None] * np.ndim(value) - - return coords, dims + new_dims: Sequence[str] | Sequence[None] = [None] * np.ndim(value) + else: + new_dims = dims + return coords, new_dims def ConstantData( @@ -366,7 +367,7 @@ def Data( The name for this variable. value : array_like or pandas.Series, pandas.Dataframe A value to associate with this variable. - dims : str or tuple of str, optional + dims : str, tuple of str or tuple of None, optional Dimension names of the random variables (as opposed to the shapes of these random variables). Use this when ``value`` is a pandas Series or DataFrame. The ``dims`` will then be the name of the Series / DataFrame's columns. See ArviZ @@ -451,14 +452,17 @@ def Data( expected=x.ndim, ) + new_dims: Sequence[str] | Sequence[None] | None if infer_dims_and_coords: - coords, dims = determine_coords(model, value, dims) + coords, new_dims = determine_coords(model, value, dims) + else: + new_dims = dims - if dims: + if new_dims: xshape = x.shape # Register new dimension lengths - for d, dname in enumerate(dims): - if dname not in model.dim_lengths: + for d, dname in enumerate(new_dims): + if dname not in model.dim_lengths and dname is not None: model.add_coord( name=dname, # Note: Coordinate values can't be taken from @@ -467,6 +471,6 @@ def Data( length=xshape[d], ) - model.register_data_var(x, dims=dims) + model.register_data_var(x, dims=new_dims) return x diff --git a/pymc/model/core.py b/pymc/model/core.py index 48d2117eb26..ad60a84dfb9 100644 --- a/pymc/model/core.py +++ b/pymc/model/core.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import functools import sys @@ -19,13 +20,8 @@ import warnings from collections.abc import Iterable, Sequence -from sys import modules from typing import ( - TYPE_CHECKING, Literal, - Optional, - TypeVar, - Union, cast, overload, ) @@ -42,7 +38,6 @@ from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.random.type import RandomType from pytensor.tensor.variable import TensorConstant, TensorVariable -from typing_extensions import Self from pymc.blocking import DictToArrayBijection, RaveledVars from pymc.data import is_valid_observed @@ -73,6 +68,7 @@ VarName, WithMemoization, _add_future_warning_tag, + _UnsetType, get_transformed_name, get_value_vars_from_user_vars, get_var_name, @@ -92,118 +88,36 @@ ] -T = TypeVar("T", bound="ContextMeta") +class ModelManager(threading.local): + """Keeps track of currently active model contexts. + A global instance of this is created in this module on import. + Use that instance, `MODEL_MANAGER` to inspect current contexts. -class ContextMeta(type): - """Functionality for objects that put themselves in a context manager.""" - - def __new__(cls, name, bases, dct, **kwargs): - """Add __enter__ and __exit__ methods to the class.""" - - def __enter__(self): - self.__class__.context_class.get_contexts().append(self) - return self - - def __exit__(self, typ, value, traceback): - self.__class__.context_class.get_contexts().pop() + It inherits from threading.local so is thread-safe, if models + can be entered/exited within individual threads. + """ - dct[__enter__.__name__] = __enter__ - dct[__exit__.__name__] = __exit__ + def __init__(self): + self.active_contexts: list[Model] = [] - # We strip off keyword args, per the warning from - # StackExchange: - # DO NOT send "**kwargs" to "type.__new__". It won't catch them and - # you'll get a "TypeError: type() takes 1 or 3 arguments" exception. - return super().__new__(cls, name, bases, dct) + @property + def current_context(self) -> Model | None: + """Return the innermost context of any current contexts.""" + return self.active_contexts[-1] if self.active_contexts else None - # FIXME: is there a more elegant way to automatically add methods to the class that - # are instance methods instead of class methods? - def __init__(cls, name, bases, nmspc, context_class: type | None = None, **kwargs): - """Add ``__enter__`` and ``__exit__`` methods to the new class automatically.""" - if context_class is not None: - cls._context_class = context_class - super().__init__(name, bases, nmspc) + @property + def parent_context(self) -> Model | None: + """Return the parent context to the active context, if any.""" + return self.active_contexts[-2] if len(self.active_contexts) > 1 else None - def get_context(cls, error_if_none=True, allow_block_model_access=False) -> T | None: - """Return the most recently pushed context object of type ``cls`` on the stack, or ``None``. - If ``error_if_none`` is True (default), raise a ``TypeError`` instead of returning ``None``. - """ - try: - candidate: T | None = cls.get_contexts()[-1] - except IndexError: - # Calling code expects to get a TypeError if the entity - # is unfound, and there's too much to fix. - if error_if_none: - raise TypeError(f"No {cls} on context stack") - return None - if isinstance(candidate, BlockModelAccess) and not allow_block_model_access: - raise BlockModelAccessError(candidate.error_msg_on_access) - return candidate - - def get_contexts(cls) -> list[T]: - """Return a stack of context instances for the ``context_class`` of ``cls``.""" - # This lazily creates the context class's contexts - # thread-local object, as needed. This seems inelegant to me, - # but since the context class is not guaranteed to exist when - # the metaclass is being instantiated, I couldn't figure out a - # better way. [2019/10/11:rpg] - - # no race-condition here, contexts is a thread-local object - # be sure not to override contexts in a subclass however! - context_class = cls.context_class - assert isinstance( - context_class, type - ), f"Name of context class, {context_class} was not resolvable to a class" - if not hasattr(context_class, "contexts"): - context_class.contexts = threading.local() - - contexts = context_class.contexts - - if not hasattr(contexts, "stack"): - contexts.stack = [] - return contexts.stack - - # the following complex property accessor is necessary because the - # context_class may not have been created at the point it is - # specified, so the context_class may be a class *name* rather - # than a class. - @property - def context_class(cls) -> type: - def resolve_type(c: type | str) -> type: - if isinstance(c, str): - c = getattr(modules[cls.__module__], c) - if isinstance(c, type): - return c - raise ValueError(f"Cannot resolve context class {c}") - - assert cls is not None - if isinstance(cls._context_class, str): - cls._context_class = resolve_type(cls._context_class) - if not isinstance(cls._context_class, str | type): - raise ValueError( - f"Context class for {cls.__name__}, {cls._context_class}, is not of the right type" - ) - return cls._context_class - - # Inherit context class from parent - def __init_subclass__(cls, **kwargs): - super().__init_subclass__(**kwargs) - cls.context_class = super().context_class - - # Initialize object in its own context... - # Merged from InitContextMeta in the original. - def __call__(cls, *args, **kwargs): - # We type hint Model here so type checkers understand that Model is a context manager. - # This metaclass is only used for Model, so this is safe to do. See #6809 for more info. - instance: Model = cls.__new__(cls, *args, **kwargs) - with instance: # appends context - instance.__init__(*args, **kwargs) - return instance +# MODEL_MANAGER is instantiated at import, and serves as a truth for +# what any currently active model contexts are. +MODEL_MANAGER = ModelManager() -def modelcontext(model: Optional["Model"]) -> "Model": +def modelcontext(model: Model | None) -> Model: """Return the given model or, if None was supplied, try to find one in the context stack.""" if model is None: model = Model.get_context(error_if_none=False) @@ -372,6 +286,18 @@ def profile(self): return self._pytensor_function.profile +class ContextMeta(type): + """A metaclass in order to apply a model's context during `Model.__init__``.""" + + # We want the Model's context to be active during __init__. In order for this + # to apply to subclasses of Model as well, we need to use a metaclass. + def __call__(cls: type[Model], *args, **kwargs): + instance = cls.__new__(cls, *args, **kwargs) + with instance: # applies context + instance.__init__(*args, **kwargs) + return instance + + class Model(WithMemoization, metaclass=ContextMeta): """Encapsulates the variables and likelihood factors of a model. @@ -495,22 +421,14 @@ class Model(WithMemoization, metaclass=ContextMeta): """ - if TYPE_CHECKING: + def __enter__(self): + """Enter the context manager.""" + MODEL_MANAGER.active_contexts.append(self) + return self - def __enter__(self: Self) -> Self: - """Enter the context manager.""" - - def __exit__(self, exc_type: None, exc_val: None, exc_tb: None) -> None: - """Exit the context manager.""" - - def __new__(cls, *args, model: Union[Literal[UNSET], None, "Model"] = UNSET, **kwargs): - # resolves the parent instance - instance = super().__new__(cls) - if model is UNSET: - instance._parent = cls.get_context(error_if_none=False) - else: - instance._parent = model - return instance + def __exit__(self, exc_type: None, exc_val: None, exc_tb: None) -> None: + """Exit the context manager.""" + _ = MODEL_MANAGER.active_contexts.pop() @staticmethod def _validate_name(name): @@ -525,11 +443,11 @@ def __init__( check_bounds=True, *, coords_mutable=None, - model: Union[Literal[UNSET], None, "Model"] = UNSET, + model: _UnsetType | None | Model = UNSET, ): - del model # used in __new__ to define the parent of this model self.name = self._validate_name(name) self.check_bounds = check_bounds + self._parent = model if not isinstance(model, _UnsetType) else MODEL_MANAGER.parent_context if coords_mutable is not None: warnings.warn( @@ -577,6 +495,17 @@ def __init__( functools.partial(str_for_model, formatting="latex"), self ) + @classmethod + def get_context( + cls, error_if_none: bool = True, allow_block_model_access: bool = False + ) -> Model | None: + model = MODEL_MANAGER.current_context + if isinstance(model, BlockModelAccess) and not allow_block_model_access: + raise BlockModelAccessError(model.error_msg_on_access) + if model is None and error_if_none: + raise TypeError("No model on context stack") + return model + @property def parent(self): return self._parent @@ -967,7 +896,7 @@ def shape_from_dims(self, dims): def add_coord( self, name: str, - values: Sequence | None = None, + values: Sequence | np.ndarray | None = None, mutable: bool | None = None, *, length: int | Variable | None = None, @@ -1233,8 +1162,8 @@ def set_data( def register_rv( self, - rv_var, - name, + rv_var: RandomVariable, + name: str, *, observed=None, total_size=None, @@ -1242,7 +1171,7 @@ def register_rv( default_transform=UNSET, transform=UNSET, initval=None, - ): + ) -> TensorVariable: """Register an (un)observed random variable with the model. Parameters @@ -2074,11 +2003,6 @@ def to_graphviz( ) -# this is really disgusting, but it breaks a self-loop: I can't pass Model -# itself as context class init arg. -Model._context_class = Model - - class BlockModelAccess(Model): """Can be used to prevent user access to Model contexts."""