Skip to content

Commit 914095d

Browse files
dcreagerAlexWaygoodcarljm
authored
[red-knot] Super-basic generic inference at call sites (#17301)
This PR adds **_very_** basic inference of generic typevars at call sites. It does not bring in a full unification algorithm, and there are a few TODOs in the test suite that are not discharged by this. But it handles a good number of useful cases! And the PR does not add anything that would go away with a more sophisticated constraint solver. In short, we just look for typevars in the formal parameters, and assume that the inferred type of the corresponding argument is what that typevar should map to. If a typevar appears more than once, we union together the corresponding argument types. Cases we are not yet handling: - We are not widening literals. - We are not recursing into parameters that are themselves generic aliases. - We are not being very clever with parameters that are union types. --------- Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com> Co-authored-by: Carl Meyer <carl@astral.sh>
1 parent 5350288 commit 914095d

File tree

15 files changed

+622
-241
lines changed

15 files changed

+622
-241
lines changed

crates/red_knot_python_semantic/resources/mdtest/call/methods.md

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -410,29 +410,19 @@ def does_nothing[T](f: T) -> T:
410410

411411
class C:
412412
@classmethod
413-
# TODO: no error should be emitted here (needs support for generics)
414-
# error: [invalid-argument-type]
415413
@does_nothing
416414
def f1(cls: type[C], x: int) -> str:
417415
return "a"
418-
# TODO: no error should be emitted here (needs support for generics)
419-
# error: [invalid-argument-type]
416+
420417
@does_nothing
421418
@classmethod
422419
def f2(cls: type[C], x: int) -> str:
423420
return "a"
424421

425-
# TODO: All of these should be `str` (and not emit an error), once we support generics
426-
427-
# error: [call-non-callable]
428-
reveal_type(C.f1(1)) # revealed: Unknown
429-
# error: [call-non-callable]
430-
reveal_type(C().f1(1)) # revealed: Unknown
431-
432-
# error: [call-non-callable]
433-
reveal_type(C.f2(1)) # revealed: Unknown
434-
# error: [call-non-callable]
435-
reveal_type(C().f2(1)) # revealed: Unknown
422+
reveal_type(C.f1(1)) # revealed: str
423+
reveal_type(C().f1(1)) # revealed: str
424+
reveal_type(C.f2(1)) # revealed: str
425+
reveal_type(C().f2(1)) # revealed: str
436426
```
437427

438428
[functions and methods]: https://docs.python.org/3/howto/descriptor.html#functions-and-methods

crates/red_knot_python_semantic/resources/mdtest/generics/classes.md

Lines changed: 87 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -149,23 +149,102 @@ If a typevar does not provide a default, we use `Unknown`:
149149
reveal_type(C()) # revealed: C[Unknown]
150150
```
151151

152+
## Inferring generic class parameters from constructors
153+
152154
If the type of a constructor parameter is a class typevar, we can use that to infer the type
153-
parameter:
155+
parameter. The types inferred from a type context and from a constructor parameter must be
156+
consistent with each other.
157+
158+
## `__new__` only
154159

155160
```py
156-
class E[T]:
161+
class C[T]:
162+
def __new__(cls, x: T) -> "C"[T]:
163+
return object.__new__(cls)
164+
165+
reveal_type(C(1)) # revealed: C[Literal[1]]
166+
167+
# TODO: error: [invalid-argument-type]
168+
wrong_innards: C[int] = C("five")
169+
```
170+
171+
## `__init__` only
172+
173+
```py
174+
class C[T]:
157175
def __init__(self, x: T) -> None: ...
158176

159-
# TODO: revealed: E[int] or E[Literal[1]]
160-
reveal_type(E(1)) # revealed: E[Unknown]
177+
reveal_type(C(1)) # revealed: C[Literal[1]]
178+
179+
# TODO: error: [invalid-argument-type]
180+
wrong_innards: C[int] = C("five")
161181
```
162182

163-
The types inferred from a type context and from a constructor parameter must be consistent with each
164-
other:
183+
## Identical `__new__` and `__init__` signatures
165184

166185
```py
186+
class C[T]:
187+
def __new__(cls, x: T) -> "C"[T]:
188+
return object.__new__(cls)
189+
190+
def __init__(self, x: T) -> None: ...
191+
192+
reveal_type(C(1)) # revealed: C[Literal[1]]
193+
167194
# TODO: error: [invalid-argument-type]
168-
wrong_innards: E[int] = E("five")
195+
wrong_innards: C[int] = C("five")
196+
```
197+
198+
## Compatible `__new__` and `__init__` signatures
199+
200+
```py
201+
class C[T]:
202+
def __new__(cls, *args, **kwargs) -> "C"[T]:
203+
return object.__new__(cls)
204+
205+
def __init__(self, x: T) -> None: ...
206+
207+
reveal_type(C(1)) # revealed: C[Literal[1]]
208+
209+
# TODO: error: [invalid-argument-type]
210+
wrong_innards: C[int] = C("five")
211+
212+
class D[T]:
213+
def __new__(cls, x: T) -> "D"[T]:
214+
return object.__new__(cls)
215+
216+
def __init__(self, *args, **kwargs) -> None: ...
217+
218+
reveal_type(D(1)) # revealed: D[Literal[1]]
219+
220+
# TODO: error: [invalid-argument-type]
221+
wrong_innards: D[int] = D("five")
222+
```
223+
224+
## `__init__` is itself generic
225+
226+
TODO: These do not currently work yet, because we don't correctly model the nested generic contexts.
227+
228+
```py
229+
class C[T]:
230+
def __init__[S](self, x: T, y: S) -> None: ...
231+
232+
# TODO: no error
233+
# TODO: revealed: C[Literal[1]]
234+
# error: [invalid-argument-type]
235+
reveal_type(C(1, 1)) # revealed: C[Unknown]
236+
# TODO: no error
237+
# TODO: revealed: C[Literal[1]]
238+
# error: [invalid-argument-type]
239+
reveal_type(C(1, "string")) # revealed: C[Unknown]
240+
# TODO: no error
241+
# TODO: revealed: C[Literal[1]]
242+
# error: [invalid-argument-type]
243+
reveal_type(C(1, True)) # revealed: C[Unknown]
244+
245+
# TODO: error for the correct reason
246+
# error: [invalid-argument-type] "Argument to this function is incorrect: Expected `S`, found `Literal[1]`"
247+
wrong_innards: C[int] = C("five", 1)
169248
```
170249

171250
## Generic subclass
@@ -200,10 +279,7 @@ class C[T]:
200279
def cannot_shadow_class_typevar[T](self, t: T): ...
201280

202281
c: C[int] = C[int]()
203-
# TODO: no error
204-
# TODO: revealed: str or Literal["string"]
205-
# error: [invalid-argument-type]
206-
reveal_type(c.method("string")) # revealed: U
282+
reveal_type(c.method("string")) # revealed: Literal["string"]
207283
```
208284

209285
## Cyclic class definition

crates/red_knot_python_semantic/resources/mdtest/generics/functions.md

Lines changed: 30 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -43,33 +43,14 @@ def absurd[T]() -> T:
4343
If the type of a generic function parameter is a typevar, then we can infer what type that typevar
4444
is bound to at each call site.
4545

46-
TODO: Note that some of the TODO revealed types have two options, since we haven't decided yet
47-
whether we want to infer a more specific `Literal` type where possible, or use heuristics to weaken
48-
the inferred type to e.g. `int`.
49-
5046
```py
5147
def f[T](x: T) -> T:
5248
return x
5349

54-
# TODO: no error
55-
# TODO: revealed: int or Literal[1]
56-
# error: [invalid-argument-type]
57-
reveal_type(f(1)) # revealed: T
58-
59-
# TODO: no error
60-
# TODO: revealed: float
61-
# error: [invalid-argument-type]
62-
reveal_type(f(1.0)) # revealed: T
63-
64-
# TODO: no error
65-
# TODO: revealed: bool or Literal[true]
66-
# error: [invalid-argument-type]
67-
reveal_type(f(True)) # revealed: T
68-
69-
# TODO: no error
70-
# TODO: revealed: str or Literal["string"]
71-
# error: [invalid-argument-type]
72-
reveal_type(f("string")) # revealed: T
50+
reveal_type(f(1)) # revealed: Literal[1]
51+
reveal_type(f(1.0)) # revealed: float
52+
reveal_type(f(True)) # revealed: Literal[True]
53+
reveal_type(f("string")) # revealed: Literal["string"]
7354
```
7455

7556
## Inferring “deep” generic parameter types
@@ -82,7 +63,7 @@ def f[T](x: list[T]) -> T:
8263
return x[0]
8364

8465
# TODO: revealed: float
85-
reveal_type(f([1.0, 2.0])) # revealed: T
66+
reveal_type(f([1.0, 2.0])) # revealed: Unknown
8667
```
8768

8869
## Typevar constraints
@@ -93,7 +74,6 @@ in the function.
9374

9475
```py
9576
def good_param[T: int](x: T) -> None:
96-
# TODO: revealed: T & int
9777
reveal_type(x) # revealed: T
9878
```
9979

@@ -162,61 +142,41 @@ parameters simultaneously.
162142
def two_params[T](x: T, y: T) -> T:
163143
return x
164144

165-
# TODO: no error
166-
# TODO: revealed: str
167-
# error: [invalid-argument-type]
168-
# error: [invalid-argument-type]
169-
reveal_type(two_params("a", "b")) # revealed: T
170-
171-
# TODO: no error
172-
# TODO: revealed: str | int
173-
# error: [invalid-argument-type]
174-
# error: [invalid-argument-type]
175-
reveal_type(two_params("a", 1)) # revealed: T
145+
reveal_type(two_params("a", "b")) # revealed: Literal["a", "b"]
146+
reveal_type(two_params("a", 1)) # revealed: Literal["a", 1]
176147
```
177148

178-
```py
179-
def param_with_union[T](x: T | int, y: T) -> T:
180-
return y
149+
When one of the parameters is a union, we attempt to find the smallest specialization that satisfies
150+
all of the constraints.
181151

182-
# TODO: no error
183-
# TODO: revealed: str
184-
# error: [invalid-argument-type]
185-
reveal_type(param_with_union(1, "a")) # revealed: T
152+
```py
153+
def union_param[T](x: T | None) -> T:
154+
if x is None:
155+
raise ValueError
156+
return x
186157

187-
# TODO: no error
188-
# TODO: revealed: str
189-
# error: [invalid-argument-type]
190-
# error: [invalid-argument-type]
191-
reveal_type(param_with_union("a", "a")) # revealed: T
158+
reveal_type(union_param("a")) # revealed: Literal["a"]
159+
reveal_type(union_param(1)) # revealed: Literal[1]
160+
reveal_type(union_param(None)) # revealed: Unknown
161+
```
192162

193-
# TODO: no error
194-
# TODO: revealed: int
195-
# error: [invalid-argument-type]
196-
reveal_type(param_with_union(1, 1)) # revealed: T
163+
```py
164+
def union_and_nonunion_params[T](x: T | int, y: T) -> T:
165+
return y
197166

198-
# TODO: no error
199-
# TODO: revealed: str | int
200-
# error: [invalid-argument-type]
201-
# error: [invalid-argument-type]
202-
reveal_type(param_with_union("a", 1)) # revealed: T
167+
reveal_type(union_and_nonunion_params(1, "a")) # revealed: Literal["a"]
168+
reveal_type(union_and_nonunion_params("a", "a")) # revealed: Literal["a"]
169+
reveal_type(union_and_nonunion_params(1, 1)) # revealed: Literal[1]
170+
reveal_type(union_and_nonunion_params(3, 1)) # revealed: Literal[1]
171+
reveal_type(union_and_nonunion_params("a", 1)) # revealed: Literal["a", 1]
203172
```
204173

205174
```py
206175
def tuple_param[T, S](x: T | S, y: tuple[T, S]) -> tuple[T, S]:
207176
return y
208177

209-
# TODO: no error
210-
# TODO: revealed: tuple[str, int]
211-
# error: [invalid-argument-type]
212-
# error: [invalid-argument-type]
213-
reveal_type(tuple_param("a", ("a", 1))) # revealed: tuple[T, S]
214-
215-
# TODO: no error
216-
# TODO: revealed: tuple[str, int]
217-
# error: [invalid-argument-type]
218-
# error: [invalid-argument-type]
219-
reveal_type(tuple_param(1, ("a", 1))) # revealed: tuple[T, S]
178+
reveal_type(tuple_param("a", ("a", 1))) # revealed: tuple[Literal["a"], Literal[1]]
179+
reveal_type(tuple_param(1, ("a", 1))) # revealed: tuple[Literal["a"], Literal[1]]
220180
```
221181

222182
## Inferring nested generic function calls
@@ -231,15 +191,6 @@ def f[T](x: T) -> tuple[T, int]:
231191
def g[T](x: T) -> T | None:
232192
return x
233193

234-
# TODO: no error
235-
# TODO: revealed: tuple[str | None, int]
236-
# error: [invalid-argument-type]
237-
# error: [invalid-argument-type]
238-
reveal_type(f(g("a"))) # revealed: tuple[T, int]
239-
240-
# TODO: no error
241-
# TODO: revealed: tuple[str, int] | None
242-
# error: [invalid-argument-type]
243-
# error: [invalid-argument-type]
244-
reveal_type(g(f("a"))) # revealed: T | None
194+
reveal_type(f(g("a"))) # revealed: tuple[Literal["a"] | None, int]
195+
reveal_type(g(f("a"))) # revealed: tuple[Literal["a"], int] | None
245196
```

crates/red_knot_python_semantic/resources/mdtest/generics/scoping.md

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -59,14 +59,8 @@ to a different type each time.
5959
def f[T](x: T) -> T:
6060
return x
6161

62-
# TODO: no error
63-
# TODO: revealed: int or Literal[1]
64-
# error: [invalid-argument-type]
65-
reveal_type(f(1)) # revealed: T
66-
# TODO: no error
67-
# TODO: revealed: str or Literal["a"]
68-
# error: [invalid-argument-type]
69-
reveal_type(f("a")) # revealed: T
62+
reveal_type(f(1)) # revealed: Literal[1]
63+
reveal_type(f("a")) # revealed: Literal["a"]
7064
```
7165

7266
## Methods can mention class typevars
@@ -157,10 +151,7 @@ class C[T]:
157151
return y
158152

159153
c: C[int] = C()
160-
# TODO: no errors
161-
# TODO: revealed: str
162-
# error: [invalid-argument-type]
163-
reveal_type(c.m(1, "string")) # revealed: S
154+
reveal_type(c.m(1, "string")) # revealed: Literal["string"]
164155
```
165156

166157
## Unbound typevars

0 commit comments

Comments
 (0)