Skip to content

Commit 72acb7e

Browse files
committed
infer arguments of generic calls with declared type context
1 parent 64ab79e commit 72acb7e

File tree

2 files changed

+53
-8
lines changed

2 files changed

+53
-8
lines changed

crates/ty_python_semantic/resources/mdtest/assignment/annotations.md

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -476,10 +476,8 @@ def _(i: int):
476476
b: list[int | None] | None = id([i])
477477
c: list[int | None] | int | None = id([i])
478478
reveal_type(a) # revealed: list[int | None]
479-
# TODO: these should reveal `list[int | None]`
480-
# we currently do not use the call expression annotation as type context for argument inference
481-
reveal_type(b) # revealed: list[Unknown | int]
482-
reveal_type(c) # revealed: list[Unknown | int]
479+
reveal_type(b) # revealed: list[int | None]
480+
reveal_type(c) # revealed: list[int | None]
483481

484482
a: list[int | None] | None = [i]
485483
b: list[int | None] | None = lst(i)
@@ -495,3 +493,26 @@ def _(i: int):
495493
reveal_type(b) # revealed: list[Unknown]
496494
reveal_type(c) # revealed: list[Unknown]
497495
```
496+
497+
The function arguments are inferred using the type context:
498+
499+
```py
500+
from typing import TypedDict
501+
502+
class TD(TypedDict):
503+
x: int
504+
505+
def f[T](x: list[T]) -> T:
506+
return x[0]
507+
508+
a: TD = f([{"x": 0}, {"x": 1}])
509+
reveal_type(a) # revealed: TD
510+
511+
b: TD | None = f([{"x": 0}, {"x": 1}])
512+
# TODO: Narrow away the `None` here.
513+
reveal_type(b) # revealed: TD | None
514+
515+
# error: [missing-typed-dict-key] "Missing required key 'x' in TypedDict `TD` constructor"
516+
# error: [invalid-key] "Invalid key access on TypedDict `TD`: Unknown key "y""
517+
c: TD | None = f([{"y": 0}, {"x": 1}])
518+
```

crates/ty_python_semantic/src/types/infer/builder.rs

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5535,6 +5535,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
55355535
ast_arguments: &ast::Arguments,
55365536
arguments: &mut CallArguments<'a, 'db>,
55375537
bindings: &Bindings<'db>,
5538+
call_expression_tcx: TypeContext<'db>,
55385539
) {
55395540
debug_assert!(
55405541
ast_arguments.len() == arguments.len()
@@ -5603,10 +5604,28 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
56035604
return None;
56045605
};
56055606

5606-
let parameter_type =
5607+
let mut parameter_type =
56075608
overload.signature.parameters()[*parameter_index].annotated_type()?;
56085609

5609-
// TODO: For now, skip any parameter annotations that mention any typevars. There
5610+
// If this is a generic call, attempt to specialize the parameter type using the
5611+
// declared type context, if provided.
5612+
if let Some(generic_context) = overload.signature.generic_context
5613+
&& let Some(return_ty) = overload.signature.return_ty
5614+
&& let Some(declared_return_ty) = call_expression_tcx.annotation
5615+
{
5616+
let mut builder =
5617+
SpecializationBuilder::new(db, generic_context.inferable_typevars(db));
5618+
5619+
let _ = builder.infer(return_ty, declared_return_ty);
5620+
let specialization = builder.build(generic_context, call_expression_tcx);
5621+
5622+
// Note that we are not necessarily "preferring the declared type" here, as the
5623+
// type context will only be preferred during the inference of this expression
5624+
// by the same heuristics we use for the inference of the outer generic call.
5625+
parameter_type = parameter_type.apply_specialization(db, specialization);
5626+
}
5627+
5628+
// TODO: For now, skip any parameter annotations that still mention any typevars. There
56105629
// are two issues:
56115630
//
56125631
// First, if we include those typevars in the type context that we use to infer the
@@ -6820,7 +6839,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
68206839
let infer_call_arguments = |bindings: Option<Bindings<'db>>| {
68216840
if let Some(bindings) = bindings {
68226841
let bindings = bindings.match_parameters(self.db(), &call_arguments);
6823-
self.infer_all_argument_types(arguments, &mut call_arguments, &bindings);
6842+
self.infer_all_argument_types(
6843+
arguments,
6844+
&mut call_arguments,
6845+
&bindings,
6846+
tcx,
6847+
);
68246848
} else {
68256849
let argument_forms = vec![Some(ParameterForm::Value); call_arguments.len()];
68266850
self.infer_argument_types(arguments, &mut call_arguments, &argument_forms);
@@ -6841,7 +6865,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
68416865
let bindings = callable_type
68426866
.bindings(self.db())
68436867
.match_parameters(self.db(), &call_arguments);
6844-
self.infer_all_argument_types(arguments, &mut call_arguments, &bindings);
6868+
self.infer_all_argument_types(arguments, &mut call_arguments, &bindings, tcx);
68456869

68466870
// Validate `TypedDict` constructor calls after argument type inference
68476871
if let Some(class_literal) = callable_type.as_class_literal() {

0 commit comments

Comments
 (0)