Skip to content

Commit cb1dc4c

Browse files
committed
[ty] Expansion of enums into unions of literals
1 parent b8dddd5 commit cb1dc4c

File tree

7 files changed

+145
-5
lines changed

7 files changed

+145
-5
lines changed

crates/ty_python_semantic/resources/mdtest/intersection_types.md

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -763,6 +763,49 @@ def f(
763763
reveal_type(j) # revealed: Unknown & Literal[""]
764764
```
765765

766+
## Simplification of enum literals
767+
768+
```toml
769+
[environment]
770+
python-version = "3.12"
771+
```
772+
773+
```py
774+
from ty_extensions import Intersection, Not
775+
from typing import Literal
776+
from enum import Enum
777+
778+
class Color(Enum):
779+
RED = "red"
780+
GREEN = "green"
781+
BLUE = "blue"
782+
783+
type Red = Literal[Color.RED]
784+
type Green = Literal[Color.GREEN]
785+
type Blue = Literal[Color.BLUE]
786+
787+
def f(
788+
a: Intersection[Color, Red],
789+
b: Intersection[Color, Not[Red]],
790+
c: Intersection[Color, Not[Red | Green]],
791+
d: Intersection[Color, Not[Red | Green | Blue]],
792+
e: Intersection[Red, Not[Color]],
793+
f: Intersection[Red | Green, Not[Color]],
794+
g: Intersection[Not[Red], Color],
795+
h: Intersection[Red, Green],
796+
i: Intersection[Red | Green, Green | Blue],
797+
):
798+
reveal_type(a) # revealed: Literal[Color.RED]
799+
reveal_type(b) # revealed: Literal[Color.GREEN, Color.BLUE]
800+
reveal_type(c) # revealed: Literal[Color.BLUE]
801+
reveal_type(d) # revealed: Never
802+
reveal_type(e) # revealed: Never
803+
reveal_type(f) # revealed: Never
804+
reveal_type(g) # revealed: Literal[Color.GREEN, Color.BLUE]
805+
reveal_type(h) # revealed: Never
806+
reveal_type(i) # revealed: Literal[Color.GREEN]
807+
```
808+
766809
## Addition of a type to an intersection with many non-disjoint types
767810

768811
This slightly strange-looking test is a regression test for a mistake that was nearly made in a PR:

crates/ty_python_semantic/resources/mdtest/narrow/conditionals/eq.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,7 @@ class Answer(Enum):
3535

3636
def _(answer: Answer):
3737
if answer != Answer.NO:
38-
# TODO: This should be simplified to `Literal[Answer.YES]`
39-
reveal_type(answer) # revealed: Answer & ~Literal[Answer.NO]
38+
reveal_type(answer) # revealed: Literal[Answer.YES]
4039
else:
4140
# TODO: This should be `Literal[Answer.NO]`
4241
reveal_type(answer) # revealed: Answer

crates/ty_python_semantic/resources/mdtest/narrow/conditionals/is.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,7 @@ def _(answer: Answer):
7878
if answer is Answer.NO:
7979
reveal_type(answer) # revealed: Literal[Answer.NO]
8080
else:
81-
# TODO: This should be `Literal[Answer.YES]`
82-
reveal_type(answer) # revealed: Answer & ~Literal[Answer.NO]
81+
reveal_type(answer) # revealed: Literal[Answer.YES]
8382
```
8483

8584
## `is` for `EllipsisType` (Python 3.10+)

crates/ty_python_semantic/resources/mdtest/union_types.md

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,31 @@ def _(
114114
reveal_type(u5) # revealed: bool | Literal[17]
115115
```
116116

117+
## Enum literals
118+
119+
```py
120+
from enum import Enum
121+
from typing import Literal
122+
123+
class Color(Enum):
124+
RED = "red"
125+
GREEN = "green"
126+
BLUE = "blue"
127+
128+
def _(
129+
u1: Literal[Color.RED, Color.GREEN],
130+
u2: Color | Literal[Color.RED],
131+
u3: Literal[Color.RED] | Color,
132+
u4: Literal[Color.RED] | Literal[Color.RED, Color.GREEN],
133+
u5: Literal[Color.RED, Color.GREEN, Color.BLUE],
134+
) -> None:
135+
reveal_type(u1) # revealed: Literal[Color.RED, Color.GREEN]
136+
reveal_type(u2) # revealed: Color
137+
reveal_type(u3) # revealed: Color
138+
reveal_type(u4) # revealed: Literal[Color.RED, Color.GREEN]
139+
reveal_type(u5) # revealed: Color
140+
```
141+
117142
## Do not erase `Unknown`
118143

119144
```py

crates/ty_python_semantic/src/types.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8334,6 +8334,7 @@ pub struct EnumLiteralType<'db> {
83348334
/// A reference to the enum class this literal belongs to
83358335
enum_class: ClassLiteral<'db>,
83368336
/// The name of the enum member
8337+
#[returns(ref)]
83378338
name: Name,
83388339
}
83398340

crates/ty_python_semantic/src/types/builder.rs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
//! are subtypes of each other (unless exactly the same literal type), we can avoid many
3838
//! unnecessary `is_subtype_of` checks.
3939
40+
use crate::types::enums::{enum_metadata, expand_enum_to_member_union};
4041
use crate::types::{
4142
BytesLiteralType, IntersectionType, KnownClass, StringLiteralType, Type,
4243
TypeVarBoundOrConstraints, UnionType,
@@ -374,6 +375,42 @@ impl<'db> UnionBuilder<'db> {
374375
self.elements.swap_remove(index);
375376
}
376377
}
378+
Type::EnumLiteral(enum_literal) => {
379+
let enum_class = enum_literal.enum_class(self.db);
380+
381+
let metadata = enum_metadata(self.db, enum_class);
382+
let metadata = metadata.as_ref().expect("Class of enum literal is an enum");
383+
384+
// TODO: This is a horribly slow implementation just to check if this works.
385+
let all_members = metadata.members.iter().collect::<FxOrderSet<_>>();
386+
387+
let mut existing_enum_members = self
388+
.elements
389+
.iter()
390+
.filter_map(|element| {
391+
if let UnionElement::Type(Type::EnumLiteral(lit)) = element {
392+
Some(lit.name(self.db))
393+
} else {
394+
None
395+
}
396+
})
397+
.collect::<FxOrderSet<_>>();
398+
existing_enum_members.insert(enum_literal.name(self.db));
399+
400+
if all_members == existing_enum_members {
401+
self.add_in_place(enum_literal.enum_class_instance(self.db));
402+
} else {
403+
if !self.elements.iter().any(|element| match element {
404+
UnionElement::Type(ty) => {
405+
Type::EnumLiteral(enum_literal).is_subtype_of(self.db, *ty)
406+
}
407+
_ => false,
408+
}) {
409+
self.elements
410+
.push(UnionElement::Type(Type::EnumLiteral(enum_literal)));
411+
}
412+
}
413+
}
377414
// Adding `object` to a union results in `object`.
378415
ty if ty.is_object(self.db) => {
379416
self.collapse_to_object();
@@ -773,6 +810,8 @@ impl<'db> InnerIntersectionBuilder<'db> {
773810
.any(KnownClass::is_bool)
774811
};
775812

813+
let contains_enum = |enum_instance| self.positive.iter().any(|ty| *ty == enum_instance);
814+
776815
match new_negative {
777816
Type::Intersection(inter) => {
778817
for pos in inter.positive(db) {
@@ -814,6 +853,18 @@ impl<'db> InnerIntersectionBuilder<'db> {
814853
Type::AlwaysFalsy if self.positive.contains(&Type::LiteralString) => {
815854
self.add_negative(db, Type::string_literal(db, ""));
816855
}
856+
Type::EnumLiteral(enum_literal)
857+
if contains_enum(enum_literal.enum_class_instance(db)) =>
858+
{
859+
self.add_positive(
860+
db,
861+
expand_enum_to_member_union(
862+
db,
863+
enum_literal.enum_class(db),
864+
Some(enum_literal.name(db)),
865+
),
866+
);
867+
}
817868
_ => {
818869
let mut to_remove = SmallVec::<[usize; 1]>::new();
819870
for (index, existing_negative) in self.negative.iter().enumerate() {

crates/ty_python_semantic/src/types/enums.rs

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@ use crate::{
55
Db,
66
place::{Place, PlaceAndQualifiers, place_from_bindings, place_from_declarations},
77
semantic_index::{place_table, use_def_map},
8-
types::{ClassLiteral, DynamicType, KnownClass, MemberLookupPolicy, Type, TypeQualifiers},
8+
types::{
9+
ClassLiteral, DynamicType, EnumLiteralType, KnownClass, MemberLookupPolicy, Type,
10+
TypeQualifiers, UnionType,
11+
},
912
};
1013

1114
#[derive(Debug, PartialEq, Eq, get_size2::GetSize)]
@@ -217,3 +220,22 @@ pub(crate) fn enum_metadata<'db>(
217220

218221
Some(EnumMetadata { members, aliases })
219222
}
223+
224+
pub(crate) fn expand_enum_to_member_union<'db>(
225+
db: &'db dyn Db,
226+
class: ClassLiteral<'db>,
227+
exclude_member: Option<&Name>,
228+
) -> Type<'db> {
229+
enum_metadata(db, class)
230+
.as_ref()
231+
.map_or(Type::Never, |metadata| {
232+
UnionType::from_elements(
233+
db,
234+
metadata
235+
.members
236+
.iter()
237+
.filter(|name| Some(*name) != exclude_member)
238+
.map(|name| Type::EnumLiteral(EnumLiteralType::new(db, class, name.clone()))),
239+
)
240+
})
241+
}

0 commit comments

Comments
 (0)