diff --git a/mypy/plugins/default.py b/mypy/plugins/default.py index e492b8dd7335..6a97c1ab8107 100644 --- a/mypy/plugins/default.py +++ b/mypy/plugins/default.py @@ -5,7 +5,7 @@ import mypy.errorcodes as codes from mypy import message_registry -from mypy.nodes import DictExpr, IntExpr, StrExpr, UnaryExpr +from mypy.nodes import DictExpr, IntExpr, StrExpr, TypeInfo, UnaryExpr from mypy.plugin import ( AttributeContext, ClassDefContext, @@ -47,7 +47,12 @@ dataclass_tag_callback, replace_function_sig_callback, ) -from mypy.plugins.enums import enum_member_callback, enum_name_callback, enum_value_callback +from mypy.plugins.enums import ( + enum_member_callback, + enum_name_callback, + enum_new_callback, + enum_value_callback, +) from mypy.plugins.functools import ( functools_total_ordering_maker_callback, functools_total_ordering_makers, @@ -104,6 +109,12 @@ def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type] return partial_new_callback elif fullname == "enum.member": return enum_member_callback + elif ( + (st := self.lookup_fully_qualified(fullname)) + and isinstance(st.node, TypeInfo) + and getattr(st.node, "is_enum", False) + ): + return enum_new_callback return None def get_function_signature_hook( diff --git a/mypy/plugins/enums.py b/mypy/plugins/enums.py index 0be2e083b6dd..79929691eed6 100644 --- a/mypy/plugins/enums.py +++ b/mypy/plugins/enums.py @@ -28,6 +28,8 @@ LiteralType, ProperType, Type, + TypeVarType, + UnionType, get_proper_type, is_named_instance, ) @@ -297,3 +299,62 @@ def _extract_underlying_field_name(typ: Type) -> str | None: # as a string. assert isinstance(underlying_literal.value, str) return underlying_literal.value + + +def enum_new_callback(ctx: mypy.plugin.FunctionContext) -> Type: + """This plugin refines the return type of `__new__`, ensuring reconstructed + Enums are idempotent. + + By default, mypy will infer that `Foo(Foo.x)` is of type `Foo`. This plugin + ensures types are not loosened, meaning with this plugin enabled + `Foo(Foo.x)` is of type `Literal[Foo.x]?`. + + This means with this plugin: + ``` + reveal_type(Foo(Foo.x)) # mypy reveals Literal[Foo.x]? + ``` + + This plugin works by adjusting the return type of `__new__` to be the given + argument type, if and only if `__new__` comes from `enum.Enum`. + + This plugin supports arguments that are Final, Literial, Union of Literials + and generic TypeVars. + """ + base_ret = ctx.default_return_type + enum_inst = get_proper_type(base_ret) + if not isinstance(enum_inst, Instance): + return base_ret + + info: TypeInfo = enum_inst.type + if not info.is_enum: + return base_ret + + if _implements_new(info): + return base_ret + + if not ctx.args or not ctx.args[0] or not ctx.arg_types or not ctx.arg_types[0]: + return base_ret + + arg0_t = get_proper_type(ctx.arg_types[0][0]) + + if isinstance(arg0_t, Instance) and arg0_t.type is info: + return arg0_t + elif isinstance(arg0_t, LiteralType) and arg0_t.fallback.type is info: + return arg0_t + elif isinstance(arg0_t, UnionType): + + def is_memeber(given_t: ProperType) -> bool: + return (isinstance(given_t, Instance) and given_t.type is info) or ( + isinstance(given_t, LiteralType) and given_t.fallback.type is info + ) + + items = [get_proper_type(it) for it in arg0_t.items] + if items and all(is_memeber(item) for item in items): + return arg0_t + elif (isinstance(arg0_t, TypeVarType)) and isinstance( + upperbound_t := get_proper_type(arg0_t.upper_bound), Instance + ): + if upperbound_t.type is info: + return arg0_t + + return base_ret diff --git a/test-data/unit/check-enum.test b/test-data/unit/check-enum.test index 3bcf9745a801..4be1b514c75e 100644 --- a/test-data/unit/check-enum.test +++ b/test-data/unit/check-enum.test @@ -40,6 +40,415 @@ reveal_type(Animal.BEE) # N: Revealed type is "Literal[__main__.Animal.BEE]?" reveal_type(Animal.CAT) # N: Revealed type is "Literal[__main__.Animal.CAT]?" reveal_type(Animal.DOG) # N: Revealed type is "Literal[__main__.Animal.DOG]?" +-- Ensure idempotentency +-- ----------------------- + +[builtins fixtures/tuple.pyi] +[case enum_ctor_idempotent_enum] +from enum import Enum + +class E(Enum): + A = "a" + B = "b" + +reveal_type(E(E.A)) # N: Revealed type is "Literal[__main__.E.A]?" +reveal_type(E(E.B)) # N: Revealed type is "Literal[__main__.E.B]?" +reveal_type(E("a")) # N: Revealed type is "__main__.E" +[out] + + +[builtins fixtures/tuple.pyi] +[case enum_ctor_idempotent_intenum] +from enum import IntEnum + +class I(IntEnum): + A = 0 + B = 1 + +reveal_type(I(I.A)) # N: Revealed type is "Literal[__main__.I.A]?" +reveal_type(I(I.B)) # N: Revealed type is "Literal[__main__.I.B]?" +reveal_type(I(0)) # N: Revealed type is "__main__.I" +[out] + + +[builtins fixtures/tuple.pyi] +[case enum_ctor_idempotent_strenum] +# flags: --python-version 3.11 +from enum import StrEnum + +class S(StrEnum): + A = "a" + B = "b" + +reveal_type(S(S.A)) # N: Revealed type is "Literal[__main__.S.A]?" +reveal_type(S(S.B)) # N: Revealed type is "Literal[__main__.S.B]?" +reveal_type(S("x")) # N: Revealed type is "__main__.S" +[out] + + +[builtins fixtures/tuple.pyi] +[case enum_ctor_idempotent_union_members] +from enum import Enum +from typing import Literal, Union + +class E(Enum): + A = "a" + B = "b" + C = "c" + +u: Union[Literal[E.A], Literal[E.B]] = E.A +reveal_type(E(u)) # N: Revealed type is "Union[Literal[__main__.E.A], Literal[__main__.E.B]]" +[out] + + +[builtins fixtures/tuple.pyi] +[case enum_ctor_idempotent_alias_final] +from enum import Enum +from typing import Final + +class E(Enum): + A = "a" + B = "b" + +o = E.A +m: Final = E.A +n: Final = E.B + +reveal_type(E(o)) # N: Revealed type is "__main__.E" +reveal_type(E(m)) # N: Revealed type is "Literal[__main__.E.A]?" +reveal_type(E(n)) # N: Revealed type is "Literal[__main__.E.B]?" +[out] + +[builtins fixtures/tuple.pyi] +[case enum_ctor_idempotent_alias_literal] +from enum import Enum +from typing import Final, Literal + +class E(Enum): + A = "a" + B = "b" + +o = E.A +m: Literal[E.A] = E.A +n: Literal[E.B] = E.B + +reveal_type(m) # N: Revealed type is "Literal[__main__.E.A]" +reveal_type(E(o)) # N: Revealed type is "__main__.E" +reveal_type(E(m)) # N: Revealed type is "Literal[__main__.E.A]" +reveal_type(E(n)) # N: Revealed type is "Literal[__main__.E.B]" +[out] + +[builtins fixtures/tuple.pyi] +[case enum_ctor_idempotent_alias_nonfinal] +from enum import Enum + +class E(Enum): + A = "a" + B = "b" + +m = E.A # non-Final alias -> typically typed as E, not a literal +reveal_type(m) # N: Revealed type is "__main__.E" +reveal_type(E(m)) # N: Revealed type is "__main__.E" +[out] + + +[builtins fixtures/tuple.pyi] +[case enum_ctor_idempotent_intflag] +from enum import IntFlag + +class F(IntFlag): + A = 1 + B = 2 + +reveal_type(F(F.A)) # N: Revealed type is "Literal[__main__.F.A]?" +reveal_type(F(F.B)) # N: Revealed type is "Literal[__main__.F.B]?" +reveal_type(F(1)) # N: Revealed type is "__main__.F" +reveal_type(F(F.A | F.B)) # N: Revealed type is "__main__.F" +[out] + + +[builtins fixtures/tuple.pyi] +[case enum_ctor_idempotent_callsite_compat] +from enum import Enum +from typing import Literal + +class E(Enum): + A = "a" + B = "b" + +def takes_A(x: Literal[E.A]) -> None: pass +def takes_B(x: Literal[E.B]) -> None: pass + +takes_A(E(E.A)) # OK +takes_B(E(E.B)) # OK +[out] + + +[builtins fixtures/tuple.pyi] +[case enum_ctor_idempotent_functional_api] +from enum import Enum + +E = Enum("E", "A B") +reveal_type(E(E.A)) # N: Revealed type is "Literal[__main__.E.A]?" +reveal_type(E(E.B)) # N: Revealed type is "Literal[__main__.E.B]?" +reveal_type(E("A")) # N: Revealed type is "__main__.E" +[out] + + +[builtins fixtures/tuple.pyi] +[case enum_dctor_non_enum_no_change] +class C: + def __init__(self, x: int) -> None: + self.x = x + +reveal_type(C(1)) # N: Revealed type is "__main__.C" +[out] + + +[builtins fixtures/tuple.pyi] +[case enum_ctor_idempotent_with_custom_new] +from enum import Enum +from typing import Self + +class NewEnum(Enum): + A = "a" + B = "B" + + def __new__(self, value: "NewEnum") -> "NewEnum": + return NewEnum.A + + +reveal_type(NewEnum(NewEnum.A)) # N: Revealed type is "__main__.NewEnum" +reveal_type(NewEnum(NewEnum.B)) # N: Revealed type is "__main__.NewEnum" +[out] + +[builtins fixtures/tuple.pyi] +[case enum_ctor_literal_via_function_call] +from enum import Enum +from typing import Literal + +class E(Enum): + A = "a" + B = "b" + +def get_A() -> Literal[E.A]: # returns a literal by type, but arg expr is a CallExpr + return E.A + +reveal_type(E(get_A())) # N: Revealed type is "Literal[__main__.E.A]" +[out] + + +[builtins fixtures/tuple.pyi] +[case enum_ctor_union_via_function_call] +from enum import Enum +from typing import Literal, Union + +class E(Enum): + A = "a" + B = "b" + C = "c" + +def get_u() -> Union[Literal[E.A], Literal[E.B]]: + return E.A + +reveal_type(E(get_u())) # N: Revealed type is "Union[Literal[__main__.E.A], Literal[__main__.E.B]]" +[out] + + +[builtins fixtures/tuple.pyi] +[case enum_ctor_literal_via_cast] +from enum import Enum +from typing import Literal, cast + +class E(Enum): + A = "a" + B = "b" + +reveal_type(E(cast(Literal[E.A], E.A))) # N: Revealed type is "Literal[__main__.E.A]" +[out] + + +[builtins fixtures/tuple.pyi] +[case enum_ctor_union_via_cast] +from enum import Enum +from typing import Literal, Union, cast + +class E(Enum): + A = "a" + B = "b" + C = "c" + +u = cast(Union[Literal[E.A], Literal[E.B]], E.A) + +reveal_type(E(u)) # N: Revealed type is "Union[Literal[__main__.E.A], Literal[__main__.E.B]]" +[out] + +[builtins fixtures/tuple.pyi] +[case enum_ctor_union_mixed_instance_and_literal] +from enum import Enum +from typing import Literal, Union + +class E(Enum): + A = "a" + B = "b" + +u: Union[E, Literal[E.A]] = E.A +# Not a pure union of member literals -> should stay E +reveal_type(E(u)) # N: Revealed type is "__main__.E" +[out] + + +[builtins fixtures/tuple.pyi] +[case enum_ctor_optional_literal] +from enum import Enum +from typing import Optional, Literal + +class E(Enum): + A = "a" + B = "b" + +v: Optional[Literal[E.A]] = E.A +# Contains None → not a pure union of member literals -> should stay E +reveal_type(E(v)) # N: Revealed type is "__main__.E" +[out] + + +[builtins fixtures/tuple.pyi] +[case enum_ctor_union_via_conditional_expression] +from enum import Enum +from typing import Literal, Union + +class E(Enum): + A = "a" + B = "b" + C = "c" + +flag: bool +# The expression is a ConditionalExpr (not a NameExpr), but the *type* should be a union of member literals. +w: Union[Literal[E.A], Literal[E.B]] = E.A +w = E.A if flag else E.B +reveal_type(E(w)) # N: Revealed type is "Union[Literal[__main__.E.A], Literal[__main__.E.B]]" +[out] + + + +[builtins fixtures/tuple.pyi] +[case enum_ctor_custom_new_should_not_refine] +from enum import Enum + +class E(Enum): + def __new__(cls, val): + obj = object.__new__(cls) + obj._value_ = val + return obj + A = 1 + B = 2 + +reveal_type(E(E.A)) # N: Revealed type is "__main__.E" +[out] + + +[builtins fixtures/tuple.pyi] +[case enum_ctor_non_final_alias_to_member] +from enum import Enum + +class E(Enum): + A = "a" + +m = E.A # non-Final alias; typical type is just E +reveal_type(E(m)) # N: Revealed type is "__main__.E" +[out] + + +[builtins fixtures/tuple.pyi] +[case enum_ctor_alias_to_class] +from enum import Enum + +class E(Enum): + A = "a" + B = "b" + +Alias = E # type alias to the class +reveal_type(Alias(E.A)) # N: Revealed type is "Literal[__main__.E.A]?" +[out] + + +[builtins fixtures/tuple.pyi] +[case enum_ctor_intflag_combinations] +from enum import IntFlag + +class F(IntFlag): + A = 1 + B = 2 + +reveal_type(F(F.A | F.B)) # N: Revealed type is "__main__.F" +[out] + +[builtins fixtures/tuple.pyi] +[case enumctor_member_direct] +from enum import Enum + +class E(Enum): + A = "a" + B = "b" + +reveal_type(E(E.A)) # N: Revealed type is "Literal[__main__.E.A]?" +reveal_type(E(E.B)) # N: Revealed type is "Literal[__main__.E.B]?" +[out] + + +[builtins fixtures/tuple.pyi] +[case enum_ctor_member_union_var] +from enum import Enum +from typing import Literal, Union + +class E(Enum): + A = "a" + B = "b" + C = "c" + +u: Union[Literal[E.A], Literal[E.B]] = E.A +reveal_type(E(u)) # N: Revealed type is "Union[Literal[__main__.E.A], Literal[__main__.E.B]]" +[out] + +[builtins fixtures/tuple.pyi] +[case enum_ctor_literal_via_generic_identity] +from enum import Enum +from typing import TypeVar, Literal + +class E(Enum): + A = "a" + B = "b" + +T = TypeVar("T", bound=E) + +def ident(x: T) -> T: + return E(x) + +x: Literal[E.A] = E.A +reveal_type(ident(x)) # N: Revealed type is "Literal[__main__.E.A]" + +[out] + +[builtins fixtures/tuple.pyi] +[case enum_ctor_literal_via_generic_idedntity_2] +from enum import IntEnum +from typing import TypeVar, Generic + +class Option(IntEnum): + x=0 + y=1 + +T = TypeVar("T", bound=Option) + +class Base(Generic[T]): + option: T + + def __init__(self, option: T) -> None: + self.option = Option(option) # for runtime safety + +[out] + [builtins fixtures/tuple.pyi] -- Creation from EnumMeta