Skip to content

Commit 9720385

Browse files
committed
[ty] Use C[T] instead of C[Unknown] for the upper bound of Self
1 parent 48ada2d commit 9720385

File tree

6 files changed

+188
-11
lines changed

6 files changed

+188
-11
lines changed

crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,31 @@ reveal_type(f(g("a"))) # revealed: tuple[Literal["a"] | None, int]
366366
reveal_type(g(f("a"))) # revealed: tuple[Literal["a"], int] | None
367367
```
368368

369+
## Passing generic functions to generics functions
370+
371+
```py
372+
from typing import Callable, TypeVar
373+
374+
A = TypeVar("A")
375+
B = TypeVar("B")
376+
T = TypeVar("T")
377+
378+
def invoke(fn: Callable[[A], B], value: A) -> B:
379+
return fn(value)
380+
381+
def identity(x: T) -> T:
382+
return x
383+
384+
def head(xs: list[T]) -> T:
385+
return xs[0]
386+
387+
# TODO: this should be `Literal[1]`
388+
reveal_type(invoke(identity, 1)) # revealed: Unknown
389+
390+
# TODO: this should be `Unknown | int`
391+
reveal_type(invoke(head, [1, 2, 3])) # revealed: Unknown
392+
```
393+
369394
## Opaque decorators don't affect typevar binding
370395

371396
Inside the body of a generic function, we should be able to see that the typevars bound by that

crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,27 @@ reveal_type(f(g("a"))) # revealed: tuple[Literal["a"] | None, int]
323323
reveal_type(g(f("a"))) # revealed: tuple[Literal["a"], int] | None
324324
```
325325

326+
## Passing generic functions to generics functions
327+
328+
```py
329+
from typing import Callable
330+
331+
def invoke[A, B](fn: Callable[[A], B], value: A) -> B:
332+
return fn(value)
333+
334+
def identity[T](x: T) -> T:
335+
return x
336+
337+
def head[T](xs: list[T]) -> T:
338+
return xs[0]
339+
340+
# TODO: this should be `Literal[1]`
341+
reveal_type(invoke(identity, 1)) # revealed: Unknown
342+
343+
# TODO: this should be `Unknown | int`
344+
reveal_type(invoke(head, [1, 2, 3])) # revealed: Unknown
345+
```
346+
326347
## Protocols as TypeVar bounds
327348

328349
Protocol types can be used as TypeVar bounds, just like nominal types.

crates/ty_python_semantic/resources/mdtest/narrow/isinstance.md

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -321,8 +321,10 @@ a covariant generic, this is equivalent to using the upper bound of the type par
321321
`object`):
322322

323323
```py
324+
from typing import Self
325+
324326
class Covariant[T]:
325-
def get(self) -> T:
327+
def get(self: Self) -> T:
326328
raise NotImplementedError
327329

328330
def _(x: object):
@@ -335,11 +337,12 @@ Similarly, contravariant type parameters use their lower bound of `Never`:
335337

336338
```py
337339
class Contravariant[T]:
338-
def push(self, x: T) -> None: ...
340+
def push(self: Self, x: T) -> None: ...
339341

340342
def _(x: object):
341343
if isinstance(x, Contravariant):
342344
reveal_type(x) # revealed: Contravariant[Never]
345+
# error: [invalid-argument-type] "Argument to bound method `push` is incorrect: Argument type `Contravariant[Never]` does not satisfy upper bound `Contravariant[T@Contravariant]` of type variable `Self`"
343346
# error: [invalid-argument-type] "Argument to bound method `push` is incorrect: Expected `Never`, found `Literal[42]`"
344347
x.push(42)
345348
```
@@ -350,8 +353,8 @@ the type system, so we represent it with the internal `Top[]` special form.
350353

351354
```py
352355
class Invariant[T]:
353-
def push(self, x: T) -> None: ...
354-
def get(self) -> T:
356+
def push(self: Self, x: T) -> None: ...
357+
def get(self: Self) -> T:
355358
raise NotImplementedError
356359

357360
def _(x: object):

crates/ty_python_semantic/src/types.rs

Lines changed: 115 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1631,6 +1631,34 @@ impl<'db> Type<'db> {
16311631
// be specialized to `Never`.)
16321632
(_, Type::NonInferableTypeVar(_)) => ConstraintSet::from(false),
16331633

1634+
(Type::TypeVar(typevar), _)
1635+
if relation.is_assignability()
1636+
&& typevar.typevar(db).upper_bound(db).is_none_or(|bound| {
1637+
!bound
1638+
.has_relation_to_impl(db, target, relation, visitor)
1639+
.is_never_satisfied()
1640+
}) =>
1641+
{
1642+
typevar
1643+
.typevar(db)
1644+
.upper_bound(db)
1645+
.when_none_or(|bound| bound.has_relation_to_impl(db, target, relation, visitor))
1646+
}
1647+
1648+
(_, Type::TypeVar(typevar))
1649+
if relation.is_assignability()
1650+
&& typevar.typevar(db).upper_bound(db).is_none_or(|bound| {
1651+
!self
1652+
.has_relation_to_impl(db, bound, relation, visitor)
1653+
.is_never_satisfied()
1654+
}) =>
1655+
{
1656+
typevar
1657+
.typevar(db)
1658+
.upper_bound(db)
1659+
.when_none_or(|bound| self.has_relation_to_impl(db, bound, relation, visitor))
1660+
}
1661+
16341662
// TODO: Infer specializations here
16351663
(Type::TypeVar(_), _) | (_, Type::TypeVar(_)) => ConstraintSet::from(false),
16361664

@@ -5661,13 +5689,25 @@ impl<'db> Type<'db> {
56615689
],
56625690
});
56635691
};
5664-
let instance = Type::instance(db, class.unknown_specialization(db));
5692+
5693+
let upper_bound = Type::instance(
5694+
db,
5695+
class.apply_specialization(db, |generic_context| {
5696+
let types = generic_context
5697+
.variables(db)
5698+
.iter()
5699+
.map(|typevar| Type::NonInferableTypeVar(*typevar));
5700+
5701+
generic_context.specialize(db, types.collect())
5702+
}),
5703+
);
5704+
56655705
let class_definition = class.definition(db);
56665706
let typevar = TypeVarInstance::new(
56675707
db,
56685708
ast::name::Name::new_static("Self"),
56695709
Some(class_definition),
5670-
Some(TypeVarBoundOrConstraints::UpperBound(instance).into()),
5710+
Some(TypeVarBoundOrConstraints::UpperBound(upper_bound).into()),
56715711
// According to the [spec], we can consider `Self`
56725712
// equivalent to an invariant type variable
56735713
// [spec]: https://typing.python.org/en/latest/spec/generics.html#self
@@ -6009,8 +6049,8 @@ impl<'db> Type<'db> {
60096049
partial.get(db, bound_typevar).unwrap_or(self)
60106050
}
60116051
TypeMapping::MarkTypeVarsInferable(binding_context) => {
6012-
if bound_typevar.binding_context(db) == *binding_context {
6013-
Type::TypeVar(bound_typevar)
6052+
if binding_context.is_none_or(|context| context == bound_typevar.binding_context(db)) {
6053+
Type::TypeVar(bound_typevar.mark_typevars_inferable(db, visitor))
60146054
} else {
60156055
self
60166056
}
@@ -6695,7 +6735,7 @@ pub enum TypeMapping<'a, 'db> {
66956735
/// Replaces occurrences of `typing.Self` with a new `Self` type variable with the given upper bound.
66966736
ReplaceSelf { new_upper_bound: Type<'db> },
66976737
/// Marks the typevars that are bound by a generic class or function as inferable.
6698-
MarkTypeVarsInferable(BindingContext<'db>),
6738+
MarkTypeVarsInferable(Option<BindingContext<'db>>),
66996739
/// Create the top or bottom materialization of a type.
67006740
Materialize(MaterializationKind),
67016741
}
@@ -7636,6 +7676,40 @@ impl<'db> TypeVarInstance<'db> {
76367676
)
76377677
}
76387678

7679+
fn mark_typevars_inferable(
7680+
self,
7681+
db: &'db dyn Db,
7682+
visitor: &ApplyTypeMappingVisitor<'db>,
7683+
) -> Self {
7684+
let type_mapping = &TypeMapping::MarkTypeVarsInferable(None);
7685+
7686+
Self::new(
7687+
db,
7688+
self.name(db),
7689+
self.definition(db),
7690+
self._bound_or_constraints(db)
7691+
.map(|bound_or_constraints| match bound_or_constraints {
7692+
TypeVarBoundOrConstraintsEvaluation::Eager(bound_or_constraints) => {
7693+
bound_or_constraints
7694+
.mark_typevars_inferable(db, type_mapping, visitor)
7695+
.into()
7696+
}
7697+
TypeVarBoundOrConstraintsEvaluation::LazyUpperBound
7698+
| TypeVarBoundOrConstraintsEvaluation::LazyConstraints => bound_or_constraints,
7699+
}),
7700+
self.explicit_variance(db),
7701+
self._default(db).and_then(|default| match default {
7702+
TypeVarDefaultEvaluation::Eager(ty) => {
7703+
Some(ty.apply_type_mapping_impl(db, type_mapping, visitor).into())
7704+
}
7705+
TypeVarDefaultEvaluation::Lazy => self
7706+
.lazy_default(db)
7707+
.map(|ty| ty.apply_type_mapping_impl(db, type_mapping, visitor).into()),
7708+
}),
7709+
self.kind(db),
7710+
)
7711+
}
7712+
76397713
fn to_instance(self, db: &'db dyn Db) -> Option<Self> {
76407714
let bound_or_constraints = match self.bound_or_constraints(db)? {
76417715
TypeVarBoundOrConstraints::UpperBound(upper_bound) => {
@@ -7866,6 +7940,18 @@ impl<'db> BoundTypeVarInstance<'db> {
78667940
)
78677941
}
78687942

7943+
fn mark_typevars_inferable(
7944+
self,
7945+
db: &'db dyn Db,
7946+
visitor: &ApplyTypeMappingVisitor<'db>,
7947+
) -> Self {
7948+
Self::new(
7949+
db,
7950+
self.typevar(db).mark_typevars_inferable(db, visitor),
7951+
self.binding_context(db),
7952+
)
7953+
}
7954+
78697955
fn to_instance(self, db: &'db dyn Db) -> Option<Self> {
78707956
Some(Self::new(
78717957
db,
@@ -7971,6 +8057,30 @@ impl<'db> TypeVarBoundOrConstraints<'db> {
79718057
}
79728058
}
79738059
}
8060+
8061+
fn mark_typevars_inferable<'a>(
8062+
self,
8063+
db: &'db dyn Db,
8064+
type_mapping: &TypeMapping<'a, 'db>,
8065+
visitor: &ApplyTypeMappingVisitor<'db>,
8066+
) -> Self {
8067+
match self {
8068+
TypeVarBoundOrConstraints::UpperBound(bound) => TypeVarBoundOrConstraints::UpperBound(
8069+
bound.apply_type_mapping_impl(db, type_mapping, visitor),
8070+
),
8071+
TypeVarBoundOrConstraints::Constraints(constraints) => {
8072+
TypeVarBoundOrConstraints::Constraints(UnionType::new(
8073+
db,
8074+
constraints
8075+
.elements(db)
8076+
.iter()
8077+
.map(|ty| ty.apply_type_mapping_impl(db, type_mapping, visitor))
8078+
.collect::<Vec<_>>()
8079+
.into_boxed_slice(),
8080+
))
8081+
}
8082+
}
8083+
}
79748084
}
79758085

79768086
/// Error returned if a type is not awaitable.

crates/ty_python_semantic/src/types/generics.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,16 @@ fn is_subtype_in_invariant_position<'db>(
491491
let base_bottom = base_type.bottom_materialization(db);
492492

493493
let is_subtype_of = |derived: Type<'db>, base: Type<'db>| {
494+
// TODO:
495+
// This should be removed and properly handled in the respective
496+
// `(Type::TypeVar(_), _) | (_, Type::TypeVar(_))` branch of
497+
// `Type::has_relation_to_impl`. Right now, we can not generally
498+
// return `ConstraintSet::from(true)` inside that branch, as is
499+
//
500+
if matches!(base, Type::TypeVar(_)) || matches!(derived, Type::TypeVar(_)) {
501+
return ConstraintSet::from(true);
502+
}
503+
494504
derived.has_relation_to_impl(db, base, TypeRelation::Subtyping, visitor)
495505
};
496506
match (derived_materialization, base_materialization) {
@@ -829,6 +839,10 @@ impl<'db> Specialization<'db> {
829839
.zip(self.types(db))
830840
.zip(other.types(db))
831841
{
842+
// if relation.is_assignability() && (matches!(other_type, Type::TypeVar(_))) {
843+
// return ConstraintSet::from(true);
844+
// }
845+
832846
// As an optimization, we can return early if either type is dynamic, unless
833847
// we're dealing with a top or bottom materialization.
834848
if other_materialization_kind.is_none()

crates/ty_python_semantic/src/types/signatures.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,9 @@ impl<'db> Signature<'db> {
367367
let plain_return_ty = definition_expression_type(db, definition, returns.as_ref())
368368
.apply_type_mapping(
369369
db,
370-
&TypeMapping::MarkTypeVarsInferable(BindingContext::Definition(definition)),
370+
&TypeMapping::MarkTypeVarsInferable(Some(BindingContext::Definition(
371+
definition,
372+
))),
371373
);
372374
if function_node.is_async && !is_generator {
373375
KnownClass::CoroutineType
@@ -1549,7 +1551,9 @@ impl<'db> Parameter<'db> {
15491551
annotated_type: parameter.annotation().map(|annotation| {
15501552
definition_expression_type(db, definition, annotation).apply_type_mapping(
15511553
db,
1552-
&TypeMapping::MarkTypeVarsInferable(BindingContext::Definition(definition)),
1554+
&TypeMapping::MarkTypeVarsInferable(Some(BindingContext::Definition(
1555+
definition,
1556+
))),
15531557
)
15541558
}),
15551559
kind,

0 commit comments

Comments
 (0)