Skip to content

Commit 42f18ed

Browse files
committed
Fix unions-of-literals to account for subtypes of literals
1 parent b69013e commit 42f18ed

File tree

2 files changed

+89
-6
lines changed

2 files changed

+89
-6
lines changed

crates/ty_python_semantic/resources/mdtest/union_types.md

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,3 +214,42 @@ def _(
214214
reveal_type(bytes_or_falsy) # revealed: Literal[b"foo"] | AlwaysFalsy
215215
reveal_type(falsy_or_bytes) # revealed: AlwaysFalsy | Literal[b"foo"]
216216
```
217+
218+
## Unions with intersections of literals and Any
219+
220+
```toml
221+
[environment]
222+
python-version = "3.12"
223+
```
224+
225+
```py
226+
from typing import Any, Literal
227+
from ty_extensions import Intersection
228+
229+
type SA = Literal[""]
230+
type SB = Intersection[Literal[""], Any]
231+
type SC = SA | SB
232+
type SD = SB | SA
233+
234+
def _(c: SC, d: SD):
235+
reveal_type(c) # revealed: Literal[""]
236+
reveal_type(d) # revealed: Literal[""]
237+
238+
type IA = Literal[0]
239+
type IB = Intersection[Literal[0], Any]
240+
type IC = IA | IB
241+
type ID = IB | IA
242+
243+
def _(c: IC, d: ID):
244+
reveal_type(c) # revealed: Literal[0]
245+
reveal_type(d) # revealed: Literal[0]
246+
247+
type BA = Literal[b""]
248+
type BB = Intersection[Literal[b""], Any]
249+
type BC = BA | BB
250+
type BD = BB | BA
251+
252+
def _(c: BC, d: BD):
253+
reveal_type(c) # revealed: Literal[b""]
254+
reveal_type(d) # revealed: Literal[b""]
255+
```

crates/ty_python_semantic/src/types/builder.rs

Lines changed: 50 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -99,15 +99,21 @@ impl<'db> UnionElement<'db> {
9999
UnionElement::IntLiterals(literals) => {
100100
if other_type.splits_literals(db, LiteralKind::Int) {
101101
let mut collapse = false;
102+
let mut ignore = false;
102103
let negated = other_type.negate(db);
103104
literals.retain(|literal| {
104105
let ty = Type::IntLiteral(*literal);
105106
if negated.is_subtype_of(db, ty) {
106107
collapse = true;
107108
}
109+
if other_type.is_subtype_of(db, ty) {
110+
ignore = true;
111+
}
108112
!ty.is_subtype_of(db, other_type)
109113
});
110-
if collapse {
114+
if ignore {
115+
ReduceResult::Ignore
116+
} else if collapse {
111117
ReduceResult::CollapseToObject
112118
} else {
113119
ReduceResult::KeepIf(!literals.is_empty())
@@ -121,15 +127,21 @@ impl<'db> UnionElement<'db> {
121127
UnionElement::StringLiterals(literals) => {
122128
if other_type.splits_literals(db, LiteralKind::String) {
123129
let mut collapse = false;
130+
let mut ignore = false;
124131
let negated = other_type.negate(db);
125132
literals.retain(|literal| {
126133
let ty = Type::StringLiteral(*literal);
127134
if negated.is_subtype_of(db, ty) {
128135
collapse = true;
129136
}
137+
if other_type.is_subtype_of(db, ty) {
138+
ignore = true;
139+
}
130140
!ty.is_subtype_of(db, other_type)
131141
});
132-
if collapse {
142+
if ignore {
143+
ReduceResult::Ignore
144+
} else if collapse {
133145
ReduceResult::CollapseToObject
134146
} else {
135147
ReduceResult::KeepIf(!literals.is_empty())
@@ -143,15 +155,21 @@ impl<'db> UnionElement<'db> {
143155
UnionElement::BytesLiterals(literals) => {
144156
if other_type.splits_literals(db, LiteralKind::Bytes) {
145157
let mut collapse = false;
158+
let mut ignore = false;
146159
let negated = other_type.negate(db);
147160
literals.retain(|literal| {
148161
let ty = Type::BytesLiteral(*literal);
149162
if negated.is_subtype_of(db, ty) {
150163
collapse = true;
151164
}
165+
if other_type.is_subtype_of(db, ty) {
166+
ignore = true;
167+
}
152168
!ty.is_subtype_of(db, other_type)
153169
});
154-
if collapse {
170+
if ignore {
171+
ReduceResult::Ignore
172+
} else if collapse {
155173
ReduceResult::CollapseToObject
156174
} else {
157175
ReduceResult::KeepIf(!literals.is_empty())
@@ -173,6 +191,8 @@ enum ReduceResult<'db> {
173191
KeepIf(bool),
174192
/// Collapse this entire union to `object`.
175193
CollapseToObject,
194+
/// The new element is a subtype of an existing part of the `UnionElement`, ignore it.
195+
Ignore,
176196
/// The given `Type` can stand-in for the entire `UnionElement` for further union
177197
/// simplification checks.
178198
Type(Type<'db>),
@@ -230,8 +250,9 @@ impl<'db> UnionBuilder<'db> {
230250
// containing it.
231251
Type::StringLiteral(literal) => {
232252
let mut found = false;
253+
let mut to_remove = None;
233254
let ty_negated = ty.negate(self.db);
234-
for element in &mut self.elements {
255+
for (index, element) in self.elements.iter_mut().enumerate() {
235256
match element {
236257
UnionElement::StringLiterals(literals) => {
237258
if literals.len() >= MAX_UNION_LITERALS {
@@ -247,6 +268,9 @@ impl<'db> UnionBuilder<'db> {
247268
if ty.is_subtype_of(self.db, *existing) {
248269
return;
249270
}
271+
if existing.is_subtype_of(self.db, ty) {
272+
to_remove = Some(index);
273+
}
250274
if ty_negated.is_subtype_of(self.db, *existing) {
251275
// The type that includes both this new element, and its negation
252276
// (or a supertype of its negation), must be simply `object`.
@@ -257,6 +281,9 @@ impl<'db> UnionBuilder<'db> {
257281
_ => {}
258282
}
259283
}
284+
if let Some(index) = to_remove {
285+
self.elements.swap_remove(index);
286+
}
260287
if !found {
261288
self.elements
262289
.push(UnionElement::StringLiterals(FxOrderSet::from_iter([
@@ -267,8 +294,9 @@ impl<'db> UnionBuilder<'db> {
267294
// Same for bytes literals as for string literals, above.
268295
Type::BytesLiteral(literal) => {
269296
let mut found = false;
297+
let mut to_remove = None;
270298
let ty_negated = ty.negate(self.db);
271-
for element in &mut self.elements {
299+
for (index, element) in self.elements.iter_mut().enumerate() {
272300
match element {
273301
UnionElement::BytesLiterals(literals) => {
274302
if literals.len() >= MAX_UNION_LITERALS {
@@ -284,6 +312,9 @@ impl<'db> UnionBuilder<'db> {
284312
if ty.is_subtype_of(self.db, *existing) {
285313
return;
286314
}
315+
if existing.is_subtype_of(self.db, ty) {
316+
to_remove = Some(index);
317+
}
287318
if ty_negated.is_subtype_of(self.db, *existing) {
288319
// The type that includes both this new element, and its negation
289320
// (or a supertype of its negation), must be simply `object`.
@@ -294,6 +325,9 @@ impl<'db> UnionBuilder<'db> {
294325
_ => {}
295326
}
296327
}
328+
if let Some(index) = to_remove {
329+
self.elements.swap_remove(index);
330+
}
297331
if !found {
298332
self.elements
299333
.push(UnionElement::BytesLiterals(FxOrderSet::from_iter([
@@ -304,8 +338,9 @@ impl<'db> UnionBuilder<'db> {
304338
// And same for int literals as well.
305339
Type::IntLiteral(literal) => {
306340
let mut found = false;
341+
let mut to_remove = None;
307342
let ty_negated = ty.negate(self.db);
308-
for element in &mut self.elements {
343+
for (index, element) in self.elements.iter_mut().enumerate() {
309344
match element {
310345
UnionElement::IntLiterals(literals) => {
311346
if literals.len() >= MAX_UNION_LITERALS {
@@ -321,6 +356,9 @@ impl<'db> UnionBuilder<'db> {
321356
if ty.is_subtype_of(self.db, *existing) {
322357
return;
323358
}
359+
if existing.is_subtype_of(self.db, ty) {
360+
to_remove = Some(index);
361+
}
324362
if ty_negated.is_subtype_of(self.db, *existing) {
325363
// The type that includes both this new element, and its negation
326364
// (or a supertype of its negation), must be simply `object`.
@@ -331,6 +369,9 @@ impl<'db> UnionBuilder<'db> {
331369
_ => {}
332370
}
333371
}
372+
if let Some(index) = to_remove {
373+
self.elements.swap_remove(index);
374+
}
334375
if !found {
335376
self.elements
336377
.push(UnionElement::IntLiterals(FxOrderSet::from_iter([literal])));
@@ -363,6 +404,9 @@ impl<'db> UnionBuilder<'db> {
363404
self.collapse_to_object();
364405
return;
365406
}
407+
ReduceResult::Ignore => {
408+
return;
409+
}
366410
};
367411
if Some(element_type) == bool_pair {
368412
self.add_in_place(KnownClass::Bool.to_instance(self.db));

0 commit comments

Comments
 (0)