Skip to content

Commit

Permalink
infer return types
Browse files Browse the repository at this point in the history
  • Loading branch information
KotlinIsland committed Jan 8, 2025
1 parent a8d99fe commit becb55b
Show file tree
Hide file tree
Showing 16 changed files with 207 additions and 84 deletions.
30 changes: 11 additions & 19 deletions .mypy/baseline.json
Original file line number Diff line number Diff line change
Expand Up @@ -1989,15 +1989,15 @@
"code": "explicit-override",
"column": 4,
"message": "Method \"type_context\" is not using @override but is overriding a method in class \"mypy.plugin.CheckerPluginInterface\"",
"offset": 456,
"offset": 464,
"src": "def type_context(self) -> list[Type | None]:",
"target": "mypy.checker"
},
{
"code": "explicit-override",
"column": 4,
"message": "Method \"visit_overloaded_func_def\" is not using @override but is overriding a method in class \"mypy.visitor.NodeVisitor\"",
"offset": 191,
"offset": 198,
"src": "def visit_overloaded_func_def(self, defn: OverloadedFuncDef, do_items=True) -> None:",
"target": "mypy.checker.TypeChecker.visit_overloaded_func_def"
},
Expand Down Expand Up @@ -2029,7 +2029,7 @@
"code": "explicit-override",
"column": 4,
"message": "Method \"visit_class_def\" is not using @override but is overriding a method in class \"mypy.visitor.NodeVisitor\"",
"offset": 1113,
"offset": 1151,
"src": "def visit_class_def(self, defn: ClassDef) -> None:",
"target": "mypy.checker.TypeChecker.visit_class_def"
},
Expand Down Expand Up @@ -2093,7 +2093,7 @@
"code": "truthy-bool",
"column": 23,
"message": "\"signature\" has type \"Type\" which does not implement __bool__ or __len__ so it could always be true in boolean context",
"offset": 169,
"offset": 172,
"src": "if signature:",
"target": "mypy.checker.TypeChecker.check_assignment"
},
Expand Down Expand Up @@ -2245,7 +2245,7 @@
"code": "explicit-override",
"column": 4,
"message": "Method \"visit_if_stmt\" is not using @override but is overriding a method in class \"mypy.visitor.NodeVisitor\"",
"offset": 111,
"offset": 123,
"src": "def visit_if_stmt(self, s: IfStmt) -> None:",
"target": "mypy.checker.TypeChecker.visit_if_stmt"
},
Expand Down Expand Up @@ -2639,7 +2639,7 @@
"code": "truthy-bool",
"column": 34,
"message": "\"item_name_expr\" has type \"Expression\" which does not implement __bool__ or __len__ so it could always be true in boolean context",
"offset": 424,
"offset": 430,
"src": "key_context = item_name_expr or item_arg",
"target": "mypy.checkexpr.ExpressionChecker.validate_typeddict_kwargs"
},
Expand Down Expand Up @@ -3055,7 +3055,7 @@
"code": "explicit-override",
"column": 4,
"message": "Method \"visit_await_expr\" is not using @override but is overriding a method in class \"mypy.visitor.ExpressionVisitor\"",
"offset": 21,
"offset": 33,
"src": "def visit_await_expr(self, e: AwaitExpr, allow_none_return: bool = False) -> Type:",
"target": "mypy.checkexpr.ExpressionChecker.visit_await_expr"
},
Expand Down Expand Up @@ -6859,7 +6859,7 @@
"code": "redundant-expr",
"column": 19,
"message": "Condition is always false",
"offset": 375,
"offset": 377,
"src": "if e.type is None:",
"target": "mypy.errors.Errors.render_messages"
},
Expand Down Expand Up @@ -32695,7 +32695,7 @@
"code": "explicit-override",
"column": 4,
"message": "Method \"__repr__\" is not using @override but is overriding a method in class \"builtins.object\"",
"offset": 96,
"offset": 97,
"src": "def __repr__(self) -> str:",
"target": "mypy.types.Type.__repr__"
},
Expand Down Expand Up @@ -33631,7 +33631,7 @@
"code": "explicit-override",
"column": 4,
"message": "Method \"describe\" is not using @override but is overriding a method in class \"mypy.types.AnyType\"",
"offset": 59,
"offset": 61,
"src": "def describe(self) -> str:",
"target": "mypy.types.UntypedType.describe"
},
Expand Down Expand Up @@ -36585,19 +36585,11 @@
"src": "def attrgetter(name: str) -> operator.attrgetter[Any]:",
"target": "mypy.util.attrgetter"
},
{
"code": "no-any-expr",
"column": 11,
"message": "Expression type contains \"Any\" (has type \"attrgetter[Any]\")",
"offset": 1,
"src": "return operator.attrgetter(name)",
"target": "mypy.util.attrgetter"
},
{
"code": "no-any-expr",
"column": 7,
"message": "Expression has type \"Any\"",
"offset": 4,
"offset": 5,
"src": "if orjson is not None:",
"target": "mypy.util.json_dumps"
},
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Basedmypy Changelog

## [Unreleased]
### Added
- infer return types and generator types

## [2.9.1]
### Fixed
Expand Down
19 changes: 19 additions & 0 deletions docs/source/based_inference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,22 @@ When a parameter is named `_`, it's type will be inferred as `object`:
reveal_type(_) # Revealed type is "object"
This is to help with writing functions for callbacks where you don't care about certain parameters.


Return Type Inferred
--------------------

.. code-block:: python
def f(): # Revealed type is "() -> 1"
return 1
Generator Type Inferred
-----------------------

.. code-block:: python
def f(): # Revealed type is "() -> Generator[1, str, 2]"
a: str = yield 1
return 2
96 changes: 82 additions & 14 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import itertools
from collections import defaultdict
from contextlib import ExitStack, contextmanager
from contextlib import ExitStack, contextmanager, nullcontext
from typing import (
AbstractSet,
Callable,
Expand All @@ -23,7 +23,7 @@
cast,
overload,
)
from typing_extensions import TypeAlias as _TypeAlias
from typing_extensions import ContextManager, TypeAlias as _TypeAlias

import mypy.checkexpr
from mypy import errorcodes as codes, join, message_registry, nodes, operators
Expand Down Expand Up @@ -249,7 +249,9 @@
# Maximum length of fixed tuple types inferred when narrowing from variadic tuples.
MAX_PRECISE_TUPLE_SIZE: Final = 8

DeferredNodeType: _TypeAlias = Union[FuncDef, LambdaExpr, OverloadedFuncDef, Decorator]
DeferredNodeType: _TypeAlias = Union[
FuncDef, LambdaExpr, OverloadedFuncDef, Decorator, AssignmentStmt
]
FineGrainedDeferredNodeType: _TypeAlias = Union[FuncDef, MypyFile, OverloadedFuncDef]


Expand All @@ -260,7 +262,7 @@
class DeferredNode(NamedTuple):
node: DeferredNodeType
# And its TypeInfo (for semantic analysis self type handling
active_typeinfo: TypeInfo | None
active_typeinfo: TypeInfo | FuncItem | None


# Same as above, but for fine-grained mode targets. Only top-level functions/methods
Expand Down Expand Up @@ -452,6 +454,12 @@ def __init__(
# always the next if statement to have a redundant expression
self.allow_redundant_expr = False

self.should_defer_current_node = False
"""when a parent node should be defered"""
self.inferred_return_types: list[Type] = []
self.inferred_yield_types: list[Type] = []
self.inferred_send_types: list[Type] = []

@property
def type_context(self) -> list[Type | None]:
return self.expr_checker.type_context
Expand Down Expand Up @@ -551,10 +559,15 @@ def check_second_pass(
# (self.pass_num, type_name, node.fullname or node.name))
done.add(node)
with ExitStack() as stack:
if active_typeinfo:
cm: ContextManager[object] = nullcontext()
if isinstance(active_typeinfo, FuncItem):
stack.enter_context(self.scope.push_function(active_typeinfo))
cm = self.tscope.function_scope(active_typeinfo)
elif active_typeinfo:
stack.enter_context(self.tscope.class_scope(active_typeinfo))
stack.enter_context(self.scope.push_class(active_typeinfo))
self.check_partial(node)
with cm:
self.check_partial(node)
return True

def check_partial(self, node: DeferredNodeType | FineGrainedDeferredNodeType) -> None:
Expand All @@ -579,7 +592,9 @@ def check_top_level(self, node: MypyFile) -> None:
assert not self.current_node_deferred
# TODO: Handle __all__

def defer_node(self, node: DeferredNodeType, enclosing_class: TypeInfo | None) -> None:
def defer_node(
self, node: DeferredNodeType, enclosing_class: TypeInfo | FuncItem | None
) -> None:
"""Defer a node for processing during next type-checking pass.
Args:
Expand Down Expand Up @@ -1577,6 +1592,12 @@ def check_func_def(
new_frame = self.binder.push_frame()
new_frame.types[key] = narrowed_type
self.binder.declarations[key] = old_binder.declarations[key]
inferred_return_types = self.inferred_return_types
self.inferred_return_types = []
inferred_yield_types = self.inferred_yield_types
self.inferred_yield_types = []
inferred_send_types = self.inferred_send_types
self.inferred_send_types = []
with self.scope.push_function(defn):
# We suppress reachability warnings for empty generator functions
# (return; yield) which have a "yield" that's unreachable by definition
Expand All @@ -1591,6 +1612,37 @@ def check_func_def(
if _is_empty_generator_function(item) or len(expanded) >= 2:
self.binder.suppress_unreachable_warnings()
self.accept(item.body)
if self.options.default_return:
if not self.binder.is_unreachable():
self.inferred_return_types.append(NoneType())

ret_type = get_proper_type(typ.ret_type)
if (
isinstance(ret_type, UntypedType)
and ret_type.type_of_any == TypeOfAny.to_be_inferred
):
ret_type = make_simplified_union(self.inferred_return_types)
item.type = typ.copy_modified(ret_type=ret_type)
if self.inferred_yield_types or self.inferred_send_types:
yield_type = (
make_simplified_union(self.inferred_yield_types)
if self.inferred_yield_types
else self.named_type("builtins.object")
)
# `Never` here isn't ideal, and neither is `object`, so we just go with the default typevar value
send_type = (
make_simplified_intersection(self.inferred_send_types)
if self.inferred_send_types
else NoneType()
)
assert isinstance(item.type, CallableType)
item.type.ret_type = Instance(
self.lookup_typeinfo("typing.Generator"),
[yield_type, send_type, item.type.ret_type],
)
self.inferred_return_types = inferred_return_types
self.inferred_yield_types = inferred_yield_types
self.inferred_send_types = inferred_send_types
unreachable = self.binder.is_unreachable()
if new_frame is not None:
self.binder.pop_frame(True, 0)
Expand Down Expand Up @@ -1783,13 +1835,14 @@ def check_for_missing_annotations(self, fdef: FuncItem) -> None:
if not fdef.arguments or (
len(fdef.arguments) == 1 and (fdef.arg_names[0] in ("self", "cls"))
):
self.fail(message_registry.RETURN_TYPE_EXPECTED, fdef)
if not has_return_statement(fdef) and not fdef.is_generator:
self.note(
'Use "-> None" if function does not return a value',
fdef,
code=codes.NO_UNTYPED_DEF,
)
if not self.options.default_return:
self.fail(message_registry.RETURN_TYPE_EXPECTED, fdef)
if not has_return_statement(fdef) and not fdef.is_generator:
self.note(
'Use "-> None" if function does not return a value',
fdef,
code=codes.NO_UNTYPED_DEF,
)
else:
self.fail(message_registry.FUNCTION_TYPE_EXPECTED, fdef)
elif isinstance(fdef.type, CallableType):
Expand Down Expand Up @@ -3260,6 +3313,9 @@ def visit_assignment_stmt(self, s: AssignmentStmt) -> None:
s.new_syntax,
override_infer=s.unanalyzed_type is not None,
)
if self.should_defer_current_node:
self.defer_node(s, self.scope.top_function())
self.should_defer_current_node = False
if s.is_alias_def:
self.check_type_alias_rvalue(s)

Expand Down Expand Up @@ -5055,6 +5111,18 @@ def check_return_stmt(self, s: ReturnStmt) -> None:
if defn.is_async_generator:
self.fail(message_registry.RETURN_IN_ASYNC_GENERATOR, s)
return
if defn.type:
assert isinstance(defn.type, CallableType)
proper_type = get_proper_type(defn.type.ret_type)
infer = (
isinstance(proper_type, UntypedType)
and proper_type.type_of_any == TypeOfAny.to_be_inferred
)
else:
infer = True
if infer:
self.inferred_return_types.append(typ)
return
# Returning a value of type Any is always fine.
if isinstance(typ, AnyType):
# (Unless you asked to be warned in that case, and the
Expand Down
40 changes: 29 additions & 11 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,7 @@ def visit_call_expr_inner(self, e: CallExpr, allow_none_return: bool = False) ->
fullname == p or fullname.startswith(f"{p}.")
for p in self.chk.options.untyped_calls_exclude
):
if callee_type.implicit:
if callee_type.implicit and not self.chk.options.infer_function_types:
self.msg.untyped_function_call(callee_type, e)
if fullname is None and member is not None:
assert object_type is not None
Expand All @@ -638,7 +638,13 @@ def visit_call_expr_inner(self, e: CallExpr, allow_none_return: bool = False) ->
fullname == p or fullname.startswith(f"{p}.")
for p in self.chk.options.untyped_calls_exclude
):
if callee_type.implicit:
proper_type = get_proper_type(callee_type.ret_type)
if (
isinstance(proper_type, UntypedType)
and proper_type.type_of_any == TypeOfAny.to_be_inferred
):
self.chk.current_node_deferred = True
elif callee_type.implicit and not self.chk.options.infer_function_types:
self.msg.untyped_function_call(callee_type, e)
elif has_untyped_type(callee_type):
# Get module of the function, to get its settings
Expand Down Expand Up @@ -6290,22 +6296,34 @@ def not_ready_callback(self, name: str, context: Context) -> None:
def visit_yield_expr(self, e: YieldExpr) -> Type:
return_type = self.chk.return_types[-1]
expected_item_type = self.chk.get_generator_yield_type(return_type, False)
proper_type = get_proper_type(return_type)
infer = (
isinstance(proper_type, UntypedType)
and proper_type.type_of_any == TypeOfAny.to_be_inferred
)
if infer and self.type_context[-1]:
self.chk.inferred_send_types.append(self.type_context[-1])
if e.expr is None:
if (
if infer:
self.chk.inferred_yield_types.append(NoneType())
elif (
not isinstance(get_proper_type(expected_item_type), (NoneType, AnyType))
and self.chk.in_checked_function()
):
self.chk.fail(message_registry.YIELD_VALUE_EXPECTED, e)
else:
actual_item_type = self.accept(e.expr, expected_item_type)
self.chk.check_subtype(
actual_item_type,
expected_item_type,
e,
message_registry.INCOMPATIBLE_TYPES_IN_YIELD,
"actual type",
"expected type",
)
if infer:
self.chk.inferred_yield_types.append(actual_item_type)
else:
self.chk.check_subtype(
actual_item_type,
expected_item_type,
e,
message_registry.INCOMPATIBLE_TYPES_IN_YIELD,
"actual type",
"expected type",
)
return self.chk.get_generator_receive_type(return_type, False)

def visit_await_expr(self, e: AwaitExpr, allow_none_return: bool = False) -> Type:
Expand Down
Loading

0 comments on commit becb55b

Please sign in to comment.