Skip to content

Commit 1e9a823

Browse files
Add SequenceValue for heterogeneous sequences (#515)
PEP 646 requires that we support heterogeneous, variable-length sequences like tuple[int, *tuple[str, ...]]. This requires replacing the SequenceIncompleteValue class, so we're doing this in steps: 1. Add SequenceValue and accept it in all places that currently accept SequenceIncompleteValue, but don't infer it in any existing contexts (this PR). 2. Update internal usages of SequenceIncompleteValue to also accept SequenceValue. 3. Infer only SequenceValue inside pyanalyze. 4. Remove internal usage of SequenceIncompleteValue. 5. Drop SequenceIncompleteValue.
1 parent bf6e792 commit 1e9a823

14 files changed

+757
-38
lines changed

pyanalyze/annotations.py

+142-19
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@
8282
ParamSpecKwargsValue,
8383
ParameterTypeGuardExtension,
8484
SelfTVV,
85+
SequenceValue,
8586
TypeGuardExtension,
8687
TypedValue,
8788
SequenceIncompleteValue,
@@ -344,14 +345,26 @@ def value_from_ast(
344345
return val
345346

346347

347-
def _type_from_ast(node: ast.AST, ctx: Context, is_typeddict: bool = False) -> Value:
348+
def _type_from_ast(
349+
node: ast.AST,
350+
ctx: Context,
351+
*,
352+
is_typeddict: bool = False,
353+
unpack_allowed: bool = False,
354+
) -> Value:
348355
val = value_from_ast(node, ctx)
349-
return _type_from_value(val, ctx, is_typeddict=is_typeddict)
356+
return _type_from_value(
357+
val, ctx, is_typeddict=is_typeddict, unpack_allowed=unpack_allowed
358+
)
350359

351360

352-
def _type_from_runtime(val: Any, ctx: Context, is_typeddict: bool = False) -> Value:
361+
def _type_from_runtime(
362+
val: Any, ctx: Context, *, is_typeddict: bool = False, unpack_allowed: bool = False
363+
) -> Value:
353364
if isinstance(val, str):
354-
return _eval_forward_ref(val, ctx, is_typeddict=is_typeddict)
365+
return _eval_forward_ref(
366+
val, ctx, is_typeddict=is_typeddict, unpack_allowed=unpack_allowed
367+
)
355368
elif isinstance(val, tuple):
356369
# This happens under some Python versions for types
357370
# nested in tuples, e.g. on 3.6:
@@ -365,13 +378,17 @@ def _type_from_runtime(val: Any, ctx: Context, is_typeddict: bool = False) -> Va
365378
args = (val[1],)
366379
else:
367380
args = val[1:]
368-
return _value_of_origin_args(origin, args, val, ctx)
381+
return _value_of_origin_args(
382+
origin, args, val, ctx, unpack_allowed=unpack_allowed
383+
)
369384
elif GenericAlias is not None and isinstance(val, GenericAlias):
370385
origin = get_origin(val)
371386
args = get_args(val)
372387
if origin is tuple and not args:
373388
return SequenceIncompleteValue(tuple, [])
374-
return _value_of_origin_args(origin, args, val, ctx)
389+
return _value_of_origin_args(
390+
origin, args, val, ctx, unpack_allowed=origin is tuple
391+
)
375392
elif typing_inspect.is_literal_type(val):
376393
args = typing_inspect.get_args(val)
377394
if len(args) == 0:
@@ -393,7 +410,17 @@ def _type_from_runtime(val: Any, ctx: Context, is_typeddict: bool = False) -> Va
393410
elif len(args) == 1 and args[0] == ():
394411
return SequenceIncompleteValue(tuple, []) # empty tuple
395412
else:
396-
args_vals = [_type_from_runtime(arg, ctx) for arg in args]
413+
args_vals = [
414+
_type_from_runtime(arg, ctx, unpack_allowed=True) for arg in args
415+
]
416+
if any(isinstance(val, UnpackedValue) for val in args_vals):
417+
members = []
418+
for val in args_vals:
419+
if isinstance(val, UnpackedValue):
420+
members += val.elements
421+
else:
422+
members.append((False, val))
423+
return SequenceValue(tuple, members)
397424
return SequenceIncompleteValue(tuple, args_vals)
398425
elif is_instance_of_typing_name(val, "_TypedDictMeta"):
399426
required_keys = getattr(val, "__required_keys__", None)
@@ -434,7 +461,14 @@ def _type_from_runtime(val: Any, ctx: Context, is_typeddict: bool = False) -> Va
434461
args = typing_inspect.get_args(val)
435462
if getattr(val, "_special", False):
436463
args = [] # distinguish List from List[T] on 3.7 and 3.8
437-
return _value_of_origin_args(origin, args, val, ctx, is_typeddict=is_typeddict)
464+
return _value_of_origin_args(
465+
origin,
466+
args,
467+
val,
468+
ctx,
469+
is_typeddict=is_typeddict,
470+
unpack_allowed=unpack_allowed or origin is tuple or origin is Tuple,
471+
)
438472
elif typing_inspect.is_callable_type(val):
439473
args = typing_inspect.get_args(val)
440474
return _value_of_origin_args(Callable, args, val, ctx)
@@ -535,6 +569,13 @@ def _type_from_runtime(val: Any, ctx: Context, is_typeddict: bool = False) -> Va
535569
cls = "Required" if required else "NotRequired"
536570
ctx.show_error(f"{cls}[] used in unsupported context")
537571
return AnyValue(AnySource.error)
572+
# Also 3.6 only.
573+
elif is_instance_of_typing_name(val, "_Unpack"):
574+
if unpack_allowed:
575+
return _make_unpacked_value(_type_from_runtime(val.__type__, ctx), ctx)
576+
else:
577+
ctx.show_error("Unpack[] used in unsupported context")
578+
return AnyValue(AnySource.error)
538579
elif is_typing_name(val, "TypeAlias"):
539580
return AnyValue(AnySource.incomplete_annotation)
540581
elif is_typing_name(val, "TypedDict"):
@@ -638,28 +679,51 @@ def _get_typeddict_value(
638679
return required, val
639680

640681

641-
def _eval_forward_ref(val: str, ctx: Context, is_typeddict: bool = False) -> Value:
682+
def _eval_forward_ref(
683+
val: str, ctx: Context, *, is_typeddict: bool = False, unpack_allowed: bool = False
684+
) -> Value:
642685
try:
643686
tree = ast.parse(val, mode="eval")
644687
except SyntaxError:
645688
ctx.show_error(f"Syntax error in type annotation: {val}")
646689
return AnyValue(AnySource.error)
647690
else:
648-
return _type_from_ast(tree.body, ctx, is_typeddict=is_typeddict)
691+
return _type_from_ast(
692+
tree.body, ctx, is_typeddict=is_typeddict, unpack_allowed=unpack_allowed
693+
)
649694

650695

651-
def _type_from_value(value: Value, ctx: Context, is_typeddict: bool = False) -> Value:
696+
def _type_from_value(
697+
value: Value,
698+
ctx: Context,
699+
*,
700+
is_typeddict: bool = False,
701+
unpack_allowed: bool = False,
702+
) -> Value:
652703
if isinstance(value, KnownValue):
653-
return _type_from_runtime(value.val, ctx, is_typeddict=is_typeddict)
704+
return _type_from_runtime(
705+
value.val, ctx, is_typeddict=is_typeddict, unpack_allowed=unpack_allowed
706+
)
654707
elif isinstance(value, TypeVarValue):
655708
return value
656709
elif isinstance(value, MultiValuedValue):
657-
return unite_values(*[_type_from_value(val, ctx) for val in value.vals])
710+
return unite_values(
711+
*[
712+
_type_from_value(
713+
val, ctx, is_typeddict=is_typeddict, unpack_allowed=unpack_allowed
714+
)
715+
for val in value.vals
716+
]
717+
)
658718
elif isinstance(value, AnnotatedValue):
659719
return _type_from_value(value.value, ctx)
660720
elif isinstance(value, _SubscriptedValue):
661721
return _type_from_subscripted_value(
662-
value.root, value.members, ctx, is_typeddict=is_typeddict
722+
value.root,
723+
value.members,
724+
ctx,
725+
is_typeddict=is_typeddict,
726+
unpack_allowed=unpack_allowed,
663727
)
664728
elif isinstance(value, AnyValue):
665729
return value
@@ -677,7 +741,9 @@ def _type_from_subscripted_value(
677741
root: Optional[Value],
678742
members: Sequence[Value],
679743
ctx: Context,
744+
*,
680745
is_typeddict: bool = False,
746+
unpack_allowed: bool = False,
681747
) -> Value:
682748
if isinstance(root, GenericValue):
683749
if len(root.args) == len(members):
@@ -690,7 +756,13 @@ def _type_from_subscripted_value(
690756
elif isinstance(root, MultiValuedValue):
691757
return unite_values(
692758
*[
693-
_type_from_subscripted_value(subval, members, ctx, is_typeddict)
759+
_type_from_subscripted_value(
760+
subval,
761+
members,
762+
ctx,
763+
is_typeddict=is_typeddict,
764+
unpack_allowed=unpack_allowed,
765+
)
694766
for subval in root.vals
695767
]
696768
)
@@ -729,9 +801,16 @@ def _type_from_subscripted_value(
729801
elif len(members) == 1 and members[0] == KnownValue(()):
730802
return SequenceIncompleteValue(tuple, [])
731803
else:
732-
return SequenceIncompleteValue(
733-
tuple, [_type_from_value(arg, ctx) for arg in members]
734-
)
804+
args = [_type_from_value(arg, ctx, unpack_allowed=True) for arg in members]
805+
if any(isinstance(val, UnpackedValue) for val in args):
806+
tuple_members = []
807+
for val in args:
808+
if isinstance(val, UnpackedValue):
809+
tuple_members += val.elements
810+
else:
811+
tuple_members.append((False, val))
812+
return SequenceValue(tuple, tuple_members)
813+
return SequenceIncompleteValue(tuple, args)
735814
elif root is typing.Optional:
736815
if len(members) != 1:
737816
ctx.show_error("Optional[] takes only one argument")
@@ -769,6 +848,14 @@ def _type_from_subscripted_value(
769848
ctx.show_error("NotRequired[] requires a single argument")
770849
return AnyValue(AnySource.error)
771850
return Pep655Value(False, _type_from_value(members[0], ctx))
851+
elif is_typing_name(root, "Unpack"):
852+
if not unpack_allowed:
853+
ctx.show_error("Unpack[] used in unsupported context")
854+
return AnyValue(AnySource.error)
855+
if len(members) != 1:
856+
ctx.show_error("Unpack requires a single argument")
857+
return AnyValue(AnySource.error)
858+
return _make_unpacked_value(_type_from_value(members[0], ctx), ctx)
772859
elif root is Callable or root is typing.Callable:
773860
if len(members) == 2:
774861
args, return_value = members
@@ -877,6 +964,11 @@ class Pep655Value(Value):
877964
value: Value
878965

879966

967+
@dataclass
968+
class UnpackedValue(Value):
969+
elements: Sequence[Tuple[bool, Value]]
970+
971+
880972
class _Visitor(ast.NodeVisitor):
881973
def __init__(self, ctx: Context) -> None:
882974
self.ctx = ctx
@@ -892,6 +984,12 @@ def visit_Subscript(self, node: ast.Subscript) -> Value:
892984
index = self.visit(node.slice)
893985
if isinstance(index, SequenceIncompleteValue):
894986
members = index.members
987+
elif isinstance(index, SequenceValue):
988+
members = index.get_member_sequence()
989+
if members is None:
990+
# TODO support unpacking here
991+
return AnyValue(AnySource.inference)
992+
members = tuple(members)
895993
else:
896994
members = (index,)
897995
return _SubscriptedValue(value, members)
@@ -1047,7 +1145,9 @@ def _value_of_origin_args(
10471145
args: Sequence[object],
10481146
val: object,
10491147
ctx: Context,
1148+
*,
10501149
is_typeddict: bool = False,
1150+
unpack_allowed: bool = False,
10511151
) -> Value:
10521152
if origin is typing.Type or origin is type:
10531153
if not args:
@@ -1061,7 +1161,9 @@ def _value_of_origin_args(
10611161
elif len(args) == 1 and args[0] == ():
10621162
return SequenceIncompleteValue(tuple, [])
10631163
else:
1064-
args_vals = [_type_from_runtime(arg, ctx) for arg in args]
1164+
args_vals = [
1165+
_type_from_runtime(arg, ctx, unpack_allowed=True) for arg in args
1166+
]
10651167
return SequenceIncompleteValue(tuple, args_vals)
10661168
elif origin is typing.Union:
10671169
return unite_values(*[_type_from_runtime(arg, ctx) for arg in args])
@@ -1126,6 +1228,14 @@ def _value_of_origin_args(
11261228
ctx.show_error("NotRequired[] requires a single argument")
11271229
return AnyValue(AnySource.error)
11281230
return Pep655Value(False, _type_from_runtime(args[0], ctx))
1231+
elif is_typing_name(origin, "Unpack"):
1232+
if not unpack_allowed:
1233+
ctx.show_error("Invalid usage of Unpack")
1234+
return AnyValue(AnySource.error)
1235+
if len(args) != 1:
1236+
ctx.show_error("Unpack requires a single argument")
1237+
return AnyValue(AnySource.error)
1238+
return _make_unpacked_value(_type_from_runtime(args[0], ctx), ctx)
11291239
elif origin is None and isinstance(val, type):
11301240
# This happens for SupportsInt in 3.7.
11311241
return _maybe_typed_value(val)
@@ -1144,6 +1254,19 @@ def _maybe_typed_value(val: Union[type, str]) -> Value:
11441254
return TypedValue(val)
11451255

11461256

1257+
def _make_unpacked_value(val: Value, ctx: Context) -> UnpackedValue:
1258+
if isinstance(val, SequenceValue) and val.typ is tuple:
1259+
return UnpackedValue(val.members)
1260+
elif isinstance(val, SequenceIncompleteValue) and val.typ is tuple:
1261+
return UnpackedValue([(False, elt) for elt in val.members])
1262+
elif isinstance(val, GenericValue) and val.typ is tuple:
1263+
return UnpackedValue([(True, val.args[0])])
1264+
elif isinstance(val, TypedValue) and val.typ is tuple:
1265+
return UnpackedValue([(True, AnyValue(AnySource.generic_argument))])
1266+
ctx.show_error(f"Invalid argument for Unpack: {val}")
1267+
return UnpackedValue([])
1268+
1269+
11471270
def _make_callable_from_value(
11481271
args: Value, return_value: Value, ctx: Context, is_asynq: bool = False
11491272
) -> Value:

pyanalyze/boolability.py

+18
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
KnownValue,
2121
MultiValuedValue,
2222
SequenceIncompleteValue,
23+
SequenceValue,
2324
SubclassValue,
2425
TypedDictValue,
2526
TypedValue,
@@ -125,6 +126,23 @@ def _get_boolability_no_mvv(value: Value) -> Boolability:
125126
return Boolability.value_always_true_mutable
126127
else:
127128
return Boolability.value_always_false_mutable
129+
elif isinstance(value, SequenceValue):
130+
if not value.members:
131+
if value.typ is tuple:
132+
return Boolability.value_always_false
133+
else:
134+
return Boolability.value_always_false_mutable
135+
may_be_empty = all(is_many for is_many, _ in value.members)
136+
if may_be_empty:
137+
return Boolability.boolable
138+
if value.typ is tuple:
139+
# We lie slightly here, since at the type level a tuple
140+
# may be false. But tuples are a common source of boolability
141+
# bugs and they're rarely mutated, so we put a stronger
142+
# condition on them.
143+
return Boolability.type_always_true
144+
else:
145+
return Boolability.value_always_true_mutable
128146
elif isinstance(value, DictIncompleteValue):
129147
if any(pair.is_required and not pair.is_many for pair in value.kv_pairs):
130148
return Boolability.value_always_true_mutable

pyanalyze/format_strings.py

+5
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
KnownValue,
3131
DictIncompleteValue,
3232
SequenceIncompleteValue,
33+
SequenceValue,
3334
TypedValue,
3435
Value,
3536
flatten_values,
@@ -370,6 +371,10 @@ def accept_tuple_args_no_mvv(
370371
args = replace_known_sequence_value(args)
371372
if isinstance(args, SequenceIncompleteValue):
372373
all_args = args.members
374+
elif isinstance(args, SequenceValue):
375+
all_args = args.get_member_sequence()
376+
if all_args is None:
377+
return
373378
else:
374379
# it's a tuple but we don't know what's in it, so assume it's ok
375380
return

0 commit comments

Comments
 (0)