diff --git a/mypy/nodes.py b/mypy/nodes.py index e4d8514ad6e2..af9f941e6aad 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -20,7 +20,7 @@ Union, cast, ) -from typing_extensions import Final, TypeAlias as _TypeAlias +from typing_extensions import Final, TypeAlias as _TypeAlias, TypeGuard from mypy_extensions import trait @@ -1635,6 +1635,10 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: return visitor.visit_str_expr(self) +def is_StrExpr_list(seq: list[Expression]) -> TypeGuard[list[StrExpr]]: + return all(isinstance(item, StrExpr) for item in seq) + + class BytesExpr(Expression): """Bytes literal""" diff --git a/mypy/semanal_enum.py b/mypy/semanal_enum.py index c7b8e44f65aa..efb9764545eb 100644 --- a/mypy/semanal_enum.py +++ b/mypy/semanal_enum.py @@ -27,6 +27,7 @@ TupleExpr, TypeInfo, Var, + is_StrExpr_list, ) from mypy.options import Options from mypy.semanal_shared import SemanticAnalyzerInterface @@ -177,8 +178,8 @@ def parse_enum_call_args( items.append(field) elif isinstance(names, (TupleExpr, ListExpr)): seq_items = names.items - if all(isinstance(seq_item, StrExpr) for seq_item in seq_items): - items = [cast(StrExpr, seq_item).value for seq_item in seq_items] + if is_StrExpr_list(seq_items): + items = [seq_item.value for seq_item in seq_items] elif all( isinstance(seq_item, (TupleExpr, ListExpr)) and len(seq_item.items) == 2 diff --git a/mypy/semanal_namedtuple.py b/mypy/semanal_namedtuple.py index 1194557836b1..bfd73eca59ef 100644 --- a/mypy/semanal_namedtuple.py +++ b/mypy/semanal_namedtuple.py @@ -41,6 +41,7 @@ TypeInfo, TypeVarExpr, Var, + is_StrExpr_list, ) from mypy.options import Options from mypy.semanal_shared import ( @@ -392,10 +393,10 @@ def parse_namedtuple_args( listexpr = args[1] if fullname == "collections.namedtuple": # The fields argument contains just names, with implicit Any types. - if any(not isinstance(item, StrExpr) for item in listexpr.items): + if not is_StrExpr_list(listexpr.items): self.fail('String literal expected as "namedtuple()" item', call) return None - items = [cast(StrExpr, item).value for item in listexpr.items] + items = [item.value for item in listexpr.items] else: type_exprs = [ t.items[1] diff --git a/mypy/stubgen.py b/mypy/stubgen.py index 6cb4669887fe..212d934a11b7 100755 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -48,7 +48,7 @@ import sys import traceback from collections import defaultdict -from typing import Iterable, List, Mapping, cast +from typing import Iterable, List, Mapping from typing_extensions import Final import mypy.build @@ -102,6 +102,7 @@ TupleExpr, TypeInfo, UnaryExpr, + is_StrExpr_list, ) from mypy.options import Options as MypyOptions from mypy.stubdoc import Sig, find_unique_signatures, parse_all_signatures @@ -1052,7 +1053,8 @@ def process_namedtuple(self, lvalue: NameExpr, rvalue: CallExpr) -> None: if isinstance(rvalue.args[1], StrExpr): items = rvalue.args[1].value.replace(",", " ").split() elif isinstance(rvalue.args[1], (ListExpr, TupleExpr)): - list_items = cast(List[StrExpr], rvalue.args[1].items) + list_items = rvalue.args[1].items + assert is_StrExpr_list(list_items) items = [item.value for item in list_items] else: self.add(f"{self._indent}{lvalue.name}: Incomplete")