Skip to content

Commit

Permalink
Rework the check for an out-of-order opcode in a match block.
Browse files Browse the repository at this point in the history
Since we only care about execution order for the purposes of seeing whether we
have truly seen all of a match block, check for that specifically rather than
trying to come up with some more general heuristics for detecting any
out-of-order opcodes.

PiperOrigin-RevId: 605197589
  • Loading branch information
martindemello authored and rchen152 committed Feb 9, 2024
1 parent eab9166 commit c50fa11
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 52 deletions.
17 changes: 0 additions & 17 deletions pytype/blocks/process_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,20 +128,3 @@ def adjust_returns(code, block_returns):
new_line = next(lines, None)
if new_line:
op.line = new_line


def check_out_of_order(code):
"""Check if a line of code is executed out of order."""
# This sometimes happens due to compiler optimisations, and needs to be
# recorded so that we don't trigger code that is only meant to execute when
# the main flow of control reaches a certain line.
last_line = []
for block in code.order:
for op in block:
if not last_line or last_line[-1].line == op.line:
last_line.append(op)
else:
if op.line < last_line[-1].line:
for x in last_line:
x.metadata.is_out_of_order = True
last_line = [op]
84 changes: 51 additions & 33 deletions pytype/pattern_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,19 +252,20 @@ class _Matches:
"""Tracks branches of match statements."""

def __init__(self, ast_matches):
self.start_to_end = {}
self.start_to_end = {} # match_line : match_end_line
self.end_to_starts = collections.defaultdict(list)
self.match_cases = {}
self.defaults = set()
self.as_names = {}
self.matches = []
self.match_cases = {} # opcode_line : match_line
self.defaults = set() # lines with defaults
self.as_names = {} # case_end_line : case_as_name
self.unseen_cases = {} # match_line : num_unseen_cases

for m in ast_matches.matches:
self._add_match(m.start, m.end, m.cases)

def _add_match(self, start, end, cases):
self.start_to_end[start] = end
self.end_to_starts[end].append(start)
self.unseen_cases[start] = len(cases)
for c in cases:
for i in range(c.start, c.end + 1):
self.match_cases[i] = start
Expand All @@ -273,6 +274,10 @@ def _add_match(self, start, end, cases):
if c.as_name:
self.as_names[c.end] = c.as_name

def register_case(self, match_line, case_line):
assert self.match_cases[case_line] == match_line
self.unseen_cases[match_line] -= 1

def __repr__(self):
return f"""
Matches: {sorted(self.start_to_end.items())}
Expand Down Expand Up @@ -301,10 +306,9 @@ def __init__(self, ast_matches, ctx):
self.ctx = ctx

def _get_option_tracker(
self, match_var: cfg.Variable, case_line: int
self, match_var: cfg.Variable, match_line: int
) -> _OptionTracker:
"""Get the option tracker for a match line."""
match_line = self.matches.match_cases[case_line]
if (match_line not in self._option_tracker or
match_var.id not in self._option_tracker[match_line]):
self._option_tracker[match_line][match_var.id] = (
Expand All @@ -323,8 +327,16 @@ def _make_instance_for_match(self, node, types):
ret.append(self.ctx.vm.init_class(node, cls))
return self.ctx.join_variables(node, ret)

def _register_case_branch(self, op: opcodes.Opcode) -> Optional[int]:
match_line = self.matches.match_cases.get(op.line)
if match_line is None:
return None
self.matches.register_case(match_line, op.line)
return match_line

def instantiate_case_var(self, op, match_var, node):
tracker = self._get_option_tracker(match_var, op.line)
match_line = self.matches.match_cases[op.line]
tracker = self._get_option_tracker(match_var, match_line)
if tracker.cases[op.line]:
# We have matched on one or more classes in this case.
types = [x.typ for x in tracker.cases[op.line]]
Expand Down Expand Up @@ -360,14 +372,16 @@ def register_match_type(self, op: opcodes.Opcode):
self._match_types[match_line].add(_MatchTypes.make(op))

def add_none_branch(self, op: opcodes.Opcode, match_var: cfg.Variable):
if op.line in self.matches.match_cases:
tracker = self._get_option_tracker(match_var, op.line)
tracker.cover_from_none(op.line)
if not tracker.is_complete:
return None
else:
# This is the last remaining case, and will always succeed.
return True
match_line = self._register_case_branch(op)
if not match_line:
return None
tracker = self._get_option_tracker(match_var, match_line)
tracker.cover_from_none(op.line)
if not tracker.is_complete:
return None
else:
# This is the last remaining case, and will always succeed.
return True

def add_cmp_branch(
self,
Expand All @@ -377,12 +391,13 @@ def add_cmp_branch(
case_var: cfg.Variable
) -> _MatchSuccessType:
"""Add a compare-based match case branch to the tracker."""
if cmp_type not in (slots.CMP_EQ, slots.CMP_IS):
match_line = self._register_case_branch(op)
if not match_line:
return None

match_line = self.matches.match_cases.get(op.line)
if not match_line:
if cmp_type not in (slots.CMP_EQ, slots.CMP_IS):
return None

match_type = self._match_types[match_line]

try:
Expand All @@ -403,7 +418,7 @@ def add_cmp_branch(
# (enum or union of literals) that we are tracking.
if not tracker:
if _is_literal_match(match_var) or _is_enum_match(match_var, case_val):
tracker = self._get_option_tracker(match_var, op.line)
tracker = self._get_option_tracker(match_var, match_line)

# If none of the above apply we cannot do any sort of tracking.
if not tracker:
Expand All @@ -425,32 +440,31 @@ def add_cmp_branch(
def add_class_branch(self, op: opcodes.Opcode, match_var: cfg.Variable,
case_var: cfg.Variable) -> _MatchSuccessType:
"""Add a class-based match case branch to the tracker."""
tracker = self._get_option_tracker(match_var, op.line)
match_line = self._register_case_branch(op)
if not match_line:
return None
tracker = self._get_option_tracker(match_var, match_line)
tracker.cover(op.line, case_var)
return tracker.is_complete or None

def add_default_branch(self, op: opcodes.Opcode) -> _MatchSuccessType:
"""Add a default match case branch to the tracker."""
match_line = self.matches.match_cases.get(op.line)
if match_line is None:
return None
if match_line in self._option_tracker:
for opt in self._option_tracker[match_line].values():
# We no longer check for exhaustive or redundant matches once we hit a
# default case.
opt.invalidate()
return True
else:
match_line = self._register_case_branch(op)
if not match_line or match_line not in self._option_tracker:
return None

for opt in self._option_tracker[match_line].values():
# We no longer check for exhaustive or redundant matches once we hit a
# default case.
opt.invalidate()
return True

def check_ending(
self,
op: opcodes.Opcode,
implicit_return: bool = False
) -> List[IncompleteMatch]:
"""Check if we have ended a match statement with leftover cases."""
if op.metadata.is_out_of_order:
return []
line = op.line
if implicit_return:
done = set()
Expand All @@ -464,6 +478,10 @@ def check_ending(
ret = []
for i in done:
for start in self.matches.end_to_starts[i]:
if self.matches.unseen_cases[start] > 0:
# We have executed some opcode out of order and thus gone past the end
# of the match block before seeing all case branches.
continue
trackers = self._option_tracker[start]
for tracker in trackers.values():
if tracker.is_valid:
Expand Down
30 changes: 29 additions & 1 deletion pytype/tests/test_pattern_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -1559,7 +1559,7 @@ def f(self, x: A):
raise ValueError('foo')
""")

def test_optimized_bytecode_out_of_order(self):
def test_optimized_bytecode_out_of_order_1(self):
"""Regression test for a bug resulting from compiler optimisations."""
# Compiler optimisations that inline code can put blocks out of order, which
# could potentially interfere with our checks for the end of a match block.
Expand All @@ -1584,6 +1584,34 @@ def test(color: Color):
return color
""")

def test_optimized_bytecode_out_of_order_2(self):
"""Regression test for a bug resulting from compiler optimisations."""
# See comment in the previous test case.
self.Check("""
import enum
class A(enum.Enum):
RED = 1
GREEN = 2
BLUE = 3
def f(x: A):
ret = True
count = 0
while count < 10:
match x:
case A.RED:
ret = 10
case A.BLUE:
ret = 20
case _:
return ret
if ret:
break
else:
count += 1
""")


@test_utils.skipBeforePy((3, 10), "New syntax in 3.10")
class LiteralMatchCoverageTest(test_base.BaseTest):
Expand Down
1 change: 0 additions & 1 deletion pytype/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,6 @@ def _run_frame_blocks(self, frame, node, annotated_locals):
return_nodes = []
finally_tracker = vm_utils.FinallyStateTracker()
process_blocks.adjust_returns(frame.f_code, self._director.block_returns)
process_blocks.check_out_of_order(frame.f_code)
for block in frame.f_code.order:
state = frame.states.get(block[0])
if not state:
Expand Down

0 comments on commit c50fa11

Please sign in to comment.