Skip to content

Commit

Permalink
Analyze and check default arguments to lambdas (python#7306)
Browse files Browse the repository at this point in the history
This will allow us to report errors in them and also is needed to
support them in mypyc.
  • Loading branch information
msullivan authored Aug 9, 2019
1 parent 4ff341f commit 5bb6796
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 17 deletions.
29 changes: 16 additions & 13 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,19 +926,7 @@ def check_func_def(self, defn: FuncItem, typ: CallableType, name: Optional[str])

# Type check initialization expressions.
body_is_trivial = self.is_trivial_body(defn.body)
for arg in item.arguments:
if arg.initializer is None:
continue
if body_is_trivial and isinstance(arg.initializer, EllipsisExpr):
continue
name = arg.variable.name()
msg = 'Incompatible default for '
if name.startswith('__tuple_arg_'):
msg += "tuple argument {}".format(name[12:])
else:
msg += 'argument "{}"'.format(name)
self.check_simple_assignment(arg.variable.type, arg.initializer,
context=arg, msg=msg, lvalue_name='argument', rvalue_name='default')
self.check_default_args(item, body_is_trivial)

# Type check body in a new scope.
with self.binder.top_frame_context():
Expand Down Expand Up @@ -978,6 +966,21 @@ def check_func_def(self, defn: FuncItem, typ: CallableType, name: Optional[str])

self.binder = old_binder

def check_default_args(self, item: FuncItem, body_is_trivial: bool) -> None:
for arg in item.arguments:
if arg.initializer is None:
continue
if body_is_trivial and isinstance(arg.initializer, EllipsisExpr):
continue
name = arg.variable.name()
msg = 'Incompatible default for '
if name.startswith('__tuple_arg_'):
msg += "tuple argument {}".format(name[12:])
else:
msg += 'argument "{}"'.format(name)
self.check_simple_assignment(arg.variable.type, arg.initializer,
context=arg, msg=msg, lvalue_name='argument', rvalue_name='default')

def is_forward_op_method(self, method_name: str) -> bool:
if self.options.python_version[0] == 2 and method_name == '__div__':
return True
Expand Down
1 change: 1 addition & 0 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3065,6 +3065,7 @@ def find_typeddict_context(self, context: Optional[Type]) -> Optional[TypedDictT

def visit_lambda_expr(self, e: LambdaExpr) -> Type:
"""Type check lambda expression."""
self.chk.check_default_args(e, body_is_trivial=False)
inferred_type, type_override = self.infer_lambda_type_using_context(e)
if not inferred_type:
self.chk.return_types.append(AnyType(TypeOfAny.special_form))
Expand Down
14 changes: 10 additions & 4 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,7 @@ def analyze_func_def(self, defn: FuncDef) -> None:
defn.type = defn.type.copy_modified(ret_type=NoneType())
self.prepare_method_signature(defn, self.type)

# Analyze function signature and initializers first.
# Analyze function signature
with self.tvar_scope_frame(self.tvar_scope.method_frame()):
if defn.type:
self.check_classvar_in_signature(defn.type)
Expand All @@ -577,10 +577,8 @@ def analyze_func_def(self, defn: FuncDef) -> None:
if isinstance(defn, FuncDef):
assert isinstance(defn.type, CallableType)
defn.type = set_callable_name(defn.type, defn)
for arg in defn.arguments:
if arg.initializer:
arg.initializer.accept(self)

self.analyze_arg_initializers(defn)
self.analyze_function_body(defn)
if defn.is_coroutine and isinstance(defn.type, CallableType) and not self.deferred:
if defn.is_async_generator:
Expand Down Expand Up @@ -868,6 +866,13 @@ def add_function_to_symbol_table(self, func: Union[FuncDef, OverloadedFuncDef])
func._fullname = self.qualified_name(func.name())
self.add_symbol(func.name(), func, func)

def analyze_arg_initializers(self, defn: FuncItem) -> None:
with self.tvar_scope_frame(self.tvar_scope.method_frame()):
# Analyze default arguments
for arg in defn.arguments:
if arg.initializer:
arg.initializer.accept(self)

def analyze_function_body(self, defn: FuncItem) -> None:
is_method = self.is_class_scope()
with self.tvar_scope_frame(self.tvar_scope.method_frame()):
Expand Down Expand Up @@ -3722,6 +3727,7 @@ def analyze_comp_for_2(self, expr: Union[GeneratorExpr,
expr.sequences[0].accept(self)

def visit_lambda_expr(self, expr: LambdaExpr) -> None:
self.analyze_arg_initializers(expr)
self.analyze_function_body(expr)

def visit_conditional_expr(self, expr: ConditionalExpr) -> None:
Expand Down
6 changes: 6 additions & 0 deletions test-data/unit/check-functions.test
Original file line number Diff line number Diff line change
Expand Up @@ -2549,3 +2549,9 @@ def f() -> int: ...
[file p/d.py]
import p
def f() -> int: ...

[case testLambdaDefaultTypeErrors]
lambda a=nonsense: a # E: Name 'nonsense' is not defined
lambda a=(1 + 'asdf'): a # E: Unsupported operand types for + ("int" and "str")
def f(x: int = i): # E: Name 'i' is not defined
i = 42

0 comments on commit 5bb6796

Please sign in to comment.