Skip to content

Commit 07faf92

Browse files
committed
Enum expansion in overload resolution
1 parent cb1dc4c commit 07faf92

File tree

4 files changed

+40
-29
lines changed

4 files changed

+40
-29
lines changed

crates/ty_python_semantic/resources/mdtest/call/overloads.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -400,9 +400,7 @@ def _(x: SomeEnum):
400400
reveal_type(f(SomeEnum.A)) # revealed: A
401401
reveal_type(f(SomeEnum.B)) # revealed: B
402402
reveal_type(f(SomeEnum.C)) # revealed: C
403-
# TODO: This should not be an error. The return type should be `A | B | C` once enums are expanded
404-
# error: [no-matching-overload]
405-
reveal_type(f(x)) # revealed: Unknown
403+
reveal_type(f(x)) # revealed: A | B | C
406404
```
407405

408406
### No matching overloads

crates/ty_python_semantic/src/types/builder.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
//! are subtypes of each other (unless exactly the same literal type), we can avoid many
3838
//! unnecessary `is_subtype_of` checks.
3939
40-
use crate::types::enums::{enum_metadata, expand_enum_to_member_union};
40+
use crate::types::enums::{enum_member_literals, enum_metadata};
4141
use crate::types::{
4242
BytesLiteralType, IntersectionType, KnownClass, StringLiteralType, Type,
4343
TypeVarBoundOrConstraints, UnionType,
@@ -858,10 +858,13 @@ impl<'db> InnerIntersectionBuilder<'db> {
858858
{
859859
self.add_positive(
860860
db,
861-
expand_enum_to_member_union(
861+
UnionType::from_elements(
862862
db,
863-
enum_literal.enum_class(db),
864-
Some(enum_literal.name(db)),
863+
enum_member_literals(
864+
db,
865+
enum_literal.enum_class(db),
866+
Some(enum_literal.name(db)),
867+
),
865868
),
866869
);
867870
}

crates/ty_python_semantic/src/types/call/arguments.rs

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use ruff_python_ast as ast;
55

66
use crate::Db;
77
use crate::types::KnownClass;
8+
use crate::types::enums::{enum_member_literals, enum_metadata};
89
use crate::types::tuple::{TupleSpec, TupleType};
910

1011
use super::Type;
@@ -199,13 +200,22 @@ impl<'a, 'db> FromIterator<(Argument<'a>, Option<Type<'db>>)> for CallArguments<
199200
///
200201
/// Returns [`None`] if the type cannot be expanded.
201202
fn expand_type<'db>(db: &'db dyn Db, ty: Type<'db>) -> Option<Vec<Type<'db>>> {
202-
// TODO: Expand enums to their variants
203203
match ty {
204-
Type::NominalInstance(instance) if instance.class.is_known(db, KnownClass::Bool) => {
205-
Some(vec![
206-
Type::BooleanLiteral(true),
207-
Type::BooleanLiteral(false),
208-
])
204+
Type::NominalInstance(instance) => {
205+
if instance.class.is_known(db, KnownClass::Bool) {
206+
return Some(vec![
207+
Type::BooleanLiteral(true),
208+
Type::BooleanLiteral(false),
209+
]);
210+
}
211+
212+
let class_literal = instance.class.class_literal(db).0;
213+
214+
if enum_metadata(db, class_literal).is_some() {
215+
return Some(enum_member_literals(db, class_literal, None).collect());
216+
}
217+
218+
None
209219
}
210220
Type::Tuple(tuple_type) => {
211221
// Note: This should only account for tuples of known length, i.e., `tuple[bool, ...]`

crates/ty_python_semantic/src/types/enums.rs

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use itertools::Either;
12
use ruff_python_ast::name::Name;
23
use rustc_hash::FxHashMap;
34

@@ -7,7 +8,7 @@ use crate::{
78
semantic_index::{place_table, use_def_map},
89
types::{
910
ClassLiteral, DynamicType, EnumLiteralType, KnownClass, MemberLookupPolicy, Type,
10-
TypeQualifiers, UnionType,
11+
TypeQualifiers,
1112
},
1213
};
1314

@@ -221,21 +222,20 @@ pub(crate) fn enum_metadata<'db>(
221222
Some(EnumMetadata { members, aliases })
222223
}
223224

224-
pub(crate) fn expand_enum_to_member_union<'db>(
225+
pub(crate) fn enum_member_literals<'a, 'db: 'a>(
225226
db: &'db dyn Db,
226227
class: ClassLiteral<'db>,
227-
exclude_member: Option<&Name>,
228-
) -> Type<'db> {
229-
enum_metadata(db, class)
230-
.as_ref()
231-
.map_or(Type::Never, |metadata| {
232-
UnionType::from_elements(
233-
db,
234-
metadata
235-
.members
236-
.iter()
237-
.filter(|name| Some(*name) != exclude_member)
238-
.map(|name| Type::EnumLiteral(EnumLiteralType::new(db, class, name.clone()))),
239-
)
240-
})
228+
exclude_member: Option<&'a Name>,
229+
) -> impl Iterator<Item = Type<'a>> + 'a {
230+
if let Some(metadata) = enum_metadata(db, class) {
231+
Either::Left(
232+
metadata
233+
.members
234+
.iter()
235+
.filter(move |name| Some(*name) != exclude_member)
236+
.map(move |name| Type::EnumLiteral(EnumLiteralType::new(db, class, name.clone()))),
237+
)
238+
} else {
239+
Either::Right(std::iter::empty())
240+
}
241241
}

0 commit comments

Comments
 (0)