Skip to content

Commit

Permalink
Add a helper function for narrowing list[Expression] to `list[StrEx…
Browse files Browse the repository at this point in the history
…pr]` (#14877)

There are several places in the code base where we need to narrow
`list[Expression]` -> `list[StrExpr]`. Currently we do this using
`cast`s, but `TypeGuard`s are arguably a much more idiomatic way to do
this kind of operation nowadays.
  • Loading branch information
AlexWaygood authored Mar 11, 2023
1 parent 32dd3c0 commit 106d57e
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 7 deletions.
6 changes: 5 additions & 1 deletion mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
Union,
cast,
)
from typing_extensions import Final, TypeAlias as _TypeAlias
from typing_extensions import Final, TypeAlias as _TypeAlias, TypeGuard

from mypy_extensions import trait

Expand Down Expand Up @@ -1635,6 +1635,10 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T:
return visitor.visit_str_expr(self)


def is_StrExpr_list(seq: list[Expression]) -> TypeGuard[list[StrExpr]]:
return all(isinstance(item, StrExpr) for item in seq)


class BytesExpr(Expression):
"""Bytes literal"""

Expand Down
5 changes: 3 additions & 2 deletions mypy/semanal_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
TupleExpr,
TypeInfo,
Var,
is_StrExpr_list,
)
from mypy.options import Options
from mypy.semanal_shared import SemanticAnalyzerInterface
Expand Down Expand Up @@ -177,8 +178,8 @@ def parse_enum_call_args(
items.append(field)
elif isinstance(names, (TupleExpr, ListExpr)):
seq_items = names.items
if all(isinstance(seq_item, StrExpr) for seq_item in seq_items):
items = [cast(StrExpr, seq_item).value for seq_item in seq_items]
if is_StrExpr_list(seq_items):
items = [seq_item.value for seq_item in seq_items]
elif all(
isinstance(seq_item, (TupleExpr, ListExpr))
and len(seq_item.items) == 2
Expand Down
5 changes: 3 additions & 2 deletions mypy/semanal_namedtuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
TypeInfo,
TypeVarExpr,
Var,
is_StrExpr_list,
)
from mypy.options import Options
from mypy.semanal_shared import (
Expand Down Expand Up @@ -392,10 +393,10 @@ def parse_namedtuple_args(
listexpr = args[1]
if fullname == "collections.namedtuple":
# The fields argument contains just names, with implicit Any types.
if any(not isinstance(item, StrExpr) for item in listexpr.items):
if not is_StrExpr_list(listexpr.items):
self.fail('String literal expected as "namedtuple()" item', call)
return None
items = [cast(StrExpr, item).value for item in listexpr.items]
items = [item.value for item in listexpr.items]
else:
type_exprs = [
t.items[1]
Expand Down
6 changes: 4 additions & 2 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
import sys
import traceback
from collections import defaultdict
from typing import Iterable, List, Mapping, cast
from typing import Iterable, List, Mapping
from typing_extensions import Final

import mypy.build
Expand Down Expand Up @@ -102,6 +102,7 @@
TupleExpr,
TypeInfo,
UnaryExpr,
is_StrExpr_list,
)
from mypy.options import Options as MypyOptions
from mypy.stubdoc import Sig, find_unique_signatures, parse_all_signatures
Expand Down Expand Up @@ -1052,7 +1053,8 @@ def process_namedtuple(self, lvalue: NameExpr, rvalue: CallExpr) -> None:
if isinstance(rvalue.args[1], StrExpr):
items = rvalue.args[1].value.replace(",", " ").split()
elif isinstance(rvalue.args[1], (ListExpr, TupleExpr)):
list_items = cast(List[StrExpr], rvalue.args[1].items)
list_items = rvalue.args[1].items
assert is_StrExpr_list(list_items)
items = [item.value for item in list_items]
else:
self.add(f"{self._indent}{lvalue.name}: Incomplete")
Expand Down

0 comments on commit 106d57e

Please sign in to comment.