diff --git a/mypy_extensions.py b/mypy_extensions.py index 1910000..90a2d18 100644 --- a/mypy_extensions.py +++ b/mypy_extensions.py @@ -5,12 +5,37 @@ from mypy_extensions import TypedDict """ -from typing import Any, Dict +from typing import Any, Callable, Dict, Final, Literal, TypeVar import sys # _type_check is NOT a part of public typing API, it is used here only to mimic # the (convenient) behavior of types provided by typing module. -from typing import _type_check # type: ignore +from typing import _type_check # type: ignore [attr-defined] + +from typing_extensions import NotRequired, TypeGuard, Unpack + + +_C = TypeVar("_C", bound=Callable[..., Any]) + +MYPYC_ATTRS: Final = frozenset([ + "native_class", + "allow_interpreted_subclasses", + "serializable", + "free_list_len", +]) + +MypycAttr = Literal[ + "native_class", + "allow_interpreted_subclasses", + "serializable", + "free_list_len", +] + +class MypycAttrs(TypedDict): + native_class: NotRequired[bool] + allow_interpreted_subclasses: NotRequired[bool] + serializable: NotRequired[bool] + free_list_len: NotRequired[int] def _check_fails(cls, other): @@ -158,12 +183,25 @@ def trait(cls): return cls -def mypyc_attr(*attrs, **kwattrs): +def _validate_mypyc_attr_key(key: str) -> TypeGuard[MypycAttr]: + if key not in MYPYC_ATTRS: + raise ValueError( + f"{key!r} is not a valid `mypyc_attr` key.\n" + "Valid keys are: {', '.join(map(repr, sorted(MYPYC_ATTRS)))}" + ) + + +def mypyc_attr(*attrs: MypycAttr, **kwattrs: Unpack[MypycAttrs]) -> Callable[[_C], _C]: + for key in attrs: + _validate_mypyc_attr_key(key) + for key, value in kwattrs.items(): + _validate_mypyc_attr_key(key) + if not isinstance(value, bool): + raise TypeError(f"{key} value should be boolean, not {type(value).__name__}.") return lambda x: x -# TODO: We may want to try to properly apply this to any type -# variables left over... +# TODO: We may want to try to properly apply this to any type variables left over... class _FlexibleAliasClsApplied: def __init__(self, val): self.val = val