Skip to content

Commit 7995c73

Browse files
committed
[red-knot] optimize building large unions of literals
1 parent 1dedcb9 commit 7995c73

File tree

3 files changed

+150
-23
lines changed

3 files changed

+150
-23
lines changed

crates/red_knot_python_semantic/resources/mdtest/annotations/literal.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def x(
6868
a3: Literal[Literal["w"], Literal["r"], Literal[Literal["w+"]]],
6969
a4: Literal[True] | Literal[1, 2] | Literal["foo"],
7070
):
71-
reveal_type(a1) # revealed: Literal[1, 2, 3, "foo", 5] | None
71+
reveal_type(a1) # revealed: Literal[1, 2, 3, 5, "foo"] | None
7272
reveal_type(a2) # revealed: Literal["w", "r"]
7373
reveal_type(a3) # revealed: Literal["w", "r", "w+"]
7474
reveal_type(a4) # revealed: Literal[True, 1, 2, "foo"]
@@ -108,7 +108,7 @@ def union_example(
108108
None,
109109
],
110110
):
111-
reveal_type(x) # revealed: Unknown | Literal[-1, "A", b"A", b"\x00", b"\x07", 0, 1, "B", "foo", "bar", True] | None
111+
reveal_type(x) # revealed: Unknown | Literal[-1, 0, 1, "A", "B", "foo", "bar", b"A", b"\x00", b"\x07", True] | None
112112
```
113113

114114
## Detecting Literal outside typing and typing_extensions

crates/red_knot_python_semantic/resources/mdtest/annotations/string.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def f1(
105105
from typing import Literal
106106

107107
def f(v: Literal["a", r"b", b"c", "d" "e", "\N{LATIN SMALL LETTER F}", "\x67", """h"""]):
108-
reveal_type(v) # revealed: Literal["a", "b", b"c", "de", "f", "g", "h"]
108+
reveal_type(v) # revealed: Literal["a", "b", "de", "f", "g", "h", b"c"]
109109
```
110110

111111
## Class variables

crates/red_knot_python_semantic/src/types/builder.rs

Lines changed: 147 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,26 +12,47 @@
1212
//! flattens into the outer one), intersections cannot contain other intersections (also
1313
//! flattens), and intersections cannot contain unions (the intersection distributes over the
1414
//! union, inverting it into a union-of-intersections).
15+
//! * No type in a union can be a subtype of any other type in the union (just eliminate the
16+
//! subtype from the union).
17+
//! * No type in an intersection can be a supertype of any other type in the intersection (just
18+
//! eliminate the supertype from the intersection).
19+
//! * An intersection containing two non-overlapping types simplifies to [`Type::Never`].
1520
//!
1621
//! The implication of these invariants is that a [`UnionBuilder`] does not necessarily build a
1722
//! [`Type::Union`]. For example, if only one type is added to the [`UnionBuilder`], `build()` will
1823
//! just return that type directly. The same is true for [`IntersectionBuilder`]; for example, if a
1924
//! union type is added to the intersection, it will distribute and [`IntersectionBuilder::build`]
2025
//! may end up returning a [`Type::Union`] of intersections.
2126
//!
22-
//! In the future we should have these additional invariants, but they aren't implemented yet:
23-
//! * No type in a union can be a subtype of any other type in the union (just eliminate the
24-
//! subtype from the union).
25-
//! * No type in an intersection can be a supertype of any other type in the intersection (just
26-
//! eliminate the supertype from the intersection).
27-
//! * An intersection containing two non-overlapping types should simplify to [`Type::Never`].
28-
29-
use crate::types::{IntersectionType, KnownClass, Type, TypeVarBoundOrConstraints, UnionType};
27+
//! ## Performance
28+
//!
29+
//! In practice, there are two kinds of unions found in the wild: relatively-small unions made up
30+
//! of normal user types (classes, etc), and large unions made up of literals, which can occur via
31+
//! large enums (not yet implemented) or from string/integer/bytes literals, which can grow due to
32+
//! literal arithmetic or operations on literal strings/bytes. For normal unions, it's most
33+
//! efficient to just store the member types in a vector, and do O(n^2) `is_subtype_of` checks to
34+
//! maintain the union in simplified form. But literal unions can grow to a size where this becomes
35+
//! a performance problem. For this reason, we group literal types in `UnionBuilder`. Since every
36+
//! different string literal type shares exactly the same possible super-types, and none of them
37+
//! are subtypes of each other (unless exactly the same literal type), we can avoid many
38+
//! unnecessary `is_subtype_of` checks.
39+
40+
use crate::types::{
41+
BytesLiteralType, IntersectionType, KnownClass, StringLiteralType, Type,
42+
TypeVarBoundOrConstraints, UnionType,
43+
};
3044
use crate::{Db, FxOrderSet};
3145
use smallvec::SmallVec;
3246

47+
enum UnionElement<'db> {
48+
IntLiterals(FxOrderSet<i64>),
49+
StringLiterals(FxOrderSet<StringLiteralType<'db>>),
50+
BytesLiterals(FxOrderSet<BytesLiteralType<'db>>),
51+
Type(Type<'db>),
52+
}
53+
3354
pub(crate) struct UnionBuilder<'db> {
34-
elements: Vec<Type<'db>>,
55+
elements: Vec<UnionElement<'db>>,
3556
db: &'db dyn Db,
3657
}
3758

@@ -50,7 +71,8 @@ impl<'db> UnionBuilder<'db> {
5071
/// Collapse the union to a single type: `object`.
5172
fn collapse_to_object(mut self) -> Self {
5273
self.elements.clear();
53-
self.elements.push(Type::object(self.db));
74+
self.elements
75+
.push(UnionElement::Type(Type::object(self.db)));
5476
self
5577
}
5678

@@ -66,6 +88,75 @@ impl<'db> UnionBuilder<'db> {
6688
}
6789
// Adding `Never` to a union is a no-op.
6890
Type::Never => {}
91+
// If adding a string literal, look for an existing `UnionElement::StringLiterals` to
92+
// add it to, or an existing element that is a super-type of string literals, which
93+
// means we shouldn't add it. Otherwise, add a new `UnionElement::StringLiterals`
94+
// containing it.
95+
Type::StringLiteral(literal) => {
96+
let mut found = false;
97+
for element in &mut self.elements {
98+
match element {
99+
UnionElement::StringLiterals(literals) => {
100+
literals.insert(literal);
101+
found = true;
102+
break;
103+
}
104+
UnionElement::Type(existing) if ty.is_subtype_of(self.db, *existing) => {
105+
return self;
106+
}
107+
_ => {}
108+
}
109+
}
110+
if !found {
111+
let mut literals = FxOrderSet::default();
112+
literals.insert(literal);
113+
self.elements.push(UnionElement::StringLiterals(literals));
114+
}
115+
}
116+
// Same for bytes literals as for string literals, above.
117+
Type::BytesLiteral(literal) => {
118+
let mut found = false;
119+
for element in &mut self.elements {
120+
match element {
121+
UnionElement::BytesLiterals(literals) => {
122+
literals.insert(literal);
123+
found = true;
124+
break;
125+
}
126+
UnionElement::Type(existing) if ty.is_subtype_of(self.db, *existing) => {
127+
return self;
128+
}
129+
_ => {}
130+
}
131+
}
132+
if !found {
133+
let mut literals = FxOrderSet::default();
134+
literals.insert(literal);
135+
self.elements.push(UnionElement::BytesLiterals(literals));
136+
}
137+
}
138+
// And same for int literals as well.
139+
Type::IntLiteral(literal) => {
140+
let mut found = false;
141+
for element in &mut self.elements {
142+
match element {
143+
UnionElement::IntLiterals(literals) => {
144+
literals.insert(literal);
145+
found = true;
146+
break;
147+
}
148+
UnionElement::Type(existing) if ty.is_subtype_of(self.db, *existing) => {
149+
return self;
150+
}
151+
_ => {}
152+
}
153+
}
154+
if !found {
155+
let mut literals = FxOrderSet::default();
156+
literals.insert(literal);
157+
self.elements.push(UnionElement::IntLiterals(literals));
158+
}
159+
}
69160
// Adding `object` to a union results in `object`.
70161
ty if ty.is_object(self.db) => {
71162
return self.collapse_to_object();
@@ -81,8 +172,29 @@ impl<'db> UnionBuilder<'db> {
81172
let mut to_remove = SmallVec::<[usize; 2]>::new();
82173
let ty_negated = ty.negate(self.db);
83174

84-
for (index, element) in self.elements.iter().enumerate() {
85-
if Some(*element) == bool_pair {
175+
for (index, element) in self
176+
.elements
177+
.iter()
178+
.map(|element| {
179+
// For literals, the first element in the set can stand in for all the rest,
180+
// since they all have the same super-types. SAFETY: a `UnionElement` of
181+
// literal kind must always have at least one element in it.
182+
match element {
183+
UnionElement::IntLiterals(literals) => {
184+
Type::IntLiteral(*literals.iter().next().unwrap())
185+
}
186+
UnionElement::StringLiterals(literals) => {
187+
Type::StringLiteral(*literals.iter().next().unwrap())
188+
}
189+
UnionElement::BytesLiterals(literals) => {
190+
Type::BytesLiteral(*literals.iter().next().unwrap())
191+
}
192+
UnionElement::Type(ty) => *ty,
193+
}
194+
})
195+
.enumerate()
196+
{
197+
if Some(element) == bool_pair {
86198
to_add = KnownClass::Bool.to_instance(self.db);
87199
to_remove.push(index);
88200
// The type we are adding is a BooleanLiteral, which doesn't have any
@@ -92,14 +204,14 @@ impl<'db> UnionBuilder<'db> {
92204
break;
93205
}
94206

95-
if ty.is_same_gradual_form(*element)
96-
|| ty.is_subtype_of(self.db, *element)
207+
if ty.is_same_gradual_form(element)
208+
|| ty.is_subtype_of(self.db, element)
97209
|| element.is_object(self.db)
98210
{
99211
return self;
100212
} else if element.is_subtype_of(self.db, ty) {
101213
to_remove.push(index);
102-
} else if ty_negated.is_subtype_of(self.db, *element) {
214+
} else if ty_negated.is_subtype_of(self.db, element) {
103215
// We add `ty` to the union. We just checked that `~ty` is a subtype of an existing `element`.
104216
// This also means that `~ty | ty` is a subtype of `element | ty`, because both elements in the
105217
// first union are subtypes of the corresponding elements in the second union. But `~ty | ty` is
@@ -111,24 +223,39 @@ impl<'db> UnionBuilder<'db> {
111223
}
112224
}
113225
if let Some((&first, rest)) = to_remove.split_first() {
114-
self.elements[first] = to_add;
226+
self.elements[first] = UnionElement::Type(to_add);
115227
// We iterate in descending order to keep remaining indices valid after `swap_remove`.
116228
for &index in rest.iter().rev() {
117229
self.elements.swap_remove(index);
118230
}
119231
} else {
120-
self.elements.push(to_add);
232+
self.elements.push(UnionElement::Type(to_add));
121233
}
122234
}
123235
}
124236
self
125237
}
126238

127239
pub(crate) fn build(self) -> Type<'db> {
128-
match self.elements.len() {
240+
let mut types = vec![];
241+
for element in self.elements {
242+
match element {
243+
UnionElement::IntLiterals(literals) => {
244+
types.extend(literals.into_iter().map(Type::IntLiteral));
245+
}
246+
UnionElement::StringLiterals(literals) => {
247+
types.extend(literals.into_iter().map(Type::StringLiteral));
248+
}
249+
UnionElement::BytesLiterals(literals) => {
250+
types.extend(literals.into_iter().map(Type::BytesLiteral));
251+
}
252+
UnionElement::Type(ty) => types.push(ty),
253+
}
254+
}
255+
match types.len() {
129256
0 => Type::Never,
130-
1 => self.elements[0],
131-
_ => Type::Union(UnionType::new(self.db, self.elements.into_boxed_slice())),
257+
1 => types[0],
258+
_ => Type::Union(UnionType::new(self.db, types.into_boxed_slice())),
132259
}
133260
}
134261
}

0 commit comments

Comments
 (0)