diff --git a/mypyc/irbuild/prepare.py b/mypyc/irbuild/prepare.py index b6cd632e475f..e014d97fedd9 100644 --- a/mypyc/irbuild/prepare.py +++ b/mypyc/irbuild/prepare.py @@ -88,7 +88,7 @@ def build_type_map( is_abstract=cdef.info.is_abstract, is_final_class=cdef.info.is_final, ) - class_ir.is_ext_class = is_extension_class(cdef) + class_ir.is_ext_class = is_extension_class(module.path, cdef, errors) if class_ir.is_ext_class: class_ir.deletable = cdef.info.deletable_attributes.copy() # If global optimizations are disabled, turn of tracking of class children diff --git a/mypyc/irbuild/util.py b/mypyc/irbuild/util.py index 43ee547f8b4f..939c543c85a2 100644 --- a/mypyc/irbuild/util.py +++ b/mypyc/irbuild/util.py @@ -29,6 +29,7 @@ ) from mypy.semanal import refers_to_fullname from mypy.types import FINAL_DECORATOR_NAMES +from mypyc.errors import Errors DATACLASS_DECORATORS = {"dataclasses.dataclass", "attr.s", "attr.attrs"} @@ -125,15 +126,68 @@ def get_mypyc_attrs(stmt: ClassDef | Decorator) -> dict[str, Any]: return attrs -def is_extension_class(cdef: ClassDef) -> bool: - if any( - not is_trait_decorator(d) - and not is_dataclass_decorator(d) - and not get_mypyc_attr_call(d) - and not is_final_decorator(d) - for d in cdef.decorators - ): +def is_extension_class(path: str, cdef: ClassDef, errors: Errors) -> bool: + # Check for @mypyc_attr(native_class=True/False) decorator. + explicit_native_class = get_explicit_native_class(path, cdef, errors) + + # Classes with native_class=False are explicitly marked as non extension. + if explicit_native_class is False: return False + + implicit_extension_class = is_implicit_extension_class(cdef) + + # Classes with native_class=True should be extension classes, but they might + # not be able to be due to other reasons. Print an error in that case. + if explicit_native_class is True and not implicit_extension_class: + errors.error( + "Class is marked as native_class=True but it can't be a native class", path, cdef.line + ) + + return implicit_extension_class + + +def get_explicit_native_class(path: str, cdef: ClassDef, errors: Errors) -> bool | None: + """Return value of @mypyc_attr(native_class=True/False) decorator. + + Look for a @mypyc_attr decorator with native_class=True/False and return + the value assigned or None if it doesn't exist. Other values are an error. + """ + + for d in cdef.decorators: + mypyc_attr_call = get_mypyc_attr_call(d) + if not mypyc_attr_call: + continue + + for i, name in enumerate(mypyc_attr_call.arg_names): + if name != "native_class": + continue + + arg = mypyc_attr_call.args[i] + if not isinstance(arg, NameExpr): + errors.error("native_class must be used with True or False only", path, cdef.line) + return None + + if arg.name == "False": + return False + elif arg.name == "True": + return True + else: + errors.error("native_class must be used with True or False only", path, cdef.line) + return None + return None + + +def is_implicit_extension_class(cdef: ClassDef) -> bool: + for d in cdef.decorators: + # Classes that have any decorator other than supported decorators, are not extension classes + if ( + not is_trait_decorator(d) + and not is_dataclass_decorator(d) + and not get_mypyc_attr_call(d) + and not is_final_decorator(d) + ): + return False + if cdef.info.typeddict_type: return False if cdef.info.is_named_tuple: diff --git a/mypyc/test-data/fixtures/ir.py b/mypyc/test-data/fixtures/ir.py index e651e7adc384..0986f7522411 100644 --- a/mypyc/test-data/fixtures/ir.py +++ b/mypyc/test-data/fixtures/ir.py @@ -350,8 +350,10 @@ def next(i: Iterator[_T]) -> _T: pass def next(i: Iterator[_T], default: _T) -> _T: pass def hash(o: object) -> int: ... def globals() -> Dict[str, Any]: ... +def hasattr(obj: object, name: str) -> bool: ... def getattr(obj: object, name: str, default: Any = None) -> Any: ... def setattr(obj: object, name: str, value: Any) -> None: ... +def delattr(obj: object, name: str) -> None: ... def enumerate(x: Iterable[_T]) -> Iterator[Tuple[int, _T]]: ... @overload def zip(x: Iterable[_T], y: Iterable[_S]) -> Iterator[Tuple[_T, _S]]: ... diff --git a/mypyc/test-data/irbuild-classes.test b/mypyc/test-data/irbuild-classes.test index ed7c167d8621..972146bcb0b4 100644 --- a/mypyc/test-data/irbuild-classes.test +++ b/mypyc/test-data/irbuild-classes.test @@ -1345,3 +1345,28 @@ class SomeEnum(Enum): ALIAS = Literal[SomeEnum.AVALUE] ALIAS2 = Union[Literal[SomeEnum.AVALUE], None] + +[case testMypycAttrNativeClassErrors] +from mypy_extensions import mypyc_attr + +@mypyc_attr(native_class=False) +class AnnontatedNonExtensionClass: + pass + +@mypyc_attr(native_class=False) +class DerivedExplicitNonNativeClass(AnnontatedNonExtensionClass): + pass + + +def decorator(cls): + return cls + +@mypyc_attr(native_class=True) +@decorator +class NonNativeClassContradiction(): # E: Class is marked as native_class=True but it can't be a native class + pass + + +@mypyc_attr(native_class="yes") +class BadUse(): # E: native_class must be used with True or False only + pass diff --git a/mypyc/test-data/run-classes.test b/mypyc/test-data/run-classes.test index 601d6d7a65a0..edf9e6bf1906 100644 --- a/mypyc/test-data/run-classes.test +++ b/mypyc/test-data/run-classes.test @@ -2829,3 +2829,56 @@ Traceback (most recent call last): File "native.py", line 5, in __del__ raise Exception("e2") Exception: e2 + +[case testMypycAttrNativeClass] +from mypy_extensions import mypyc_attr +from testutil import assertRaises + +@mypyc_attr(native_class=False) +class AnnontatedNonExtensionClass: + pass + +class DerivedClass(AnnontatedNonExtensionClass): + pass + +class ImplicitExtensionClass(): + pass + +@mypyc_attr(native_class=True) +class AnnotatedExtensionClass(): + pass + +def test_function(): + setattr(AnnontatedNonExtensionClass, 'attr_class', 5) + assert(hasattr(AnnontatedNonExtensionClass, 'attr_class') == True) + assert(getattr(AnnontatedNonExtensionClass, 'attr_class') == 5) + delattr(AnnontatedNonExtensionClass, 'attr_class') + assert(hasattr(AnnontatedNonExtensionClass, 'attr_class') == False) + + inst = AnnontatedNonExtensionClass() + setattr(inst, 'attr_instance', 6) + assert(hasattr(inst, 'attr_instance') == True) + assert(getattr(inst, 'attr_instance') == 6) + delattr(inst, 'attr_instance') + assert(hasattr(inst, 'attr_instance') == False) + + setattr(DerivedClass, 'attr_class', 5) + assert(hasattr(DerivedClass, 'attr_class') == True) + assert(getattr(DerivedClass, 'attr_class') == 5) + delattr(DerivedClass, 'attr_class') + assert(hasattr(DerivedClass, 'attr_class') == False) + + derived_inst = DerivedClass() + setattr(derived_inst, 'attr_instance', 6) + assert(hasattr(derived_inst, 'attr_instance') == True) + assert(getattr(derived_inst, 'attr_instance') == 6) + delattr(derived_inst, 'attr_instance') + assert(hasattr(derived_inst, 'attr_instance') == False) + + ext_inst = ImplicitExtensionClass() + with assertRaises(AttributeError): + setattr(ext_inst, 'attr_instance', 6) + + explicit_ext_inst = AnnotatedExtensionClass() + with assertRaises(AttributeError): + setattr(explicit_ext_inst, 'attr_instance', 6)