Skip to content

Commit 66e9d57

Browse files
authored
[ty] Support legacy typing special forms in implicit type aliases (#21433)
## Summary Support various legacy `typing` special forms (`List`, `Dict`, …) in implicit type aliases. ## Ecosystem impact A lot of true positives (e.g. on `alerta`)! ## Test Plan New Markdown tests
1 parent 87dafb8 commit 66e9d57

File tree

2 files changed

+307
-8
lines changed

2 files changed

+307
-8
lines changed

crates/ty_python_semantic/resources/mdtest/implicit_type_aliases.md

Lines changed: 218 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -680,8 +680,21 @@ def _(
680680
Invalid uses result in diagnostics:
681681

682682
```py
683+
from typing import Literal
684+
683685
# error: [invalid-type-form]
684-
InvalidSubclass = type[1]
686+
InvalidSubclassOf1 = type[1]
687+
688+
# TODO: This should be an error
689+
InvalidSubclassOfLiteral = type[Literal[42]]
690+
691+
def _(
692+
invalid_subclass_of_1: InvalidSubclassOf1,
693+
invalid_subclass_of_literal: InvalidSubclassOfLiteral,
694+
):
695+
reveal_type(invalid_subclass_of_1) # revealed: type[Unknown]
696+
# TODO: this should be `type[Unknown]` or `Unknown`
697+
reveal_type(invalid_subclass_of_literal) # revealed: <class 'int'>
685698
```
686699

687700
### `Type[…]`
@@ -759,6 +772,178 @@ Invalid uses result in diagnostics:
759772
InvalidSubclass = Type[1]
760773
```
761774

775+
## Other `typing` special forms
776+
777+
The following special forms from the `typing` module are also supported in implicit type aliases:
778+
779+
```py
780+
from typing import List, Dict, Set, FrozenSet, ChainMap, Counter, DefaultDict, Deque, OrderedDict
781+
782+
MyList = List[str]
783+
MySet = Set[str]
784+
MyDict = Dict[str, int]
785+
MyFrozenSet = FrozenSet[str]
786+
MyChainMap = ChainMap[str, int]
787+
MyCounter = Counter[str]
788+
MyDefaultDict = DefaultDict[str, int]
789+
MyDeque = Deque[str]
790+
MyOrderedDict = OrderedDict[str, int]
791+
792+
reveal_type(MyList) # revealed: <class 'list[str]'>
793+
reveal_type(MySet) # revealed: <class 'set[str]'>
794+
reveal_type(MyDict) # revealed: <class 'dict[str, int]'>
795+
reveal_type(MyFrozenSet) # revealed: <class 'frozenset[str]'>
796+
reveal_type(MyChainMap) # revealed: <class 'ChainMap[str, int]'>
797+
reveal_type(MyCounter) # revealed: <class 'Counter[str]'>
798+
reveal_type(MyDefaultDict) # revealed: <class 'defaultdict[str, int]'>
799+
reveal_type(MyDeque) # revealed: <class 'deque[str]'>
800+
reveal_type(MyOrderedDict) # revealed: <class 'OrderedDict[str, int]'>
801+
802+
def _(
803+
my_list: MyList,
804+
my_set: MySet,
805+
my_dict: MyDict,
806+
my_frozen_set: MyFrozenSet,
807+
my_chain_map: MyChainMap,
808+
my_counter: MyCounter,
809+
my_default_dict: MyDefaultDict,
810+
my_deque: MyDeque,
811+
my_ordered_dict: MyOrderedDict,
812+
):
813+
reveal_type(my_list) # revealed: list[str]
814+
reveal_type(my_set) # revealed: set[str]
815+
reveal_type(my_dict) # revealed: dict[str, int]
816+
reveal_type(my_frozen_set) # revealed: frozenset[str]
817+
reveal_type(my_chain_map) # revealed: ChainMap[str, int]
818+
reveal_type(my_counter) # revealed: Counter[str]
819+
reveal_type(my_default_dict) # revealed: defaultdict[str, int]
820+
reveal_type(my_deque) # revealed: deque[str]
821+
reveal_type(my_ordered_dict) # revealed: OrderedDict[str, int]
822+
```
823+
824+
All of them are supported in unions:
825+
826+
```py
827+
NoneOrList = None | List[str]
828+
NoneOrSet = None | Set[str]
829+
NoneOrDict = None | Dict[str, int]
830+
NoneOrFrozenSet = None | FrozenSet[str]
831+
NoneOrChainMap = None | ChainMap[str, int]
832+
NoneOrCounter = None | Counter[str]
833+
NoneOrDefaultDict = None | DefaultDict[str, int]
834+
NoneOrDeque = None | Deque[str]
835+
NoneOrOrderedDict = None | OrderedDict[str, int]
836+
837+
ListOrNone = List[int] | None
838+
SetOrNone = Set[int] | None
839+
DictOrNone = Dict[str, int] | None
840+
FrozenSetOrNone = FrozenSet[int] | None
841+
ChainMapOrNone = ChainMap[str, int] | None
842+
CounterOrNone = Counter[str] | None
843+
DefaultDictOrNone = DefaultDict[str, int] | None
844+
DequeOrNone = Deque[str] | None
845+
OrderedDictOrNone = OrderedDict[str, int] | None
846+
847+
reveal_type(NoneOrList) # revealed: types.UnionType
848+
reveal_type(NoneOrSet) # revealed: types.UnionType
849+
reveal_type(NoneOrDict) # revealed: types.UnionType
850+
reveal_type(NoneOrFrozenSet) # revealed: types.UnionType
851+
reveal_type(NoneOrChainMap) # revealed: types.UnionType
852+
reveal_type(NoneOrCounter) # revealed: types.UnionType
853+
reveal_type(NoneOrDefaultDict) # revealed: types.UnionType
854+
reveal_type(NoneOrDeque) # revealed: types.UnionType
855+
reveal_type(NoneOrOrderedDict) # revealed: types.UnionType
856+
857+
reveal_type(ListOrNone) # revealed: types.UnionType
858+
reveal_type(SetOrNone) # revealed: types.UnionType
859+
reveal_type(DictOrNone) # revealed: types.UnionType
860+
reveal_type(FrozenSetOrNone) # revealed: types.UnionType
861+
reveal_type(ChainMapOrNone) # revealed: types.UnionType
862+
reveal_type(CounterOrNone) # revealed: types.UnionType
863+
reveal_type(DefaultDictOrNone) # revealed: types.UnionType
864+
reveal_type(DequeOrNone) # revealed: types.UnionType
865+
reveal_type(OrderedDictOrNone) # revealed: types.UnionType
866+
867+
def _(
868+
none_or_list: NoneOrList,
869+
none_or_set: NoneOrSet,
870+
none_or_dict: NoneOrDict,
871+
none_or_frozen_set: NoneOrFrozenSet,
872+
none_or_chain_map: NoneOrChainMap,
873+
none_or_counter: NoneOrCounter,
874+
none_or_default_dict: NoneOrDefaultDict,
875+
none_or_deque: NoneOrDeque,
876+
none_or_ordered_dict: NoneOrOrderedDict,
877+
list_or_none: ListOrNone,
878+
set_or_none: SetOrNone,
879+
dict_or_none: DictOrNone,
880+
frozen_set_or_none: FrozenSetOrNone,
881+
chain_map_or_none: ChainMapOrNone,
882+
counter_or_none: CounterOrNone,
883+
default_dict_or_none: DefaultDictOrNone,
884+
deque_or_none: DequeOrNone,
885+
ordered_dict_or_none: OrderedDictOrNone,
886+
):
887+
reveal_type(none_or_list) # revealed: None | list[str]
888+
reveal_type(none_or_set) # revealed: None | set[str]
889+
reveal_type(none_or_dict) # revealed: None | dict[str, int]
890+
reveal_type(none_or_frozen_set) # revealed: None | frozenset[str]
891+
reveal_type(none_or_chain_map) # revealed: None | ChainMap[str, int]
892+
reveal_type(none_or_counter) # revealed: None | Counter[str]
893+
reveal_type(none_or_default_dict) # revealed: None | defaultdict[str, int]
894+
reveal_type(none_or_deque) # revealed: None | deque[str]
895+
reveal_type(none_or_ordered_dict) # revealed: None | OrderedDict[str, int]
896+
897+
reveal_type(list_or_none) # revealed: list[int] | None
898+
reveal_type(set_or_none) # revealed: set[int] | None
899+
reveal_type(dict_or_none) # revealed: dict[str, int] | None
900+
reveal_type(frozen_set_or_none) # revealed: frozenset[int] | None
901+
reveal_type(chain_map_or_none) # revealed: ChainMap[str, int] | None
902+
reveal_type(counter_or_none) # revealed: Counter[str] | None
903+
reveal_type(default_dict_or_none) # revealed: defaultdict[str, int] | None
904+
reveal_type(deque_or_none) # revealed: deque[str] | None
905+
reveal_type(ordered_dict_or_none) # revealed: OrderedDict[str, int] | None
906+
```
907+
908+
Invalid uses result in diagnostics:
909+
910+
```py
911+
from typing import List, Dict
912+
913+
# error: [invalid-type-form] "Int literals are not allowed in this context in a type expression"
914+
InvalidList = List[1]
915+
916+
# error: [invalid-type-form] "`typing.typing.List` requires exactly one argument"
917+
ListTooManyArgs = List[int, str]
918+
919+
# error: [invalid-type-form] "Int literals are not allowed in this context in a type expression"
920+
InvalidDict1 = Dict[1, str]
921+
922+
# error: [invalid-type-form] "Int literals are not allowed in this context in a type expression"
923+
InvalidDict2 = Dict[str, 2]
924+
925+
# error: [invalid-type-form] "`typing.typing.Dict` requires exactly two arguments, got 1"
926+
DictTooFewArgs = Dict[str]
927+
928+
# error: [invalid-type-form] "`typing.typing.Dict` requires exactly two arguments, got 3"
929+
DictTooManyArgs = Dict[str, int, float]
930+
931+
def _(
932+
invalid_list: InvalidList,
933+
list_too_many_args: ListTooManyArgs,
934+
invalid_dict1: InvalidDict1,
935+
invalid_dict2: InvalidDict2,
936+
dict_too_few_args: DictTooFewArgs,
937+
dict_too_many_args: DictTooManyArgs,
938+
):
939+
reveal_type(invalid_list) # revealed: list[Unknown]
940+
reveal_type(list_too_many_args) # revealed: list[Unknown]
941+
reveal_type(invalid_dict1) # revealed: dict[Unknown, str]
942+
reveal_type(invalid_dict2) # revealed: dict[str, Unknown]
943+
reveal_type(dict_too_few_args) # revealed: dict[str, Unknown]
944+
reveal_type(dict_too_many_args) # revealed: dict[Unknown, Unknown]
945+
```
946+
762947
## Stringified annotations?
763948

764949
From the [typing spec on type aliases](https://typing.python.org/en/latest/spec/aliases.html):
@@ -789,22 +974,28 @@ We *do* support stringified annotations if they appear in a position where a typ
789974
syntactically expected:
790975

791976
```py
792-
from typing import Union
977+
from typing import Union, List, Dict
793978

794-
ListOfInts = list["int"]
979+
ListOfInts1 = list["int"]
980+
ListOfInts2 = List["int"]
795981
StrOrStyle = Union[str, "Style"]
796982
SubclassOfStyle = type["Style"]
983+
DictStrToStyle = Dict[str, "Style"]
797984

798985
class Style: ...
799986

800987
def _(
801-
list_of_ints: ListOfInts,
988+
list_of_ints1: ListOfInts1,
989+
list_of_ints2: ListOfInts2,
802990
str_or_style: StrOrStyle,
803991
subclass_of_style: SubclassOfStyle,
992+
dict_str_to_style: DictStrToStyle,
804993
):
805-
reveal_type(list_of_ints) # revealed: list[int]
994+
reveal_type(list_of_ints1) # revealed: list[int]
995+
reveal_type(list_of_ints2) # revealed: list[int]
806996
reveal_type(str_or_style) # revealed: str | Style
807997
reveal_type(subclass_of_style) # revealed: type[Style]
998+
reveal_type(dict_str_to_style) # revealed: dict[str, Style]
808999
```
8091000

8101001
## Recursive
@@ -828,8 +1019,27 @@ python-version = "3.12"
8281019
```
8291020

8301021
```py
831-
Recursive = list["Recursive" | None]
1022+
from typing import List, Dict
8321023

833-
def _(r: Recursive):
834-
reveal_type(r) # revealed: list[Divergent]
1024+
RecursiveList1 = list["RecursiveList1" | None]
1025+
RecursiveList2 = List["RecursiveList2" | None]
1026+
RecursiveDict1 = dict[str, "RecursiveDict1" | None]
1027+
RecursiveDict2 = Dict[str, "RecursiveDict2" | None]
1028+
RecursiveDict3 = dict["RecursiveDict3", int]
1029+
RecursiveDict4 = Dict["RecursiveDict4", int]
1030+
1031+
def _(
1032+
recursive_list1: RecursiveList1,
1033+
recursive_list2: RecursiveList2,
1034+
recursive_dict1: RecursiveDict1,
1035+
recursive_dict2: RecursiveDict2,
1036+
recursive_dict3: RecursiveDict3,
1037+
recursive_dict4: RecursiveDict4,
1038+
):
1039+
reveal_type(recursive_list1) # revealed: list[Divergent]
1040+
reveal_type(recursive_list2) # revealed: list[Divergent]
1041+
reveal_type(recursive_dict1) # revealed: dict[str, Divergent]
1042+
reveal_type(recursive_dict2) # revealed: dict[str, Divergent]
1043+
reveal_type(recursive_dict3) # revealed: dict[Divergent, int]
1044+
reveal_type(recursive_dict4) # revealed: dict[Divergent, int]
8351045
```

crates/ty_python_semantic/src/types/infer/builder.rs

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10779,6 +10779,95 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
1077910779
InternedType::new(self.db(), argument_ty),
1078010780
));
1078110781
}
10782+
// `typing` special forms with a single generic argument
10783+
Type::SpecialForm(
10784+
special_form @ (SpecialFormType::List
10785+
| SpecialFormType::Set
10786+
| SpecialFormType::FrozenSet
10787+
| SpecialFormType::Counter
10788+
| SpecialFormType::Deque),
10789+
) => {
10790+
let slice_ty = self.infer_type_expression(slice);
10791+
10792+
let element_ty = if matches!(**slice, ast::Expr::Tuple(_)) {
10793+
if let Some(builder) = self.context.report_lint(&INVALID_TYPE_FORM, subscript) {
10794+
builder.into_diagnostic(format_args!(
10795+
"`typing.{}` requires exactly one argument",
10796+
special_form.repr()
10797+
));
10798+
}
10799+
Type::unknown()
10800+
} else {
10801+
slice_ty
10802+
};
10803+
10804+
let class = special_form
10805+
.aliased_stdlib_class()
10806+
.expect("A known stdlib class is available");
10807+
10808+
return class
10809+
.to_specialized_class_type(self.db(), [element_ty])
10810+
.map(Type::from)
10811+
.unwrap_or_else(Type::unknown);
10812+
}
10813+
// `typing` special forms with two generic arguments
10814+
Type::SpecialForm(
10815+
special_form @ (SpecialFormType::Dict
10816+
| SpecialFormType::ChainMap
10817+
| SpecialFormType::DefaultDict
10818+
| SpecialFormType::OrderedDict),
10819+
) => {
10820+
let (first_ty, second_ty) = if let ast::Expr::Tuple(ast::ExprTuple {
10821+
elts: ref arguments,
10822+
..
10823+
}) = **slice
10824+
{
10825+
if arguments.len() != 2 {
10826+
if let Some(builder) =
10827+
self.context.report_lint(&INVALID_TYPE_FORM, subscript)
10828+
{
10829+
builder.into_diagnostic(format_args!(
10830+
"`typing.{}` requires exactly two arguments, got {}",
10831+
special_form.repr(),
10832+
arguments.len()
10833+
));
10834+
}
10835+
}
10836+
10837+
if let [first_expr, second_expr] = &arguments[..] {
10838+
let first_ty = self.infer_type_expression(first_expr);
10839+
let second_ty = self.infer_type_expression(second_expr);
10840+
10841+
(first_ty, second_ty)
10842+
} else {
10843+
for argument in arguments {
10844+
self.infer_type_expression(argument);
10845+
}
10846+
10847+
(Type::unknown(), Type::unknown())
10848+
}
10849+
} else {
10850+
let first_ty = self.infer_type_expression(slice);
10851+
10852+
if let Some(builder) = self.context.report_lint(&INVALID_TYPE_FORM, subscript) {
10853+
builder.into_diagnostic(format_args!(
10854+
"`typing.{}` requires exactly two arguments, got 1",
10855+
special_form.repr()
10856+
));
10857+
}
10858+
10859+
(first_ty, Type::unknown())
10860+
};
10861+
10862+
let class = special_form
10863+
.aliased_stdlib_class()
10864+
.expect("Stdlib class available");
10865+
10866+
return class
10867+
.to_specialized_class_type(self.db(), [first_ty, second_ty])
10868+
.map(Type::from)
10869+
.unwrap_or_else(Type::unknown);
10870+
}
1078210871
_ => {}
1078310872
}
1078410873

0 commit comments

Comments
 (0)