Skip to content

Commit 51cff8e

Browse files
committed
[red-knot] Infer return type of lambda expression
1 parent 58d5fe9 commit 51cff8e

File tree

3 files changed

+29
-25
lines changed

3 files changed

+29
-25
lines changed

crates/red_knot_python_semantic/resources/mdtest/expression/lambda.md

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
`lambda` expressions can be defined without any parameters.
66

77
```py
8-
reveal_type(lambda: 1) # revealed: () -> @Todo(lambda return type)
8+
reveal_type(lambda: 1) # revealed: () -> Literal[1]
99

1010
# error: [unresolved-reference]
11-
reveal_type(lambda: a) # revealed: () -> @Todo(lambda return type)
11+
reveal_type(lambda: a) # revealed: () -> Unknown
1212
```
1313

1414
## With parameters
@@ -17,45 +17,47 @@ Unlike parameters in function definition, the parameters in a `lambda` expressio
1717
annotated.
1818

1919
```py
20-
reveal_type(lambda a: a) # revealed: (a) -> @Todo(lambda return type)
21-
reveal_type(lambda a, b: a + b) # revealed: (a, b) -> @Todo(lambda return type)
20+
reveal_type(lambda a: a) # revealed: (a) -> Unknown
21+
reveal_type(lambda a, b: a + b) # revealed: (a, b) -> Unknown
2222
```
2323

2424
But, it can have default values:
2525

2626
```py
27-
reveal_type(lambda a=1: a) # revealed: (a=Literal[1]) -> @Todo(lambda return type)
28-
reveal_type(lambda a, b=2: a) # revealed: (a, b=Literal[2]) -> @Todo(lambda return type)
27+
reveal_type(lambda a=1: a) # revealed: (a=Literal[1]) -> Unknown | Literal[1]
28+
reveal_type(lambda a, b=2: a) # revealed: (a, b=Literal[2]) -> Unknown
2929
```
3030

3131
And, positional-only parameters:
3232

3333
```py
34-
reveal_type(lambda a, b, /, c: c) # revealed: (a, b, /, c) -> @Todo(lambda return type)
34+
reveal_type(lambda a, b, /, c: c) # revealed: (a, b, /, c) -> Unknown
3535
```
3636

3737
And, keyword-only parameters:
3838

3939
```py
40-
reveal_type(lambda a, *, b=2, c: b) # revealed: (a, *, b=Literal[2], c) -> @Todo(lambda return type)
40+
reveal_type(lambda a, *, b=2, c: b) # revealed: (a, *, b=Literal[2], c) -> Unknown | Literal[2]
4141
```
4242

4343
And, variadic parameter:
4444

4545
```py
46-
reveal_type(lambda *args: args) # revealed: (*args) -> @Todo(lambda return type)
46+
# TODO: should be `tuple[Unknown, ...]` (needs generics)
47+
reveal_type(lambda *args: args) # revealed: (*args) -> tuple
4748
```
4849

4950
And, keyword-varidic parameter:
5051

5152
```py
52-
reveal_type(lambda **kwargs: kwargs) # revealed: (**kwargs) -> @Todo(lambda return type)
53+
# TODO: should be `dict[str, Unknown]` (needs generics)
54+
reveal_type(lambda **kwargs: kwargs) # revealed: (**kwargs) -> dict
5355
```
5456

5557
Mixing all of them together:
5658

5759
```py
58-
# revealed: (a, b, /, c=Literal[True], *args, *, d=Literal["default"], e=Literal[5], **kwargs) -> @Todo(lambda return type)
60+
# revealed: (a, b, /, c=Literal[True], *args, *, d=Literal["default"], e=Literal[5], **kwargs) -> None
5961
reveal_type(lambda a, b, /, c=True, *args, d="default", e=5, **kwargs: None)
6062
```
6163

@@ -96,5 +98,5 @@ Here, a `lambda` expression is used as the default value for a parameter in anot
9698
expression.
9799

98100
```py
99-
reveal_type(lambda a=lambda x, y: 0: 2) # revealed: (a=(x, y) -> @Todo(lambda return type)) -> @Todo(lambda return type)
101+
reveal_type(lambda a=lambda x, y: 0: 2) # revealed: (a=(x, y) -> Literal[0]) -> Literal[2]
100102
```

crates/red_knot_python_semantic/src/semantic_index/builder.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -913,6 +913,7 @@ where
913913
.iter_non_variadic_params()
914914
.filter_map(|param| param.default.as_deref())
915915
{
916+
self.add_standalone_expression(default);
916917
self.visit_expr(default);
917918
}
918919
// The symbol for the function name itself has to be evaluated
@@ -1738,6 +1739,7 @@ where
17381739
.iter_non_variadic_params()
17391740
.filter_map(|param| param.default.as_deref())
17401741
{
1742+
self.add_standalone_expression(default);
17411743
self.visit_expr(default);
17421744
}
17431745
self.visit_parameters(parameters);

crates/red_knot_python_semantic/src/types/infer.rs

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1338,7 +1338,7 @@ impl<'db> TypeInferenceBuilder<'db> {
13381338
.iter_non_variadic_params()
13391339
.filter_map(|param| param.default.as_deref())
13401340
{
1341-
self.infer_expression(default);
1341+
self.infer_standalone_expression(default);
13421342
}
13431343

13441344
// If there are type params, parameters and returns are evaluated in that scope, that is, in
@@ -1463,7 +1463,7 @@ impl<'db> TypeInferenceBuilder<'db> {
14631463
} = parameter_with_default;
14641464
let default_ty = default
14651465
.as_ref()
1466-
.map(|default| self.file_expression_type(default));
1466+
.map(|default| infer_expression_type(self.db(), self.index.expression(&**default)));
14671467
if let Some(annotation) = parameter.annotation.as_ref() {
14681468
let declared_ty = self.file_expression_type(annotation);
14691469
let declared_and_inferred_ty = if let Some(default_ty) = default_ty {
@@ -3447,9 +3447,15 @@ impl<'db> TypeInferenceBuilder<'db> {
34473447
let ast::ExprLambda {
34483448
range: _,
34493449
parameters,
3450-
body: _,
3450+
body,
34513451
} = lambda_expression;
34523452

3453+
let mut default_type = |parameter: &ast::ParameterWithDefault| {
3454+
parameter
3455+
.default()
3456+
.map(|default| self.infer_standalone_expression(default))
3457+
};
3458+
34533459
let parameters = if let Some(parameters) = parameters {
34543460
let positional_only = parameters
34553461
.posonlyargs
@@ -3459,9 +3465,7 @@ impl<'db> TypeInferenceBuilder<'db> {
34593465
Some(parameter.name().id.clone()),
34603466
None,
34613467
ParameterKind::PositionalOnly {
3462-
default_ty: parameter
3463-
.default()
3464-
.map(|default| self.infer_expression(default)),
3468+
default_ty: default_type(parameter),
34653469
},
34663470
)
34673471
})
@@ -3474,9 +3478,7 @@ impl<'db> TypeInferenceBuilder<'db> {
34743478
Some(parameter.name().id.clone()),
34753479
None,
34763480
ParameterKind::PositionalOrKeyword {
3477-
default_ty: parameter
3478-
.default()
3479-
.map(|default| self.infer_expression(default)),
3481+
default_ty: default_type(parameter),
34803482
},
34813483
)
34823484
})
@@ -3496,9 +3498,7 @@ impl<'db> TypeInferenceBuilder<'db> {
34963498
Some(parameter.name().id.clone()),
34973499
None,
34983500
ParameterKind::KeywordOnly {
3499-
default_ty: parameter
3500-
.default()
3501-
.map(|default| self.infer_expression(default)),
3501+
default_ty: default_type(parameter),
35023502
},
35033503
)
35043504
})
@@ -3525,7 +3525,7 @@ impl<'db> TypeInferenceBuilder<'db> {
35253525

35263526
Type::Callable(CallableType::General(GeneralCallableType::new(
35273527
self.db(),
3528-
Signature::new(parameters, Some(todo_type!("lambda return type"))),
3528+
Signature::new(parameters, Some(self.file_expression_type(body))),
35293529
)))
35303530
}
35313531

0 commit comments

Comments
 (0)