From 1bcfc041bb767ee93e90676b0a61f3e40267e858 Mon Sep 17 00:00:00 2001 From: Chad Dombrova Date: Thu, 25 Nov 2021 03:14:08 -0800 Subject: [PATCH] [mypyc] Introduce ClassBuilder and add support for attrs classes (#11328) Create a class structure for building different types of classes, extension vs non-extension, plus various types of dataclasses. Specilize member expressions. Prior to this change dataclasses.field would only be specialized if called without the module, as field() Add support for attrs classes and build proper __annotations__ for dataclasses. Co-authored-by: Jingchen Ye <97littleleaf11@gmail.com> --- mypyc/ir/class_ir.py | 8 + mypyc/irbuild/classdef.py | 380 ++++++++++++++++++++---------- mypyc/irbuild/expression.py | 28 +-- mypyc/irbuild/function.py | 10 - mypyc/irbuild/specialize.py | 32 ++- mypyc/irbuild/util.py | 44 +++- mypyc/test-data/run-attrs.test | 318 +++++++++++++++++++++++++ mypyc/test-data/run-python37.test | 8 + mypyc/test/test_run.py | 4 +- 9 files changed, 671 insertions(+), 161 deletions(-) create mode 100644 mypyc/test-data/run-attrs.test diff --git a/mypyc/ir/class_ir.py b/mypyc/ir/class_ir.py index 65274e425098f..ade04f39edcb9 100644 --- a/mypyc/ir/class_ir.py +++ b/mypyc/ir/class_ir.py @@ -152,6 +152,14 @@ def __init__(self, name: str, module_name: str, is_trait: bool = False, # None if separate compilation prevents this from working self.children: Optional[List[ClassIR]] = [] + def __repr__(self) -> str: + return ( + "ClassIR(" + "name={self.name}, module_name={self.module_name}, " + "is_trait={self.is_trait}, is_generated={self.is_generated}, " + "is_abstract={self.is_abstract}, is_ext_class={self.is_ext_class}" + ")".format(self=self)) + @property def fullname(self) -> str: return "{}.{}".format(self.module_name, self.name) diff --git a/mypyc/irbuild/classdef.py b/mypyc/irbuild/classdef.py index 80cce0bbd35fc..6c2958c5b8de8 100644 --- a/mypyc/irbuild/classdef.py +++ b/mypyc/irbuild/classdef.py @@ -1,12 +1,14 @@ """Transform class definitions from the mypy AST form to IR.""" -from typing import List, Optional, Tuple +from abc import abstractmethod +from typing import Callable, List, Optional, Tuple from typing_extensions import Final from mypy.nodes import ( - ClassDef, FuncDef, OverloadedFuncDef, PassStmt, AssignmentStmt, NameExpr, StrExpr, - ExpressionStmt, TempNode, Decorator, Lvalue, RefExpr, is_class_var + ClassDef, FuncDef, OverloadedFuncDef, PassStmt, AssignmentStmt, CallExpr, NameExpr, StrExpr, + ExpressionStmt, TempNode, Decorator, Lvalue, MemberExpr, RefExpr, TypeInfo, is_class_var ) +from mypy.types import Instance, get_proper_type from mypyc.ir.ops import ( Value, Register, Call, LoadErrorValue, LoadStatic, InitStatic, TupleSet, SetAttr, Return, BasicBlock, Branch, MethodCall, NAMESPACE_TYPE, LoadAddress @@ -24,10 +26,10 @@ ) from mypyc.primitives.dict_ops import dict_set_item_op, dict_new_op from mypyc.irbuild.util import ( - is_dataclass_decorator, get_func_def, is_dataclass, is_constant + is_dataclass_decorator, get_func_def, is_constant, dataclass_type ) from mypyc.irbuild.builder import IRBuilder -from mypyc.irbuild.function import transform_method +from mypyc.irbuild.function import handle_ext_method, handle_non_ext_method, load_type def transform_class_def(builder: IRBuilder, cdef: ClassDef) -> None: @@ -61,32 +63,28 @@ def transform_class_def(builder: IRBuilder, cdef: ClassDef) -> None: # decorated or inherit from Enum. Classes decorated with @trait do not # apply here, and are handled in a different way. if ir.is_ext_class: - # If the class is not decorated, generate an extension class for it. - type_obj: Optional[Value] = allocate_class(builder, cdef) - non_ext: Optional[NonExtClassInfo] = None - dataclass_non_ext = dataclass_non_ext_info(builder, cdef) + cls_type = dataclass_type(cdef) + if cls_type is None: + cls_builder: ClassBuilder = ExtClassBuilder(builder, cdef) + elif cls_type in ['dataclasses', 'attr-auto']: + cls_builder = DataClassBuilder(builder, cdef) + elif cls_type == 'attr': + cls_builder = AttrsClassBuilder(builder, cdef) + else: + raise ValueError(cls_type) else: - non_ext_bases = populate_non_ext_bases(builder, cdef) - non_ext_metaclass = find_non_ext_metaclass(builder, cdef, non_ext_bases) - non_ext_dict = setup_non_ext_dict(builder, cdef, non_ext_metaclass, non_ext_bases) - # We populate __annotations__ for non-extension classes - # because dataclasses uses it to determine which attributes to compute on. - # TODO: Maybe generate more precise types for annotations - non_ext_anns = builder.call_c(dict_new_op, [], cdef.line) - non_ext = NonExtClassInfo(non_ext_dict, non_ext_bases, non_ext_anns, non_ext_metaclass) - dataclass_non_ext = None - attrs_to_cache: List[Tuple[Lvalue, RType]] = [] + cls_builder = NonExtClassBuilder(builder, cdef) for stmt in cdef.defs.body: if isinstance(stmt, OverloadedFuncDef) and stmt.is_property: - if not ir.is_ext_class: + if isinstance(cls_builder, NonExtClassBuilder): # properties with both getters and setters in non_extension # classes not supported builder.error("Property setters not supported in non-extension classes", stmt.line) for item in stmt.items: with builder.catch_errors(stmt.line): - transform_method(builder, cdef, non_ext, get_func_def(item)) + cls_builder.add_method(get_func_def(item)) elif isinstance(stmt, (FuncDef, Decorator, OverloadedFuncDef)): # Ignore plugin generated methods (since they have no # bodies to compile and will need to have the bodies @@ -94,7 +92,7 @@ def transform_class_def(builder: IRBuilder, cdef: ClassDef) -> None: if cdef.info.names[stmt.name].plugin_generated: continue with builder.catch_errors(stmt.line): - transform_method(builder, cdef, non_ext, get_func_def(stmt)) + cls_builder.add_method(get_func_def(stmt)) elif isinstance(stmt, PassStmt): continue elif isinstance(stmt, AssignmentStmt): @@ -108,53 +106,217 @@ def transform_class_def(builder: IRBuilder, cdef: ClassDef) -> None: continue # We want to collect class variables in a dictionary for both real # non-extension classes and fake dataclass ones. - var_non_ext = non_ext or dataclass_non_ext - if var_non_ext: - add_non_ext_class_attr(builder, var_non_ext, lvalue, stmt, cdef, attrs_to_cache) - if non_ext: - continue - # Variable declaration with no body - if isinstance(stmt.rvalue, TempNode): - continue - # Only treat marked class variables as class variables. - if not (is_class_var(lvalue) or stmt.is_final_def): - continue - typ = builder.load_native_type_object(cdef.fullname) - value = builder.accept(stmt.rvalue) - builder.call_c( - py_setattr_op, [typ, builder.load_str(lvalue.name), value], stmt.line) - if builder.non_function_scope() and stmt.is_final_def: - builder.init_final_static(lvalue, value, cdef.name) + cls_builder.add_attr(lvalue, stmt) + elif isinstance(stmt, ExpressionStmt) and isinstance(stmt.expr, StrExpr): # Docstring. Ignore pass else: builder.error("Unsupported statement in class body", stmt.line) - if not non_ext: # That is, an extension class - generate_attr_defaults(builder, cdef) - create_ne_from_eq(builder, cdef) - if dataclass_non_ext: - assert type_obj - dataclass_finalize(builder, cdef, dataclass_non_ext, type_obj) - else: + cls_builder.finalize(ir) + + +class ClassBuilder: + """Create IR for a class definition. + + This is an abstract base class. + """ + + def __init__(self, builder: IRBuilder, cdef: ClassDef) -> None: + self.builder = builder + self.cdef = cdef + self.attrs_to_cache: List[Tuple[Lvalue, RType]] = [] + + @abstractmethod + def add_method(self, fdef: FuncDef) -> None: + """Add a method to the class IR""" + + @abstractmethod + def add_attr(self, lvalue: NameExpr, stmt: AssignmentStmt) -> None: + """Add an attribute to the class IR""" + + @abstractmethod + def finalize(self, ir: ClassIR) -> None: + """Perform any final operations to complete the class IR""" + + +class NonExtClassBuilder(ClassBuilder): + def __init__(self, builder: IRBuilder, cdef: ClassDef) -> None: + super().__init__(builder, cdef) + self.non_ext = self.create_non_ext_info() + + def create_non_ext_info(self) -> NonExtClassInfo: + non_ext_bases = populate_non_ext_bases(self.builder, self.cdef) + non_ext_metaclass = find_non_ext_metaclass(self.builder, self.cdef, non_ext_bases) + non_ext_dict = setup_non_ext_dict(self.builder, self.cdef, non_ext_metaclass, + non_ext_bases) + # We populate __annotations__ for non-extension classes + # because dataclasses uses it to determine which attributes to compute on. + # TODO: Maybe generate more precise types for annotations + non_ext_anns = self.builder.call_c(dict_new_op, [], self.cdef.line) + return NonExtClassInfo(non_ext_dict, non_ext_bases, non_ext_anns, non_ext_metaclass) + + def add_method(self, fdef: FuncDef) -> None: + handle_non_ext_method(self.builder, self.non_ext, self.cdef, fdef) + + def add_attr(self, lvalue: NameExpr, stmt: AssignmentStmt) -> None: + add_non_ext_class_attr_ann(self.builder, self.non_ext, lvalue, stmt) + add_non_ext_class_attr(self.builder, self.non_ext, lvalue, stmt, self.cdef, + self.attrs_to_cache) + + def finalize(self, ir: ClassIR) -> None: # Dynamically create the class via the type constructor - non_ext_class = load_non_ext_class(builder, ir, non_ext, cdef.line) - non_ext_class = load_decorated_class(builder, cdef, non_ext_class) + non_ext_class = load_non_ext_class(self.builder, ir, self.non_ext, self.cdef.line) + non_ext_class = load_decorated_class(self.builder, self.cdef, non_ext_class) # Save the decorated class - builder.add(InitStatic(non_ext_class, cdef.name, builder.module_name, NAMESPACE_TYPE)) + self.builder.add(InitStatic(non_ext_class, self.cdef.name, self.builder.module_name, + NAMESPACE_TYPE)) # Add the non-extension class to the dict - builder.call_c(dict_set_item_op, - [ - builder.load_globals_dict(), - builder.load_str(cdef.name), - non_ext_class - ], cdef.line) + self.builder.call_c(dict_set_item_op, + [ + self.builder.load_globals_dict(), + self.builder.load_str(self.cdef.name), + non_ext_class + ], self.cdef.line) # Cache any cacheable class attributes - cache_class_attrs(builder, attrs_to_cache, cdef) + cache_class_attrs(self.builder, self.attrs_to_cache, self.cdef) + + +class ExtClassBuilder(ClassBuilder): + def __init__(self, builder: IRBuilder, cdef: ClassDef) -> None: + super().__init__(builder, cdef) + # If the class is not decorated, generate an extension class for it. + self.type_obj: Optional[Value] = allocate_class(builder, cdef) + + def skip_attr_default(self, name: str, stmt: AssignmentStmt) -> bool: + """Controls whether to skip generating a default for an attribute.""" + return False + + def add_method(self, fdef: FuncDef) -> None: + handle_ext_method(self.builder, self.cdef, fdef) + + def add_attr(self, lvalue: NameExpr, stmt: AssignmentStmt) -> None: + # Variable declaration with no body + if isinstance(stmt.rvalue, TempNode): + return + # Only treat marked class variables as class variables. + if not (is_class_var(lvalue) or stmt.is_final_def): + return + typ = self.builder.load_native_type_object(self.cdef.fullname) + value = self.builder.accept(stmt.rvalue) + self.builder.call_c( + py_setattr_op, [typ, self.builder.load_str(lvalue.name), value], stmt.line) + if self.builder.non_function_scope() and stmt.is_final_def: + self.builder.init_final_static(lvalue, value, self.cdef.name) + + def finalize(self, ir: ClassIR) -> None: + generate_attr_defaults(self.builder, self.cdef, self.skip_attr_default) + create_ne_from_eq(self.builder, self.cdef) + + +class DataClassBuilder(ExtClassBuilder): + # controls whether an __annotations__ attribute should be added to the class + # __dict__. This is not desirable for attrs classes where auto_attribs is + # disabled, as attrs will reject it. + add_annotations_to_dict = True + + def __init__(self, builder: IRBuilder, cdef: ClassDef) -> None: + super().__init__(builder, cdef) + self.non_ext = self.create_non_ext_info() + + def create_non_ext_info(self) -> NonExtClassInfo: + """Set up a NonExtClassInfo to track dataclass attributes. + + In addition to setting up a normal extension class for dataclasses, + we also collect its class attributes like a non-extension class so + that we can hand them to the dataclass decorator. + """ + return NonExtClassInfo( + self.builder.call_c(dict_new_op, [], self.cdef.line), + self.builder.add(TupleSet([], self.cdef.line)), + self.builder.call_c(dict_new_op, [], self.cdef.line), + self.builder.add(LoadAddress(type_object_op.type, type_object_op.src, self.cdef.line)) + ) + + def skip_attr_default(self, name: str, stmt: AssignmentStmt) -> bool: + return stmt.type is not None + + def get_type_annotation(self, stmt: AssignmentStmt) -> Optional[TypeInfo]: + # We populate __annotations__ because dataclasses uses it to determine + # which attributes to compute on. + ann_type = get_proper_type(stmt.type) + if isinstance(ann_type, Instance): + return ann_type.type + return None + + def add_attr(self, lvalue: NameExpr, stmt: AssignmentStmt) -> None: + add_non_ext_class_attr_ann(self.builder, self.non_ext, lvalue, stmt, + self.get_type_annotation) + add_non_ext_class_attr(self.builder, self.non_ext, lvalue, stmt, self.cdef, + self.attrs_to_cache) + super().add_attr(lvalue, stmt) + + def finalize(self, ir: ClassIR) -> None: + """Generate code to finish instantiating a dataclass. + + This works by replacing all of the attributes on the class + (which will be descriptors) with whatever they would be in a + non-extension class, calling dataclass, then switching them back. + + The resulting class is an extension class and instances of it do not + have a __dict__ (unless something else requires it). + All methods written explicitly in the source are compiled and + may be called through the vtable while the methods generated + by dataclasses are interpreted and may not be. + + (If we just called dataclass without doing this, it would think that all + of the descriptors for our attributes are default values and generate an + incorrect constructor. We need to do the switch so that dataclass gets the + appropriate defaults.) + """ + super().finalize(ir) + assert self.type_obj + add_dunders_to_non_ext_dict(self.builder, self.non_ext, self.cdef.line, + self.add_annotations_to_dict) + dec = self.builder.accept( + next(d for d in self.cdef.decorators if is_dataclass_decorator(d))) + self.builder.call_c( + dataclass_sleight_of_hand, [dec, self.type_obj, self.non_ext.dict, self.non_ext.anns], + self.cdef.line) + + +class AttrsClassBuilder(DataClassBuilder): + """Create IR for an attrs class where auto_attribs=False (the default). + + When auto_attribs is enabled, attrs classes behave similarly to dataclasses + (i.e. types are stored as annotations on the class) and are thus handled + by DataClassBuilder, but when auto_attribs is disabled the types are + provided via attr.ib(type=...) + """ + + add_annotations_to_dict = False + + def skip_attr_default(self, name: str, stmt: AssignmentStmt) -> bool: + return True + + def get_type_annotation(self, stmt: AssignmentStmt) -> Optional[TypeInfo]: + if isinstance(stmt.rvalue, CallExpr): + # find the type arg in `attr.ib(type=str)` + callee = stmt.rvalue.callee + if (isinstance(callee, MemberExpr) and + callee.fullname in ['attr.ib', 'attr.attr'] and + 'type' in stmt.rvalue.arg_names): + index = stmt.rvalue.arg_names.index('type') + type_name = stmt.rvalue.args[index] + if isinstance(type_name, NameExpr) and isinstance(type_name.node, TypeInfo): + lvalue = stmt.lvalues[0] + assert isinstance(lvalue, NameExpr) + return type_name.node + return None def allocate_class(builder: IRBuilder, cdef: ClassDef) -> Value: @@ -312,23 +474,39 @@ def setup_non_ext_dict(builder: IRBuilder, return non_ext_dict +def add_non_ext_class_attr_ann(builder: IRBuilder, + non_ext: NonExtClassInfo, + lvalue: NameExpr, + stmt: AssignmentStmt, + get_type_info: Optional[Callable[[AssignmentStmt], + Optional[TypeInfo]]] = None + ) -> None: + """Add a class attribute to __annotations__ of a non-extension class.""" + typ: Optional[Value] = None + if get_type_info is not None: + type_info = get_type_info(stmt) + if type_info: + typ = load_type(builder, type_info, stmt.line) + + if typ is None: + # FIXME: if get_type_info is not provided, don't fall back to stmt.type? + ann_type = get_proper_type(stmt.type) + if isinstance(ann_type, Instance): + typ = load_type(builder, ann_type.type, stmt.line) + else: + typ = builder.add(LoadAddress(type_object_op.type, type_object_op.src, stmt.line)) + + key = builder.load_str(lvalue.name) + builder.call_c(dict_set_item_op, [non_ext.anns, key, typ], stmt.line) + + def add_non_ext_class_attr(builder: IRBuilder, non_ext: NonExtClassInfo, lvalue: NameExpr, stmt: AssignmentStmt, cdef: ClassDef, attr_to_cache: List[Tuple[Lvalue, RType]]) -> None: - """Add a class attribute to __annotations__ of a non-extension class. - - If the attribute is initialized with a value, also add it to __dict__. - """ - # We populate __annotations__ because dataclasses uses it to determine - # which attributes to compute on. - # TODO: Maybe generate more precise types for annotations - key = builder.load_str(lvalue.name) - typ = builder.add(LoadAddress(type_object_op.type, type_object_op.src, stmt.line)) - builder.call_c(dict_set_item_op, [non_ext.anns, key, typ], stmt.line) - + """Add a class attribute to __dict__ of a non-extension class.""" # Only add the attribute to the __dict__ if the assignment is of the form: # x: type = value (don't add attributes of the form 'x: type' to the __dict__). if not isinstance(stmt.rvalue, TempNode): @@ -346,8 +524,14 @@ def add_non_ext_class_attr(builder: IRBuilder, attr_to_cache.append((lvalue, object_rprimitive)) -def generate_attr_defaults(builder: IRBuilder, cdef: ClassDef) -> None: - """Generate an initialization method for default attr values (from class vars).""" +def generate_attr_defaults(builder: IRBuilder, cdef: ClassDef, + skip: Optional[Callable[[str, AssignmentStmt], bool]] = None) -> None: + """Generate an initialization method for default attr values (from class vars). + + If provided, the skip arg should be a callable which will return whether + to skip generating a default for an attribute. It will be passed the name of + the attribute and the corresponding AssignmentStmt. + """ cls = builder.mapper.type_to_ir[cdef.info] if cls.builtin_base: return @@ -371,8 +555,7 @@ def generate_attr_defaults(builder: IRBuilder, cdef: ClassDef) -> None: check_deletable_declaration(builder, cls, stmt.line) continue - # Skip type annotated assignments in dataclasses - if is_dataclass(cdef) and stmt.type: + if skip is not None and skip(name, stmt): continue default_assignments.append(stmt) @@ -456,7 +639,7 @@ def load_non_ext_class(builder: IRBuilder, line: int) -> Value: cls_name = builder.load_str(ir.name) - finish_non_ext_dict(builder, non_ext, line) + add_dunders_to_non_ext_dict(builder, non_ext, line) class_type_obj = builder.py_call( non_ext.metaclass, @@ -503,11 +686,11 @@ def create_mypyc_attrs_tuple(builder: IRBuilder, ir: ClassIR, line: int) -> Valu return builder.new_tuple(items, line) -def finish_non_ext_dict(builder: IRBuilder, non_ext: NonExtClassInfo, line: int) -> None: - # Add __annotations__ to the class dict. - builder.call_c(dict_set_item_op, - [non_ext.dict, builder.load_str('__annotations__'), - non_ext.anns], -1) +def add_dunders_to_non_ext_dict(builder: IRBuilder, non_ext: NonExtClassInfo, + line: int, add_annotations: bool = True) -> None: + if add_annotations: + # Add __annotations__ to the class dict. + builder.add_to_non_ext_dict(non_ext, '__annotations__', non_ext.anns, line) # We add a __doc__ attribute so if the non-extension class is decorated with the # dataclass decorator, dataclass will not try to look for __text_signature__. @@ -517,46 +700,3 @@ def finish_non_ext_dict(builder: IRBuilder, non_ext: NonExtClassInfo, line: int) non_ext, '__doc__', builder.load_str(filler_doc_str), line) builder.add_to_non_ext_dict( non_ext, '__module__', builder.load_str(builder.module_name), line) - - -def dataclass_finalize( - builder: IRBuilder, cdef: ClassDef, non_ext: NonExtClassInfo, type_obj: Value) -> None: - """Generate code to finish instantiating a dataclass. - - This works by replacing all of the attributes on the class - (which will be descriptors) with whatever they would be in a - non-extension class, calling dataclass, then switching them back. - - The resulting class is an extension class and instances of it do not - have a __dict__ (unless something else requires it). - All methods written explicitly in the source are compiled and - may be called through the vtable while the methods generated - by dataclasses are interpreted and may not be. - - (If we just called dataclass without doing this, it would think that all - of the descriptors for our attributes are default values and generate an - incorrect constructor. We need to do the switch so that dataclass gets the - appropriate defaults.) - """ - finish_non_ext_dict(builder, non_ext, cdef.line) - dec = builder.accept(next(d for d in cdef.decorators if is_dataclass_decorator(d))) - builder.call_c( - dataclass_sleight_of_hand, [dec, type_obj, non_ext.dict, non_ext.anns], cdef.line) - - -def dataclass_non_ext_info(builder: IRBuilder, cdef: ClassDef) -> Optional[NonExtClassInfo]: - """Set up a NonExtClassInfo to track dataclass attributes. - - In addition to setting up a normal extension class for dataclasses, - we also collect its class attributes like a non-extension class so - that we can hand them to the dataclass decorator. - """ - if is_dataclass(cdef): - return NonExtClassInfo( - builder.call_c(dict_new_op, [], cdef.line), - builder.add(TupleSet([], cdef.line)), - builder.call_c(dict_new_op, [], cdef.line), - builder.add(LoadAddress(type_object_op.type, type_object_op.src, cdef.line)) - ) - else: - return None diff --git a/mypyc/irbuild/expression.py b/mypyc/irbuild/expression.py index c0b5a3bf4ff40..6a42820b9b21e 100644 --- a/mypyc/irbuild/expression.py +++ b/mypyc/irbuild/expression.py @@ -38,7 +38,7 @@ from mypyc.primitives.set_ops import set_add_op, set_update_op from mypyc.primitives.str_ops import str_slice_op from mypyc.primitives.int_ops import int_comparison_op_mapping -from mypyc.irbuild.specialize import specializers +from mypyc.irbuild.specialize import apply_function_specialization, apply_method_specialization from mypyc.irbuild.builder import IRBuilder from mypyc.irbuild.for_helpers import ( translate_list_comprehension, translate_set_comprehension, @@ -209,7 +209,8 @@ def transform_call_expr(builder: IRBuilder, expr: CallExpr) -> Value: callee = callee.analyzed.expr # Unwrap type application if isinstance(callee, MemberExpr): - return translate_method_call(builder, expr, callee) + return apply_method_specialization(builder, expr, callee) or \ + translate_method_call(builder, expr, callee) elif isinstance(callee, SuperExpr): return translate_super_method_call(builder, expr, callee) else: @@ -219,7 +220,8 @@ def transform_call_expr(builder: IRBuilder, expr: CallExpr) -> Value: def translate_call(builder: IRBuilder, expr: CallExpr, callee: Expression) -> Value: # The common case of calls is refexprs if isinstance(callee, RefExpr): - return translate_refexpr_call(builder, expr, callee) + return apply_function_specialization(builder, expr, callee) or \ + translate_refexpr_call(builder, expr, callee) function = builder.accept(callee) args = [builder.accept(arg) for arg in expr.args] @@ -229,18 +231,6 @@ def translate_call(builder: IRBuilder, expr: CallExpr, callee: Expression) -> Va def translate_refexpr_call(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value: """Translate a non-method call.""" - - # TODO: Allow special cases to have default args or named args. Currently they don't since - # they check that everything in arg_kinds is ARG_POS. - - # If there is a specializer for this function, try calling it. - # We would return the first successful one. - if callee.fullname and (callee.fullname, None) in specializers: - for specializer in specializers[callee.fullname, None]: - val = specializer(builder, expr, callee) - if val is not None: - return val - # Gen the argument values arg_values = [builder.accept(arg) for arg in expr.args] @@ -297,11 +287,9 @@ def translate_method_call(builder: IRBuilder, expr: CallExpr, callee: MemberExpr # If there is a specializer for this method name/type, try calling it. # We would return the first successful one. - if (callee.name, receiver_typ) in specializers: - for specializer in specializers[callee.name, receiver_typ]: - val = specializer(builder, expr, callee) - if val is not None: - return val + val = apply_method_specialization(builder, expr, callee, receiver_typ) + if val is not None: + return val obj = builder.accept(callee.expr) args = [builder.accept(arg) for arg in expr.args] diff --git a/mypyc/irbuild/function.py b/mypyc/irbuild/function.py index bdd4ed992f2fd..5b567251111ab 100644 --- a/mypyc/irbuild/function.py +++ b/mypyc/irbuild/function.py @@ -117,16 +117,6 @@ def transform_decorator(builder: IRBuilder, dec: Decorator) -> None: builder.functions.append(func_ir) -def transform_method(builder: IRBuilder, - cdef: ClassDef, - non_ext: Optional[NonExtClassInfo], - fdef: FuncDef) -> None: - if non_ext: - handle_non_ext_method(builder, non_ext, cdef, fdef) - else: - handle_ext_method(builder, cdef, fdef) - - def transform_lambda_expr(builder: IRBuilder, expr: LambdaExpr) -> Value: typ = get_proper_type(builder.types[expr]) assert isinstance(typ, CallableType) diff --git a/mypyc/irbuild/specialize.py b/mypyc/irbuild/specialize.py index 05975d7b2bf5e..d35039ecc0bc0 100644 --- a/mypyc/irbuild/specialize.py +++ b/mypyc/irbuild/specialize.py @@ -57,6 +57,34 @@ specializers: Dict[Tuple[str, Optional[RType]], List[Specializer]] = {} +def _apply_specialization(builder: 'IRBuilder', expr: CallExpr, callee: RefExpr, + name: Optional[str], typ: Optional[RType] = None) -> Optional[Value]: + # TODO: Allow special cases to have default args or named args. Currently they don't since + # they check that everything in arg_kinds is ARG_POS. + + # If there is a specializer for this function, try calling it. + # Return the first successful one. + if name and (name, typ) in specializers: + for specializer in specializers[name, typ]: + val = specializer(builder, expr, callee) + if val is not None: + return val + return None + + +def apply_function_specialization(builder: 'IRBuilder', expr: CallExpr, + callee: RefExpr) -> Optional[Value]: + """Invoke the Specializer callback for a function if one has been registered""" + return _apply_specialization(builder, expr, callee, callee.fullname) + + +def apply_method_specialization(builder: 'IRBuilder', expr: CallExpr, callee: MemberExpr, + typ: Optional[RType] = None) -> Optional[Value]: + """Invoke the Specializer callback for a method if one has been registered""" + name = callee.fullname if typ is None else callee.name + return _apply_specialization(builder, expr, callee, name, typ) + + def specialize_function( name: str, typ: Optional[RType] = None) -> Callable[[Specializer], Specializer]: """Decorator to register a function as being a specializer. @@ -329,10 +357,12 @@ def gen_inner_stmts() -> None: @specialize_function('dataclasses.field') +@specialize_function('attr.ib') +@specialize_function('attr.attrib') @specialize_function('attr.Factory') def translate_dataclasses_field_call( builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Optional[Value]: - """Special case for 'dataclasses.field' and 'attr.Factory' + """Special case for 'dataclasses.field', 'attr.attrib', and 'attr.Factory' function calls because the results of such calls are type-checked by mypy using the types of the arguments to their respective functions, resulting in attempted coercions by mypyc that throw a diff --git a/mypyc/irbuild/util.py b/mypyc/irbuild/util.py index cceef36a9feb0..7a7b95245d4c3 100644 --- a/mypyc/irbuild/util.py +++ b/mypyc/irbuild/util.py @@ -4,11 +4,18 @@ from mypy.nodes import ( ClassDef, FuncDef, Decorator, OverloadedFuncDef, StrExpr, CallExpr, RefExpr, Expression, - IntExpr, FloatExpr, Var, TupleExpr, UnaryExpr, BytesExpr, + IntExpr, FloatExpr, Var, NameExpr, TupleExpr, UnaryExpr, BytesExpr, ArgKind, ARG_NAMED, ARG_NAMED_OPT, ARG_POS, ARG_OPT, GDEF, ) +DATACLASS_DECORATORS = { + 'dataclasses.dataclass', + 'attr.s', + 'attr.attrs', +} + + def is_trait_decorator(d: Expression) -> bool: return isinstance(d, RefExpr) and d.fullname == 'mypy_extensions.trait' @@ -17,21 +24,40 @@ def is_trait(cdef: ClassDef) -> bool: return any(is_trait_decorator(d) for d in cdef.decorators) or cdef.info.is_protocol -def is_dataclass_decorator(d: Expression) -> bool: - return ( - (isinstance(d, RefExpr) and d.fullname == 'dataclasses.dataclass') - or ( - isinstance(d, CallExpr) +def dataclass_decorator_type(d: Expression) -> Optional[str]: + if isinstance(d, RefExpr) and d.fullname in DATACLASS_DECORATORS: + return d.fullname.split('.')[0] + elif (isinstance(d, CallExpr) and isinstance(d.callee, RefExpr) - and d.callee.fullname == 'dataclasses.dataclass' - ) - ) + and d.callee.fullname in DATACLASS_DECORATORS): + name = d.callee.fullname.split('.')[0] + if name == 'attr' and 'auto_attribs' in d.arg_names: + # Note: the mypy attrs plugin checks that the value of auto_attribs is + # not computed at runtime, so we don't need to perform that check here + auto = d.args[d.arg_names.index('auto_attribs')] + if isinstance(auto, NameExpr) and auto.name == 'True': + return 'attr-auto' + return name + else: + return None + + +def is_dataclass_decorator(d: Expression) -> bool: + return dataclass_decorator_type(d) is not None def is_dataclass(cdef: ClassDef) -> bool: return any(is_dataclass_decorator(d) for d in cdef.decorators) +def dataclass_type(cdef: ClassDef) -> Optional[str]: + for d in cdef.decorators: + typ = dataclass_decorator_type(d) + if typ is not None: + return typ + return None + + def get_mypyc_attr_literal(e: Expression) -> Any: """Convert an expression from a mypyc_attr decorator to a value. diff --git a/mypyc/test-data/run-attrs.test b/mypyc/test-data/run-attrs.test new file mode 100644 index 0000000000000..9c402a3eea7c5 --- /dev/null +++ b/mypyc/test-data/run-attrs.test @@ -0,0 +1,318 @@ +-- Test cases for dataclasses based on the attrs library, where auto_attribs=True + +[case testRunAttrsclass] +import attr +from typing import Set, List, Callable, Any + +@attr.s(auto_attribs=True) +class Person1: + age : int + name : str + + def __bool__(self) -> bool: + return self.name == 'robot' + +def testBool(p: Person1) -> bool: + if p: + return True + else: + return False + +@attr.s(auto_attribs=True) +class Person1b(Person1): + id: str = '000' + +@attr.s(auto_attribs=True) +class Person2: + age : int + name : str = attr.ib(default='robot') + +@attr.s(auto_attribs=True, order=True) +class Person3: + age : int = attr.ib(default = 6) + friendIDs : List[int] = attr.ib(factory = list) + + def get_age(self) -> int: + return (self.age) + + def set_age(self, new_age : int) -> None: + self.age = new_age + + def add_friendID(self, fid : int) -> None: + self.friendIDs.append(fid) + + def get_friendIDs(self) -> List[int]: + return self.friendIDs + +def get_next_age(g: Callable[[Any], int]) -> Callable[[Any], int]: + def f(a: Any) -> int: + return g(a) + 1 + return f + +@attr.s(auto_attribs=True) +class Person4: + age : int + _name : str = 'Bot' + + @get_next_age + def get_age(self) -> int: + return self.age + + @property + def name(self) -> str: + return self._name + +@attr.s(auto_attribs=True) +class Point: + x : int = attr.ib(converter=int) + y : int = attr.ib(init=False) + + def __attrs_post_init__(self): + self.y = self.x + 1 + + +[file other.py] +from native import Person1, Person1b, Person2, Person3, Person4, testBool, Point +i1 = Person1(age = 5, name = 'robot') +assert i1.age == 5 +assert i1.name == 'robot' +assert testBool(i1) == True +assert testBool(Person1(age = 5, name = 'robo')) == False +i1b = Person1b(age = 5, name = 'robot') +assert i1b.age == 5 +assert i1b.name == 'robot' +assert i1b.id == '000' +assert testBool(i1b) == True +assert testBool(Person1b(age = 5, name = 'robo')) == False +i1c = Person1b(age = 20, name = 'robot', id = 'test') +assert i1c.age == 20 +assert i1c.id == 'test' + +i2 = Person2(age = 5) +assert i2.age == 5 +assert i2.name == 'robot' +i3 = Person2(age = 5, name = 'new_robot') +assert i3.age == 5 +assert i3.name == 'new_robot' +i4 = Person3() +assert i4.age == 6 +assert i4.friendIDs == [] +i5 = Person3(age = 5) +assert i5.age == 5 +assert i5.friendIDs == [] +i6 = Person3(age = 5, friendIDs = [1,2,3]) +assert i6.age == 5 +assert i6.friendIDs == [1,2,3] +assert i6.get_age() == 5 +i6.set_age(10) +assert i6.get_age() == 10 +i6.add_friendID(4) +assert i6.get_friendIDs() == [1,2,3,4] +i7 = Person4(age = 5) +assert i7.get_age() == 6 +i7.age += 3 +assert i7.age == 8 +assert i7.name == 'Bot' +i8 = Person3(age = 1, friendIDs = [1,2]) +i9 = Person3(age = 1, friendIDs = [1,2]) +assert i8 == i9 +i8.age = 2 +assert i8 > i9 + +assert Person1.__annotations__ == {'age': int, 'name': str} +assert Person2.__annotations__ == {'age': int, 'name': str} + +p1 = Point(2) +assert p1.x == 2 +assert p1.y == 3 +p2 = Point('2') +assert p2.x == 2 +assert p2.y == 3 + +assert Point.__annotations__ == {'x': int, 'y': int} + +[file driver.py] +import sys + +# PEP 526 introduced in 3.6 +version = sys.version_info[:2] +if version[0] < 3 or version[1] < 6: + exit() + +# Run the tests in both interpreted and compiled mode +import other +import other_interpreted + +# Test for an exceptional cases +from testutil import assertRaises +from native import Person1, Person1b, Person3 +from types import BuiltinMethodType + +with assertRaises(TypeError, "missing 1 required positional argument"): + Person1(0) + +with assertRaises(TypeError, "missing 2 required positional arguments"): + Person1b() + +with assertRaises(TypeError, "int object expected; got str"): + Person1('nope', 'test') + +p = Person1(0, 'test') +with assertRaises(TypeError, "int object expected; got str"): + p.age = 'nope' + +assert isinstance(Person3().get_age, BuiltinMethodType) + + +[case testRunAttrsclassNonAuto] +import attr +from typing import Set, List, Callable, Any + +@attr.s +class Person1: + age = attr.ib(type=int) + name = attr.ib(type=str) + + def __bool__(self) -> bool: + return self.name == 'robot' + +def testBool(p: Person1) -> bool: + if p: + return True + else: + return False + +@attr.s +class Person1b(Person1): + id = attr.ib(type=str, default='000') + +@attr.s +class Person2: + age = attr.ib(type=int) + name = attr.ib(type=str, default='robot') + +@attr.s(order=True) +class Person3: + age = attr.ib(type=int, default=6) + friendIDs = attr.ib(factory=list, type=List[int]) + + def get_age(self) -> int: + return (self.age) + + def set_age(self, new_age : int) -> None: + self.age = new_age + + def add_friendID(self, fid : int) -> None: + self.friendIDs.append(fid) + + def get_friendIDs(self) -> List[int]: + return self.friendIDs + +def get_next_age(g: Callable[[Any], int]) -> Callable[[Any], int]: + def f(a: Any) -> int: + return g(a) + 1 + return f + +@attr.s +class Person4: + age = attr.ib(type=int) + _name = attr.ib(type=str, default='Bot') + + @get_next_age + def get_age(self) -> int: + return self.age + + @property + def name(self) -> str: + return self._name + +@attr.s +class Point: + x = attr.ib(type=int, converter=int) + y = attr.ib(type=int, init=False) + + def __attrs_post_init__(self): + self.y = self.x + 1 + + +[file other.py] +from native import Person1, Person1b, Person2, Person3, Person4, testBool, Point +i1 = Person1(age = 5, name = 'robot') +assert i1.age == 5 +assert i1.name == 'robot' +assert testBool(i1) == True +assert testBool(Person1(age = 5, name = 'robo')) == False +i1b = Person1b(age = 5, name = 'robot') +assert i1b.age == 5 +assert i1b.name == 'robot' +assert i1b.id == '000' +assert testBool(i1b) == True +assert testBool(Person1b(age = 5, name = 'robo')) == False +i1c = Person1b(age = 20, name = 'robot', id = 'test') +assert i1c.age == 20 +assert i1c.id == 'test' + +i2 = Person2(age = 5) +assert i2.age == 5 +assert i2.name == 'robot' +i3 = Person2(age = 5, name = 'new_robot') +assert i3.age == 5 +assert i3.name == 'new_robot' +i4 = Person3() +assert i4.age == 6 +assert i4.friendIDs == [] +i5 = Person3(age = 5) +assert i5.age == 5 +assert i5.friendIDs == [] +i6 = Person3(age = 5, friendIDs = [1,2,3]) +assert i6.age == 5 +assert i6.friendIDs == [1,2,3] +assert i6.get_age() == 5 +i6.set_age(10) +assert i6.get_age() == 10 +i6.add_friendID(4) +assert i6.get_friendIDs() == [1,2,3,4] +i7 = Person4(age = 5) +assert i7.get_age() == 6 +i7.age += 3 +assert i7.age == 8 +assert i7.name == 'Bot' +i8 = Person3(age = 1, friendIDs = [1,2]) +i9 = Person3(age = 1, friendIDs = [1,2]) +assert i8 == i9 +i8.age = 2 +assert i8 > i9 + +p1 = Point(2) +assert p1.x == 2 +assert p1.y == 3 +p2 = Point('2') +assert p2.x == 2 +assert p2.y == 3 + +[file driver.py] +import sys + +# Run the tests in both interpreted and compiled mode +import other +import other_interpreted + +# Test for an exceptional cases +from testutil import assertRaises +from native import Person1, Person1b, Person3 +from types import BuiltinMethodType + +with assertRaises(TypeError, "missing 1 required positional argument"): + Person1(0) + +with assertRaises(TypeError, "missing 2 required positional arguments"): + Person1b() + +with assertRaises(TypeError, "int object expected; got str"): + Person1('nope', 'test') + +p = Person1(0, 'test') +with assertRaises(TypeError, "int object expected; got str"): + p.age = 'nope' + +assert isinstance(Person3().get_age, BuiltinMethodType) diff --git a/mypyc/test-data/run-python37.test b/mypyc/test-data/run-python37.test index 3660ba13a6c8d..734e116c13354 100644 --- a/mypyc/test-data/run-python37.test +++ b/mypyc/test-data/run-python37.test @@ -1,6 +1,7 @@ -- Test cases for Python 3.7 features [case testRunDataclass] +import dataclasses from dataclasses import dataclass, field from typing import Set, List, Callable, Any @@ -27,6 +28,11 @@ class Person2: age : int name : str = field(default='robot') +@dataclasses.dataclass +class Person2b: + age : int + name : str = dataclasses.field(default='robot') + @dataclass(order = True) class Person3: age : int = field(default = 6) @@ -109,6 +115,8 @@ assert i8 == i9 i8.age = 2 assert i8 > i9 +assert Person1.__annotations__ == {'age': int, 'name': str} +assert Person2.__annotations__ == {'age': int, 'name': str} [file driver.py] import sys diff --git a/mypyc/test/test_run.py b/mypyc/test/test_run.py index 3876fee2a1579..306c151ef319c 100644 --- a/mypyc/test/test_run.py +++ b/mypyc/test/test_run.py @@ -52,8 +52,10 @@ 'run-bench.test', 'run-mypy-sim.test', 'run-dunders.test', - 'run-singledispatch.test' + 'run-singledispatch.test', + 'run-attrs.test', ] + if sys.version_info >= (3, 7): files.append('run-python37.test') if sys.version_info >= (3, 8):