Skip to content

Commit

Permalink
Fix Literal strings containing pipe characters (#17148)
Browse files Browse the repository at this point in the history
Fixes #16367

During semantic analysis, we try to parse all strings as types,
including those inside Literal[]. Previously, we preserved the original
string in the `UnboundType.original_str_expr` attribute, but if a type
is parsed as a Union, we didn't have a place to put the value.

This PR instead always wraps string types in a RawExpressionType node,
which now optionally includes a `.node` attribute containing the parsed
type. This way, we don't need to worry about preserving the original
string as a custom attribute on different kinds of types that can appear
in this context.

The downside is that more code needs to be aware of RawExpressionType.
  • Loading branch information
JelleZijlstra committed Apr 22, 2024
1 parent df35dcf commit 1072c78
Show file tree
Hide file tree
Showing 15 changed files with 142 additions and 90 deletions.
11 changes: 3 additions & 8 deletions mypy/fastparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,14 +319,7 @@ def parse_type_string(
"""
try:
_, node = parse_type_comment(f"({expr_string})", line=line, column=column, errors=None)
if isinstance(node, UnboundType) and node.original_str_expr is None:
node.original_str_expr = expr_string
node.original_str_fallback = expr_fallback_name
return node
elif isinstance(node, UnionType):
return node
else:
return RawExpressionType(expr_string, expr_fallback_name, line, column)
return RawExpressionType(expr_string, expr_fallback_name, line, column, node=node)
except (SyntaxError, ValueError):
# Note: the parser will raise a `ValueError` instead of a SyntaxError if
# the string happens to contain things like \x00.
Expand Down Expand Up @@ -1034,6 +1027,8 @@ def set_type_optional(self, type: Type | None, initializer: Expression | None) -
return
# Indicate that type should be wrapped in an Optional if arg is initialized to None.
optional = isinstance(initializer, NameExpr) and initializer.name == "None"
if isinstance(type, RawExpressionType) and type.node is not None:
type = type.node
if isinstance(type, UnboundType):
type.optional = optional

Expand Down
31 changes: 18 additions & 13 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3231,10 +3231,10 @@ def analyze_typeddict_assign(self, s: AssignmentStmt) -> bool:
def analyze_lvalues(self, s: AssignmentStmt) -> None:
# We cannot use s.type, because analyze_simple_literal_type() will set it.
explicit = s.unanalyzed_type is not None
if self.is_final_type(s.unanalyzed_type):
final_type = self.unwrap_final_type(s.unanalyzed_type)
if final_type is not None:
# We need to exclude bare Final.
assert isinstance(s.unanalyzed_type, UnboundType)
if not s.unanalyzed_type.args:
if not final_type.args:
explicit = False

if s.rvalue:
Expand Down Expand Up @@ -3300,19 +3300,19 @@ def unwrap_final(self, s: AssignmentStmt) -> bool:
Returns True if Final[...] was present.
"""
if not s.unanalyzed_type or not self.is_final_type(s.unanalyzed_type):
final_type = self.unwrap_final_type(s.unanalyzed_type)
if final_type is None:
return False
assert isinstance(s.unanalyzed_type, UnboundType)
if len(s.unanalyzed_type.args) > 1:
self.fail("Final[...] takes at most one type argument", s.unanalyzed_type)
if len(final_type.args) > 1:
self.fail("Final[...] takes at most one type argument", final_type)
invalid_bare_final = False
if not s.unanalyzed_type.args:
if not final_type.args:
s.type = None
if isinstance(s.rvalue, TempNode) and s.rvalue.no_rhs:
invalid_bare_final = True
self.fail("Type in Final[...] can only be omitted if there is an initializer", s)
else:
s.type = s.unanalyzed_type.args[0]
s.type = final_type.args[0]

if s.type is not None and self.is_classvar(s.type):
self.fail("Variable should not be annotated with both ClassVar and Final", s)
Expand Down Expand Up @@ -4713,13 +4713,18 @@ def is_classvar(self, typ: Type) -> bool:
return False
return sym.node.fullname == "typing.ClassVar"

def is_final_type(self, typ: Type | None) -> bool:
def unwrap_final_type(self, typ: Type | None) -> UnboundType | None:
if typ is None:
return None
typ = typ.resolve_string_annotation()
if not isinstance(typ, UnboundType):
return False
return None
sym = self.lookup_qualified(typ.name, typ)
if not sym or not sym.node:
return False
return sym.node.fullname in FINAL_TYPE_NAMES
return None
if sym.node.fullname in FINAL_TYPE_NAMES:
return typ
return None

def fail_invalid_classvar(self, context: Context) -> None:
self.fail(message_registry.CLASS_VAR_OUTSIDE_OF_CLASS, context)
Expand Down
3 changes: 2 additions & 1 deletion mypy/server/astmerge.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,8 @@ def visit_typeddict_type(self, typ: TypedDictType) -> None:
typ.fallback.accept(self)

def visit_raw_expression_type(self, t: RawExpressionType) -> None:
pass
if t.node is not None:
t.node.accept(self)

def visit_literal_type(self, typ: LiteralType) -> None:
typ.fallback.accept(self)
Expand Down
16 changes: 12 additions & 4 deletions mypy/stubutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,16 @@
from mypy.modulefinder import ModuleNotFoundReason
from mypy.moduleinspect import InspectError, ModuleInspect
from mypy.stubdoc import ArgSig, FunctionSig
from mypy.types import AnyType, NoneType, Type, TypeList, TypeStrVisitor, UnboundType, UnionType
from mypy.types import (
AnyType,
NoneType,
RawExpressionType,
Type,
TypeList,
TypeStrVisitor,
UnboundType,
UnionType,
)

# Modules that may fail when imported, or that may have side effects (fully qualified).
NOT_IMPORTABLE_MODULES = ()
Expand Down Expand Up @@ -291,12 +300,11 @@ def args_str(self, args: Iterable[Type]) -> str:
The main difference from list_str is the preservation of quotes for string
arguments
"""
types = ["builtins.bytes", "builtins.str"]
res = []
for arg in args:
arg_str = arg.accept(self)
if isinstance(arg, UnboundType) and arg.original_str_fallback in types:
res.append(f"'{arg_str}'")
if isinstance(arg, RawExpressionType):
res.append(repr(arg.literal_value))
else:
res.append(arg_str)
return ", ".join(res)
Expand Down
4 changes: 4 additions & 0 deletions mypy/type_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,8 @@ def visit_typeddict_type(self, t: TypedDictType) -> T:
return self.query_types(t.items.values())

def visit_raw_expression_type(self, t: RawExpressionType) -> T:
if t.node is not None:
return t.node.accept(self)
return self.strategy([])

def visit_literal_type(self, t: LiteralType) -> T:
Expand Down Expand Up @@ -516,6 +518,8 @@ def visit_typeddict_type(self, t: TypedDictType) -> bool:
return self.query_types(list(t.items.values()))

def visit_raw_expression_type(self, t: RawExpressionType) -> bool:
if t.node is not None:
return t.node.accept(self)
return self.default

def visit_literal_type(self, t: LiteralType) -> bool:
Expand Down
21 changes: 8 additions & 13 deletions mypy/typeanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1070,6 +1070,7 @@ def visit_callable_type(self, t: CallableType, nested: bool = True) -> Type:
return ret

def anal_type_guard(self, t: Type) -> Type | None:
t = t.resolve_string_annotation()
if isinstance(t, UnboundType):
sym = self.lookup_qualified(t.name, t)
if sym is not None and sym.node is not None:
Expand All @@ -1088,6 +1089,7 @@ def anal_type_guard_arg(self, t: UnboundType, fullname: str) -> Type | None:
return None

def anal_type_is(self, t: Type) -> Type | None:
t = t.resolve_string_annotation()
if isinstance(t, UnboundType):
sym = self.lookup_qualified(t.name, t)
if sym is not None and sym.node is not None:
Expand All @@ -1105,6 +1107,7 @@ def anal_type_is_arg(self, t: UnboundType, fullname: str) -> Type | None:

def anal_star_arg_type(self, t: Type, kind: ArgKind, nested: bool) -> Type:
"""Analyze signature argument type for *args and **kwargs argument."""
t = t.resolve_string_annotation()
if isinstance(t, UnboundType) and t.name and "." in t.name and not t.args:
components = t.name.split(".")
tvar_name = ".".join(components[:-1])
Expand Down Expand Up @@ -1195,6 +1198,8 @@ def visit_raw_expression_type(self, t: RawExpressionType) -> Type:
# make signatures like "foo(x: 20) -> None" legal, we can change
# this method so it generates and returns an actual LiteralType
# instead.
if t.node is not None:
return t.node.accept(self)

if self.report_invalid_types:
if t.base_type_name in ("builtins.int", "builtins.bool"):
Expand Down Expand Up @@ -1455,6 +1460,7 @@ def analyze_callable_args(
invalid_unpacks: list[Type] = []
second_unpack_last = False
for i, arg in enumerate(arglist.items):
arg = arg.resolve_string_annotation()
if isinstance(arg, CallableArgument):
args.append(arg.typ)
names.append(arg.name)
Expand Down Expand Up @@ -1535,18 +1541,6 @@ def analyze_literal_type(self, t: UnboundType) -> Type:
return UnionType.make_union(output, line=t.line)

def analyze_literal_param(self, idx: int, arg: Type, ctx: Context) -> list[Type] | None:
# This UnboundType was originally defined as a string.
if isinstance(arg, UnboundType) and arg.original_str_expr is not None:
assert arg.original_str_fallback is not None
return [
LiteralType(
value=arg.original_str_expr,
fallback=self.named_type(arg.original_str_fallback),
line=arg.line,
column=arg.column,
)
]

# If arg is an UnboundType that was *not* originally defined as
# a string, try expanding it in case it's a type alias or something.
if isinstance(arg, UnboundType):
Expand Down Expand Up @@ -2528,7 +2522,8 @@ def visit_typeddict_type(self, t: TypedDictType) -> None:
self.process_types(list(t.items.values()))

def visit_raw_expression_type(self, t: RawExpressionType) -> None:
pass
if t.node is not None:
t.node.accept(self)

def visit_literal_type(self, t: LiteralType) -> None:
pass
Expand Down
Loading

0 comments on commit 1072c78

Please sign in to comment.