Skip to content

Commit 4e68dd9

Browse files
authored
[ty] Infer types for ty_extensions.Intersection[A, B] tuple expressions (#18321)
## Summary fixes astral-sh/ty#366 ## Test Plan * Added panic corpus regression tests * I also wrote a hover regression test (see below), but decided not to include it. The corpus tests are much more "effective" at finding these types of errors, since they exhaustively check all expressions for types. <details> ```rs #[test] fn hover_regression_test_366() { let test = cursor_test( r#" from ty_extensions import Intersection class A: ... class B: ... def _(x: Intersection[A,<CURSOR> B]): pass "#, ); assert_snapshot!(test.hover(), @r" A & B --------------------------------------------- ```text A & B ``` --------------------------------------------- info[hover]: Hovered content is --> main.py:7:31 | 5 | class B: ... 6 | 7 | def _(x: Intersection[A, B]): | ^^-^ | | | | | Cursor offset | source 8 | pass | "); } ``` </details>
1 parent b25b642 commit 4e68dd9

File tree

2 files changed

+37
-2
lines changed

2 files changed

+37
-2
lines changed
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
"""
2+
Make sure that types are inferred for all subexpressions of the following
3+
annotations involving ty_extension `_SpecialForm`s.
4+
5+
This is a regression test for https://github.com/astral-sh/ty/issues/366
6+
"""
7+
8+
from ty_extensions import CallableTypeOf, Intersection, Not, TypeOf
9+
10+
11+
class A: ...
12+
13+
14+
class B: ...
15+
16+
17+
def _(x: Not[A]):
18+
pass
19+
20+
21+
def _(x: Intersection[A], y: Intersection[A, B]):
22+
pass
23+
24+
25+
def _(x: TypeOf[1j]):
26+
pass
27+
28+
29+
def _(x: CallableTypeOf[str]):
30+
pass

crates/ty_python_semantic/src/types/infer.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8648,11 +8648,16 @@ impl<'db> TypeInferenceBuilder<'db> {
86488648
element => Either::Right(std::iter::once(element)),
86498649
};
86508650

8651-
elements
8651+
let ty = elements
86528652
.fold(IntersectionBuilder::new(db), |builder, element| {
86538653
builder.add_positive(self.infer_type_expression(element))
86548654
})
8655-
.build()
8655+
.build();
8656+
8657+
if matches!(arguments_slice, ast::Expr::Tuple(_)) {
8658+
self.store_expression_type(arguments_slice, ty);
8659+
}
8660+
ty
86568661
}
86578662
KnownInstanceType::TypeOf => match arguments_slice {
86588663
ast::Expr::Tuple(_) => {

0 commit comments

Comments
 (0)