diff --git a/crates/red_knot_python_semantic/resources/mdtest/call/union.md b/crates/red_knot_python_semantic/resources/mdtest/call/union.md index d76fc2ed2b4225..b3615496c1fe9d 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/call/union.md +++ b/crates/red_knot_python_semantic/resources/mdtest/call/union.md @@ -175,3 +175,29 @@ def _(flag: bool): # error: [conflicting-argument-forms] "Argument is used as both a value and a type form in call" reveal_type(f(int)) # revealed: str | Literal[True] ``` + +## Size limit on unions of literals + +Beyond a certain size, large unions of literal types collapse to their nearest super-type (`int`, +`bytes`, `str`). + +```py +from typing import Literal + +def _(literals_2: Literal[0, 1], b: bool, flag: bool): + literals_4 = 2 * literals_2 + literals_2 # Literal[0, 1, 2, 3] + literals_16 = 4 * literals_4 + literals_4 # Literal[0, 1, .., 15] + literals_64 = 4 * literals_16 + literals_4 # Literal[0, 1, .., 63] + literals_128 = 2 * literals_64 + literals_2 # Literal[0, 1, .., 127] + + # Going beyond the MAX_UNION_LITERALS limit (currently 200): + literals_256 = 16 * literals_16 + literals_16 + reveal_type(literals_256) # revealed: int + + # Going beyond the limit when another type is already part of the union + bool_and_literals_128 = b if flag else literals_128 # bool | Literal[0, 1, ..., 127] + literals_128_shifted = literals_128 + 128 # Literal[128, 129, ..., 255] + + # Now union the two: + reveal_type(bool_and_literals_128 if flag else literals_128_shifted) # revealed: int +``` diff --git a/crates/red_knot_python_semantic/src/types/builder.rs b/crates/red_knot_python_semantic/src/types/builder.rs index 43f46cf715c505..bc2b1fb802771c 100644 --- a/crates/red_knot_python_semantic/src/types/builder.rs +++ b/crates/red_knot_python_semantic/src/types/builder.rs @@ -51,6 +51,10 @@ enum UnionElement<'db> { Type(Type<'db>), } +// TODO increase this once we extend `UnionElement` throughout all union/intersection +// 
representations, so that we can make large unions of literals fast in all operations. +const MAX_UNION_LITERALS: usize = 200; + pub(crate) struct UnionBuilder<'db> { elements: Vec<UnionElement<'db>>, db: &'db dyn Db, @@ -93,10 +97,15 @@ impl<'db> UnionBuilder<'db> { // means we shouldn't add it. Otherwise, add a new `UnionElement::StringLiterals` // containing it. Type::StringLiteral(literal) => { + let mut too_large = false; let mut found = false; for element in &mut self.elements { match element { UnionElement::StringLiterals(literals) => { + if literals.len() >= MAX_UNION_LITERALS { + too_large = true; + break; + } literals.insert(literal); found = true; break; @@ -107,6 +116,10 @@ impl<'db> UnionBuilder<'db> { _ => {} } } + if too_large { + let replace_with = KnownClass::Str.to_instance(self.db); + return self.add(replace_with); + } if !found { self.elements .push(UnionElement::StringLiterals(FxOrderSet::from_iter([ @@ -117,9 +130,14 @@ impl<'db> UnionBuilder<'db> { // Same for bytes literals as for string literals, above. Type::BytesLiteral(literal) => { let mut found = false; + let mut too_large = false; for element in &mut self.elements { match element { UnionElement::BytesLiterals(literals) => { + if literals.len() >= MAX_UNION_LITERALS { + too_large = true; + break; + } literals.insert(literal); found = true; break; @@ -130,6 +148,10 @@ impl<'db> UnionBuilder<'db> { _ => {} } } + if too_large { + let replace_with = KnownClass::Bytes.to_instance(self.db); + return self.add(replace_with); + } if !found { self.elements .push(UnionElement::BytesLiterals(FxOrderSet::from_iter([ @@ -140,9 +162,14 @@ impl<'db> UnionBuilder<'db> { // And same for int literals as well. 
Type::IntLiteral(literal) => { let mut found = false; + let mut too_large = false; for element in &mut self.elements { match element { UnionElement::IntLiterals(literals) => { + if literals.len() >= MAX_UNION_LITERALS { + too_large = true; + break; + } literals.insert(literal); found = true; break; @@ -153,6 +180,10 @@ impl<'db> UnionBuilder<'db> { _ => {} } } + if too_large { + let replace_with = KnownClass::Int.to_instance(self.db); + return self.add(replace_with); + } if !found { self.elements .push(UnionElement::IntLiterals(FxOrderSet::from_iter([literal])));