Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions crates/red_knot_python_semantic/resources/mdtest/call/union.md
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,29 @@ def _(flag: bool):
# error: [conflicting-argument-forms] "Argument is used as both a value and a type form in call"
reveal_type(f(int)) # revealed: str | Literal[True]
```

## Size limit on unions of literals

Beyond a certain size, large unions of literal types collapse to their nearest super-type (`int`,
`bytes`, `str`).

```py
from typing import Literal

def _(literals_2: Literal[0, 1], b: bool, flag: bool):
literals_4 = 2 * literals_2 + literals_2 # Literal[0, 1, 2, 3]
literals_16 = 4 * literals_4 + literals_4 # Literal[0, 1, .., 15]
literals_64 = 4 * literals_16 + literals_4 # Literal[0, 1, .., 63]
literals_128 = 2 * literals_64 + literals_2 # Literal[0, 1, .., 127]

# Going beyond the MAX_UNION_LITERALS limit (currently 200):
literals_256 = 16 * literals_16 + literals_16
reveal_type(literals_256) # revealed: int

# Going beyond the limit when another type is already part of the union
bool_and_literals_128 = b if flag else literals_128 # bool | Literal[0, 1, ..., 127]
literals_128_shifted = literals_128 + 128 # Literal[128, 129, ..., 255]

# Now union the two:
reveal_type(bool_and_literals_128 if flag else literals_128_shifted) # revealed: int
```
31 changes: 31 additions & 0 deletions crates/red_knot_python_semantic/src/types/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ enum UnionElement<'db> {
Type(Type<'db>),
}

// TODO increase this once we extend `UnionElement` throughout all union/intersection
// representations, so that we can make large unions of literals fast in all operations.
const MAX_UNION_LITERALS: usize = 200;

pub(crate) struct UnionBuilder<'db> {
elements: Vec<UnionElement<'db>>,
db: &'db dyn Db,
Expand Down Expand Up @@ -93,10 +97,15 @@ impl<'db> UnionBuilder<'db> {
// means we shouldn't add it. Otherwise, add a new `UnionElement::StringLiterals`
// containing it.
Type::StringLiteral(literal) => {
let mut too_large = false;
let mut found = false;
for element in &mut self.elements {
match element {
UnionElement::StringLiterals(literals) => {
if literals.len() >= MAX_UNION_LITERALS {
too_large = true;
break;
}
literals.insert(literal);
found = true;
break;
Expand All @@ -107,6 +116,10 @@ impl<'db> UnionBuilder<'db> {
_ => {}
}
}
if too_large {
let replace_with = KnownClass::Str.to_instance(self.db);
return self.add(replace_with);
}
if !found {
self.elements
.push(UnionElement::StringLiterals(FxOrderSet::from_iter([
Expand All @@ -117,9 +130,14 @@ impl<'db> UnionBuilder<'db> {
// Same for bytes literals as for string literals, above.
Type::BytesLiteral(literal) => {
let mut found = false;
let mut too_large = false;
for element in &mut self.elements {
match element {
UnionElement::BytesLiterals(literals) => {
if literals.len() >= MAX_UNION_LITERALS {
too_large = true;
break;
}
literals.insert(literal);
found = true;
break;
Expand All @@ -130,6 +148,10 @@ impl<'db> UnionBuilder<'db> {
_ => {}
}
}
if too_large {
let replace_with = KnownClass::Bytes.to_instance(self.db);
return self.add(replace_with);
}
if !found {
self.elements
.push(UnionElement::BytesLiterals(FxOrderSet::from_iter([
Expand All @@ -140,9 +162,14 @@ impl<'db> UnionBuilder<'db> {
// And same for int literals as well.
Type::IntLiteral(literal) => {
let mut found = false;
let mut too_large = false;
for element in &mut self.elements {
match element {
UnionElement::IntLiterals(literals) => {
if literals.len() >= MAX_UNION_LITERALS {
too_large = true;
break;
}
literals.insert(literal);
found = true;
break;
Expand All @@ -153,6 +180,10 @@ impl<'db> UnionBuilder<'db> {
_ => {}
}
}
if too_large {
let replace_with = KnownClass::Int.to_instance(self.db);
return self.add(replace_with);
}
if !found {
self.elements
.push(UnionElement::IntLiterals(FxOrderSet::from_iter([literal])));
Expand Down
Loading