Skip to content

Commit

Permalink
Support enum.member for python3.11+ (python#17382)
Browse files Browse the repository at this point in the history
There are no tests for `@enum.member` used as a decorator, because I can
only decorate classes and functions, which are not supported right now:
https://mypy-play.net/?mypy=latest&python=3.12&gist=449ee8c12eba9f807cfc7832f1ea2c49

```python
import enum

class A(enum.Enum):
    class x: ...

reveal_type(A.x)  # Revealed type is "def () -> __main__.A.x"
```

This issue is separate and rather complex, so I would prefer to solve it
independently.

Refs python#17376

---------

Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
  • Loading branch information
sobolevn and AlexWaygood authored Jun 14, 2024
1 parent 5dd062a commit dac88f3
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 1 deletion.
4 changes: 3 additions & 1 deletion mypy/plugins/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class DefaultPlugin(Plugin):
"""Type checker plugin that is enabled by default."""

def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type] | None:
from mypy.plugins import ctypes, singledispatch
from mypy.plugins import ctypes, enums, singledispatch

if fullname == "_ctypes.Array":
return ctypes.array_constructor_callback
Expand All @@ -51,6 +51,8 @@ def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type]
import mypy.plugins.functools

return mypy.plugins.functools.partial_new_callback
elif fullname == "enum.member":
return enums.enum_member_callback

return None

Expand Down
18 changes: 18 additions & 0 deletions mypy/plugins/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ def _infer_value_type_with_auto_fallback(
return None
proper_type = get_proper_type(fixup_partial_type(proper_type))
if not (isinstance(proper_type, Instance) and proper_type.type.fullname == "enum.auto"):
if is_named_instance(proper_type, "enum.member") and proper_type.args:
return proper_type.args[0]
return proper_type
assert isinstance(ctx.type, Instance), "An incorrect ctx.type was passed."
info = ctx.type.type
Expand Down Expand Up @@ -126,6 +128,22 @@ def _implements_new(info: TypeInfo) -> bool:
return type_with_new.fullname not in ("enum.Enum", "enum.IntEnum", "enum.StrEnum")


def enum_member_callback(ctx: mypy.plugin.FunctionContext) -> Type:
"""By default `member(1)` will be infered as `member[int]`,
we want to improve the inference to be `Literal[1]` here."""
if ctx.arg_types or ctx.arg_types[0]:
arg = get_proper_type(ctx.arg_types[0][0])
proper_return = get_proper_type(ctx.default_return_type)
if (
isinstance(arg, Instance)
and arg.last_known_value
and isinstance(proper_return, Instance)
and len(proper_return.args) == 1
):
return proper_return.copy_modified(args=[arg])
return ctx.default_return_type


def enum_value_callback(ctx: mypy.plugin.AttributeContext) -> Type:
"""This plugin refines the 'value' attribute in enums to refer to
the original underlying value. For example, suppose we have the
Expand Down
19 changes: 19 additions & 0 deletions test-data/unit/check-enum.test
Original file line number Diff line number Diff line change
Expand Up @@ -2166,3 +2166,22 @@ class Other(Enum):
reveal_type(Other.a) # N: Revealed type is "Literal[__main__.Other.a]?"
reveal_type(Other.Support.b) # N: Revealed type is "builtins.int"
[builtins fixtures/dict.pyi]


[case testEnumMemberSupport]
# flags: --python-version 3.11
# This was added in 3.11
from enum import Enum, member

class A(Enum):
x = member(1)
y = 2

reveal_type(A.x) # N: Revealed type is "Literal[__main__.A.x]?"
reveal_type(A.x.value) # N: Revealed type is "Literal[1]?"
reveal_type(A.y) # N: Revealed type is "Literal[__main__.A.y]?"
reveal_type(A.y.value) # N: Revealed type is "Literal[2]?"

def some_a(a: A):
reveal_type(a.value) # N: Revealed type is "Union[Literal[1]?, Literal[2]?]"
[builtins fixtures/dict.pyi]
5 changes: 5 additions & 0 deletions test-data/unit/lib-stub/enum.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,8 @@ class StrEnum(str, Enum):
class nonmember(Generic[_T]):
value: _T
def __init__(self, value: _T) -> None: ...

# It is python-3.11+ only:
class member(Generic[_T]):
value: _T
def __init__(self, value: _T) -> None: ...

0 comments on commit dac88f3

Please sign in to comment.