diff --git a/pytype/blocks/process_blocks.py b/pytype/blocks/process_blocks.py index 52183f328..80a6d5b9c 100644 --- a/pytype/blocks/process_blocks.py +++ b/pytype/blocks/process_blocks.py @@ -128,20 +128,3 @@ def adjust_returns(code, block_returns): new_line = next(lines, None) if new_line: op.line = new_line - - -def check_out_of_order(code): - """Check if a line of code is executed out of order.""" - # This sometimes happens due to compiler optimisations, and needs to be - # recorded so that we don't trigger code that is only meant to execute when - # the main flow of control reaches a certain line. - last_line = [] - for block in code.order: - for op in block: - if not last_line or last_line[-1].line == op.line: - last_line.append(op) - else: - if op.line < last_line[-1].line: - for x in last_line: - x.metadata.is_out_of_order = True - last_line = [op] diff --git a/pytype/pattern_matching.py b/pytype/pattern_matching.py index d4c464186..5fdd73d42 100644 --- a/pytype/pattern_matching.py +++ b/pytype/pattern_matching.py @@ -252,12 +252,12 @@ class _Matches: """Tracks branches of match statements.""" def __init__(self, ast_matches): - self.start_to_end = {} + self.start_to_end = {} # match_line : match_end_line self.end_to_starts = collections.defaultdict(list) - self.match_cases = {} - self.defaults = set() - self.as_names = {} - self.matches = [] + self.match_cases = {} # opcode_line : match_line + self.defaults = set() # lines with defaults + self.as_names = {} # case_end_line : case_as_name + self.unseen_cases = {} # match_line : num_unseen_cases for m in ast_matches.matches: self._add_match(m.start, m.end, m.cases) @@ -265,6 +265,7 @@ def __init__(self, ast_matches): def _add_match(self, start, end, cases): self.start_to_end[start] = end self.end_to_starts[end].append(start) + self.unseen_cases[start] = len(cases) for c in cases: for i in range(c.start, c.end + 1): self.match_cases[i] = start @@ -273,6 +274,10 @@ def _add_match(self, start, end, cases): if c.as_name: self.as_names[c.end] = c.as_name + def register_case(self, match_line, case_line): + assert self.match_cases[case_line] == match_line + self.unseen_cases[match_line] -= 1 + def __repr__(self): return f""" Matches: {sorted(self.start_to_end.items())} @@ -301,10 +306,9 @@ def __init__(self, ast_matches, ctx): self.ctx = ctx def _get_option_tracker( - self, match_var: cfg.Variable, case_line: int + self, match_var: cfg.Variable, match_line: int ) -> _OptionTracker: """Get the option tracker for a match line.""" - match_line = self.matches.match_cases[case_line] if (match_line not in self._option_tracker or match_var.id not in self._option_tracker[match_line]): self._option_tracker[match_line][match_var.id] = ( @@ -323,8 +327,16 @@ def _make_instance_for_match(self, node, types): ret.append(self.ctx.vm.init_class(node, cls)) return self.ctx.join_variables(node, ret) + def _register_case_branch(self, op: opcodes.Opcode) -> Optional[int]: + match_line = self.matches.match_cases.get(op.line) + if match_line is None: + return None + self.matches.register_case(match_line, op.line) + return match_line + def instantiate_case_var(self, op, match_var, node): - tracker = self._get_option_tracker(match_var, op.line) + match_line = self.matches.match_cases[op.line] + tracker = self._get_option_tracker(match_var, match_line) if tracker.cases[op.line]: # We have matched on one or more classes in this case. types = [x.typ for x in tracker.cases[op.line]] @@ -360,14 +372,16 @@ def register_match_type(self, op: opcodes.Opcode): self._match_types[match_line].add(_MatchTypes.make(op)) def add_none_branch(self, op: opcodes.Opcode, match_var: cfg.Variable): - if op.line in self.matches.match_cases: - tracker = self._get_option_tracker(match_var, op.line) - tracker.cover_from_none(op.line) - if not tracker.is_complete: - return None - else: - # This is the last remaining case, and will always succeed. - return True + match_line = self._register_case_branch(op) + if not match_line: + return None + tracker = self._get_option_tracker(match_var, match_line) + tracker.cover_from_none(op.line) + if not tracker.is_complete: + return None + else: + # This is the last remaining case, and will always succeed. + return True def add_cmp_branch( self, @@ -377,12 +391,13 @@ def add_cmp_branch( case_var: cfg.Variable ) -> _MatchSuccessType: """Add a compare-based match case branch to the tracker.""" - if cmp_type not in (slots.CMP_EQ, slots.CMP_IS): + match_line = self._register_case_branch(op) + if not match_line: return None - match_line = self.matches.match_cases.get(op.line) - if not match_line: + if cmp_type not in (slots.CMP_EQ, slots.CMP_IS): return None + match_type = self._match_types[match_line] try: @@ -403,7 +418,7 @@ def add_cmp_branch( # (enum or union of literals) that we are tracking. if not tracker: if _is_literal_match(match_var) or _is_enum_match(match_var, case_val): - tracker = self._get_option_tracker(match_var, op.line) + tracker = self._get_option_tracker(match_var, match_line) # If none of the above apply we cannot do any sort of tracking. if not tracker: @@ -425,32 +440,31 @@ def add_cmp_branch( def add_class_branch(self, op: opcodes.Opcode, match_var: cfg.Variable, case_var: cfg.Variable) -> _MatchSuccessType: """Add a class-based match case branch to the tracker.""" - tracker = self._get_option_tracker(match_var, op.line) + match_line = self._register_case_branch(op) + if not match_line: + return None + tracker = self._get_option_tracker(match_var, match_line) tracker.cover(op.line, case_var) return tracker.is_complete or None def add_default_branch(self, op: opcodes.Opcode) -> _MatchSuccessType: """Add a default match case branch to the tracker.""" - match_line = self.matches.match_cases.get(op.line) - if match_line is None: - return None - if match_line in self._option_tracker: - for opt in self._option_tracker[match_line].values(): - # We no longer check for exhaustive or redundant matches once we hit a - # default case. - opt.invalidate() - return True - else: + match_line = self._register_case_branch(op) + if not match_line or match_line not in self._option_tracker: return None + for opt in self._option_tracker[match_line].values(): + # We no longer check for exhaustive or redundant matches once we hit a + # default case. + opt.invalidate() + return True + def check_ending( self, op: opcodes.Opcode, implicit_return: bool = False ) -> List[IncompleteMatch]: """Check if we have ended a match statement with leftover cases.""" - if op.metadata.is_out_of_order: - return [] line = op.line if implicit_return: done = set() @@ -464,6 +478,10 @@ def check_ending( ret = [] for i in done: for start in self.matches.end_to_starts[i]: + if self.matches.unseen_cases[start] > 0: + # We have executed some opcode out of order and thus gone past the end + # of the match block before seeing all case branches. + continue trackers = self._option_tracker[start] for tracker in trackers.values(): if tracker.is_valid: diff --git a/pytype/tests/test_pattern_matching.py b/pytype/tests/test_pattern_matching.py index 996255d0d..8debc72e9 100644 --- a/pytype/tests/test_pattern_matching.py +++ b/pytype/tests/test_pattern_matching.py @@ -1559,7 +1559,7 @@ def f(self, x: A): raise ValueError('foo') """) - def test_optimized_bytecode_out_of_order(self): + def test_optimized_bytecode_out_of_order_1(self): """Regression test for a bug resulting from compiler optimisations.""" # Compier optimisations that inline code can put blocks out of order, which # could potentially interfere with our checks for the end of a match block. @@ -1584,6 +1584,34 @@ def test(color: Color): return color """) + def test_optimized_bytecode_out_of_order_2(self): + """Regression test for a bug resulting from compiler optimisations.""" + # See comment in the previous test case. + self.Check(""" + import enum + + class A(enum.Enum): + RED = 1 + GREEN = 2 + BLUE = 3 + + def f(x: A): + ret = True + count = 0 + while count < 10: + match x: + case A.RED: + ret = 10 + case A.BLUE: + ret = 20 + case _: + return ret + if ret: + break + else: + count += 1 + """) + @test_utils.skipBeforePy((3, 10), "New syntax in 3.10") class LiteralMatchCoverageTest(test_base.BaseTest): diff --git a/pytype/vm.py b/pytype/vm.py index f577c0c7b..c5acf67f1 100644 --- a/pytype/vm.py +++ b/pytype/vm.py @@ -308,7 +308,6 @@ def _run_frame_blocks(self, frame, node, annotated_locals): return_nodes = [] finally_tracker = vm_utils.FinallyStateTracker() process_blocks.adjust_returns(frame.f_code, self._director.block_returns) - process_blocks.check_out_of_order(frame.f_code) for block in frame.f_code.order: state = frame.states.get(block[0]) if not state: