Skip to content

Commit 742f8a4

Browse files
authored
[ty] Use C[T] instead of C[Unknown] for the upper bound of Self (#20479)
### Summary This PR includes two changes, both of which are necessary to resolve astral-sh/ty#1196: * For a generic class `C[T]`, we previously used `C[Unknown]` as the upper bound of the `Self` type variable. There were two problems with this. For one, when `Self` appeared in contravariant position, we would materialize its upper bound to `Bottom[C[Unknown]]` (which might simplify to `C[Never]` if `C` is covariant in `T`) when accessing methods on `Top[C[Unknown]]`. This would result in `invalid-argument` errors on the `self` parameter. Also, using an upper bound of `C[Unknown]` would mean that inside methods, references to `T` would be treated as `Unknown`. This could lead to false negatives. To fix this, we now use `C[T]` (with a "nested" typevar) as the upper bound for `Self` on `C[T]`. * In order to make this work, we needed to allow assignability/subtyping of inferable typevars to other types, since we now check assignability of e.g. `C[int]` to `C[T]` (when checking assignability to the upper bound of `Self`) when calling an instance-method on `C[int]` whose `self` parameter is annotated as `self: Self` (or implicitly `Self`, following #18007). closes astral-sh/ty#1196 closes astral-sh/ty#1208 ### Test Plan Regression tests for both issues.
1 parent fd5c48c commit 742f8a4

File tree

7 files changed

+213
-15
lines changed

7 files changed

+213
-15
lines changed

crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,31 @@ reveal_type(f(g("a"))) # revealed: tuple[Literal["a"] | None, int]
366366
reveal_type(g(f("a"))) # revealed: tuple[Literal["a"], int] | None
367367
```
368368

369+
## Passing generic functions to generic functions
370+
371+
```py
372+
from typing import Callable, TypeVar
373+
374+
A = TypeVar("A")
375+
B = TypeVar("B")
376+
T = TypeVar("T")
377+
378+
def invoke(fn: Callable[[A], B], value: A) -> B:
379+
return fn(value)
380+
381+
def identity(x: T) -> T:
382+
return x
383+
384+
def head(xs: list[T]) -> T:
385+
return xs[0]
386+
387+
# TODO: this should be `Literal[1]`
388+
reveal_type(invoke(identity, 1)) # revealed: Unknown
389+
390+
# TODO: this should be `Unknown | int`
391+
reveal_type(invoke(head, [1, 2, 3])) # revealed: Unknown
392+
```
393+
369394
## Opaque decorators don't affect typevar binding
370395

371396
Inside the body of a generic function, we should be able to see that the typevars bound by that

crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,27 @@ reveal_type(f(g("a"))) # revealed: tuple[Literal["a"] | None, int]
323323
reveal_type(g(f("a"))) # revealed: tuple[Literal["a"], int] | None
324324
```
325325

326+
## Passing generic functions to generic functions
327+
328+
```py
329+
from typing import Callable
330+
331+
def invoke[A, B](fn: Callable[[A], B], value: A) -> B:
332+
return fn(value)
333+
334+
def identity[T](x: T) -> T:
335+
return x
336+
337+
def head[T](xs: list[T]) -> T:
338+
return xs[0]
339+
340+
# TODO: this should be `Literal[1]`
341+
reveal_type(invoke(identity, 1)) # revealed: Unknown
342+
343+
# TODO: this should be `Unknown | int`
344+
reveal_type(invoke(head, [1, 2, 3])) # revealed: Unknown
345+
```
346+
326347
## Protocols as TypeVar bounds
327348

328349
Protocol types can be used as TypeVar bounds, just like nominal types.

crates/ty_python_semantic/resources/mdtest/narrow/isinstance.md

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -321,8 +321,11 @@ a covariant generic, this is equivalent to using the upper bound of the type par
321321
`object`):
322322

323323
```py
324+
from typing import Self
325+
324326
class Covariant[T]:
325-
def get(self) -> T:
327+
# TODO: remove the explicit `Self` annotation, once we support the implicit type of `self`
328+
def get(self: Self) -> T:
326329
raise NotImplementedError
327330

328331
def _(x: object):
@@ -335,7 +338,8 @@ Similarly, contravariant type parameters use their lower bound of `Never`:
335338

336339
```py
337340
class Contravariant[T]:
338-
def push(self, x: T) -> None: ...
341+
# TODO: remove the explicit `Self` annotation, once we support the implicit type of `self`
342+
def push(self: Self, x: T) -> None: ...
339343

340344
def _(x: object):
341345
if isinstance(x, Contravariant):
@@ -350,8 +354,10 @@ the type system, so we represent it with the internal `Top[]` special form.
350354

351355
```py
352356
class Invariant[T]:
353-
def push(self, x: T) -> None: ...
354-
def get(self) -> T:
357+
# TODO: remove the explicit `Self` annotation, once we support the implicit type of `self`
358+
def push(self: Self, x: T) -> None: ...
359+
# TODO: remove the explicit `Self` annotation, once we support the implicit type of `self`
360+
def get(self: Self) -> T:
355361
raise NotImplementedError
356362

357363
def _(x: object):

crates/ty_python_semantic/resources/mdtest/overloads.md

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ reveal_type(foo(b"")) # revealed: bytes
9999
## Methods
100100

101101
```py
102-
from typing import overload
102+
from typing_extensions import Self, overload
103103

104104
class Foo1:
105105
@overload
@@ -126,6 +126,18 @@ foo2 = Foo2()
126126
reveal_type(foo2.method) # revealed: Overload[() -> None, (x: str) -> str]
127127
reveal_type(foo2.method()) # revealed: None
128128
reveal_type(foo2.method("")) # revealed: str
129+
130+
class Foo3:
131+
@overload
132+
def takes_self_or_int(self: Self, x: Self) -> Self: ...
133+
@overload
134+
def takes_self_or_int(self: Self, x: int) -> int: ...
135+
def takes_self_or_int(self: Self, x: Self | int) -> Self | int:
136+
return x
137+
138+
foo3 = Foo3()
139+
reveal_type(foo3.takes_self_or_int(foo3)) # revealed: Foo3
140+
reveal_type(foo3.takes_self_or_int(1)) # revealed: int
129141
```
130142

131143
## Constructor

crates/ty_python_semantic/src/types.rs

Lines changed: 126 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1593,9 +1593,16 @@ impl<'db> Type<'db> {
15931593
})
15941594
}
15951595

1596+
(Type::TypeVar(_), _) if relation.is_assignability() => {
1597+
// The implicit lower bound of a typevar is `Never`, which means
1598+
// that it is always assignable to any other type.
1599+
1600+
// TODO: record the unification constraints
1601+
1602+
ConstraintSet::from(true)
1603+
}
1604+
15961605
// `Never` is the bottom type, the empty set.
1597-
// Other than one unlikely edge case (TypeVars bound to `Never`),
1598-
// no other type is a subtype of or assignable to `Never`.
15991606
(_, Type::Never) => ConstraintSet::from(false),
16001607

16011608
(Type::Union(union), _) => union.elements(db).iter().when_all(db, |&elem_ty| {
@@ -1632,6 +1639,22 @@ impl<'db> Type<'db> {
16321639
// be specialized to `Never`.)
16331640
(_, Type::NonInferableTypeVar(_)) => ConstraintSet::from(false),
16341641

1642+
(_, Type::TypeVar(typevar))
1643+
if relation.is_assignability()
1644+
&& typevar.typevar(db).upper_bound(db).is_none_or(|bound| {
1645+
!self
1646+
.has_relation_to_impl(db, bound, relation, visitor)
1647+
.is_never_satisfied()
1648+
}) =>
1649+
{
1650+
// TODO: record the unification constraints
1651+
1652+
typevar
1653+
.typevar(db)
1654+
.upper_bound(db)
1655+
.when_none_or(|bound| self.has_relation_to_impl(db, bound, relation, visitor))
1656+
}
1657+
16351658
// TODO: Infer specializations here
16361659
(Type::TypeVar(_), _) | (_, Type::TypeVar(_)) => ConstraintSet::from(false),
16371660

@@ -5662,13 +5685,25 @@ impl<'db> Type<'db> {
56625685
],
56635686
});
56645687
};
5665-
let instance = Type::instance(db, class.unknown_specialization(db));
5688+
5689+
let upper_bound = Type::instance(
5690+
db,
5691+
class.apply_specialization(db, |generic_context| {
5692+
let types = generic_context
5693+
.variables(db)
5694+
.iter()
5695+
.map(|typevar| Type::NonInferableTypeVar(*typevar));
5696+
5697+
generic_context.specialize(db, types.collect())
5698+
}),
5699+
);
5700+
56665701
let class_definition = class.definition(db);
56675702
let typevar = TypeVarInstance::new(
56685703
db,
56695704
ast::name::Name::new_static("Self"),
56705705
Some(class_definition),
5671-
Some(TypeVarBoundOrConstraints::UpperBound(instance).into()),
5706+
Some(TypeVarBoundOrConstraints::UpperBound(upper_bound).into()),
56725707
// According to the [spec], we can consider `Self`
56735708
// equivalent to an invariant type variable
56745709
// [spec]: https://typing.python.org/en/latest/spec/generics.html#self
@@ -6010,8 +6045,8 @@ impl<'db> Type<'db> {
60106045
partial.get(db, bound_typevar).unwrap_or(self)
60116046
}
60126047
TypeMapping::MarkTypeVarsInferable(binding_context) => {
6013-
if bound_typevar.binding_context(db) == *binding_context {
6014-
Type::TypeVar(bound_typevar)
6048+
if binding_context.is_none_or(|context| context == bound_typevar.binding_context(db)) {
6049+
Type::TypeVar(bound_typevar.mark_typevars_inferable(db, visitor))
60156050
} else {
60166051
self
60176052
}
@@ -6695,8 +6730,17 @@ pub enum TypeMapping<'a, 'db> {
66956730
BindSelf(Type<'db>),
66966731
/// Replaces occurrences of `typing.Self` with a new `Self` type variable with the given upper bound.
66976732
ReplaceSelf { new_upper_bound: Type<'db> },
6698-
/// Marks the typevars that are bound by a generic class or function as inferable.
6699-
MarkTypeVarsInferable(BindingContext<'db>),
6733+
/// Marks type variables as inferable.
6734+
///
6735+
/// When we create the signature for a generic function, we mark its type variables as inferable. Since
6736+
/// the generic function might reference type variables from enclosing generic scopes, we include the
6737+
/// function's binding context in order to only mark those type variables as inferable that are actually
6738+
/// bound by that function.
6739+
///
6740+
/// When the parameter is set to `None`, *all* type variables will be marked as inferable. We use this
6741+
/// variant when descending into the bounds and/or constraints, and the default value of a type variable,
6742+
/// which may include nested type variables (`Self` has a bound of `C[T]` for a generic class `C[T]`).
6743+
MarkTypeVarsInferable(Option<BindingContext<'db>>),
67006744
/// Create the top or bottom materialization of a type.
67016745
Materialize(MaterializationKind),
67026746
}
@@ -7637,6 +7681,43 @@ impl<'db> TypeVarInstance<'db> {
76377681
)
76387682
}
76397683

7684+
fn mark_typevars_inferable(
7685+
self,
7686+
db: &'db dyn Db,
7687+
visitor: &ApplyTypeMappingVisitor<'db>,
7688+
) -> Self {
7689+
// Type variables can have nested type variables in their bounds, constraints, or default value.
7690+
// When we mark a type variable as inferable, we also mark all of these nested type variables as
7691+
// inferable, so we set the parameter to `None` here.
7692+
let type_mapping = &TypeMapping::MarkTypeVarsInferable(None);
7693+
7694+
Self::new(
7695+
db,
7696+
self.name(db),
7697+
self.definition(db),
7698+
self._bound_or_constraints(db)
7699+
.map(|bound_or_constraints| match bound_or_constraints {
7700+
TypeVarBoundOrConstraintsEvaluation::Eager(bound_or_constraints) => {
7701+
bound_or_constraints
7702+
.mark_typevars_inferable(db, visitor)
7703+
.into()
7704+
}
7705+
TypeVarBoundOrConstraintsEvaluation::LazyUpperBound
7706+
| TypeVarBoundOrConstraintsEvaluation::LazyConstraints => bound_or_constraints,
7707+
}),
7708+
self.explicit_variance(db),
7709+
self._default(db).and_then(|default| match default {
7710+
TypeVarDefaultEvaluation::Eager(ty) => {
7711+
Some(ty.apply_type_mapping_impl(db, type_mapping, visitor).into())
7712+
}
7713+
TypeVarDefaultEvaluation::Lazy => self
7714+
.lazy_default(db)
7715+
.map(|ty| ty.apply_type_mapping_impl(db, type_mapping, visitor).into()),
7716+
}),
7717+
self.kind(db),
7718+
)
7719+
}
7720+
76407721
fn to_instance(self, db: &'db dyn Db) -> Option<Self> {
76417722
let bound_or_constraints = match self.bound_or_constraints(db)? {
76427723
TypeVarBoundOrConstraints::UpperBound(upper_bound) => {
@@ -7867,6 +7948,18 @@ impl<'db> BoundTypeVarInstance<'db> {
78677948
)
78687949
}
78697950

7951+
fn mark_typevars_inferable(
7952+
self,
7953+
db: &'db dyn Db,
7954+
visitor: &ApplyTypeMappingVisitor<'db>,
7955+
) -> Self {
7956+
Self::new(
7957+
db,
7958+
self.typevar(db).mark_typevars_inferable(db, visitor),
7959+
self.binding_context(db),
7960+
)
7961+
}
7962+
78707963
fn to_instance(self, db: &'db dyn Db) -> Option<Self> {
78717964
Some(Self::new(
78727965
db,
@@ -7972,6 +8065,31 @@ impl<'db> TypeVarBoundOrConstraints<'db> {
79728065
}
79738066
}
79748067
}
8068+
8069+
fn mark_typevars_inferable(
8070+
self,
8071+
db: &'db dyn Db,
8072+
visitor: &ApplyTypeMappingVisitor<'db>,
8073+
) -> Self {
8074+
let type_mapping = &TypeMapping::MarkTypeVarsInferable(None);
8075+
8076+
match self {
8077+
TypeVarBoundOrConstraints::UpperBound(bound) => TypeVarBoundOrConstraints::UpperBound(
8078+
bound.apply_type_mapping_impl(db, type_mapping, visitor),
8079+
),
8080+
TypeVarBoundOrConstraints::Constraints(constraints) => {
8081+
TypeVarBoundOrConstraints::Constraints(UnionType::new(
8082+
db,
8083+
constraints
8084+
.elements(db)
8085+
.iter()
8086+
.map(|ty| ty.apply_type_mapping_impl(db, type_mapping, visitor))
8087+
.collect::<Vec<_>>()
8088+
.into_boxed_slice(),
8089+
))
8090+
}
8091+
}
8092+
}
79758093
}
79768094

79778095
/// Error returned if a type is not awaitable.

crates/ty_python_semantic/src/types/generics.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,18 @@ fn is_subtype_in_invariant_position<'db>(
491491
let base_bottom = base_type.bottom_materialization(db);
492492

493493
let is_subtype_of = |derived: Type<'db>, base: Type<'db>| {
494+
// TODO:
495+
// This should be removed and properly handled in the respective
496+
// `(Type::TypeVar(_), _) | (_, Type::TypeVar(_))` branch of
497+
// `Type::has_relation_to_impl`. Right now, we can not generally
498+
// return `ConstraintSet::from(true)` from that branch, as that
499+
// leads to union simplification, which means that we lose track
500+
// of type variables without recording the constraints under which
501+
// the relation holds.
502+
if matches!(base, Type::TypeVar(_)) || matches!(derived, Type::TypeVar(_)) {
503+
return ConstraintSet::from(true);
504+
}
505+
494506
derived.has_relation_to_impl(db, base, TypeRelation::Subtyping, visitor)
495507
};
496508
match (derived_materialization, base_materialization) {

crates/ty_python_semantic/src/types/signatures.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,9 @@ impl<'db> Signature<'db> {
367367
let plain_return_ty = definition_expression_type(db, definition, returns.as_ref())
368368
.apply_type_mapping(
369369
db,
370-
&TypeMapping::MarkTypeVarsInferable(BindingContext::Definition(definition)),
370+
&TypeMapping::MarkTypeVarsInferable(Some(BindingContext::Definition(
371+
definition,
372+
))),
371373
);
372374
if function_node.is_async && !is_generator {
373375
KnownClass::CoroutineType
@@ -1549,7 +1551,9 @@ impl<'db> Parameter<'db> {
15491551
annotated_type: parameter.annotation().map(|annotation| {
15501552
definition_expression_type(db, definition, annotation).apply_type_mapping(
15511553
db,
1552-
&TypeMapping::MarkTypeVarsInferable(BindingContext::Definition(definition)),
1554+
&TypeMapping::MarkTypeVarsInferable(Some(BindingContext::Definition(
1555+
definition,
1556+
))),
15531557
)
15541558
}),
15551559
kind,

0 commit comments

Comments
 (0)