Commit a1f3619

carljm and AlexWaygood authored
[red-knot] optimize building large unions of literals (#17403)
## Summary

Special-case literal types in `UnionBuilder` to speed up building large unions of literals.

This optimization is extremely effective at speeding up building even a very large union (it improves the large-unions benchmark by 41x!). The remaining problem is that it is easy to then hit another operation on the very large union that is still slow (for instance, narrowing may add it to an intersection, which then distributes it over the intersection).

I think it is possible to avoid this by extending this optimized "grouped" representation throughout not just `UnionBuilder`, but all of our union and intersection representations. I have some work in this direction, but rather than spending more time on it right now, I'd rather just land this much, along with a limit on the size of these unions (to avoid building really big unions quickly and then hitting issues where they are used).

## Test Plan

Existing tests and benchmarks.

---------

Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
1 parent 13ea4e5 commit a1f3619
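
The grouping idea described in the summary can be illustrated with a small standalone sketch. This is not the code from the commit: `Ty`, `OrderSet`, `Element`, and `Builder` are hypothetical stand-ins (loosely mirroring `Type`, `FxOrderSet`, `UnionElement`, and `UnionBuilder`), the subtype relation is a toy one, and only int and string literals are modeled. The point it tries to show is that adding a literal only needs a set insertion, and that a non-literal type only needs to be compared against one representative per literal group.

```rust
use std::collections::HashSet;
use std::hash::Hash;

/// Toy stand-in for a type; only the variants this sketch needs.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
enum Ty {
    Int(i64),
    Str(String),
    IntInstance, // models `int`, a supertype of every int literal
    StrInstance, // models `str`, a supertype of every string literal
}

impl Ty {
    /// Toy subtype relation: reflexive, and literals are subtypes of their instance type.
    fn is_subtype_of(&self, other: &Ty) -> bool {
        self == other
            || matches!(
                (self, other),
                (Ty::Int(_), Ty::IntInstance) | (Ty::Str(_), Ty::StrInstance)
            )
    }
}

/// Minimal insertion-ordered set (assumption: stands in for an ordered hash set).
struct OrderSet<T> {
    items: Vec<T>,
    seen: HashSet<T>,
}

impl<T: Clone + Eq + Hash> OrderSet<T> {
    fn of(value: T) -> Self {
        let mut set = OrderSet { items: Vec::new(), seen: HashSet::new() };
        set.insert(value);
        set
    }
    fn insert(&mut self, value: T) {
        if self.seen.insert(value.clone()) {
            self.items.push(value);
        }
    }
}

enum Element {
    Ints(OrderSet<i64>),
    Strs(OrderSet<String>),
    Other(Ty),
}

impl Element {
    /// The first literal of a group stands in for the whole group,
    /// since every member has the same supertypes.
    fn representative(&self) -> Ty {
        match self {
            Element::Ints(set) => Ty::Int(set.items[0]),
            Element::Strs(set) => Ty::Str(set.items[0].clone()),
            Element::Other(ty) => ty.clone(),
        }
    }
}

#[derive(Default)]
struct Builder {
    elements: Vec<Element>,
}

impl Builder {
    fn add(mut self, ty: Ty) -> Self {
        // Nothing to do if an existing non-literal element already covers `ty`.
        if self.elements.iter().any(|e| matches!(e, Element::Other(t) if ty.is_subtype_of(t))) {
            return self;
        }
        match ty {
            Ty::Int(n) => {
                let group = self.elements.iter_mut().find_map(|e| match e {
                    Element::Ints(set) => Some(set),
                    _ => None,
                });
                match group {
                    Some(set) => set.insert(n),
                    None => self.elements.push(Element::Ints(OrderSet::of(n))),
                }
            }
            Ty::Str(s) => {
                let group = self.elements.iter_mut().find_map(|e| match e {
                    Element::Strs(set) => Some(set),
                    _ => None,
                });
                match group {
                    Some(set) => set.insert(s),
                    None => self.elements.push(Element::Strs(OrderSet::of(s))),
                }
            }
            other => {
                // One check per element; a whole literal group is dropped at once.
                self.elements.retain(|e| !e.representative().is_subtype_of(&other));
                self.elements.push(Element::Other(other));
            }
        }
        self
    }

    /// Flatten the groups back into a plain list of member types.
    fn build(self) -> Vec<Ty> {
        let mut types = Vec::new();
        for element in self.elements {
            match element {
                Element::Ints(set) => types.extend(set.items.into_iter().map(Ty::Int)),
                Element::Strs(set) => types.extend(set.items.into_iter().map(Ty::Str)),
                Element::Other(ty) => types.push(ty),
            }
        }
        types
    }
}
```

As a usage example, `Builder::default().add(Ty::Int(1)).add(Ty::Str("a".into())).add(Ty::Int(2)).build()` yields the two int literals followed by the string literal. One simplification to keep in mind: the sketch appends an absorbing type at the end, whereas the diff below writes it into the first removed slot.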

3 files changed: +149 -23 lines changed

crates/red_knot_python_semantic/resources/mdtest/annotations/literal.md

Lines changed: 2 additions & 2 deletions
@@ -68,7 +68,7 @@ def x(
     a3: Literal[Literal["w"], Literal["r"], Literal[Literal["w+"]]],
     a4: Literal[True] | Literal[1, 2] | Literal["foo"],
 ):
-    reveal_type(a1) # revealed: Literal[1, 2, 3, "foo", 5] | None
+    reveal_type(a1) # revealed: Literal[1, 2, 3, 5, "foo"] | None
     reveal_type(a2) # revealed: Literal["w", "r"]
     reveal_type(a3) # revealed: Literal["w", "r", "w+"]
     reveal_type(a4) # revealed: Literal[True, 1, 2, "foo"]
@@ -108,7 +108,7 @@ def union_example(
         None,
     ],
 ):
-    reveal_type(x) # revealed: Unknown | Literal[-1, "A", b"A", b"\x00", b"\x07", 0, 1, "B", "foo", "bar", True] | None
+    reveal_type(x) # revealed: Unknown | Literal[-1, 0, 1, "A", "B", "foo", "bar", b"A", b"\x00", b"\x07", True] | None
 ```

 ## Detecting Literal outside typing and typing_extensions
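
The changed `revealed:` expectations above follow from the grouping: literals of the same kind now end up next to each other, at the position where that kind first appeared. A rough way to model this, assuming the old expected order reflects insertion order, is the sketch below. `Lit`, `Kind`, and `grouped_order` are hypothetical names used only for illustration, and boolean literals are treated as ungrouped because the diff only groups int, string, and bytes literals.

```rust
#[derive(Clone, Copy, Debug, PartialEq)]
enum Lit {
    Int(i64),
    Str(&'static str),
    Bytes(&'static [u8]),
    Bool(bool),
}

#[derive(Clone, Copy, PartialEq)]
enum Kind {
    Int,
    Str,
    Bytes,
}

/// Group key for a literal; `None` means "not grouped" in this sketch.
fn kind(lit: Lit) -> Option<Kind> {
    match lit {
        Lit::Int(_) => Some(Kind::Int),
        Lit::Str(_) => Some(Kind::Str),
        Lit::Bytes(_) => Some(Kind::Bytes),
        Lit::Bool(_) => None,
    }
}

/// Reorder literals so each kind forms one run, placed where the kind first occurred.
fn grouped_order(input: &[Lit]) -> Vec<Lit> {
    let mut slots: Vec<(Option<Kind>, Vec<Lit>)> = Vec::new();
    for &lit in input {
        let key = kind(lit);
        // Only grouped kinds may reuse an existing slot.
        let slot = slots.iter_mut().find(|slot| key.is_some() && slot.0 == key);
        match slot {
            Some((_, members)) => {
                if !members.contains(&lit) {
                    members.push(lit);
                }
            }
            None => slots.push((key, vec![lit])),
        }
    }
    slots.into_iter().flat_map(|(_, members)| members).collect()
}
```

Feeding it `1, 2, 3, "foo", 5` gives `1, 2, 3, 5, "foo"`, which matches the updated expectation for `a1`.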

crates/red_knot_python_semantic/resources/mdtest/annotations/string.md

Lines changed: 1 addition & 1 deletion
@@ -105,7 +105,7 @@ def f1(
 from typing import Literal

 def f(v: Literal["a", r"b", b"c", "d" "e", "\N{LATIN SMALL LETTER F}", "\x67", """h"""]):
-    reveal_type(v) # revealed: Literal["a", "b", b"c", "de", "f", "g", "h"]
+    reveal_type(v) # revealed: Literal["a", "b", "de", "f", "g", "h", b"c"]
 ```

 ## Class variables

crates/red_knot_python_semantic/src/types/builder.rs

Lines changed: 146 additions & 20 deletions
@@ -12,26 +12,47 @@
 //! flattens into the outer one), intersections cannot contain other intersections (also
 //! flattens), and intersections cannot contain unions (the intersection distributes over the
 //! union, inverting it into a union-of-intersections).
+//! * No type in a union can be a subtype of any other type in the union (just eliminate the
+//! subtype from the union).
+//! * No type in an intersection can be a supertype of any other type in the intersection (just
+//! eliminate the supertype from the intersection).
+//! * An intersection containing two non-overlapping types simplifies to [`Type::Never`].
 //!
 //! The implication of these invariants is that a [`UnionBuilder`] does not necessarily build a
 //! [`Type::Union`]. For example, if only one type is added to the [`UnionBuilder`], `build()` will
 //! just return that type directly. The same is true for [`IntersectionBuilder`]; for example, if a
 //! union type is added to the intersection, it will distribute and [`IntersectionBuilder::build`]
 //! may end up returning a [`Type::Union`] of intersections.
 //!
-//! In the future we should have these additional invariants, but they aren't implemented yet:
-//! * No type in a union can be a subtype of any other type in the union (just eliminate the
-//! subtype from the union).
-//! * No type in an intersection can be a supertype of any other type in the intersection (just
-//! eliminate the supertype from the intersection).
-//! * An intersection containing two non-overlapping types should simplify to [`Type::Never`].
-
-use crate::types::{IntersectionType, KnownClass, Type, TypeVarBoundOrConstraints, UnionType};
+//! ## Performance
+//!
+//! In practice, there are two kinds of unions found in the wild: relatively-small unions made up
+//! of normal user types (classes, etc), and large unions made up of literals, which can occur via
+//! large enums (not yet implemented) or from string/integer/bytes literals, which can grow due to
+//! literal arithmetic or operations on literal strings/bytes. For normal unions, it's most
+//! efficient to just store the member types in a vector, and do O(n^2) `is_subtype_of` checks to
+//! maintain the union in simplified form. But literal unions can grow to a size where this becomes
+//! a performance problem. For this reason, we group literal types in `UnionBuilder`. Since every
+//! different string literal type shares exactly the same possible super-types, and none of them
+//! are subtypes of each other (unless exactly the same literal type), we can avoid many
+//! unnecessary `is_subtype_of` checks.
+
+use crate::types::{
+    BytesLiteralType, IntersectionType, KnownClass, StringLiteralType, Type,
+    TypeVarBoundOrConstraints, UnionType,
+};
 use crate::{Db, FxOrderSet};
 use smallvec::SmallVec;

+enum UnionElement<'db> {
+    IntLiterals(FxOrderSet<i64>),
+    StringLiterals(FxOrderSet<StringLiteralType<'db>>),
+    BytesLiterals(FxOrderSet<BytesLiteralType<'db>>),
+    Type(Type<'db>),
+}
+
 pub(crate) struct UnionBuilder<'db> {
-    elements: Vec<Type<'db>>,
+    elements: Vec<UnionElement<'db>>,
     db: &'db dyn Db,
 }

@@ -50,7 +71,8 @@ impl<'db> UnionBuilder<'db> {
     /// Collapse the union to a single type: `object`.
     fn collapse_to_object(mut self) -> Self {
         self.elements.clear();
-        self.elements.push(Type::object(self.db));
+        self.elements
+            .push(UnionElement::Type(Type::object(self.db)));
         self
     }

@@ -66,6 +88,76 @@ impl<'db> UnionBuilder<'db> {
             }
             // Adding `Never` to a union is a no-op.
             Type::Never => {}
+            // If adding a string literal, look for an existing `UnionElement::StringLiterals` to
+            // add it to, or an existing element that is a super-type of string literals, which
+            // means we shouldn't add it. Otherwise, add a new `UnionElement::StringLiterals`
+            // containing it.
+            Type::StringLiteral(literal) => {
+                let mut found = false;
+                for element in &mut self.elements {
+                    match element {
+                        UnionElement::StringLiterals(literals) => {
+                            literals.insert(literal);
+                            found = true;
+                            break;
+                        }
+                        UnionElement::Type(existing) if ty.is_subtype_of(self.db, *existing) => {
+                            return self;
+                        }
+                        _ => {}
+                    }
+                }
+                if !found {
+                    self.elements
+                        .push(UnionElement::StringLiterals(FxOrderSet::from_iter([
+                            literal,
+                        ])));
+                }
+            }
+            // Same for bytes literals as for string literals, above.
+            Type::BytesLiteral(literal) => {
+                let mut found = false;
+                for element in &mut self.elements {
+                    match element {
+                        UnionElement::BytesLiterals(literals) => {
+                            literals.insert(literal);
+                            found = true;
+                            break;
+                        }
+                        UnionElement::Type(existing) if ty.is_subtype_of(self.db, *existing) => {
+                            return self;
+                        }
+                        _ => {}
+                    }
+                }
+                if !found {
+                    self.elements
+                        .push(UnionElement::BytesLiterals(FxOrderSet::from_iter([
+                            literal,
+                        ])));
+                }
+            }
+            // And same for int literals as well.
+            Type::IntLiteral(literal) => {
+                let mut found = false;
+                for element in &mut self.elements {
+                    match element {
+                        UnionElement::IntLiterals(literals) => {
+                            literals.insert(literal);
+                            found = true;
+                            break;
+                        }
+                        UnionElement::Type(existing) if ty.is_subtype_of(self.db, *existing) => {
+                            return self;
+                        }
+                        _ => {}
+                    }
+                }
+                if !found {
+                    self.elements
+                        .push(UnionElement::IntLiterals(FxOrderSet::from_iter([literal])));
+                }
+            }
             // Adding `object` to a union results in `object`.
             ty if ty.is_object(self.db) => {
                 return self.collapse_to_object();
@@ -81,8 +173,27 @@ impl<'db> UnionBuilder<'db> {
                 let mut to_remove = SmallVec::<[usize; 2]>::new();
                 let ty_negated = ty.negate(self.db);

-                for (index, element) in self.elements.iter().enumerate() {
-                    if Some(*element) == bool_pair {
+                for (index, element) in self
+                    .elements
+                    .iter()
+                    .map(|element| {
+                        // For literals, the first element in the set can stand in for all the rest,
+                        // since they all have the same super-types. SAFETY: a `UnionElement` of
+                        // literal kind must always have at least one element in it.
+                        match element {
+                            UnionElement::IntLiterals(literals) => Type::IntLiteral(literals[0]),
+                            UnionElement::StringLiterals(literals) => {
+                                Type::StringLiteral(literals[0])
+                            }
+                            UnionElement::BytesLiterals(literals) => {
+                                Type::BytesLiteral(literals[0])
+                            }
+                            UnionElement::Type(ty) => *ty,
+                        }
+                    })
+                    .enumerate()
+                {
+                    if Some(element) == bool_pair {
                         to_add = KnownClass::Bool.to_instance(self.db);
                         to_remove.push(index);
                         // The type we are adding is a BooleanLiteral, which doesn't have any
@@ -92,14 +203,14 @@ impl<'db> UnionBuilder<'db> {
                         break;
                     }

-                    if ty.is_same_gradual_form(*element)
-                        || ty.is_subtype_of(self.db, *element)
+                    if ty.is_same_gradual_form(element)
+                        || ty.is_subtype_of(self.db, element)
                         || element.is_object(self.db)
                     {
                         return self;
                     } else if element.is_subtype_of(self.db, ty) {
                         to_remove.push(index);
-                    } else if ty_negated.is_subtype_of(self.db, *element) {
+                    } else if ty_negated.is_subtype_of(self.db, element) {
                         // We add `ty` to the union. We just checked that `~ty` is a subtype of an existing `element`.
                         // This also means that `~ty | ty` is a subtype of `element | ty`, because both elements in the
                         // first union are subtypes of the corresponding elements in the second union. But `~ty | ty` is
@@ -111,24 +222,39 @@ impl<'db> UnionBuilder<'db> {
                     }
                 }
                 if let Some((&first, rest)) = to_remove.split_first() {
-                    self.elements[first] = to_add;
+                    self.elements[first] = UnionElement::Type(to_add);
                     // We iterate in descending order to keep remaining indices valid after `swap_remove`.
                     for &index in rest.iter().rev() {
                         self.elements.swap_remove(index);
                     }
                 } else {
-                    self.elements.push(to_add);
+                    self.elements.push(UnionElement::Type(to_add));
                 }
             }
         }
         self
     }

     pub(crate) fn build(self) -> Type<'db> {
-        match self.elements.len() {
+        let mut types = vec![];
+        for element in self.elements {
+            match element {
+                UnionElement::IntLiterals(literals) => {
+                    types.extend(literals.into_iter().map(Type::IntLiteral));
+                }
+                UnionElement::StringLiterals(literals) => {
+                    types.extend(literals.into_iter().map(Type::StringLiteral));
+                }
+                UnionElement::BytesLiterals(literals) => {
+                    types.extend(literals.into_iter().map(Type::BytesLiteral));
+                }
+                UnionElement::Type(ty) => types.push(ty),
+            }
+        }
+        match types.len() {
             0 => Type::Never,
-            1 => self.elements[0],
-            _ => Type::Union(UnionType::new(self.db, self.elements.into_boxed_slice())),
+            1 => types[0],
+            _ => Type::Union(UnionType::new(self.db, types.into_boxed_slice())),
         }
     }
 }
