Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,31 @@ reveal_type(f(g("a"))) # revealed: tuple[Literal["a"] | None, int]
reveal_type(g(f("a"))) # revealed: tuple[Literal["a"], int] | None
```

## Passing generic functions to generic functions

```py
from typing import Callable, TypeVar

A = TypeVar("A")
B = TypeVar("B")
T = TypeVar("T")

def invoke(fn: Callable[[A], B], value: A) -> B:
return fn(value)

def identity(x: T) -> T:
return x

def head(xs: list[T]) -> T:
return xs[0]

# TODO: this should be `Literal[1]`
reveal_type(invoke(identity, 1)) # revealed: Unknown

# TODO: this should be `Unknown | int`
reveal_type(invoke(head, [1, 2, 3])) # revealed: Unknown
```

## Opaque decorators don't affect typevar binding

Inside the body of a generic function, we should be able to see that the typevars bound by that
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,27 @@ reveal_type(f(g("a"))) # revealed: tuple[Literal["a"] | None, int]
reveal_type(g(f("a"))) # revealed: tuple[Literal["a"], int] | None
```

## Passing generic functions to generic functions

```py
from typing import Callable

def invoke[A, B](fn: Callable[[A], B], value: A) -> B:
return fn(value)

def identity[T](x: T) -> T:
return x

def head[T](xs: list[T]) -> T:
return xs[0]

# TODO: this should be `Literal[1]`
reveal_type(invoke(identity, 1)) # revealed: Unknown

# TODO: this should be `Unknown | int`
reveal_type(invoke(head, [1, 2, 3])) # revealed: Unknown
```

## Protocols as TypeVar bounds

Protocol types can be used as TypeVar bounds, just like nominal types.
Expand Down
14 changes: 10 additions & 4 deletions crates/ty_python_semantic/resources/mdtest/narrow/isinstance.md
Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,11 @@ a covariant generic, this is equivalent to using the upper bound of the type par
`object`):

```py
from typing import Self

class Covariant[T]:
def get(self) -> T:
# TODO: remove the explicit `Self` annotation, once we support the implicit type of `self`
def get(self: Self) -> T:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need the explicit Self annotations in this test?

Is this the right way to write the tests long-term (i.e. once we support implicit-type-of-self properly)? If not, should we have a TODO here to remove these annotations?

Copy link
Contributor Author

@sharkdp sharkdp Sep 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need the explicit Self annotations in this test?

With these explicit Self annotations, these all become regression tests for astral-sh/ty#1196. I could have added new tests, but would have had to duplicate the different variance examples from here. Hopefully self and self: Self will soon be exactly equivalent (?), so yes, we will be able to remove these soon. I added TODO comments.

raise NotImplementedError

def _(x: object):
Expand All @@ -335,7 +338,8 @@ Similarly, contravariant type parameters use their lower bound of `Never`:

```py
class Contravariant[T]:
def push(self, x: T) -> None: ...
# TODO: remove the explicit `Self` annotation, once we support the implicit type of `self`
def push(self: Self, x: T) -> None: ...

def _(x: object):
if isinstance(x, Contravariant):
Expand All @@ -350,8 +354,10 @@ the type system, so we represent it with the internal `Top[]` special form.

```py
class Invariant[T]:
def push(self, x: T) -> None: ...
def get(self) -> T:
# TODO: remove the explicit `Self` annotation, once we support the implicit type of `self`
def push(self: Self, x: T) -> None: ...
# TODO: remove the explicit `Self` annotation, once we support the implicit type of `self`
def get(self: Self) -> T:
raise NotImplementedError

def _(x: object):
Expand Down
14 changes: 13 additions & 1 deletion crates/ty_python_semantic/resources/mdtest/overloads.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ reveal_type(foo(b"")) # revealed: bytes
## Methods

```py
from typing import overload
from typing_extensions import Self, overload

class Foo1:
@overload
Expand All @@ -126,6 +126,18 @@ foo2 = Foo2()
reveal_type(foo2.method) # revealed: Overload[() -> None, (x: str) -> str]
reveal_type(foo2.method()) # revealed: None
reveal_type(foo2.method("")) # revealed: str

class Foo3:
@overload
def takes_self_or_int(self: Self, x: Self) -> Self: ...
@overload
def takes_self_or_int(self: Self, x: int) -> int: ...
def takes_self_or_int(self: Self, x: Self | int) -> Self | int:
return x

foo3 = Foo3()
reveal_type(foo3.takes_self_or_int(foo3)) # revealed: Foo3
reveal_type(foo3.takes_self_or_int(1)) # revealed: int
Copy link
Contributor Author

@sharkdp sharkdp Sep 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a new test case adapted from a datetime/timedelta example observed in the ecosystem.

The lower of these two assertions was failing with the induct-into-bounds-of-typevars approach, because we would pick the first overload and return Foo3 | Literal[1], if assignability to the upper bound of Self was not properly checked. And if we properly check the assignability to the upper bound of Self, we're getting false from the missing has_relation_to_impl branch, and can't fix astral-sh/ty#1196. So at that point, we need to bring back all of the subtyping/assignability adjustments. And then the induction into the bounds of Self doesn't have any observable effect anymore, so I reverted it.

```

## Constructor
Expand Down
134 changes: 126 additions & 8 deletions crates/ty_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1593,9 +1593,16 @@ impl<'db> Type<'db> {
})
}

(Type::TypeVar(_), _) if relation.is_assignability() => {
// The implicit lower bound of a typevar is `Never`, which means
// that it is always assignable to any other type.

// TODO: record the unification constraints

ConstraintSet::from(true)
}

// `Never` is the bottom type, the empty set.
// Other than one unlikely edge case (TypeVars bound to `Never`),
// no other type is a subtype of or assignable to `Never`.
(_, Type::Never) => ConstraintSet::from(false),

(Type::Union(union), _) => union.elements(db).iter().when_all(db, |&elem_ty| {
Expand Down Expand Up @@ -1632,6 +1639,22 @@ impl<'db> Type<'db> {
// be specialized to `Never`.)
(_, Type::NonInferableTypeVar(_)) => ConstraintSet::from(false),

(_, Type::TypeVar(typevar))
if relation.is_assignability()
&& typevar.typevar(db).upper_bound(db).is_none_or(|bound| {
!self
.has_relation_to_impl(db, bound, relation, visitor)
.is_never_satisfied()
}) =>
{
// TODO: record the unification constraints

typevar
.typevar(db)
.upper_bound(db)
.when_none_or(|bound| self.has_relation_to_impl(db, bound, relation, visitor))
}

// TODO: Infer specializations here
(Type::TypeVar(_), _) | (_, Type::TypeVar(_)) => ConstraintSet::from(false),

Expand Down Expand Up @@ -5662,13 +5685,25 @@ impl<'db> Type<'db> {
],
});
};
let instance = Type::instance(db, class.unknown_specialization(db));

let upper_bound = Type::instance(
db,
class.apply_specialization(db, |generic_context| {
let types = generic_context
.variables(db)
.iter()
.map(|typevar| Type::NonInferableTypeVar(*typevar));

generic_context.specialize(db, types.collect())
}),
);

let class_definition = class.definition(db);
let typevar = TypeVarInstance::new(
db,
ast::name::Name::new_static("Self"),
Some(class_definition),
Some(TypeVarBoundOrConstraints::UpperBound(instance).into()),
Some(TypeVarBoundOrConstraints::UpperBound(upper_bound).into()),
// According to the [spec], we can consider `Self`
// equivalent to an invariant type variable
// [spec]: https://typing.python.org/en/latest/spec/generics.html#self
Expand Down Expand Up @@ -6010,8 +6045,8 @@ impl<'db> Type<'db> {
partial.get(db, bound_typevar).unwrap_or(self)
}
TypeMapping::MarkTypeVarsInferable(binding_context) => {
if bound_typevar.binding_context(db) == *binding_context {
Type::TypeVar(bound_typevar)
if binding_context.is_none_or(|context| context == bound_typevar.binding_context(db)) {
Type::TypeVar(bound_typevar.mark_typevars_inferable(db, visitor))
} else {
self
}
Expand Down Expand Up @@ -6695,8 +6730,17 @@ pub enum TypeMapping<'a, 'db> {
BindSelf(Type<'db>),
/// Replaces occurrences of `typing.Self` with a new `Self` type variable with the given upper bound.
ReplaceSelf { new_upper_bound: Type<'db> },
/// Marks the typevars that are bound by a generic class or function as inferable.
MarkTypeVarsInferable(BindingContext<'db>),
/// Marks type variables as inferable.
///
/// When we create the signature for a generic function, we mark its type variables as inferable. Since
/// the generic function might reference type variables from enclosing generic scopes, we include the
/// function's binding context in order to only mark those type variables as inferable that are actually
/// bound by that function.
///
/// When the parameter is set to `None`, *all* type variables will be marked as inferable. We use this
/// variant when descending into the bounds and/or constraints, and the default value of a type variable,
/// which may include nested type variables (`Self` has a bound of `C[T]` for a generic class `C[T]`).
MarkTypeVarsInferable(Option<BindingContext<'db>>),
/// Create the top or bottom materialization of a type.
Materialize(MaterializationKind),
}
Expand Down Expand Up @@ -7637,6 +7681,43 @@ impl<'db> TypeVarInstance<'db> {
)
}

fn mark_typevars_inferable(
self,
db: &'db dyn Db,
visitor: &ApplyTypeMappingVisitor<'db>,
) -> Self {
// Type variables can have nested type variables in their bounds, constraints, or default value.
// When we mark a type variable as inferable, we also mark all of these nested type variables as
// inferable, so we set the parameter to `None` here.
let type_mapping = &TypeMapping::MarkTypeVarsInferable(None);

Self::new(
db,
self.name(db),
self.definition(db),
self._bound_or_constraints(db)
.map(|bound_or_constraints| match bound_or_constraints {
TypeVarBoundOrConstraintsEvaluation::Eager(bound_or_constraints) => {
bound_or_constraints
.mark_typevars_inferable(db, visitor)
.into()
}
TypeVarBoundOrConstraintsEvaluation::LazyUpperBound
| TypeVarBoundOrConstraintsEvaluation::LazyConstraints => bound_or_constraints,
}),
self.explicit_variance(db),
self._default(db).and_then(|default| match default {
TypeVarDefaultEvaluation::Eager(ty) => {
Some(ty.apply_type_mapping_impl(db, type_mapping, visitor).into())
}
TypeVarDefaultEvaluation::Lazy => self
.lazy_default(db)
.map(|ty| ty.apply_type_mapping_impl(db, type_mapping, visitor).into()),
}),
self.kind(db),
)
}

fn to_instance(self, db: &'db dyn Db) -> Option<Self> {
let bound_or_constraints = match self.bound_or_constraints(db)? {
TypeVarBoundOrConstraints::UpperBound(upper_bound) => {
Expand Down Expand Up @@ -7867,6 +7948,18 @@ impl<'db> BoundTypeVarInstance<'db> {
)
}

fn mark_typevars_inferable(
self,
db: &'db dyn Db,
visitor: &ApplyTypeMappingVisitor<'db>,
) -> Self {
Self::new(
db,
self.typevar(db).mark_typevars_inferable(db, visitor),
self.binding_context(db),
)
}

fn to_instance(self, db: &'db dyn Db) -> Option<Self> {
Some(Self::new(
db,
Expand Down Expand Up @@ -7972,6 +8065,31 @@ impl<'db> TypeVarBoundOrConstraints<'db> {
}
}
}

fn mark_typevars_inferable(
self,
db: &'db dyn Db,
visitor: &ApplyTypeMappingVisitor<'db>,
) -> Self {
let type_mapping = &TypeMapping::MarkTypeVarsInferable(None);

match self {
TypeVarBoundOrConstraints::UpperBound(bound) => TypeVarBoundOrConstraints::UpperBound(
bound.apply_type_mapping_impl(db, type_mapping, visitor),
),
TypeVarBoundOrConstraints::Constraints(constraints) => {
TypeVarBoundOrConstraints::Constraints(UnionType::new(
db,
constraints
.elements(db)
.iter()
.map(|ty| ty.apply_type_mapping_impl(db, type_mapping, visitor))
.collect::<Vec<_>>()
.into_boxed_slice(),
))
}
}
}
}

/// Error returned if a type is not awaitable.
Expand Down
12 changes: 12 additions & 0 deletions crates/ty_python_semantic/src/types/generics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,18 @@ fn is_subtype_in_invariant_position<'db>(
let base_bottom = base_type.bottom_materialization(db);

let is_subtype_of = |derived: Type<'db>, base: Type<'db>| {
// TODO:
// This should be removed and properly handled in the respective
// `(Type::TypeVar(_), _) | (_, Type::TypeVar(_))` branch of
// `Type::has_relation_to_impl`. Right now, we can not generally
// return `ConstraintSet::from(true)` from that branch, as that
// leads to union simplification, which means that we lose track
// of type variables without recording the constraints under which
// the relation holds.
if matches!(base, Type::TypeVar(_)) || matches!(derived, Type::TypeVar(_)) {
return ConstraintSet::from(true);
}

derived.has_relation_to_impl(db, base, TypeRelation::Subtyping, visitor)
};
match (derived_materialization, base_materialization) {
Expand Down
8 changes: 6 additions & 2 deletions crates/ty_python_semantic/src/types/signatures.rs
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,9 @@ impl<'db> Signature<'db> {
let plain_return_ty = definition_expression_type(db, definition, returns.as_ref())
.apply_type_mapping(
db,
&TypeMapping::MarkTypeVarsInferable(BindingContext::Definition(definition)),
&TypeMapping::MarkTypeVarsInferable(Some(BindingContext::Definition(
definition,
))),
);
if function_node.is_async && !is_generator {
KnownClass::CoroutineType
Expand Down Expand Up @@ -1549,7 +1551,9 @@ impl<'db> Parameter<'db> {
annotated_type: parameter.annotation().map(|annotation| {
definition_expression_type(db, definition, annotation).apply_type_mapping(
db,
&TypeMapping::MarkTypeVarsInferable(BindingContext::Definition(definition)),
&TypeMapping::MarkTypeVarsInferable(Some(BindingContext::Definition(
definition,
))),
)
}),
kind,
Expand Down
Loading