From 3e6744e183dc2d4a1c38ce1753fc0670586b6955 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Sun, 9 Oct 2022 16:25:48 -0700 Subject: [PATCH 01/97] Add `match` statement entrypoint: Still getting a grip with how mypyc does things, but I feel like I am at a spot where I can actually start doing things now. --- mypyc/irbuild/statement.py | 17 +++++++++++ mypyc/irbuild/visitor.py | 3 +- mypyc/test-data/irbuild-match.test | 27 ++++++++++++++++ mypyc/test/test_irbuild.py | 49 +++++++++++++++--------------- mypyc/test/testutil.py | 2 +- 5 files changed, 72 insertions(+), 26 deletions(-) create mode 100644 mypyc/test-data/irbuild-match.test diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index 371a305e67b9e..09b2d6112eb4c 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -28,6 +28,7 @@ ImportFrom, ListExpr, Lvalue, + MatchStmt, OperatorAssignmentStmt, RaiseStmt, ReturnStmt, @@ -896,3 +897,19 @@ def transform_yield_from_expr(builder: IRBuilder, o: YieldFromExpr) -> Value: def transform_await_expr(builder: IRBuilder, o: AwaitExpr) -> Value: return emit_yield_from_or_await(builder, builder.accept(o.expr), o.line, is_await=True) + + +def transform_match_stmt(builder: IRBuilder, m: MatchStmt) -> None: + builder.accept(m.subject) + + assert len(m.bodies) == 1 + + block = BasicBlock() + next = BasicBlock() + + builder.goto(block) + builder.activate_block(block) + builder.accept(m.bodies[0]) + + builder.goto(next) + builder.activate_block(next) diff --git a/mypyc/irbuild/visitor.py b/mypyc/irbuild/visitor.py index dc126d4104094..d8725ee04dc5c 100644 --- a/mypyc/irbuild/visitor.py +++ b/mypyc/irbuild/visitor.py @@ -131,6 +131,7 @@ transform_import, transform_import_all, transform_import_from, + transform_match_stmt, transform_operator_assignment_stmt, transform_raise_stmt, transform_return_stmt, @@ -242,7 +243,7 @@ def visit_nonlocal_decl(self, stmt: NonlocalDecl) -> None: pass def visit_match_stmt(self, stmt: MatchStmt) -> None: - self.bail("Match statements are not yet supported", stmt.line) + transform_match_stmt(self.builder, stmt) # Expressions diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test new file mode 100644 index 0000000000000..f7b975f0e2d45 --- /dev/null +++ b/mypyc/test-data/irbuild-match.test @@ -0,0 +1,27 @@ +[case testMatchDefaultCase_python3_10] +def f(): + match True: + case _: + print("matched") +[out] +def f(): + r0 :: str + r1 :: object + r2 :: str + r3 :: object + r4 :: object[1] + r5 :: object_ptr + r6, r7 :: object +L0: +L1: + r0 = 'matched' + r1 = builtins :: module + r2 = 'print' + r3 = CPyObject_GetAttr(r1, r2) + r4 = [r0] + r5 = load_address r4 + r6 = _PyObject_Vectorcall(r3, r5, 1, 0) + keep_alive r0 +L2: + r7 = box(None, 1) + return r7 diff --git a/mypyc/test/test_irbuild.py b/mypyc/test/test_irbuild.py index bfce57c979036..51af88580edbd 100644 --- a/mypyc/test/test_irbuild.py +++ b/mypyc/test/test_irbuild.py @@ -21,30 +21,31 @@ ) files = [ - "irbuild-basic.test", - "irbuild-int.test", - "irbuild-lists.test", - "irbuild-tuple.test", - "irbuild-dict.test", - "irbuild-set.test", - "irbuild-str.test", - "irbuild-bytes.test", - "irbuild-statements.test", - "irbuild-nested.test", - "irbuild-classes.test", - "irbuild-optional.test", - "irbuild-any.test", - "irbuild-generics.test", - "irbuild-try.test", - "irbuild-strip-asserts.test", - "irbuild-i64.test", - "irbuild-i32.test", - "irbuild-vectorcall.test", - "irbuild-unreachable.test", - "irbuild-isinstance.test", - "irbuild-dunders.test", - "irbuild-singledispatch.test", - "irbuild-constant-fold.test", + # "irbuild-basic.test", + # "irbuild-int.test", + # "irbuild-lists.test", + # "irbuild-tuple.test", + # "irbuild-dict.test", + # "irbuild-set.test", + # "irbuild-str.test", + # "irbuild-bytes.test", + # "irbuild-statements.test", + # "irbuild-nested.test", + # "irbuild-classes.test", + # "irbuild-optional.test", + # "irbuild-any.test", + # "irbuild-generics.test", + # "irbuild-try.test", + # "irbuild-strip-asserts.test", + # "irbuild-i64.test", + # "irbuild-i32.test", + # "irbuild-vectorcall.test", + # "irbuild-unreachable.test", + # "irbuild-isinstance.test", + # "irbuild-dunders.test", + # "irbuild-singledispatch.test", + # "irbuild-constant-fold.test", + "irbuild-match.test", ] diff --git a/mypyc/test/testutil.py b/mypyc/test/testutil.py index 8339889fa9f5f..b97d8887e0f72 100644 --- a/mypyc/test/testutil.py +++ b/mypyc/test/testutil.py @@ -108,7 +108,7 @@ def build_ir_for_single_file2( options.hide_error_codes = True options.use_builtins_fixtures = True options.strict_optional = True - options.python_version = (3, 6) + options.python_version = (3, 10) options.export_types = True options.preserve_asts = True options.allow_empty_bodies = True From 39f366dddf71ba85fe5fbd91a61514ce1e636dc7 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Sun, 9 Oct 2022 16:41:16 -0700 Subject: [PATCH 02/97] Cleanup --- mypyc/irbuild/statement.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index 09b2d6112eb4c..a257d99799d60 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -42,6 +42,7 @@ YieldExpr, YieldFromExpr, ) +from mypy.patterns import AsPattern from mypyc.ir.ops import ( NO_TRACEBACK_LINE_NO, Assign, @@ -904,12 +905,20 @@ def transform_match_stmt(builder: IRBuilder, m: MatchStmt) -> None: assert len(m.bodies) == 1 - block = BasicBlock() - next = BasicBlock() + blocks = [BasicBlock() for _ in range(len(m.bodies))] + + for i, block in enumerate(blocks): + pattern = m.patterns[i] - builder.goto(block) - builder.activate_block(block) - builder.accept(m.bodies[0]) + if ( + isinstance(pattern, AsPattern) and + pattern.pattern == pattern.name == None + ): + # Default case + builder.goto(block) + builder.activate_block(block) + builder.accept(m.bodies[0]) + next = BasicBlock() builder.goto(next) builder.activate_block(next) From 8bee07a499a4750c1a12e0c37ec849c932a04c56 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Sun, 9 Oct 2022 21:54:09 -0700 Subject: [PATCH 03/97] Add value pattern check --- mypyc/irbuild/statement.py | 27 +++++++++++++++------ mypyc/test-data/irbuild-match.test | 38 ++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 7 deletions(-) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index a257d99799d60..ef207f98d5a34 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -17,6 +17,7 @@ AwaitExpr, Block, BreakStmt, + ComparisonExpr, ContinueStmt, DelStmt, Expression, @@ -42,7 +43,7 @@ YieldExpr, YieldFromExpr, ) -from mypy.patterns import AsPattern +from mypy.patterns import AsPattern, ValuePattern from mypyc.ir.ops import ( NO_TRACEBACK_LINE_NO, Assign, @@ -901,13 +902,13 @@ def transform_await_expr(builder: IRBuilder, o: AwaitExpr) -> Value: def transform_match_stmt(builder: IRBuilder, m: MatchStmt) -> None: - builder.accept(m.subject) + subject = builder.accept(m.subject) assert len(m.bodies) == 1 - blocks = [BasicBlock() for _ in range(len(m.bodies))] + blocks = [BasicBlock() for _ in range(len(m.bodies) + 1)] - for i, block in enumerate(blocks): + for i, block in enumerate(blocks[:1]): pattern = m.patterns[i] if ( @@ -919,6 +920,18 @@ def transform_match_stmt(builder: IRBuilder, m: MatchStmt) -> None: builder.activate_block(block) builder.accept(m.bodies[0]) - next = BasicBlock() - builder.goto(next) - builder.activate_block(next) + if isinstance(pattern, ValuePattern): + # eq check + cond = builder.accept( + ComparisonExpr(["=="], [m.subject, pattern.expr]) + ) + + code_block = BasicBlock() + + builder.add_bool_branch(cond, code_block, blocks[i + 1]) + builder.goto(code_block) + builder.activate_block(code_block) + builder.accept(m.bodies[0]) + + builder.goto(blocks[-1]) + builder.activate_block(blocks[-1]) diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index f7b975f0e2d45..43655f3b82cf2 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -25,3 +25,41 @@ L1: L2: r7 = box(None, 1) return r7 +[case testMatchBasicValue_python3_10] +def f(): + match 123: + case 123: + print("matched") +[out] +def f(): + r0 :: bit + r1 :: object + r2 :: int32 + r3 :: bit + r4 :: bool + r5 :: str + r6 :: object + r7 :: str + r8 :: object + r9 :: object[1] + r10 :: object_ptr + r11, r12 :: object +L0: + r0 = 246 == 246 + r1 = box(bit, r0) + r2 = PyObject_IsTrue(r1) + r3 = r2 >= 0 :: signed + r4 = truncate r2: int32 to builtins.bool + if r4 goto L1 else goto L2 :: bool +L1: + r5 = 'matched' + r6 = builtins :: module + r7 = 'print' + r8 = CPyObject_GetAttr(r6, r7) + r9 = [r5] + r10 = load_address r9 + r11 = _PyObject_Vectorcall(r8, r10, 1, 0) + keep_alive r5 +L2: + r12 = box(None, 1) + return r12 From 0883015addc5b351bda7beb18d57fd24573f4dc7 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Sun, 9 Oct 2022 22:09:49 -0700 Subject: [PATCH 04/97] Cleanup --- mypyc/irbuild/statement.py | 20 ++++++++++++-------- mypyc/test-data/irbuild-match.test | 2 ++ 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index ef207f98d5a34..95521e1e8fce0 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -906,19 +906,20 @@ def transform_match_stmt(builder: IRBuilder, m: MatchStmt) -> None: assert len(m.bodies) == 1 - blocks = [BasicBlock() for _ in range(len(m.bodies) + 1)] - - for i, block in enumerate(blocks[:1]): - pattern = m.patterns[i] + end_block = BasicBlock() + for pattern in m.patterns: if ( isinstance(pattern, AsPattern) and pattern.pattern == pattern.name == None ): + block = BasicBlock() + # Default case builder.goto(block) builder.activate_block(block) builder.accept(m.bodies[0]) + builder.goto(end_block) if isinstance(pattern, ValuePattern): # eq check @@ -927,11 +928,14 @@ def transform_match_stmt(builder: IRBuilder, m: MatchStmt) -> None: ) code_block = BasicBlock() + next_block = BasicBlock() - builder.add_bool_branch(cond, code_block, blocks[i + 1]) - builder.goto(code_block) + builder.add_bool_branch(cond, code_block, next_block) builder.activate_block(code_block) builder.accept(m.bodies[0]) + builder.goto(end_block) + + builder.activate_block(next_block) - builder.goto(blocks[-1]) - builder.activate_block(blocks[-1]) + builder.goto(end_block) + builder.activate_block(end_block) diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index 43655f3b82cf2..5a3609765286e 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -60,6 +60,8 @@ L1: r10 = load_address r9 r11 = _PyObject_Vectorcall(r8, r10, 1, 0) keep_alive r5 + goto L3 L2: +L3: r12 = box(None, 1) return r12 From e646c442693266e8e9ebfc84fb4c12f890f99f94 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Mon, 10 Oct 2022 20:08:44 -0700 Subject: [PATCH 05/97] Reset --- mypyc/irbuild/statement.py | 33 ++-------------- mypyc/test-data/irbuild-match.test | 63 +----------------------------- 2 files changed, 4 insertions(+), 92 deletions(-) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index 95521e1e8fce0..b3032fdfb658c 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -43,7 +43,7 @@ YieldExpr, YieldFromExpr, ) -from mypy.patterns import AsPattern, ValuePattern +from mypy.patterns import AsPattern, OrPattern, ValuePattern from mypyc.ir.ops import ( NO_TRACEBACK_LINE_NO, Assign, @@ -902,40 +902,13 @@ def transform_await_expr(builder: IRBuilder, o: AwaitExpr) -> Value: def transform_match_stmt(builder: IRBuilder, m: MatchStmt) -> None: - subject = builder.accept(m.subject) + builder.accept(m.subject) assert len(m.bodies) == 1 - end_block = BasicBlock() - for pattern in m.patterns: if ( isinstance(pattern, AsPattern) and pattern.pattern == pattern.name == None ): - block = BasicBlock() - - # Default case - builder.goto(block) - builder.activate_block(block) - builder.accept(m.bodies[0]) - builder.goto(end_block) - - if isinstance(pattern, ValuePattern): - # eq check - cond = builder.accept( - ComparisonExpr(["=="], [m.subject, pattern.expr]) - ) - - code_block = BasicBlock() - next_block = BasicBlock() - - builder.add_bool_branch(cond, code_block, next_block) - builder.activate_block(code_block) - builder.accept(m.bodies[0]) - builder.goto(end_block) - - builder.activate_block(next_block) - - builder.goto(end_block) - builder.activate_block(end_block) + pass diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index 5a3609765286e..aeaa32d75120a 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -1,67 +1,6 @@ -[case testMatchDefaultCase_python3_10] +[case testMatchDefaultPattern_python3_10] def f(): match True: case _: print("matched") [out] -def f(): - r0 :: str - r1 :: object - r2 :: str - r3 :: object - r4 :: object[1] - r5 :: object_ptr - r6, r7 :: object -L0: -L1: - r0 = 'matched' - r1 = builtins :: module - r2 = 'print' - r3 = CPyObject_GetAttr(r1, r2) - r4 = [r0] - r5 = load_address r4 - r6 = _PyObject_Vectorcall(r3, r5, 1, 0) - keep_alive r0 -L2: - r7 = box(None, 1) - return r7 -[case testMatchBasicValue_python3_10] -def f(): - match 123: - case 123: - print("matched") -[out] -def f(): - r0 :: bit - r1 :: object - r2 :: int32 - r3 :: bit - r4 :: bool - r5 :: str - r6 :: object - r7 :: str - r8 :: object - r9 :: object[1] - r10 :: object_ptr - r11, r12 :: object -L0: - r0 = 246 == 246 - r1 = box(bit, r0) - r2 = PyObject_IsTrue(r1) - r3 = r2 >= 0 :: signed - r4 = truncate r2: int32 to builtins.bool - if r4 goto L1 else goto L2 :: bool -L1: - r5 = 'matched' - r6 = builtins :: module - r7 = 'print' - r8 = CPyObject_GetAttr(r6, r7) - r9 = [r5] - r10 = load_address r9 - r11 = _PyObject_Vectorcall(r8, r10, 1, 0) - keep_alive r5 - goto L3 -L2: -L3: - r12 = box(None, 1) - return r12 From b27b380c4d082f246a606b4e20f5650ef9f60328 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Mon, 10 Oct 2022 20:16:59 -0700 Subject: [PATCH 06/97] Add value pattern --- mypyc/irbuild/statement.py | 19 ++++++++++----- mypyc/test-data/irbuild-match.test | 38 +++++++++++++++++++++++++++--- 2 files changed, 48 insertions(+), 9 deletions(-) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index b3032fdfb658c..b53a2662c4868 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -906,9 +906,16 @@ def transform_match_stmt(builder: IRBuilder, m: MatchStmt) -> None: assert len(m.bodies) == 1 - for pattern in m.patterns: - if ( - isinstance(pattern, AsPattern) and - pattern.pattern == pattern.name == None - ): - pass + for i, pattern in enumerate(m.patterns): + if isinstance(pattern, ValuePattern): + code_block = BasicBlock() + next_block = BasicBlock() + + cond = builder.accept(ComparisonExpr(["=="], [m.subject, pattern.expr])) + builder.add_bool_branch(cond, code_block, next_block) + + builder.activate_block(code_block) + builder.accept(m.bodies[i]) + builder.goto(next_block) + + builder.activate_block(next_block) diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index aeaa32d75120a..151c36b527855 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -1,6 +1,38 @@ -[case testMatchDefaultPattern_python3_10] +[case testMatchValue_python3_10] def f(): - match True: - case _: + match 123: + case 123: print("matched") [out] +def f(): + r0 :: bit + r1 :: object + r2 :: int32 + r3 :: bit + r4 :: bool + r5 :: str + r6 :: object + r7 :: str + r8 :: object + r9 :: object[1] + r10 :: object_ptr + r11, r12 :: object +L0: + r0 = 246 == 246 + r1 = box(bit, r0) + r2 = PyObject_IsTrue(r1) + r3 = r2 >= 0 :: signed + r4 = truncate r2: int32 to builtins.bool + if r4 goto L1 else goto L2 :: bool +L1: + r5 = 'matched' + r6 = builtins :: module + r7 = 'print' + r8 = CPyObject_GetAttr(r6, r7) + r9 = [r5] + r10 = load_address r9 + r11 = _PyObject_Vectorcall(r8, r10, 1, 0) + keep_alive r5 +L2: + r12 = box(None, 1) + return r12 From f1e6a301149011d185035c2c7aabe3bf620aa10b Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Mon, 10 Oct 2022 20:25:55 -0700 Subject: [PATCH 07/97] Explicitly type out len 2 or pattern --- mypyc/irbuild/statement.py | 26 ++++++++++++++ mypyc/test-data/irbuild-match.test | 54 +++++++++++++++++++++++++++++- 2 files changed, 79 insertions(+), 1 deletion(-) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index b53a2662c4868..489279380afd4 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -919,3 +919,29 @@ def transform_match_stmt(builder: IRBuilder, m: MatchStmt) -> None: builder.goto(next_block) builder.activate_block(next_block) + + if isinstance(pattern, OrPattern): + assert len(pattern.patterns) == 2 + assert all(isinstance(p, ValuePattern) for p in pattern.patterns) + + code_block = BasicBlock() + next_block = BasicBlock() + end_block = BasicBlock() + + cond = builder.accept( + ComparisonExpr(["=="], [m.subject, pattern.patterns[0].expr]) # type: ignore + ) + builder.add_bool_branch(cond, code_block, next_block) + + # TODO: move down below + builder.activate_block(code_block) + builder.accept(m.bodies[i]) + builder.goto(end_block) + + builder.activate_block(next_block) + next_block = BasicBlock() + cond = builder.accept( + ComparisonExpr(["=="], [m.subject, pattern.patterns[1].expr]) # type: ignore + ) + builder.add_bool_branch(cond, code_block, end_block) + builder.activate_block(end_block) diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index 151c36b527855..49c599a03517e 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -1,4 +1,4 @@ -[case testMatchValue_python3_10] +[case testMatchValuePattern_python3_10] def f(): match 123: case 123: @@ -36,3 +36,55 @@ L1: L2: r12 = box(None, 1) return r12 +[case testMatchOrPattern_python3_10] +def f(): + match 123: + case 123 | 456: + print("matched") +[out] +def f(): + r0 :: bit + r1 :: object + r2 :: int32 + r3 :: bit + r4 :: bool + r5 :: str + r6 :: object + r7 :: str + r8 :: object + r9 :: object[1] + r10 :: object_ptr + r11 :: object + r12 :: bit + r13 :: object + r14 :: int32 + r15 :: bit + r16 :: bool + r17 :: object +L0: + r0 = 246 == 246 + r1 = box(bit, r0) + r2 = PyObject_IsTrue(r1) + r3 = r2 >= 0 :: signed + r4 = truncate r2: int32 to builtins.bool + if r4 goto L1 else goto L2 :: bool +L1: + r5 = 'matched' + r6 = builtins :: module + r7 = 'print' + r8 = CPyObject_GetAttr(r6, r7) + r9 = [r5] + r10 = load_address r9 + r11 = _PyObject_Vectorcall(r8, r10, 1, 0) + keep_alive r5 + goto L3 +L2: + r12 = 246 == 912 + r13 = box(bit, r12) + r14 = PyObject_IsTrue(r13) + r15 = r14 >= 0 :: signed + r16 = truncate r14: int32 to builtins.bool + if r16 goto L1 else goto L3 :: bool +L3: + r17 = box(None, 1) + return r17 From 31e3aced055f7e12c2941c53158a0ccc7ec5dc19 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Mon, 10 Oct 2022 20:31:11 -0700 Subject: [PATCH 08/97] Add multiple Or Pattern support --- mypyc/irbuild/statement.py | 10 ++++- mypyc/test-data/irbuild-match.test | 64 ++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+), 1 deletion(-) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index 489279380afd4..c99a61ae2ce3a 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -921,7 +921,7 @@ def transform_match_stmt(builder: IRBuilder, m: MatchStmt) -> None: builder.activate_block(next_block) if isinstance(pattern, OrPattern): - assert len(pattern.patterns) == 2 + assert len(pattern.patterns) == 3 assert all(isinstance(p, ValuePattern) for p in pattern.patterns) code_block = BasicBlock() @@ -943,5 +943,13 @@ def transform_match_stmt(builder: IRBuilder, m: MatchStmt) -> None: cond = builder.accept( ComparisonExpr(["=="], [m.subject, pattern.patterns[1].expr]) # type: ignore ) + builder.add_bool_branch(cond, code_block, next_block) + # builder.activate_block(next_block) + + builder.activate_block(next_block) + next_block = BasicBlock() + cond = builder.accept( + ComparisonExpr(["=="], [m.subject, pattern.patterns[2].expr]) # type: ignore + ) builder.add_bool_branch(cond, code_block, end_block) builder.activate_block(end_block) diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index 49c599a03517e..286f0feb66a5b 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -88,3 +88,67 @@ L2: L3: r17 = box(None, 1) return r17 +[case testMatchOrPatternManyPatterns_python3_10] +def f(): + match 123: + case 123 | 456 | 789: + print("matched") +[out] +def f(): + r0 :: bit + r1 :: object + r2 :: int32 + r3 :: bit + r4 :: bool + r5 :: str + r6 :: object + r7 :: str + r8 :: object + r9 :: object[1] + r10 :: object_ptr + r11 :: object + r12 :: bit + r13 :: object + r14 :: int32 + r15 :: bit + r16 :: bool + r17 :: bit + r18 :: object + r19 :: int32 + r20 :: bit + r21 :: bool + r22 :: object +L0: + r0 = 246 == 246 + r1 = box(bit, r0) + r2 = PyObject_IsTrue(r1) + r3 = r2 >= 0 :: signed + r4 = truncate r2: int32 to builtins.bool + if r4 goto L1 else goto L2 :: bool +L1: + r5 = 'matched' + r6 = builtins :: module + r7 = 'print' + r8 = CPyObject_GetAttr(r6, r7) + r9 = [r5] + r10 = load_address r9 + r11 = _PyObject_Vectorcall(r8, r10, 1, 0) + keep_alive r5 + goto L4 +L2: + r12 = 246 == 912 + r13 = box(bit, r12) + r14 = PyObject_IsTrue(r13) + r15 = r14 >= 0 :: signed + r16 = truncate r14: int32 to builtins.bool + if r16 goto L1 else goto L3 :: bool +L3: + r17 = 246 == 1578 + r18 = box(bit, r17) + r19 = PyObject_IsTrue(r18) + r20 = r19 >= 0 :: signed + r21 = truncate r19: int32 to builtins.bool + if r21 goto L1 else goto L4 :: bool +L4: + r22 = box(None, 1) + return r22 From 538d488ce8fd74796570ac1c214e8e84305706a2 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Mon, 10 Oct 2022 20:45:27 -0700 Subject: [PATCH 09/97] Add one --- mypyc/irbuild/statement.py | 10 +++++++++- mypyc/test-data/irbuild-match.test | 22 +++++++++++++++++----- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index c99a61ae2ce3a..dd2ffc3616d84 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -921,7 +921,7 @@ def transform_match_stmt(builder: IRBuilder, m: MatchStmt) -> None: builder.activate_block(next_block) if isinstance(pattern, OrPattern): - assert len(pattern.patterns) == 3 + assert len(pattern.patterns) == 4 assert all(isinstance(p, ValuePattern) for p in pattern.patterns) code_block = BasicBlock() @@ -951,5 +951,13 @@ def transform_match_stmt(builder: IRBuilder, m: MatchStmt) -> None: cond = builder.accept( ComparisonExpr(["=="], [m.subject, pattern.patterns[2].expr]) # type: ignore ) + builder.add_bool_branch(cond, code_block, next_block) + # builder.activate_block(next_block) + + builder.activate_block(next_block) + next_block = BasicBlock() + cond = builder.accept( + ComparisonExpr(["=="], [m.subject, pattern.patterns[3].expr]) # type: ignore + ) builder.add_bool_branch(cond, code_block, end_block) builder.activate_block(end_block) diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index 286f0feb66a5b..a396dd640e0ba 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -91,7 +91,7 @@ L3: [case testMatchOrPatternManyPatterns_python3_10] def f(): match 123: - case 123 | 456 | 789: + case 123 | 456 | 789 | 999: print("matched") [out] def f(): @@ -117,7 +117,12 @@ def f(): r19 :: int32 r20 :: bit r21 :: bool - r22 :: object + r22 :: bit + r23 :: object + r24 :: int32 + r25 :: bit + r26 :: bool + r27 :: object L0: r0 = 246 == 246 r1 = box(bit, r0) @@ -134,7 +139,7 @@ L1: r10 = load_address r9 r11 = _PyObject_Vectorcall(r8, r10, 1, 0) keep_alive r5 - goto L4 + goto L5 L2: r12 = 246 == 912 r13 = box(bit, r12) @@ -150,5 +155,12 @@ L3: r21 = truncate r19: int32 to builtins.bool if r21 goto L1 else goto L4 :: bool L4: - r22 = box(None, 1) - return r22 + r22 = 246 == 1998 + r23 = box(bit, r22) + r24 = PyObject_IsTrue(r23) + r25 = r24 >= 0 :: signed + r26 = truncate r24: int32 to builtins.bool + if r26 goto L1 else goto L5 :: bool +L5: + r27 = box(None, 1) + return r27 From b65333541b4fa72de72b577a86054e3dc5d7fb43 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Mon, 10 Oct 2022 20:48:45 -0700 Subject: [PATCH 10/97] Generalize --- mypyc/irbuild/statement.py | 26 +++++++++----------------- 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index dd2ffc3616d84..05380b572fa81 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -921,7 +921,7 @@ def transform_match_stmt(builder: IRBuilder, m: MatchStmt) -> None: builder.activate_block(next_block) if isinstance(pattern, OrPattern): - assert len(pattern.patterns) == 4 + # assert len(pattern.patterns) == 4 assert all(isinstance(p, ValuePattern) for p in pattern.patterns) code_block = BasicBlock() @@ -938,26 +938,18 @@ def transform_match_stmt(builder: IRBuilder, m: MatchStmt) -> None: builder.accept(m.bodies[i]) builder.goto(end_block) - builder.activate_block(next_block) - next_block = BasicBlock() - cond = builder.accept( - ComparisonExpr(["=="], [m.subject, pattern.patterns[1].expr]) # type: ignore - ) - builder.add_bool_branch(cond, code_block, next_block) - # builder.activate_block(next_block) - - builder.activate_block(next_block) - next_block = BasicBlock() - cond = builder.accept( - ComparisonExpr(["=="], [m.subject, pattern.patterns[2].expr]) # type: ignore - ) - builder.add_bool_branch(cond, code_block, next_block) - # builder.activate_block(next_block) + for p in pattern.patterns[1:-1]: + builder.activate_block(next_block) + next_block = BasicBlock() + cond = builder.accept( + ComparisonExpr(["=="], [m.subject, p.expr]) # type: ignore + ) + builder.add_bool_branch(cond, code_block, next_block) builder.activate_block(next_block) next_block = BasicBlock() cond = builder.accept( - ComparisonExpr(["=="], [m.subject, pattern.patterns[3].expr]) # type: ignore + ComparisonExpr(["=="], [m.subject, pattern.patterns[-1].expr]) # type: ignore ) builder.add_bool_branch(cond, code_block, end_block) builder.activate_block(end_block) From 2870c8ef4c64699a7fd1ec0a79fb0926b55e4f8e Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Mon, 10 Oct 2022 21:14:33 -0700 Subject: [PATCH 11/97] Minimize --- mypyc/irbuild/statement.py | 8 ++------ mypyc/test-data/irbuild-match.test | 6 ++++-- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index 05380b572fa81..e096b8ff842f1 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -938,7 +938,7 @@ def transform_match_stmt(builder: IRBuilder, m: MatchStmt) -> None: builder.accept(m.bodies[i]) builder.goto(end_block) - for p in pattern.patterns[1:-1]: + for p in pattern.patterns[1:]: builder.activate_block(next_block) next_block = BasicBlock() cond = builder.accept( @@ -947,9 +947,5 @@ def transform_match_stmt(builder: IRBuilder, m: MatchStmt) -> None: builder.add_bool_branch(cond, code_block, next_block) builder.activate_block(next_block) - next_block = BasicBlock() - cond = builder.accept( - ComparisonExpr(["=="], [m.subject, pattern.patterns[-1].expr]) # type: ignore - ) - builder.add_bool_branch(cond, code_block, end_block) + builder.goto(end_block) builder.activate_block(end_block) diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index a396dd640e0ba..dad038b17180e 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -77,7 +77,7 @@ L1: r10 = load_address r9 r11 = _PyObject_Vectorcall(r8, r10, 1, 0) keep_alive r5 - goto L3 + goto L4 L2: r12 = 246 == 912 r13 = box(bit, r12) @@ -86,6 +86,7 @@ L2: r16 = truncate r14: int32 to builtins.bool if r16 goto L1 else goto L3 :: bool L3: +L4: r17 = box(None, 1) return r17 [case testMatchOrPatternManyPatterns_python3_10] @@ -139,7 +140,7 @@ L1: r10 = load_address r9 r11 = _PyObject_Vectorcall(r8, r10, 1, 0) keep_alive r5 - goto L5 + goto L6 L2: r12 = 246 == 912 r13 = box(bit, r12) @@ -162,5 +163,6 @@ L4: r26 = truncate r24: int32 to builtins.bool if r26 goto L1 else goto L5 :: bool L5: +L6: r27 = box(None, 1) return r27 From 988769dc68862b04c27d8d4793edefd11472dfbc Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Mon, 10 Oct 2022 21:22:52 -0700 Subject: [PATCH 12/97] Rearange codeblock --- mypyc/irbuild/statement.py | 10 +- mypyc/test-data/irbuild-match.test | 144 ++++++++++++++--------------- 2 files changed, 76 insertions(+), 78 deletions(-) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index e096b8ff842f1..6595a8e53e890 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -933,11 +933,6 @@ def transform_match_stmt(builder: IRBuilder, m: MatchStmt) -> None: ) builder.add_bool_branch(cond, code_block, next_block) - # TODO: move down below - builder.activate_block(code_block) - builder.accept(m.bodies[i]) - builder.goto(end_block) - for p in pattern.patterns[1:]: builder.activate_block(next_block) next_block = BasicBlock() @@ -948,4 +943,9 @@ def transform_match_stmt(builder: IRBuilder, m: MatchStmt) -> None: builder.activate_block(next_block) builder.goto(end_block) + + builder.activate_block(code_block) + builder.accept(m.bodies[i]) + builder.goto(end_block) + builder.activate_block(end_block) diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index dad038b17180e..773255ce1086e 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -48,44 +48,43 @@ def f(): r2 :: int32 r3 :: bit r4 :: bool - r5 :: str + r5 :: bit r6 :: object - r7 :: str - r8 :: object - r9 :: object[1] - r10 :: object_ptr + r7 :: int32 + r8 :: bit + r9 :: bool + r10 :: str r11 :: object - r12 :: bit + r12 :: str r13 :: object - r14 :: int32 - r15 :: bit - r16 :: bool - r17 :: object + r14 :: object[1] + r15 :: object_ptr + r16, r17 :: object L0: r0 = 246 == 246 r1 = box(bit, r0) r2 = PyObject_IsTrue(r1) r3 = r2 >= 0 :: signed r4 = truncate r2: int32 to builtins.bool - if r4 goto L1 else goto L2 :: bool + if r4 goto L3 else goto L1 :: bool L1: - r5 = 'matched' - r6 = builtins :: module - r7 = 'print' - r8 = CPyObject_GetAttr(r6, r7) - r9 = [r5] - r10 = load_address r9 - r11 = _PyObject_Vectorcall(r8, r10, 1, 0) - keep_alive r5 - goto L4 + r5 = 246 == 912 + r6 = box(bit, r5) + r7 = PyObject_IsTrue(r6) + r8 = r7 >= 0 :: signed + r9 = truncate r7: int32 to builtins.bool + if r9 goto L3 else goto L2 :: bool L2: - r12 = 246 == 912 - r13 = box(bit, r12) - r14 = PyObject_IsTrue(r13) - r15 = r14 >= 0 :: signed - r16 = truncate r14: int32 to builtins.bool - if r16 goto L1 else goto L3 :: bool + goto L4 L3: + r10 = 'matched' + r11 = builtins :: module + r12 = 'print' + r13 = CPyObject_GetAttr(r11, r12) + r14 = [r10] + r15 = load_address r14 + r16 = _PyObject_Vectorcall(r13, r15, 1, 0) + keep_alive r10 L4: r17 = box(None, 1) return r17 @@ -101,68 +100,67 @@ def f(): r2 :: int32 r3 :: bit r4 :: bool - r5 :: str + r5 :: bit r6 :: object - r7 :: str - r8 :: object - r9 :: object[1] - r10 :: object_ptr + r7 :: int32 + r8 :: bit + r9 :: bool + r10 :: bit r11 :: object - r12 :: bit - r13 :: object - r14 :: int32 + r12 :: int32 + r13 :: bit + r14 :: bool r15 :: bit - r16 :: bool - r17 :: bit - r18 :: object - r19 :: int32 - r20 :: bit - r21 :: bool - r22 :: bit + r16 :: object + r17 :: int32 + r18 :: bit + r19 :: bool + r20 :: str + r21 :: object + r22 :: str r23 :: object - r24 :: int32 - r25 :: bit - r26 :: bool - r27 :: object + r24 :: object[1] + r25 :: object_ptr + r26, r27 :: object L0: r0 = 246 == 246 r1 = box(bit, r0) r2 = PyObject_IsTrue(r1) r3 = r2 >= 0 :: signed r4 = truncate r2: int32 to builtins.bool - if r4 goto L1 else goto L2 :: bool + if r4 goto L5 else goto L1 :: bool L1: - r5 = 'matched' - r6 = builtins :: module - r7 = 'print' - r8 = CPyObject_GetAttr(r6, r7) - r9 = [r5] - r10 = load_address r9 - r11 = _PyObject_Vectorcall(r8, r10, 1, 0) - keep_alive r5 - goto L6 + r5 = 246 == 912 + r6 = box(bit, r5) + r7 = PyObject_IsTrue(r6) + r8 = r7 >= 0 :: signed + r9 = truncate r7: int32 to builtins.bool + if r9 goto L5 else goto L2 :: bool L2: - r12 = 246 == 912 - r13 = box(bit, r12) - r14 = PyObject_IsTrue(r13) - r15 = r14 >= 0 :: signed - r16 = truncate r14: int32 to builtins.bool - if r16 goto L1 else goto L3 :: bool + r10 = 246 == 1578 + r11 = box(bit, r10) + r12 = PyObject_IsTrue(r11) + r13 = r12 >= 0 :: signed + r14 = truncate r12: int32 to builtins.bool + if r14 goto L5 else goto L3 :: bool L3: - r17 = 246 == 1578 - r18 = box(bit, r17) - r19 = PyObject_IsTrue(r18) - r20 = r19 >= 0 :: signed - r21 = truncate r19: int32 to builtins.bool - if r21 goto L1 else goto L4 :: bool + r15 = 246 == 1998 + r16 = box(bit, r15) + r17 = PyObject_IsTrue(r16) + r18 = r17 >= 0 :: signed + r19 = truncate r17: int32 to builtins.bool + if r19 goto L5 else goto L4 :: bool L4: - r22 = 246 == 1998 - r23 = box(bit, r22) - r24 = PyObject_IsTrue(r23) - r25 = r24 >= 0 :: signed - r26 = truncate r24: int32 to builtins.bool - if r26 goto L1 else goto L5 :: bool + goto L6 L5: + r20 = 'matched' + r21 = builtins :: module + r22 = 'print' + r23 = CPyObject_GetAttr(r21, r22) + r24 = [r20] + r25 = load_address r24 + r26 = _PyObject_Vectorcall(r23, r25, 1, 0) + keep_alive r20 L6: r27 = box(None, 1) return r27 From 6cae5f83bbfa3f4e30c635c38b0dd2e7d4eb7053 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Mon, 10 Oct 2022 21:25:08 -0700 Subject: [PATCH 13/97] Cleanup --- mypyc/irbuild/statement.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index 6595a8e53e890..0a9f1aa05ea01 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -921,27 +921,20 @@ def transform_match_stmt(builder: IRBuilder, m: MatchStmt) -> None: builder.activate_block(next_block) if isinstance(pattern, OrPattern): - # assert len(pattern.patterns) == 4 assert all(isinstance(p, ValuePattern) for p in pattern.patterns) code_block = BasicBlock() next_block = BasicBlock() - end_block = BasicBlock() - cond = builder.accept( - ComparisonExpr(["=="], [m.subject, pattern.patterns[0].expr]) # type: ignore - ) - builder.add_bool_branch(cond, code_block, next_block) - - for p in pattern.patterns[1:]: - builder.activate_block(next_block) - next_block = BasicBlock() + for p in pattern.patterns: cond = builder.accept( ComparisonExpr(["=="], [m.subject, p.expr]) # type: ignore ) builder.add_bool_branch(cond, code_block, next_block) + builder.activate_block(next_block) + next_block = BasicBlock() - builder.activate_block(next_block) + end_block = BasicBlock() builder.goto(end_block) builder.activate_block(code_block) From 856dc85338a803cd5203a58783bcb3ecc6c121f2 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Mon, 10 Oct 2022 21:27:17 -0700 Subject: [PATCH 14/97] Make it more readable --- mypyc/test-data/irbuild-match.test | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index 773255ce1086e..d699aed75c828 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -90,8 +90,8 @@ L4: return r17 [case testMatchOrPatternManyPatterns_python3_10] def f(): - match 123: - case 123 | 456 | 789 | 999: + match 1: + case 1 | 2 | 3 | 4: print("matched") [out] def f(): @@ -123,28 +123,28 @@ def f(): r25 :: object_ptr r26, r27 :: object L0: - r0 = 246 == 246 + r0 = 2 == 2 r1 = box(bit, r0) r2 = PyObject_IsTrue(r1) r3 = r2 >= 0 :: signed r4 = truncate r2: int32 to builtins.bool if r4 goto L5 else goto L1 :: bool L1: - r5 = 246 == 912 + r5 = 2 == 4 r6 = box(bit, r5) r7 = PyObject_IsTrue(r6) r8 = r7 >= 0 :: signed r9 = truncate r7: int32 to builtins.bool if r9 goto L5 else goto L2 :: bool L2: - r10 = 246 == 1578 + r10 = 2 == 6 r11 = box(bit, r10) r12 = PyObject_IsTrue(r11) r13 = r12 >= 0 :: signed r14 = truncate r12: int32 to builtins.bool if r14 goto L5 else goto L3 :: bool L3: - r15 = 246 == 1998 + r15 = 2 == 8 r16 = box(bit, r15) r17 = PyObject_IsTrue(r16) r18 = r17 >= 0 :: signed From 65c9aaee1f403c02a6291acd28877ed4e5153d49 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Mon, 10 Oct 2022 21:57:54 -0700 Subject: [PATCH 15/97] Add class pattern support --- mypyc/irbuild/statement.py | 30 ++++++++++++++++++++++-- mypyc/test-data/irbuild-match.test | 37 ++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 2 deletions(-) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index 0a9f1aa05ea01..e575632fd9025 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -12,11 +12,13 @@ from typing import Callable, Sequence from mypy.nodes import ( + ArgKind, AssertStmt, AssignmentStmt, AwaitExpr, Block, BreakStmt, + CallExpr, ComparisonExpr, ContinueStmt, DelStmt, @@ -30,6 +32,7 @@ ListExpr, Lvalue, MatchStmt, + NameExpr, OperatorAssignmentStmt, RaiseStmt, ReturnStmt, @@ -43,7 +46,7 @@ YieldExpr, YieldFromExpr, ) -from mypy.patterns import AsPattern, OrPattern, ValuePattern +from mypy.patterns import ClassPattern, OrPattern, ValuePattern from mypyc.ir.ops import ( NO_TRACEBACK_LINE_NO, Assign, @@ -98,6 +101,7 @@ coro_op, import_from_op, send_op, + slow_isinstance_op, type_op, yield_from_except_op, ) @@ -902,7 +906,7 @@ def transform_await_expr(builder: IRBuilder, o: AwaitExpr) -> Value: def transform_match_stmt(builder: IRBuilder, m: MatchStmt) -> None: - builder.accept(m.subject) + subject = builder.accept(m.subject) assert len(m.bodies) == 1 @@ -942,3 +946,25 @@ def transform_match_stmt(builder: IRBuilder, m: MatchStmt) -> None: builder.goto(end_block) builder.activate_block(end_block) + + if isinstance(pattern, ClassPattern): + assert not pattern.positionals + assert not pattern.keyword_keys + assert not pattern.keyword_values + + code_block = BasicBlock() + end_block = BasicBlock() + + cond = builder.call_c( + slow_isinstance_op, + [subject, builder.accept(pattern.class_ref)], + pattern.line + ) + + builder.add_bool_branch(cond, code_block, end_block) + + builder.activate_block(code_block) + builder.accept(m.bodies[i]) + builder.goto(end_block) + + builder.activate_block(end_block) diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index d699aed75c828..3dad30107a838 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -164,3 +164,40 @@ L5: L6: r27 = box(None, 1) return r27 +[case testMatchClassPattern_python3_10] +def f(): + match 123: + case int(): + print("matched") +[out] +def f(): + r0, r1 :: object + r2 :: int32 + r3 :: bit + r4 :: bool + r5 :: str + r6 :: object + r7 :: str + r8 :: object + r9 :: object[1] + r10 :: object_ptr + r11, r12 :: object +L0: + r0 = load_address PyLong_Type + r1 = object 123 + r2 = PyObject_IsInstance(r1, r0) + r3 = r2 >= 0 :: signed + r4 = truncate r2: int32 to builtins.bool + if r4 goto L1 else goto L2 :: bool +L1: + r5 = 'matched' + r6 = builtins :: module + r7 = 'print' + r8 = CPyObject_GetAttr(r6, r7) + r9 = [r5] + r10 = load_address r9 + r11 = _PyObject_Vectorcall(r8, r10, 1, 0) + keep_alive r5 +L2: + r12 = box(None, 1) + return r12 From 2955d9fe36147dd03c81a9e6ebd6572b9b337853 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Tue, 11 Oct 2022 17:53:34 -0700 Subject: [PATCH 16/97] Add wildcard --- mypyc/irbuild/statement.py | 14 +++++++++++++- mypyc/test-data/irbuild-match.test | 27 +++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index e575632fd9025..dd2851539566f 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -46,7 +46,7 @@ YieldExpr, YieldFromExpr, ) -from mypy.patterns import ClassPattern, OrPattern, ValuePattern +from mypy.patterns import AsPattern, ClassPattern, OrPattern, ValuePattern from mypyc.ir.ops import ( NO_TRACEBACK_LINE_NO, Assign, @@ -968,3 +968,15 @@ def transform_match_stmt(builder: IRBuilder, m: MatchStmt) -> None: builder.goto(end_block) builder.activate_block(end_block) + + if isinstance(pattern, AsPattern): + assert not pattern.pattern + assert not pattern.name + + code_block = BasicBlock() + next_block = BasicBlock() + + builder.goto(code_block) + builder.activate_block(code_block) + builder.accept(m.bodies[i]) + builder.goto_and_activate(next_block) diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index 3dad30107a838..6c756646b5a1c 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -201,3 +201,30 @@ L1: L2: r12 = box(None, 1) return r12 +[case testMatchExaustivePattern_python3_10] +def f(): + match 123: + case _: + print("matched") +[out] +def f(): + r0 :: str + r1 :: object + r2 :: str + r3 :: object + r4 :: object[1] + r5 :: object_ptr + r6, r7 :: object +L0: +L1: + r0 = 'matched' + r1 = builtins :: module + r2 = 'print' + r3 = CPyObject_GetAttr(r1, r2) + r4 = [r0] + r5 = load_address r4 + r6 = _PyObject_Vectorcall(r3, r5, 1, 0) + keep_alive r0 +L2: + r7 = box(None, 1) + return r7 From 96fb472506606e78414fb9e98cc9544997e7de44 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Tue, 11 Oct 2022 18:09:00 -0700 Subject: [PATCH 17/97] Add multibody support --- mypyc/irbuild/statement.py | 6 ++- mypyc/test-data/irbuild-match.test | 71 ++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 2 deletions(-) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index dd2851539566f..1205ab0ef4950 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -908,7 +908,7 @@ def transform_await_expr(builder: IRBuilder, o: AwaitExpr) -> Value: def transform_match_stmt(builder: IRBuilder, m: MatchStmt) -> None: subject = builder.accept(m.subject) - assert len(m.bodies) == 1 + final_block = BasicBlock() for i, pattern in enumerate(m.patterns): if isinstance(pattern, ValuePattern): @@ -920,7 +920,7 @@ def transform_match_stmt(builder: IRBuilder, m: MatchStmt) -> None: builder.activate_block(code_block) builder.accept(m.bodies[i]) - builder.goto(next_block) + builder.goto(final_block) builder.activate_block(next_block) @@ -980,3 +980,5 @@ def transform_match_stmt(builder: IRBuilder, m: MatchStmt) -> None: builder.activate_block(code_block) builder.accept(m.bodies[i]) builder.goto_and_activate(next_block) + + builder.goto_and_activate(final_block) diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index 6c756646b5a1c..3ed1d6a3ef231 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -228,3 +228,74 @@ L1: L2: r7 = box(None, 1) return r7 +[case testMatchMultipleBodies_python3_10] +def f(): + match 123: + case 123: + print("matched") + case 456: + print("no match") +[out] +def f(): + r0 :: bit + r1 :: object + r2 :: int32 + r3 :: bit + r4 :: bool + r5 :: str + r6 :: object + r7 :: str + r8 :: object + r9 :: object[1] + r10 :: object_ptr + r11 :: object + r12 :: bit + r13 :: object + r14 :: int32 + r15 :: bit + r16 :: bool + r17 :: str + r18 :: object + r19 :: str + r20 :: object + r21 :: object[1] + r22 :: object_ptr + r23, r24 :: object +L0: + r0 = 246 == 246 + r1 = box(bit, r0) + r2 = PyObject_IsTrue(r1) + r3 = r2 >= 0 :: signed + r4 = truncate r2: int32 to builtins.bool + if r4 goto L1 else goto L2 :: bool +L1: + r5 = 'matched' + r6 = builtins :: module + r7 = 'print' + r8 = CPyObject_GetAttr(r6, r7) + r9 = [r5] + r10 = load_address r9 + r11 = _PyObject_Vectorcall(r8, r10, 1, 0) + keep_alive r5 + goto L5 +L2: + r12 = 246 == 912 + r13 = box(bit, r12) + r14 = PyObject_IsTrue(r13) + r15 = r14 >= 0 :: signed + r16 = truncate r14: int32 to builtins.bool + if r16 goto L3 else goto L4 :: bool +L3: + r17 = 'no match' + r18 = builtins :: module + r19 = 'print' + r20 = CPyObject_GetAttr(r18, r19) + r21 = [r17] + r22 = load_address r21 + r23 = _PyObject_Vectorcall(r20, r22, 1, 0) + keep_alive r17 + goto L5 +L4: +L5: + r24 = box(None, 1) + return r24 From f961fd0afb82c5aa70de58f8603fef457e48ee12 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Tue, 11 Oct 2022 18:11:14 -0700 Subject: [PATCH 18/97] Fix failing tests --- mypyc/test-data/irbuild-match.test | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index 3ed1d6a3ef231..0aa22e4183ee7 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -33,7 +33,9 @@ L1: r10 = load_address r9 r11 = _PyObject_Vectorcall(r8, r10, 1, 0) keep_alive r5 + goto L3 L2: +L3: r12 = box(None, 1) return r12 [case testMatchOrPattern_python3_10] @@ -86,6 +88,7 @@ L3: r16 = _PyObject_Vectorcall(r13, r15, 1, 0) keep_alive r10 L4: +L5: r17 = box(None, 1) return r17 [case testMatchOrPatternManyPatterns_python3_10] @@ -162,6 +165,7 @@ L5: r26 = _PyObject_Vectorcall(r23, r25, 1, 0) keep_alive r20 L6: +L7: r27 = box(None, 1) return r27 [case testMatchClassPattern_python3_10] @@ -199,6 +203,7 @@ L1: r11 = _PyObject_Vectorcall(r8, r10, 1, 0) keep_alive r5 L2: +L3: r12 = box(None, 1) return r12 [case testMatchExaustivePattern_python3_10] @@ -226,6 +231,7 @@ L1: r6 = _PyObject_Vectorcall(r3, r5, 1, 0) keep_alive r0 L2: +L3: r7 = box(None, 1) return r7 [case testMatchMultipleBodies_python3_10] From 899c954e84db5cc3cee59e3f4db99190bb472c90 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Tue, 11 Oct 2022 18:14:22 -0700 Subject: [PATCH 19/97] Add final block gotos --- mypyc/irbuild/statement.py | 8 +++++--- mypyc/test-data/irbuild-match.test | 4 ++++ 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index 1205ab0ef4950..01189c4d72b1c 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -943,7 +943,7 @@ def transform_match_stmt(builder: IRBuilder, m: MatchStmt) -> None: builder.activate_block(code_block) builder.accept(m.bodies[i]) - builder.goto(end_block) + builder.goto(final_block) builder.activate_block(end_block) @@ -965,7 +965,7 @@ def transform_match_stmt(builder: IRBuilder, m: MatchStmt) -> None: builder.activate_block(code_block) builder.accept(m.bodies[i]) - builder.goto(end_block) + builder.goto(final_block) builder.activate_block(end_block) @@ -979,6 +979,8 @@ def transform_match_stmt(builder: IRBuilder, m: MatchStmt) -> None: builder.goto(code_block) builder.activate_block(code_block) builder.accept(m.bodies[i]) - builder.goto_and_activate(next_block) + builder.goto(final_block) + + builder.activate_block(next_block) builder.goto_and_activate(final_block) diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index 0aa22e4183ee7..fd0f826045c89 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -87,6 +87,7 @@ L3: r15 = load_address r14 r16 = _PyObject_Vectorcall(r13, r15, 1, 0) keep_alive r10 + goto L5 L4: L5: r17 = box(None, 1) @@ -164,6 +165,7 @@ L5: r25 = load_address r24 r26 = _PyObject_Vectorcall(r23, r25, 1, 0) keep_alive r20 + goto L7 L6: L7: r27 = box(None, 1) @@ -202,6 +204,7 @@ L1: r10 = load_address r9 r11 = _PyObject_Vectorcall(r8, r10, 1, 0) keep_alive r5 + goto L3 L2: L3: r12 = box(None, 1) @@ -230,6 +233,7 @@ L1: r5 = load_address r4 r6 = _PyObject_Vectorcall(r3, r5, 1, 0) keep_alive r0 + goto L3 L2: L3: r7 = box(None, 1) From dcf84e094d26bd2cf2fba09bd28d591d94d5dc1d Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Tue, 11 Oct 2022 18:20:04 -0700 Subject: [PATCH 20/97] Add complex sanity test --- mypyc/test-data/irbuild-match.test | 116 +++++++++++++++++++++++++++++ 1 file changed, 116 insertions(+) diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index fd0f826045c89..5ad073cb5cc34 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -309,3 +309,119 @@ L4: L5: r24 = box(None, 1) return r24 +[case testMatchMultiBodyAndComplexOr_python3_10] +def f(): + match 123: + case 1: + print("here 1") + case 2 | 3: + print("here 2 | 3") + case 123: + print("here 123") +[out] +def f(): + r0 :: bit + r1 :: object + r2 :: int32 + r3 :: bit + r4 :: bool + r5 :: str + r6 :: object + r7 :: str + r8 :: object + r9 :: object[1] + r10 :: object_ptr + r11 :: object + r12 :: bit + r13 :: object + r14 :: int32 + r15 :: bit + r16 :: bool + r17 :: bit + r18 :: object + r19 :: int32 + r20 :: bit + r21 :: bool + r22 :: str + r23 :: object + r24 :: str + r25 :: object + r26 :: object[1] + r27 :: object_ptr + r28 :: object + r29 :: bit + r30 :: object + r31 :: int32 + r32 :: bit + r33 :: bool + r34 :: str + r35 :: object + r36 :: str + r37 :: object + r38 :: object[1] + r39 :: object_ptr + r40, r41 :: object +L0: + r0 = 246 == 2 + r1 = box(bit, r0) + r2 = PyObject_IsTrue(r1) + r3 = r2 >= 0 :: signed + r4 = truncate r2: int32 to builtins.bool + if r4 goto L1 else goto L2 :: bool +L1: + r5 = 'here 1' + r6 = builtins :: module + r7 = 'print' + r8 = CPyObject_GetAttr(r6, r7) + r9 = [r5] + r10 = load_address r9 + r11 = _PyObject_Vectorcall(r8, r10, 1, 0) + keep_alive r5 + goto L9 +L2: + r12 = 246 == 4 + r13 = box(bit, r12) + r14 = PyObject_IsTrue(r13) + r15 = r14 >= 0 :: signed + r16 = truncate r14: int32 to builtins.bool + if r16 goto L5 else goto L3 :: bool +L3: + r17 = 246 == 6 + r18 = box(bit, r17) + r19 = PyObject_IsTrue(r18) + r20 = r19 >= 0 :: signed + r21 = truncate r19: int32 to builtins.bool + if r21 goto L5 else goto L4 :: bool +L4: + goto L6 +L5: + r22 = 'here 2 | 3' + r23 = builtins :: module + r24 = 'print' + r25 = CPyObject_GetAttr(r23, r24) + r26 = [r22] + r27 = load_address r26 + r28 = _PyObject_Vectorcall(r25, r27, 1, 0) + keep_alive r22 + goto L9 +L6: + r29 = 246 == 246 + r30 = box(bit, r29) + r31 = PyObject_IsTrue(r30) + r32 = r31 >= 0 :: signed + r33 = truncate r31: int32 to builtins.bool + if r33 goto L7 else goto L8 :: bool +L7: + r34 = 'here 123' + r35 = builtins :: module + r36 = 'print' + r37 = CPyObject_GetAttr(r35, r36) + r38 = [r34] + r39 = load_address r38 + r40 = _PyObject_Vectorcall(r37, r39, 1, 0) + keep_alive r34 + goto L9 +L8: +L9: + r41 = box(None, 1) + return r41 From eb2dc2d5c7aa79ddb091e51da74b714a64764bc7 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Tue, 11 Oct 2022 18:39:08 -0700 Subject: [PATCH 21/97] Add run test --- mypyc/build.py | 1 + mypyc/test-data/run-match.test | 16 ++++++++++ mypyc/test/test_run.py | 55 +++++++++++++++++----------------- 3 files changed, 45 insertions(+), 27 deletions(-) create mode 100644 mypyc/test-data/run-match.test diff --git a/mypyc/build.py b/mypyc/build.py index 4f40a6cd08659..51696e86fa941 100644 --- a/mypyc/build.py +++ b/mypyc/build.py @@ -534,6 +534,7 @@ def mypycify( "-Wno-unused-command-line-argument", "-Wno-unknown-warning-option", "-Wno-unused-but-set-variable", + "-Wno-cpp", ] elif compiler.compiler_type == "msvc": # msvc doesn't have levels, '/O2' is full and '/Od' is disable diff --git a/mypyc/test-data/run-match.test b/mypyc/test-data/run-match.test new file mode 100644 index 0000000000000..5f90e98fbdf82 --- /dev/null +++ b/mypyc/test-data/run-match.test @@ -0,0 +1,16 @@ +[case testMatchBasic_python3_10] +def f(x): + match x: + case 123: + print("matched!") + case _: + print("no match") +[file driver.py] +from native import f + +f(123) +f(321) + +[out] +matched! +no match diff --git a/mypyc/test/test_run.py b/mypyc/test/test_run.py index 63e4f153da401..6c675a9eab9ff 100644 --- a/mypyc/test/test_run.py +++ b/mypyc/test/test_run.py @@ -35,33 +35,34 @@ ) files = [ - "run-async.test", - "run-misc.test", - "run-functions.test", - "run-integers.test", - "run-i64.test", - "run-i32.test", - "run-floats.test", - "run-bools.test", - "run-strings.test", - "run-bytes.test", - "run-tuples.test", - "run-lists.test", - "run-dicts.test", - "run-sets.test", - "run-primitives.test", - "run-loops.test", - "run-exceptions.test", - "run-imports.test", - "run-classes.test", - "run-traits.test", - "run-generators.test", - "run-multimodule.test", - "run-bench.test", - "run-mypy-sim.test", - "run-dunders.test", - "run-singledispatch.test", - "run-attrs.test", + # "run-async.test", + # "run-misc.test", + # "run-functions.test", + # "run-integers.test", + # "run-i64.test", + # "run-i32.test", + # "run-floats.test", + # "run-bools.test", + # "run-strings.test", + # "run-bytes.test", + # "run-tuples.test", + # "run-lists.test", + # "run-dicts.test", + # "run-sets.test", + # "run-primitives.test", + # "run-loops.test", + # "run-exceptions.test", + # "run-imports.test", + # "run-classes.test", + # "run-traits.test", + # "run-generators.test", + # "run-multimodule.test", + # "run-bench.test", + # "run-mypy-sim.test", + # "run-dunders.test", + # "run-singledispatch.test", + # "run-attrs.test", + "run-match.test", ] files.append("run-python37.test") From 522461dbbf7262232a8bff0266a01b75df677281 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Tue, 11 Oct 2022 18:46:15 -0700 Subject: [PATCH 22/97] Add pattern guard for value pattern --- mypyc/irbuild/statement.py | 9 +++++++ mypyc/test-data/irbuild-match.test | 42 ++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index 01189c4d72b1c..fd70c7e3133b3 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -919,6 +919,15 @@ def transform_match_stmt(builder: IRBuilder, m: MatchStmt) -> None: builder.add_bool_branch(cond, code_block, next_block) builder.activate_block(code_block) + + if guard := m.guards[i]: + new_code_block = BasicBlock() + + cond = builder.accept(guard) + builder.add_bool_branch(cond, new_code_block, next_block) + + builder.activate_block(new_code_block) + builder.accept(m.bodies[i]) builder.goto(final_block) diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index 5ad073cb5cc34..985489c0ea45a 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -425,3 +425,45 @@ L8: L9: r41 = box(None, 1) return r41 +[case testMatchWithGuard_python3_10] +def f(): + match 123: + case 123 if True: + print("matched") +[out] +def f(): + r0 :: bit + r1 :: object + r2 :: int32 + r3 :: bit + r4 :: bool + r5 :: str + r6 :: object + r7 :: str + r8 :: object + r9 :: object[1] + r10 :: object_ptr + r11, r12 :: object +L0: + r0 = 246 == 246 + r1 = box(bit, r0) + r2 = PyObject_IsTrue(r1) + r3 = r2 >= 0 :: signed + r4 = truncate r2: int32 to builtins.bool + if r4 goto L1 else goto L3 :: bool +L1: + if 1 goto L2 else goto L3 :: bool +L2: + r5 = 'matched' + r6 = builtins :: module + r7 = 'print' + r8 = CPyObject_GetAttr(r6, r7) + r9 = [r5] + r10 = load_address r9 + r11 = _PyObject_Vectorcall(r8, r10, 1, 0) + keep_alive r5 + goto L4 +L3: +L4: + r12 = box(None, 1) + return r12 From 8ce67d9b754b9751b608c750d53b950ba6fe2d9b Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Tue, 11 Oct 2022 18:54:16 -0700 Subject: [PATCH 23/97] Add pattern guard support to all existing patterns --- mypyc/irbuild/statement.py | 42 +++++++++++++++++++------------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index fd70c7e3133b3..11716d61448f7 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -910,6 +910,22 @@ def transform_match_stmt(builder: IRBuilder, m: MatchStmt) -> None: final_block = BasicBlock() + def build_match_body( + index: int, code_block: BasicBlock, next_block: BasicBlock + ) -> None: + builder.activate_block(code_block) + + if guard := m.guards[index]: + new_code_block = BasicBlock() + + cond = builder.accept(guard) + builder.add_bool_branch(cond, new_code_block, next_block) + + builder.activate_block(new_code_block) + + builder.accept(m.bodies[index]) + builder.goto(final_block) + for i, pattern in enumerate(m.patterns): if isinstance(pattern, ValuePattern): code_block = BasicBlock() @@ -918,18 +934,7 @@ def transform_match_stmt(builder: IRBuilder, m: MatchStmt) -> None: cond = builder.accept(ComparisonExpr(["=="], [m.subject, pattern.expr])) builder.add_bool_branch(cond, code_block, next_block) - builder.activate_block(code_block) - - if guard := m.guards[i]: - new_code_block = BasicBlock() - - cond = builder.accept(guard) - builder.add_bool_branch(cond, new_code_block, next_block) - - builder.activate_block(new_code_block) - - builder.accept(m.bodies[i]) - builder.goto(final_block) + build_match_body(i, code_block, next_block) builder.activate_block(next_block) @@ -950,9 +955,7 @@ def transform_match_stmt(builder: IRBuilder, m: MatchStmt) -> None: end_block = BasicBlock() builder.goto(end_block) - builder.activate_block(code_block) - builder.accept(m.bodies[i]) - builder.goto(final_block) + build_match_body(i, code_block, end_block) builder.activate_block(end_block) @@ -972,9 +975,7 @@ def transform_match_stmt(builder: IRBuilder, m: MatchStmt) -> None: builder.add_bool_branch(cond, code_block, end_block) - builder.activate_block(code_block) - builder.accept(m.bodies[i]) - builder.goto(final_block) + build_match_body(i, code_block, end_block) builder.activate_block(end_block) @@ -986,9 +987,8 @@ def transform_match_stmt(builder: IRBuilder, m: MatchStmt) -> None: next_block = BasicBlock() builder.goto(code_block) - builder.activate_block(code_block) - builder.accept(m.bodies[i]) - builder.goto(final_block) + + build_match_body(i, code_block, next_block) builder.activate_block(next_block) From c36448975b241190a34ce3cfd9c305dbcb9a78c6 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Tue, 11 Oct 2022 20:31:28 -0700 Subject: [PATCH 24/97] Add singleton pattern --- mypyc/irbuild/statement.py | 27 ++++++++-- mypyc/test-data/irbuild-match.test | 85 ++++++++++++++++++++++++++++++ 2 files changed, 108 insertions(+), 4 deletions(-) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index 11716d61448f7..de28783004658 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -12,13 +12,11 @@ from typing import Callable, Sequence from mypy.nodes import ( - ArgKind, AssertStmt, AssignmentStmt, AwaitExpr, Block, BreakStmt, - CallExpr, ComparisonExpr, ContinueStmt, DelStmt, @@ -32,7 +30,6 @@ ListExpr, Lvalue, MatchStmt, - NameExpr, OperatorAssignmentStmt, RaiseStmt, ReturnStmt, @@ -46,7 +43,7 @@ YieldExpr, YieldFromExpr, ) -from mypy.patterns import AsPattern, ClassPattern, OrPattern, ValuePattern +from mypy.patterns import AsPattern, ClassPattern, OrPattern, SingletonPattern, ValuePattern from mypyc.ir.ops import ( NO_TRACEBACK_LINE_NO, Assign, @@ -100,6 +97,7 @@ check_stop_op, coro_op, import_from_op, + none_object_op, send_op, slow_isinstance_op, type_op, @@ -992,4 +990,25 @@ def build_match_body( builder.activate_block(next_block) + if isinstance(pattern, SingletonPattern): + code_block = BasicBlock() + next_block = BasicBlock() + + if pattern.value is None: + obj = builder.none_object() + elif pattern.value is True: + obj = builder.true() + else: + obj = builder.false() + + cond = builder.binary_op(subject, obj, "is", pattern.line) + + builder.add_bool_branch(cond, code_block, next_block) + + builder.goto(code_block) + + build_match_body(i, code_block, next_block) + + builder.activate_block(next_block) + builder.goto_and_activate(final_block) diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index 985489c0ea45a..1ef6512bbd4b8 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -467,3 +467,88 @@ L3: L4: r12 = box(None, 1) return r12 +[case testMatchSingleton_python3_10] +def f(): + match 123: + case True: + print("value is True") + case False: + print("value is False") + case None: + print("value is None") +[out] +def f(): + r0, r1 :: object + r2 :: bit + r3 :: str + r4 :: object + r5 :: str + r6 :: object + r7 :: object[1] + r8 :: object_ptr + r9, r10, r11 :: object + r12 :: bit + r13 :: str + r14 :: object + r15 :: str + r16 :: object + r17 :: object[1] + r18 :: object_ptr + r19, r20, r21 :: object + r22 :: bit + r23 :: str + r24 :: object + r25 :: str + r26 :: object + r27 :: object[1] + r28 :: object_ptr + r29, r30 :: object +L0: + r0 = object 123 + r1 = box(bool, 1) + r2 = r0 == r1 + if r2 goto L1 else goto L2 :: bool +L1: + r3 = 'value is True' + r4 = builtins :: module + r5 = 'print' + r6 = CPyObject_GetAttr(r4, r5) + r7 = [r3] + r8 = load_address r7 + r9 = _PyObject_Vectorcall(r6, r8, 1, 0) + keep_alive r3 + goto L7 +L2: + r10 = object 123 + r11 = box(bool, 0) + r12 = r10 == r11 + if r12 goto L3 else goto L4 :: bool +L3: + r13 = 'value is False' + r14 = builtins :: module + r15 = 'print' + r16 = CPyObject_GetAttr(r14, r15) + r17 = [r13] + r18 = load_address r17 + r19 = _PyObject_Vectorcall(r16, r18, 1, 0) + keep_alive r13 + goto L7 +L4: + r20 = load_address _Py_NoneStruct + r21 = object 123 + r22 = r21 == r20 + if r22 goto L5 else goto L6 :: bool +L5: + r23 = 'value is None' + r24 = builtins :: module + r25 = 'print' + r26 = CPyObject_GetAttr(r24, r25) + r27 = [r23] + r28 = load_address r27 + r29 = _PyObject_Vectorcall(r26, r28, 1, 0) + keep_alive r23 + goto L7 +L6: +L7: + r30 = box(None, 1) + return r30 From d43e8ca3e4cb8f078df6398664d299038d6fb11f Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Tue, 11 Oct 2022 20:56:37 -0700 Subject: [PATCH 25/97] Greatly reduce number of opcodes --- mypyc/irbuild/statement.py | 10 +- mypyc/test-data/irbuild-match.test | 459 +++++++++++------------------ 2 files changed, 176 insertions(+), 293 deletions(-) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index de28783004658..25c161c333c72 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -17,7 +17,6 @@ AwaitExpr, Block, BreakStmt, - ComparisonExpr, ContinueStmt, DelStmt, Expression, @@ -929,7 +928,9 @@ def build_match_body( code_block = BasicBlock() next_block = BasicBlock() - cond = builder.accept(ComparisonExpr(["=="], [m.subject, pattern.expr])) + cond = builder.binary_op( + subject, builder.accept(pattern.expr), "==", pattern.expr.line + ) builder.add_bool_branch(cond, code_block, next_block) build_match_body(i, code_block, next_block) @@ -943,9 +944,8 @@ def build_match_body( next_block = BasicBlock() for p in pattern.patterns: - cond = builder.accept( - ComparisonExpr(["=="], [m.subject, p.expr]) # type: ignore - ) + cond = builder.binary_op(subject, builder.accept(p.expr), "==", p.expr.line) # type: ignore + builder.add_bool_branch(cond, code_block, next_block) builder.activate_block(next_block) next_block = BasicBlock() diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index 1ef6512bbd4b8..3e4c9173f05fc 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -6,38 +6,30 @@ def f(): [out] def f(): r0 :: bit - r1 :: object - r2 :: int32 - r3 :: bit - r4 :: bool - r5 :: str - r6 :: object - r7 :: str - r8 :: object - r9 :: object[1] - r10 :: object_ptr - r11, r12 :: object + r1 :: str + r2 :: object + r3 :: str + r4 :: object + r5 :: object[1] + r6 :: object_ptr + r7, r8 :: object L0: r0 = 246 == 246 - r1 = box(bit, r0) - r2 = PyObject_IsTrue(r1) - r3 = r2 >= 0 :: signed - r4 = truncate r2: int32 to builtins.bool - if r4 goto L1 else goto L2 :: bool + if r0 goto L1 else goto L2 :: bool L1: - r5 = 'matched' - r6 = builtins :: module - r7 = 'print' - r8 = CPyObject_GetAttr(r6, r7) - r9 = [r5] - r10 = load_address r9 - r11 = _PyObject_Vectorcall(r8, r10, 1, 0) - keep_alive r5 + r1 = 'matched' + r2 = builtins :: module + r3 = 'print' + r4 = CPyObject_GetAttr(r2, r3) + r5 = [r1] + r6 = load_address r5 + r7 = _PyObject_Vectorcall(r4, r6, 1, 0) + keep_alive r1 goto L3 L2: L3: - r12 = box(None, 1) - return r12 + r8 = box(None, 1) + return r8 [case testMatchOrPattern_python3_10] def f(): match 123: @@ -45,53 +37,36 @@ def f(): print("matched") [out] def f(): - r0 :: bit - r1 :: object - r2 :: int32 - r3 :: bit - r4 :: bool - r5 :: bit - r6 :: object - r7 :: int32 - r8 :: bit - r9 :: bool - r10 :: str - r11 :: object - r12 :: str - r13 :: object - r14 :: object[1] - r15 :: object_ptr - r16, r17 :: object + r0, r1 :: bit + r2 :: str + r3 :: object + r4 :: str + r5 :: object + r6 :: object[1] + r7 :: object_ptr + r8, r9 :: object L0: r0 = 246 == 246 - r1 = box(bit, r0) - r2 = PyObject_IsTrue(r1) - r3 = r2 >= 0 :: signed - r4 = truncate r2: int32 to builtins.bool - if r4 goto L3 else goto L1 :: bool + if r0 goto L3 else goto L1 :: bool L1: - r5 = 246 == 912 - r6 = box(bit, r5) - r7 = PyObject_IsTrue(r6) - r8 = r7 >= 0 :: signed - r9 = truncate r7: int32 to builtins.bool - if r9 goto L3 else goto L2 :: bool + r1 = 246 == 912 + if r1 goto L3 else goto L2 :: bool L2: goto L4 L3: - r10 = 'matched' - r11 = builtins :: module - r12 = 'print' - r13 = CPyObject_GetAttr(r11, r12) - r14 = [r10] - r15 = load_address r14 - r16 = _PyObject_Vectorcall(r13, r15, 1, 0) - keep_alive r10 + r2 = 'matched' + r3 = builtins :: module + r4 = 'print' + r5 = CPyObject_GetAttr(r3, r4) + r6 = [r2] + r7 = load_address r6 + r8 = _PyObject_Vectorcall(r5, r7, 1, 0) + keep_alive r2 goto L5 L4: L5: - r17 = box(None, 1) - return r17 + r9 = box(None, 1) + return r9 [case testMatchOrPatternManyPatterns_python3_10] def f(): match 1: @@ -99,77 +74,42 @@ def f(): print("matched") [out] def f(): - r0 :: bit - r1 :: object - r2 :: int32 - r3 :: bit - r4 :: bool - r5 :: bit - r6 :: object - r7 :: int32 - r8 :: bit - r9 :: bool - r10 :: bit - r11 :: object - r12 :: int32 - r13 :: bit - r14 :: bool - r15 :: bit - r16 :: object - r17 :: int32 - r18 :: bit - r19 :: bool - r20 :: str - r21 :: object - r22 :: str - r23 :: object - r24 :: object[1] - r25 :: object_ptr - r26, r27 :: object + r0, r1, r2, r3 :: bit + r4 :: str + r5 :: object + r6 :: str + r7 :: object + r8 :: object[1] + r9 :: object_ptr + r10, r11 :: object L0: r0 = 2 == 2 - r1 = box(bit, r0) - r2 = PyObject_IsTrue(r1) - r3 = r2 >= 0 :: signed - r4 = truncate r2: int32 to builtins.bool - if r4 goto L5 else goto L1 :: bool + if r0 goto L5 else goto L1 :: bool L1: - r5 = 2 == 4 - r6 = box(bit, r5) - r7 = PyObject_IsTrue(r6) - r8 = r7 >= 0 :: signed - r9 = truncate r7: int32 to builtins.bool - if r9 goto L5 else goto L2 :: bool + r1 = 2 == 4 + if r1 goto L5 else goto L2 :: bool L2: - r10 = 2 == 6 - r11 = box(bit, r10) - r12 = PyObject_IsTrue(r11) - r13 = r12 >= 0 :: signed - r14 = truncate r12: int32 to builtins.bool - if r14 goto L5 else goto L3 :: bool + r2 = 2 == 6 + if r2 goto L5 else goto L3 :: bool L3: - r15 = 2 == 8 - r16 = box(bit, r15) - r17 = PyObject_IsTrue(r16) - r18 = r17 >= 0 :: signed - r19 = truncate r17: int32 to builtins.bool - if r19 goto L5 else goto L4 :: bool + r3 = 2 == 8 + if r3 goto L5 else goto L4 :: bool L4: goto L6 L5: - r20 = 'matched' - r21 = builtins :: module - r22 = 'print' - r23 = CPyObject_GetAttr(r21, r22) - r24 = [r20] - r25 = load_address r24 - r26 = _PyObject_Vectorcall(r23, r25, 1, 0) - keep_alive r20 + r4 = 'matched' + r5 = builtins :: module + r6 = 'print' + r7 = CPyObject_GetAttr(r5, r6) + r8 = [r4] + r9 = load_address r8 + r10 = _PyObject_Vectorcall(r7, r9, 1, 0) + keep_alive r4 goto L7 L6: L7: - r27 = box(None, 1) - return r27 + r11 = box(None, 1) + return r11 [case testMatchClassPattern_python3_10] def f(): match 123: @@ -248,67 +188,51 @@ def f(): [out] def f(): r0 :: bit - r1 :: object - r2 :: int32 - r3 :: bit - r4 :: bool - r5 :: str - r6 :: object - r7 :: str - r8 :: object - r9 :: object[1] - r10 :: object_ptr - r11 :: object - r12 :: bit - r13 :: object - r14 :: int32 - r15 :: bit - r16 :: bool - r17 :: str - r18 :: object - r19 :: str - r20 :: object - r21 :: object[1] - r22 :: object_ptr - r23, r24 :: object + r1 :: str + r2 :: object + r3 :: str + r4 :: object + r5 :: object[1] + r6 :: object_ptr + r7 :: object + r8 :: bit + r9 :: str + r10 :: object + r11 :: str + r12 :: object + r13 :: object[1] + r14 :: object_ptr + r15, r16 :: object L0: r0 = 246 == 246 - r1 = box(bit, r0) - r2 = PyObject_IsTrue(r1) - r3 = r2 >= 0 :: signed - r4 = truncate r2: int32 to builtins.bool - if r4 goto L1 else goto L2 :: bool + if r0 goto L1 else goto L2 :: bool L1: - r5 = 'matched' - r6 = builtins :: module - r7 = 'print' - r8 = CPyObject_GetAttr(r6, r7) - r9 = [r5] - r10 = load_address r9 - r11 = _PyObject_Vectorcall(r8, r10, 1, 0) - keep_alive r5 + r1 = 'matched' + r2 = builtins :: module + r3 = 'print' + r4 = CPyObject_GetAttr(r2, r3) + r5 = [r1] + r6 = load_address r5 + r7 = _PyObject_Vectorcall(r4, r6, 1, 0) + keep_alive r1 goto L5 L2: - r12 = 246 == 912 - r13 = box(bit, r12) - r14 = PyObject_IsTrue(r13) - r15 = r14 >= 0 :: signed - r16 = truncate r14: int32 to builtins.bool - if r16 goto L3 else goto L4 :: bool + r8 = 246 == 912 + if r8 goto L3 else goto L4 :: bool L3: - r17 = 'no match' - r18 = builtins :: module - r19 = 'print' - r20 = CPyObject_GetAttr(r18, r19) - r21 = [r17] - r22 = load_address r21 - r23 = _PyObject_Vectorcall(r20, r22, 1, 0) - keep_alive r17 + r9 = 'no match' + r10 = builtins :: module + r11 = 'print' + r12 = CPyObject_GetAttr(r10, r11) + r13 = [r9] + r14 = load_address r13 + r15 = _PyObject_Vectorcall(r12, r14, 1, 0) + keep_alive r9 goto L5 L4: L5: - r24 = box(None, 1) - return r24 + r16 = box(None, 1) + return r16 [case testMatchMultiBodyAndComplexOr_python3_10] def f(): match 123: @@ -321,110 +245,77 @@ def f(): [out] def f(): r0 :: bit - r1 :: object - r2 :: int32 - r3 :: bit - r4 :: bool - r5 :: str - r6 :: object - r7 :: str - r8 :: object - r9 :: object[1] - r10 :: object_ptr + r1 :: str + r2 :: object + r3 :: str + r4 :: object + r5 :: object[1] + r6 :: object_ptr + r7 :: object + r8, r9 :: bit + r10 :: str r11 :: object - r12 :: bit + r12 :: str r13 :: object - r14 :: int32 - r15 :: bit - r16 :: bool + r14 :: object[1] + r15 :: object_ptr + r16 :: object r17 :: bit - r18 :: object - r19 :: int32 - r20 :: bit - r21 :: bool - r22 :: str - r23 :: object - r24 :: str - r25 :: object - r26 :: object[1] - r27 :: object_ptr - r28 :: object - r29 :: bit - r30 :: object - r31 :: int32 - r32 :: bit - r33 :: bool - r34 :: str - r35 :: object - r36 :: str - r37 :: object - r38 :: object[1] - r39 :: object_ptr - r40, r41 :: object + r18 :: str + r19 :: object + r20 :: str + r21 :: object + r22 :: object[1] + r23 :: object_ptr + r24, r25 :: object L0: r0 = 246 == 2 - r1 = box(bit, r0) - r2 = PyObject_IsTrue(r1) - r3 = r2 >= 0 :: signed - r4 = truncate r2: int32 to builtins.bool - if r4 goto L1 else goto L2 :: bool + if r0 goto L1 else goto L2 :: bool L1: - r5 = 'here 1' - r6 = builtins :: module - r7 = 'print' - r8 = CPyObject_GetAttr(r6, r7) - r9 = [r5] - r10 = load_address r9 - r11 = _PyObject_Vectorcall(r8, r10, 1, 0) - keep_alive r5 + r1 = 'here 1' + r2 = builtins :: module + r3 = 'print' + r4 = CPyObject_GetAttr(r2, r3) + r5 = [r1] + r6 = load_address r5 + r7 = _PyObject_Vectorcall(r4, r6, 1, 0) + keep_alive r1 goto L9 L2: - r12 = 246 == 4 - r13 = box(bit, r12) - r14 = PyObject_IsTrue(r13) - r15 = r14 >= 0 :: signed - r16 = truncate r14: int32 to builtins.bool - if r16 goto L5 else goto L3 :: bool + r8 = 246 == 4 + if r8 goto L5 else goto L3 :: bool L3: - r17 = 246 == 6 - r18 = box(bit, r17) - r19 = PyObject_IsTrue(r18) - r20 = r19 >= 0 :: signed - r21 = truncate r19: int32 to builtins.bool - if r21 goto L5 else goto L4 :: bool + r9 = 246 == 6 + if r9 goto L5 else goto L4 :: bool L4: goto L6 L5: - r22 = 'here 2 | 3' - r23 = builtins :: module - r24 = 'print' - r25 = CPyObject_GetAttr(r23, r24) - r26 = [r22] - r27 = load_address r26 - r28 = _PyObject_Vectorcall(r25, r27, 1, 0) - keep_alive r22 + r10 = 'here 2 | 3' + r11 = builtins :: module + r12 = 'print' + r13 = CPyObject_GetAttr(r11, r12) + r14 = [r10] + r15 = load_address r14 + r16 = _PyObject_Vectorcall(r13, r15, 1, 0) + keep_alive r10 goto L9 L6: - r29 = 246 == 246 - r30 = box(bit, r29) - r31 = PyObject_IsTrue(r30) - r32 = r31 >= 0 :: signed - r33 = truncate r31: int32 to builtins.bool - if r33 goto L7 else goto L8 :: bool + r17 = 246 == 246 + if r17 goto L7 else goto L8 :: bool L7: - r34 = 'here 123' - r35 = builtins :: module - r36 = 'print' - r37 = CPyObject_GetAttr(r35, r36) - r38 = [r34] - r39 = load_address r38 - r40 = _PyObject_Vectorcall(r37, r39, 1, 0) - keep_alive r34 + r18 = 'here 123' + r19 = builtins :: module + r20 = 'print' + r21 = CPyObject_GetAttr(r19, r20) + r22 = [r18] + r23 = load_address r22 + r24 = _PyObject_Vectorcall(r21, r23, 1, 0) + keep_alive r18 goto L9 L8: L9: - r41 = box(None, 1) - return r41 + r25 = box(None, 1) + return r25 [case testMatchWithGuard_python3_10] def f(): match 123: @@ -433,40 +324,32 @@ def f(): [out] def f(): r0 :: bit - r1 :: object - r2 :: int32 - r3 :: bit - r4 :: bool - r5 :: str - r6 :: object - r7 :: str - r8 :: object - r9 :: object[1] - r10 :: object_ptr - r11, r12 :: object + r1 :: str + r2 :: object + r3 :: str + r4 :: object + r5 :: object[1] + r6 :: object_ptr + r7, r8 :: object L0: r0 = 246 == 246 - r1 = box(bit, r0) - r2 = PyObject_IsTrue(r1) - r3 = r2 >= 0 :: signed - r4 = truncate r2: int32 to builtins.bool - if r4 goto L1 else goto L3 :: bool + if r0 goto L1 else goto L3 :: bool L1: if 1 goto L2 else goto L3 :: bool L2: - r5 = 'matched' - r6 = builtins :: module - r7 = 'print' - r8 = CPyObject_GetAttr(r6, r7) - r9 = [r5] - r10 = load_address r9 - r11 = _PyObject_Vectorcall(r8, r10, 1, 0) - keep_alive r5 + r1 = 'matched' + r2 = builtins :: module + r3 = 'print' + r4 = CPyObject_GetAttr(r2, r3) + r5 = [r1] + r6 = load_address r5 + r7 = _PyObject_Vectorcall(r4, r6, 1, 0) + keep_alive r1 goto L4 L3: L4: - r12 = box(None, 1) - return r12 + r8 = box(None, 1) + return r8 [case testMatchSingleton_python3_10] def f(): match 123: From acdeb33e5d07b0fe1b8371efffc513e028f01a2e Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Thu, 13 Oct 2022 17:07:29 -0700 Subject: [PATCH 26/97] Move code_block out --- mypyc/irbuild/statement.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index 25c161c333c72..dd882f6759227 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -924,8 +924,9 @@ def build_match_body( builder.goto(final_block) for i, pattern in enumerate(m.patterns): + code_block = BasicBlock() + if isinstance(pattern, ValuePattern): - code_block = BasicBlock() next_block = BasicBlock() cond = builder.binary_op( @@ -940,7 +941,6 @@ def build_match_body( if isinstance(pattern, OrPattern): assert all(isinstance(p, ValuePattern) for p in pattern.patterns) - code_block = BasicBlock() next_block = BasicBlock() for p in pattern.patterns: @@ -962,7 +962,6 @@ def build_match_body( assert not pattern.keyword_keys assert not pattern.keyword_values - code_block = BasicBlock() end_block = BasicBlock() cond = builder.call_c( @@ -981,7 +980,6 @@ def build_match_body( assert not pattern.pattern assert not pattern.name - code_block = BasicBlock() next_block = BasicBlock() builder.goto(code_block) @@ -991,7 +989,6 @@ def build_match_body( builder.activate_block(next_block) if isinstance(pattern, SingletonPattern): - code_block = BasicBlock() next_block = BasicBlock() if pattern.value is None: From beb14763faf49c98a97e11b92b6f7de9a1793be6 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Thu, 13 Oct 2022 17:08:23 -0700 Subject: [PATCH 27/97] Move out next_block --- mypyc/irbuild/statement.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index dd882f6759227..f75eed3f0def7 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -925,10 +925,9 @@ def build_match_body( for i, pattern in enumerate(m.patterns): code_block = BasicBlock() + next_block = BasicBlock() if isinstance(pattern, ValuePattern): - next_block = BasicBlock() - cond = builder.binary_op( subject, builder.accept(pattern.expr), "==", pattern.expr.line ) @@ -941,8 +940,6 @@ def build_match_body( if isinstance(pattern, OrPattern): assert all(isinstance(p, ValuePattern) for p in pattern.patterns) - next_block = BasicBlock() - for p in pattern.patterns: cond = builder.binary_op(subject, builder.accept(p.expr), "==", p.expr.line) # type: ignore @@ -980,8 +977,6 @@ def build_match_body( assert not pattern.pattern assert not pattern.name - next_block = BasicBlock() - builder.goto(code_block) build_match_body(i, code_block, next_block) @@ -989,8 +984,6 @@ def build_match_body( builder.activate_block(next_block) if isinstance(pattern, SingletonPattern): - next_block = BasicBlock() - if pattern.value is None: obj = builder.none_object() elif pattern.value is True: From 4d8cbc39fafb65487535f2cb94a004f78f70a257 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Thu, 13 Oct 2022 17:10:01 -0700 Subject: [PATCH 28/97] Cleanup --- mypyc/irbuild/statement.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index f75eed3f0def7..d09711a1fc537 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -947,31 +947,28 @@ def build_match_body( builder.activate_block(next_block) next_block = BasicBlock() - end_block = BasicBlock() - builder.goto(end_block) + builder.goto(next_block) - build_match_body(i, code_block, end_block) + build_match_body(i, code_block, next_block) - builder.activate_block(end_block) + builder.activate_block(next_block) if isinstance(pattern, ClassPattern): assert not pattern.positionals assert not pattern.keyword_keys assert not pattern.keyword_values - end_block = BasicBlock() - cond = builder.call_c( slow_isinstance_op, [subject, builder.accept(pattern.class_ref)], pattern.line ) - builder.add_bool_branch(cond, code_block, end_block) + builder.add_bool_branch(cond, code_block, next_block) - build_match_body(i, code_block, end_block) + build_match_body(i, code_block, next_block) - builder.activate_block(end_block) + builder.activate_block(next_block) if isinstance(pattern, AsPattern): assert not pattern.pattern From a3325518bab189c3e0eec7eca352de14c643ef30 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Thu, 13 Oct 2022 17:11:43 -0700 Subject: [PATCH 29/97] Add elifs --- mypyc/irbuild/statement.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index d09711a1fc537..e805e9dc83bee 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -937,7 +937,7 @@ def build_match_body( builder.activate_block(next_block) - if isinstance(pattern, OrPattern): + elif isinstance(pattern, OrPattern): assert all(isinstance(p, ValuePattern) for p in pattern.patterns) for p in pattern.patterns: @@ -953,7 +953,7 @@ def build_match_body( builder.activate_block(next_block) - if isinstance(pattern, ClassPattern): + elif isinstance(pattern, ClassPattern): assert not pattern.positionals assert not pattern.keyword_keys assert not pattern.keyword_values @@ -970,7 +970,7 @@ def build_match_body( builder.activate_block(next_block) - if isinstance(pattern, AsPattern): + elif isinstance(pattern, AsPattern): assert not pattern.pattern assert not pattern.name @@ -980,7 +980,7 @@ def build_match_body( builder.activate_block(next_block) - if isinstance(pattern, SingletonPattern): + elif isinstance(pattern, SingletonPattern): if pattern.value is None: obj = builder.none_object() elif pattern.value is True: From 8fd654651a2de4b4cb42365a05772f546ea39200 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Thu, 13 Oct 2022 17:14:32 -0700 Subject: [PATCH 30/97] Move pattern building to its own function --- mypyc/irbuild/statement.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index e805e9dc83bee..af2d1517b29bc 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -42,7 +42,14 @@ YieldExpr, YieldFromExpr, ) -from mypy.patterns import AsPattern, ClassPattern, OrPattern, SingletonPattern, ValuePattern +from mypy.patterns import ( + AsPattern, + ClassPattern, + OrPattern, + Pattern, + SingletonPattern, + ValuePattern, +) from mypyc.ir.ops import ( NO_TRACEBACK_LINE_NO, Assign, @@ -923,10 +930,9 @@ def build_match_body( builder.accept(m.bodies[index]) builder.goto(final_block) - for i, pattern in enumerate(m.patterns): - code_block = BasicBlock() - next_block = BasicBlock() - + def build_pattern( + pattern: Pattern, code_block: BasicBlock, next_block: BasicBlock + ) -> None: if isinstance(pattern, ValuePattern): cond = builder.binary_op( subject, builder.accept(pattern.expr), "==", pattern.expr.line @@ -998,4 +1004,10 @@ def build_match_body( builder.activate_block(next_block) + for i, pattern in enumerate(m.patterns): + code_block = BasicBlock() + next_block = BasicBlock() + + build_pattern(pattern, code_block, next_block) + builder.goto_and_activate(final_block) From 42b8d589e747ad823948e66fdffe6ad23f45eaf8 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Thu, 13 Oct 2022 17:30:54 -0700 Subject: [PATCH 31/97] Move body builder out of each if stmt --- mypyc/irbuild/statement.py | 29 ++++++----------------------- 1 file changed, 6 insertions(+), 23 deletions(-) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index af2d1517b29bc..ce7ec02f1f0c7 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -932,17 +932,13 @@ def build_match_body( def build_pattern( pattern: Pattern, code_block: BasicBlock, next_block: BasicBlock - ) -> None: + ) -> BasicBlock: if isinstance(pattern, ValuePattern): cond = builder.binary_op( subject, builder.accept(pattern.expr), "==", pattern.expr.line ) builder.add_bool_branch(cond, code_block, next_block) - build_match_body(i, code_block, next_block) - - builder.activate_block(next_block) - elif isinstance(pattern, OrPattern): assert all(isinstance(p, ValuePattern) for p in pattern.patterns) @@ -955,10 +951,6 @@ def build_pattern( builder.goto(next_block) - build_match_body(i, code_block, next_block) - - builder.activate_block(next_block) - elif isinstance(pattern, ClassPattern): assert not pattern.positionals assert not pattern.keyword_keys @@ -972,20 +964,12 @@ def build_pattern( builder.add_bool_branch(cond, code_block, next_block) - build_match_body(i, code_block, next_block) - - builder.activate_block(next_block) - elif isinstance(pattern, AsPattern): assert not pattern.pattern assert not pattern.name builder.goto(code_block) - build_match_body(i, code_block, next_block) - - builder.activate_block(next_block) - elif isinstance(pattern, SingletonPattern): if pattern.value is None: obj = builder.none_object() @@ -998,16 +982,15 @@ def build_pattern( builder.add_bool_branch(cond, code_block, next_block) - builder.goto(code_block) - - build_match_body(i, code_block, next_block) - - builder.activate_block(next_block) + return next_block for i, pattern in enumerate(m.patterns): code_block = BasicBlock() next_block = BasicBlock() - build_pattern(pattern, code_block, next_block) + next_block = build_pattern(pattern, code_block, next_block) + + build_match_body(i, code_block, next_block) + builder.activate_block(next_block) builder.goto_and_activate(final_block) From 130647e08d3a2432f7de40abb34c7bce08bff2f5 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Thu, 13 Oct 2022 17:44:51 -0700 Subject: [PATCH 32/97] Add recursive matching --- mypyc/irbuild/statement.py | 5 +--- mypyc/test-data/irbuild-match.test | 45 ++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 4 deletions(-) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index ce7ec02f1f0c7..0a5d6d3ce0165 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -940,12 +940,9 @@ def build_pattern( builder.add_bool_branch(cond, code_block, next_block) elif isinstance(pattern, OrPattern): - assert all(isinstance(p, ValuePattern) for p in pattern.patterns) - for p in pattern.patterns: - cond = builder.binary_op(subject, builder.accept(p.expr), "==", p.expr.line) # type: ignore + build_pattern(p, code_block, next_block) - builder.add_bool_branch(cond, code_block, next_block) builder.activate_block(next_block) next_block = BasicBlock() diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index 3e4c9173f05fc..e4cec5b0f8fa2 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -435,3 +435,48 @@ L6: L7: r30 = box(None, 1) return r30 +[case testMatchRecursiveOrPattern_python3_10] +def f(): + match 1: + case 1 | int(): + print("matched") +[out] +def f(): + r0 :: bit + r1, r2 :: object + r3 :: int32 + r4 :: bit + r5 :: bool + r6 :: str + r7 :: object + r8 :: str + r9 :: object + r10 :: object[1] + r11 :: object_ptr + r12, r13 :: object +L0: + r0 = 2 == 2 + if r0 goto L3 else goto L1 :: bool +L1: + r1 = load_address PyLong_Type + r2 = object 1 + r3 = PyObject_IsInstance(r2, r1) + r4 = r3 >= 0 :: signed + r5 = truncate r3: int32 to builtins.bool + if r5 goto L3 else goto L2 :: bool +L2: + goto L4 +L3: + r6 = 'matched' + r7 = builtins :: module + r8 = 'print' + r9 = CPyObject_GetAttr(r7, r8) + r10 = [r6] + r11 = load_address r10 + r12 = _PyObject_Vectorcall(r9, r11, 1, 0) + keep_alive r6 + goto L5 +L4: +L5: + r13 = box(None, 1) + return r13 From cc9f3759296d93e58311d664c2764f45f94db3ed Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Sat, 15 Oct 2022 16:59:27 -0700 Subject: [PATCH 33/97] Add basic AsPattern support --- mypyc/irbuild/statement.py | 4 ++-- mypyc/test-data/irbuild-match.test | 33 ++++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index 0a5d6d3ce0165..ec885f10bbeee 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -962,8 +962,8 @@ def build_pattern( builder.add_bool_branch(cond, code_block, next_block) elif isinstance(pattern, AsPattern): - assert not pattern.pattern - assert not pattern.name + if pattern.pattern: + build_pattern(pattern.pattern, code_block, next_block) builder.goto(code_block) diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index e4cec5b0f8fa2..9d3a7c65bebde 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -480,3 +480,36 @@ L4: L5: r13 = box(None, 1) return r13 +[case testMatchAsPattern_python3_10] +def f(): + match 123: + case 123 as x: + # print(x) + print("matched") +[out] +def f(): + r0 :: bit + r1 :: str + r2 :: object + r3 :: str + r4 :: object + r5 :: object[1] + r6 :: object_ptr + r7, r8 :: object +L0: + r0 = 246 == 246 + if r0 goto L1 else goto L2 :: bool +L1: + r1 = 'matched' + r2 = builtins :: module + r3 = 'print' + r4 = CPyObject_GetAttr(r2, r3) + r5 = [r1] + r6 = load_address r5 + r7 = _PyObject_Vectorcall(r4, r6, 1, 0) + keep_alive r1 + goto L3 +L2: +L3: + r8 = box(None, 1) + return r8 From 331531b76551ab81a229779ed0054d16ddac0cff Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Sat, 15 Oct 2022 17:16:08 -0700 Subject: [PATCH 34/97] Add groundwork for captured patterns --- mypyc/irbuild/statement.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index ec885f10bbeee..9e235efb75d38 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -932,7 +932,9 @@ def build_match_body( def build_pattern( pattern: Pattern, code_block: BasicBlock, next_block: BasicBlock - ) -> BasicBlock: + ) -> tuple[BasicBlock, list[Pattern]]: + captured: list[Pattern] = [] + if isinstance(pattern, ValuePattern): cond = builder.binary_op( subject, builder.accept(pattern.expr), "==", pattern.expr.line @@ -941,7 +943,7 @@ def build_pattern( elif isinstance(pattern, OrPattern): for p in pattern.patterns: - build_pattern(p, code_block, next_block) + next_block, captured = build_pattern(p, code_block, next_block) builder.activate_block(next_block) next_block = BasicBlock() @@ -963,7 +965,7 @@ def build_pattern( elif isinstance(pattern, AsPattern): if pattern.pattern: - build_pattern(pattern.pattern, code_block, next_block) + next_block, captured = build_pattern(pattern.pattern, code_block, next_block) builder.goto(code_block) @@ -979,13 +981,13 @@ def build_pattern( builder.add_bool_branch(cond, code_block, next_block) - return next_block + return next_block, captured for i, pattern in enumerate(m.patterns): code_block = BasicBlock() next_block = BasicBlock() - next_block = build_pattern(pattern, code_block, next_block) + next_block, _ = build_pattern(pattern, code_block, next_block) build_match_body(i, code_block, next_block) builder.activate_block(next_block) From 92b05b8c4201b02c4138a1ebf7d58c0c4d1b26dd Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Sat, 15 Oct 2022 17:55:52 -0700 Subject: [PATCH 35/97] Convert to visitor --- mypyc/irbuild/statement.py | 134 +++++++++++++++++++++---------------- 1 file changed, 76 insertions(+), 58 deletions(-) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index 9e235efb75d38..89a2cba1c477b 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -47,6 +47,7 @@ ClassPattern, OrPattern, Pattern, + PatternVisitor, SingletonPattern, ValuePattern, ) @@ -909,87 +910,104 @@ def transform_await_expr(builder: IRBuilder, o: AwaitExpr) -> Value: return emit_yield_from_or_await(builder, builder.accept(o.expr), o.line, is_await=True) -def transform_match_stmt(builder: IRBuilder, m: MatchStmt) -> None: - subject = builder.accept(m.subject) - - final_block = BasicBlock() +class MatchVisitor(PatternVisitor[None]): + builder: IRBuilder + code_block: BasicBlock + next_block: BasicBlock + final_block: BasicBlock + subject: Value - def build_match_body( - index: int, code_block: BasicBlock, next_block: BasicBlock + def __init__( + self, + builder: IRBuilder, + match_node: MatchStmt, ) -> None: - builder.activate_block(code_block) + self.builder = builder - if guard := m.guards[index]: - new_code_block = BasicBlock() + self.code_block = BasicBlock() + self.next_block = BasicBlock() + self.final_block = BasicBlock() - cond = builder.accept(guard) - builder.add_bool_branch(cond, new_code_block, next_block) + self.subject = builder.accept(match_node.subject) - builder.activate_block(new_code_block) + def visit_value_pattern(self, pattern: ValuePattern) -> None: + cond = self.builder.binary_op( + self.subject, + self.builder.accept(pattern.expr), + "==", + pattern.expr.line + ) + self.builder.add_bool_branch(cond, self.code_block, self.next_block) - builder.accept(m.bodies[index]) - builder.goto(final_block) + def visit_or_pattern(self, pattern: OrPattern) -> None: + for p in pattern.patterns: + p.accept(self) - def build_pattern( - pattern: Pattern, code_block: BasicBlock, next_block: BasicBlock - ) -> tuple[BasicBlock, list[Pattern]]: - captured: list[Pattern] = [] + self.builder.activate_block(self.next_block) + self.next_block = BasicBlock() - if isinstance(pattern, ValuePattern): - cond = builder.binary_op( - subject, builder.accept(pattern.expr), "==", pattern.expr.line - ) - builder.add_bool_branch(cond, code_block, next_block) + self.builder.goto(self.next_block) + + def visit_class_pattern(self, pattern: ClassPattern) -> None: + assert not pattern.positionals + assert not pattern.keyword_keys + assert not pattern.keyword_values - elif isinstance(pattern, OrPattern): - for p in pattern.patterns: - next_block, captured = build_pattern(p, code_block, next_block) + cond = self.builder.call_c( + slow_isinstance_op, + [self.subject, self.builder.accept(pattern.class_ref)], + pattern.line + ) - builder.activate_block(next_block) - next_block = BasicBlock() + self.builder.add_bool_branch(cond, self.code_block, self.next_block) - builder.goto(next_block) + def visit_as_pattern(self, pattern: AsPattern) -> None: + if pattern.pattern: + pattern.pattern.accept(self) - elif isinstance(pattern, ClassPattern): - assert not pattern.positionals - assert not pattern.keyword_keys - assert not pattern.keyword_values + self.builder.goto(self.code_block) - cond = builder.call_c( - slow_isinstance_op, - [subject, builder.accept(pattern.class_ref)], - pattern.line - ) + def visit_singleton_pattern(self, pattern: SingletonPattern) -> None: + if pattern.value is None: + obj = self.builder.none_object() + elif pattern.value is True: + obj = self.builder.true() + else: + obj = self.builder.false() - builder.add_bool_branch(cond, code_block, next_block) + cond = self.builder.binary_op(self.subject, obj, "is", pattern.line) - elif isinstance(pattern, AsPattern): - if pattern.pattern: - next_block, captured = build_pattern(pattern.pattern, code_block, next_block) + self.builder.add_bool_branch(cond, self.code_block, self.next_block) - builder.goto(code_block) - elif isinstance(pattern, SingletonPattern): - if pattern.value is None: - obj = builder.none_object() - elif pattern.value is True: - obj = builder.true() - else: - obj = builder.false() +def transform_match_stmt(builder: IRBuilder, m: MatchStmt) -> None: + final_block = BasicBlock() + + def build_match_body( + index: int, code_block: BasicBlock, next_block: BasicBlock + ) -> None: + builder.activate_block(code_block) - cond = builder.binary_op(subject, obj, "is", pattern.line) + if guard := m.guards[index]: + new_code_block = BasicBlock() + + cond = builder.accept(guard) + builder.add_bool_branch(cond, new_code_block, next_block) - builder.add_bool_branch(cond, code_block, next_block) + builder.activate_block(new_code_block) + + builder.accept(m.bodies[index]) + builder.goto(final_block) - return next_block, captured + mv = MatchVisitor(builder, m) for i, pattern in enumerate(m.patterns): - code_block = BasicBlock() - next_block = BasicBlock() + mv.code_block = BasicBlock() + mv.next_block = BasicBlock() - next_block, _ = build_pattern(pattern, code_block, next_block) + pattern.accept(mv) - build_match_body(i, code_block, next_block) - builder.activate_block(next_block) + build_match_body(i, mv.code_block, mv.next_block) + builder.activate_block(mv.next_block) builder.goto_and_activate(final_block) From 2e1a2adeba71ef466900d52f7d9544c343fa31f8 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Sat, 15 Oct 2022 18:05:05 -0700 Subject: [PATCH 36/97] Rewrite using visitor --- mypyc/irbuild/statement.py | 71 +++++++++++++++++++------------------- 1 file changed, 35 insertions(+), 36 deletions(-) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index 89a2cba1c477b..687cfe1b7294b 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -51,6 +51,7 @@ SingletonPattern, ValuePattern, ) +from mypy.traverser import TraverserVisitor from mypyc.ir.ops import ( NO_TRACEBACK_LINE_NO, Assign, @@ -910,26 +911,53 @@ def transform_await_expr(builder: IRBuilder, o: AwaitExpr) -> Value: return emit_yield_from_or_await(builder, builder.accept(o.expr), o.line, is_await=True) -class MatchVisitor(PatternVisitor[None]): + +class MatchVisitor(TraverserVisitor): builder: IRBuilder code_block: BasicBlock next_block: BasicBlock final_block: BasicBlock subject: Value + match: MatchStmt - def __init__( - self, - builder: IRBuilder, - match_node: MatchStmt, - ) -> None: + def __init__(self, builder: IRBuilder, match_node: MatchStmt) -> None: self.builder = builder self.code_block = BasicBlock() self.next_block = BasicBlock() self.final_block = BasicBlock() + self.match = match_node self.subject = builder.accept(match_node.subject) + def build_match_body( + self, index: int, code_block: BasicBlock, next_block: BasicBlock + ) -> None: + self.builder.activate_block(code_block) + + if guard := self.match.guards[index]: + new_code_block = BasicBlock() + + cond = self.builder.accept(guard) + self.builder.add_bool_branch(cond, new_code_block, next_block) + + self.builder.activate_block(new_code_block) + + self.builder.accept(self.match.bodies[index]) + self.builder.goto(self.final_block) + + def visit_match_stmt(self, m: MatchStmt) -> None: + for i, pattern in enumerate(m.patterns): + self.code_block = BasicBlock() + self.next_block = BasicBlock() + + pattern.accept(self) + + self.build_match_body(i, self.code_block, self.next_block) + self.builder.activate_block(self.next_block) + + self.builder.goto_and_activate(self.final_block) + def visit_value_pattern(self, pattern: ValuePattern) -> None: cond = self.builder.binary_op( self.subject, @@ -981,33 +1009,4 @@ def visit_singleton_pattern(self, pattern: SingletonPattern) -> None: def transform_match_stmt(builder: IRBuilder, m: MatchStmt) -> None: - final_block = BasicBlock() - - def build_match_body( - index: int, code_block: BasicBlock, next_block: BasicBlock - ) -> None: - builder.activate_block(code_block) - - if guard := m.guards[index]: - new_code_block = BasicBlock() - - cond = builder.accept(guard) - builder.add_bool_branch(cond, new_code_block, next_block) - - builder.activate_block(new_code_block) - - builder.accept(m.bodies[index]) - builder.goto(final_block) - - mv = MatchVisitor(builder, m) - - for i, pattern in enumerate(m.patterns): - mv.code_block = BasicBlock() - mv.next_block = BasicBlock() - - pattern.accept(mv) - - build_match_body(i, mv.code_block, mv.next_block) - builder.activate_block(mv.next_block) - - builder.goto_and_activate(final_block) + m.accept(MatchVisitor(builder, m)) From c55236e7911442772931bd074bc186c5560ad1d0 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Sat, 15 Oct 2022 20:35:32 -0700 Subject: [PATCH 37/97] Add basic AsPattern support for ValuePattern: Since the AsPattern binds a value to name, it is very coupled to the pattern that is actually being built. There is (in general) 3 stages to a typical pattern match: 1. Converting the expression to a value/register 2. Checking if the match was successful 3. Assigning the value to a variable if it is an AsPattern For this to work though, steps 2 and 3 need to be done at the same time, because the jump must jump to a block which sets the variable. Maybe I am wrong, and I can just assign to a variable, and if the match did not pass, it will simply be reassigned. --- mypyc/irbuild/statement.py | 14 +++++++++++++- mypyc/test-data/irbuild-match.test | 13 ++++++------- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index 687cfe1b7294b..036da1f44784e 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -920,6 +920,8 @@ class MatchVisitor(TraverserVisitor): subject: Value match: MatchStmt + as_pattern: AsPattern | None = None + def __init__(self, builder: IRBuilder, match_node: MatchStmt) -> None: self.builder = builder @@ -959,12 +961,21 @@ def visit_match_stmt(self, m: MatchStmt) -> None: self.builder.goto_and_activate(self.final_block) def visit_value_pattern(self, pattern: ValuePattern) -> None: + value = self.builder.accept(pattern.expr) + cond = self.builder.binary_op( self.subject, - self.builder.accept(pattern.expr), + value, "==", pattern.expr.line ) + + if self.as_pattern and self.as_pattern.name: + target = self.builder.get_assignment_target(self.as_pattern.name) + self.builder.assign(target, value, self.as_pattern.pattern.line) # type: ignore + + self.as_pattern = None + self.builder.add_bool_branch(cond, self.code_block, self.next_block) def visit_or_pattern(self, pattern: OrPattern) -> None: @@ -991,6 +1002,7 @@ def visit_class_pattern(self, pattern: ClassPattern) -> None: def visit_as_pattern(self, pattern: AsPattern) -> None: if pattern.pattern: + self.as_pattern = pattern pattern.pattern.accept(self) self.builder.goto(self.code_block) diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index 9d3a7c65bebde..16b7ae783f343 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -484,13 +484,11 @@ L5: def f(): match 123: case 123 as x: - # print(x) - print("matched") + print(x) [out] def f(): r0 :: bit - r1 :: str - r2 :: object + r1, x, r2 :: object r3 :: str r4 :: object r5 :: object[1] @@ -498,16 +496,17 @@ def f(): r7, r8 :: object L0: r0 = 246 == 246 + r1 = object 123 + x = r1 if r0 goto L1 else goto L2 :: bool L1: - r1 = 'matched' r2 = builtins :: module r3 = 'print' r4 = CPyObject_GetAttr(r2, r3) - r5 = [r1] + r5 = [x] r6 = load_address r5 r7 = _PyObject_Vectorcall(r4, r6, 1, 0) - keep_alive r1 + keep_alive x goto L3 L2: L3: From 0af9b242d6739c974da9373033a96d3e66f37e0c Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Sat, 15 Oct 2022 20:48:16 -0700 Subject: [PATCH 38/97] Add Or pattern support for AsPattern --- mypyc/irbuild/statement.py | 3 +-- mypyc/test-data/irbuild-match.test | 41 ++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index 036da1f44784e..d9f9da066a373 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -974,8 +974,6 @@ def visit_value_pattern(self, pattern: ValuePattern) -> None: target = self.builder.get_assignment_target(self.as_pattern.name) self.builder.assign(target, value, self.as_pattern.pattern.line) # type: ignore - self.as_pattern = None - self.builder.add_bool_branch(cond, self.code_block, self.next_block) def visit_or_pattern(self, pattern: OrPattern) -> None: @@ -1004,6 +1002,7 @@ def visit_as_pattern(self, pattern: AsPattern) -> None: if pattern.pattern: self.as_pattern = pattern pattern.pattern.accept(self) + self.as_pattern = None self.builder.goto(self.code_block) diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index 16b7ae783f343..478e0f85dc3d2 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -512,3 +512,44 @@ L2: L3: r8 = box(None, 1) return r8 +[case testMatchAsPatternOnOrPattern_python3_10] +def f(): + match 1: + case (1 | 2) as x: + print(x) +[out] +def f(): + r0 :: bit + r1, x :: object + r2 :: bit + r3, r4 :: object + r5 :: str + r6 :: object + r7 :: object[1] + r8 :: object_ptr + r9, r10 :: object +L0: + r0 = 2 == 2 + r1 = object 1 + x = r1 + if r0 goto L3 else goto L1 :: bool +L1: + r2 = 2 == 4 + r3 = object 2 + x = r3 + if r2 goto L3 else goto L2 :: bool +L2: + goto L4 +L3: + r4 = builtins :: module + r5 = 'print' + r6 = CPyObject_GetAttr(r4, r5) + r7 = [x] + r8 = load_address r7 + r9 = _PyObject_Vectorcall(r6, r8, 1, 0) + keep_alive x + goto L5 +L4: +L5: + r10 = box(None, 1) + return r10 From aedd1a6751fd98c8e66bfce0c36a229824a835ce Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Sat, 15 Oct 2022 21:06:01 -0700 Subject: [PATCH 39/97] Add AsPattern support for class pattern --- mypyc/irbuild/statement.py | 8 ++++++ mypyc/test-data/irbuild-match.test | 40 ++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index d9f9da066a373..6f6b34b0a1e9b 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -996,6 +996,14 @@ def visit_class_pattern(self, pattern: ClassPattern) -> None: pattern.line ) + if self.as_pattern and self.as_pattern.name: + # TODO: add ability to handle class pattern when not at top level + value = self.subject + + target = self.builder.get_assignment_target(self.as_pattern.name) + self.builder.assign(target, value, self.as_pattern.pattern.line) # type: ignore + + self.builder.add_bool_branch(cond, self.code_block, self.next_block) def visit_as_pattern(self, pattern: AsPattern) -> None: diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index 478e0f85dc3d2..9f018d60f3220 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -553,3 +553,43 @@ L4: L5: r10 = box(None, 1) return r10 +[case testMatchAsPatternOnClassPattern_python3_10] +def f(): + match 123: + case int() as i: + print(i) +[out] +def f(): + r0, r1 :: object + r2 :: int32 + r3 :: bit + r4 :: bool + i :: int + r5 :: object + r6 :: str + r7, r8 :: object + r9 :: object[1] + r10 :: object_ptr + r11, r12 :: object +L0: + r0 = load_address PyLong_Type + r1 = object 123 + r2 = PyObject_IsInstance(r1, r0) + r3 = r2 >= 0 :: signed + r4 = truncate r2: int32 to builtins.bool + i = 246 + if r4 goto L1 else goto L2 :: bool +L1: + r5 = builtins :: module + r6 = 'print' + r7 = CPyObject_GetAttr(r5, r6) + r8 = box(int, i) + r9 = [r8] + r10 = load_address r9 + r11 = _PyObject_Vectorcall(r7, r10, 1, 0) + keep_alive r8 + goto L3 +L2: +L3: + r12 = box(None, 1) + return r12 From 5b955965cbae9fcab26de0504b70189f1d3b60d0 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Sat, 15 Oct 2022 21:10:50 -0700 Subject: [PATCH 40/97] Cleanup --- mypyc/irbuild/statement.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index 6f6b34b0a1e9b..961a4b71d7711 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -970,9 +970,7 @@ def visit_value_pattern(self, pattern: ValuePattern) -> None: pattern.expr.line ) - if self.as_pattern and self.as_pattern.name: - target = self.builder.get_assignment_target(self.as_pattern.name) - self.builder.assign(target, value, self.as_pattern.pattern.line) # type: ignore + self.bind_as_pattern(value) self.builder.add_bool_branch(cond, self.code_block, self.next_block) @@ -996,13 +994,8 @@ def visit_class_pattern(self, pattern: ClassPattern) -> None: pattern.line ) - if self.as_pattern and self.as_pattern.name: - # TODO: add ability to handle class pattern when not at top level - value = self.subject - - target = self.builder.get_assignment_target(self.as_pattern.name) - self.builder.assign(target, value, self.as_pattern.pattern.line) # type: ignore - + # TODO: add ability to handle class pattern when not at top level + self.bind_as_pattern(self.subject) self.builder.add_bool_branch(cond, self.code_block, self.next_block) @@ -1026,6 +1019,11 @@ def visit_singleton_pattern(self, pattern: SingletonPattern) -> None: self.builder.add_bool_branch(cond, self.code_block, self.next_block) + def bind_as_pattern(self, value: Value) -> None: + if self.as_pattern and self.as_pattern.name: + target = self.builder.get_assignment_target(self.as_pattern.name) + self.builder.assign(target, value, self.as_pattern.pattern.line) # type: ignore + def transform_match_stmt(builder: IRBuilder, m: MatchStmt) -> None: m.accept(MatchVisitor(builder, m)) From 25d0edc6a8a37ec947a3ef9261e9296f00345c05 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Tue, 18 Oct 2022 17:57:39 -0700 Subject: [PATCH 41/97] Add basic positional arg parsing --- mypyc/irbuild/statement.py | 39 +++++++++++++++- mypyc/test-data/irbuild-match.test | 71 ++++++++++++++++++++++++++++++ 2 files changed, 108 insertions(+), 2 deletions(-) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index 961a4b71d7711..992f3cebfcd18 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -37,6 +37,8 @@ TempNode, TryStmt, TupleExpr, + TypeInfo, + Var, WhileStmt, WithStmt, YieldExpr, @@ -52,6 +54,7 @@ ValuePattern, ) from mypy.traverser import TraverserVisitor +from mypy.types import Instance, LiteralType, TupleType from mypyc.ir.ops import ( NO_TRACEBACK_LINE_NO, Assign, @@ -984,21 +987,53 @@ def visit_or_pattern(self, pattern: OrPattern) -> None: self.builder.goto(self.next_block) def visit_class_pattern(self, pattern: ClassPattern) -> None: - assert not pattern.positionals assert not pattern.keyword_keys assert not pattern.keyword_values + # TODO: add ability to handle class pattern when not at top level cond = self.builder.call_c( slow_isinstance_op, [self.subject, self.builder.accept(pattern.class_ref)], pattern.line ) - # TODO: add ability to handle class pattern when not at top level self.bind_as_pattern(self.subject) self.builder.add_bool_branch(cond, self.code_block, self.next_block) + if pattern.positionals: + self.builder.activate_block(self.code_block) + self.code_block = BasicBlock() + + assert len(pattern.positionals) == 1 + + expr = pattern.positionals[0] + assert isinstance(expr, ValuePattern) + + node = pattern.class_ref.node + assert isinstance(node, TypeInfo) + + ty = node.names.get("__match_args__") + assert ty and isinstance(ty.type, TupleType) + + match_args = [] + + for item in ty.type.items: + assert isinstance(item, Instance) and item.last_known_value + match_args.append(item.last_known_value.value) + + + value = self.builder.py_get_attr(self.subject, match_args[0], expr.line) + + cond2 = self.builder.binary_op( + value, + self.builder.accept(expr.expr), + "==", + expr.line, + ) + + self.builder.add_bool_branch(cond2, self.code_block, self.next_block) + def visit_as_pattern(self, pattern: AsPattern) -> None: if pattern.pattern: self.as_pattern = pattern diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index 9f018d60f3220..5a62bffe214b7 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -593,3 +593,74 @@ L2: L3: r12 = box(None, 1) return r12 +[case testMatchClassPatternWithPositionalArgs_python3_10] +class Position: + __match_args__ = ("x", "y", "z") + + x: int + y: int + z: int + +def f(x): + match x: + # case Position(1, 2, 3): + case Position(1): + print("matched") +[out] +def Position.__mypyc_defaults_setup(__mypyc_self__): + __mypyc_self__ :: __main__.Position + r0, r1, r2 :: str + r3 :: tuple[str, str, str] +L0: + r0 = 'x' + r1 = 'y' + r2 = 'z' + r3 = (r0, r1, r2) + __mypyc_self__.__match_args__ = r3 + return 1 +def f(x): + x, r0 :: object + r1 :: int32 + r2 :: bit + r3 :: bool + r4 :: str + r5, r6, r7 :: object + r8 :: int32 + r9 :: bit + r10 :: bool + r11 :: str + r12 :: object + r13 :: str + r14 :: object + r15 :: object[1] + r16 :: object_ptr + r17, r18 :: object +L0: + r0 = __main__.Position :: type + r1 = PyObject_IsInstance(x, r0) + r2 = r1 >= 0 :: signed + r3 = truncate r1: int32 to builtins.bool + if r3 goto L1 else goto L3 :: bool +L1: + r4 = 'x' + r5 = CPyObject_GetAttr(x, r4) + r6 = object 1 + r7 = PyObject_RichCompare(r5, r6, 2) + r8 = PyObject_IsTrue(r7) + r9 = r8 >= 0 :: signed + r10 = truncate r8: int32 to builtins.bool + if r10 goto L2 else goto L3 :: bool +L2: + r11 = 'matched' + r12 = builtins :: module + r13 = 'print' + r14 = CPyObject_GetAttr(r12, r13) + r15 = [r11] + r16 = load_address r15 + r17 = _PyObject_Vectorcall(r14, r16, 1, 0) + keep_alive r11 + goto L4 +L3: +L4: + r18 = box(None, 1) + return r18 From b85f1781e8ccf13df358e3b2776386578de35457 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Tue, 18 Oct 2022 18:05:39 -0700 Subject: [PATCH 42/97] Add support for variable number of positional args --- mypyc/irbuild/statement.py | 39 +++++++++-------- mypyc/test-data/irbuild-match.test | 69 +++++++++++++++++++++--------- 2 files changed, 69 insertions(+), 39 deletions(-) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index 992f3cebfcd18..612c43a9bb562 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -1002,37 +1002,40 @@ def visit_class_pattern(self, pattern: ClassPattern) -> None: self.builder.add_bool_branch(cond, self.code_block, self.next_block) if pattern.positionals: - self.builder.activate_block(self.code_block) - self.code_block = BasicBlock() - - assert len(pattern.positionals) == 1 - - expr = pattern.positionals[0] - assert isinstance(expr, ValuePattern) - node = pattern.class_ref.node assert isinstance(node, TypeInfo) ty = node.names.get("__match_args__") assert ty and isinstance(ty.type, TupleType) - match_args = [] + match_args: list[str] = [] for item in ty.type.items: assert isinstance(item, Instance) and item.last_known_value - match_args.append(item.last_known_value.value) + value = item.last_known_value.value + assert isinstance(value, str) - value = self.builder.py_get_attr(self.subject, match_args[0], expr.line) + match_args.append(value) - cond2 = self.builder.binary_op( - value, - self.builder.accept(expr.expr), - "==", - expr.line, - ) + for i, expr in enumerate(pattern.positionals): + self.builder.activate_block(self.code_block) + self.code_block = BasicBlock() + + assert isinstance(expr, ValuePattern) + + value = self.builder.py_get_attr( + self.subject, match_args[i], expr.line + ) + + cond = self.builder.binary_op( + value, + self.builder.accept(expr.expr), + "==", + expr.line, + ) - self.builder.add_bool_branch(cond2, self.code_block, self.next_block) + self.builder.add_bool_branch(cond, self.code_block, self.next_block) def visit_as_pattern(self, pattern: AsPattern) -> None: if pattern.pattern: diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index 5a62bffe214b7..1466ea598a0dd 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -603,8 +603,7 @@ class Position: def f(x): match x: - # case Position(1, 2, 3): - case Position(1): + case Position(1, 2, 3): print("matched") [out] def Position.__mypyc_defaults_setup(__mypyc_self__): @@ -629,18 +628,28 @@ def f(x): r9 :: bit r10 :: bool r11 :: str - r12 :: object - r13 :: str - r14 :: object - r15 :: object[1] - r16 :: object_ptr - r17, r18 :: object + r12, r13, r14 :: object + r15 :: int32 + r16 :: bit + r17 :: bool + r18 :: str + r19, r20, r21 :: object + r22 :: int32 + r23 :: bit + r24 :: bool + r25 :: str + r26 :: object + r27 :: str + r28 :: object + r29 :: object[1] + r30 :: object_ptr + r31, r32 :: object L0: r0 = __main__.Position :: type r1 = PyObject_IsInstance(x, r0) r2 = r1 >= 0 :: signed r3 = truncate r1: int32 to builtins.bool - if r3 goto L1 else goto L3 :: bool + if r3 goto L1 else goto L5 :: bool L1: r4 = 'x' r5 = CPyObject_GetAttr(x, r4) @@ -649,18 +658,36 @@ L1: r8 = PyObject_IsTrue(r7) r9 = r8 >= 0 :: signed r10 = truncate r8: int32 to builtins.bool - if r10 goto L2 else goto L3 :: bool + if r10 goto L2 else goto L5 :: bool L2: - r11 = 'matched' - r12 = builtins :: module - r13 = 'print' - r14 = CPyObject_GetAttr(r12, r13) - r15 = [r11] - r16 = load_address r15 - r17 = _PyObject_Vectorcall(r14, r16, 1, 0) - keep_alive r11 - goto L4 + r11 = 'y' + r12 = CPyObject_GetAttr(x, r11) + r13 = object 2 + r14 = PyObject_RichCompare(r12, r13, 2) + r15 = PyObject_IsTrue(r14) + r16 = r15 >= 0 :: signed + r17 = truncate r15: int32 to builtins.bool + if r17 goto L3 else goto L5 :: bool L3: + r18 = 'z' + r19 = CPyObject_GetAttr(x, r18) + r20 = object 3 + r21 = PyObject_RichCompare(r19, r20, 2) + r22 = PyObject_IsTrue(r21) + r23 = r22 >= 0 :: signed + r24 = truncate r22: int32 to builtins.bool + if r24 goto L4 else goto L5 :: bool L4: - r18 = box(None, 1) - return r18 + r25 = 'matched' + r26 = builtins :: module + r27 = 'print' + r28 = CPyObject_GetAttr(r26, r27) + r29 = [r25] + r30 = load_address r29 + r31 = _PyObject_Vectorcall(r28, r30, 1, 0) + keep_alive r25 + goto L6 +L5: +L6: + r32 = box(None, 1) + return r32 From 92935499ae95b7930734d6f10ff052a6b36a5992 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Tue, 18 Oct 2022 18:10:27 -0700 Subject: [PATCH 43/97] Use self.code_block --- mypyc/irbuild/statement.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index 612c43a9bb562..402941c1c5893 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -941,12 +941,12 @@ def build_match_body( self.builder.activate_block(code_block) if guard := self.match.guards[index]: - new_code_block = BasicBlock() + self.code_block = BasicBlock() cond = self.builder.accept(guard) - self.builder.add_bool_branch(cond, new_code_block, next_block) + self.builder.add_bool_branch(cond, self.code_block, next_block) - self.builder.activate_block(new_code_block) + self.builder.activate_block(self.code_block) self.builder.accept(self.match.bodies[index]) self.builder.goto(self.final_block) From 2132f783758f879a3b7a22bfdf0c311a1a4fa984 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Wed, 19 Oct 2022 15:37:27 -0700 Subject: [PATCH 44/97] Add support for keyword class patterns --- mypyc/irbuild/statement.py | 20 +++++-- mypyc/test-data/irbuild-match.test | 85 ++++++++++++++++++++++++++++++ 2 files changed, 102 insertions(+), 3 deletions(-) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index 402941c1c5893..a061ff12b6de5 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -987,9 +987,6 @@ def visit_or_pattern(self, pattern: OrPattern) -> None: self.builder.goto(self.next_block) def visit_class_pattern(self, pattern: ClassPattern) -> None: - assert not pattern.keyword_keys - assert not pattern.keyword_values - # TODO: add ability to handle class pattern when not at top level cond = self.builder.call_c( slow_isinstance_op, @@ -1037,6 +1034,23 @@ def visit_class_pattern(self, pattern: ClassPattern) -> None: self.builder.add_bool_branch(cond, self.code_block, self.next_block) + for key, value in zip(pattern.keyword_keys, pattern.keyword_values): + assert isinstance(value, ValuePattern) + + self.builder.activate_block(self.code_block) + self.code_block = BasicBlock() + + attr = self.builder.py_get_attr(self.subject, key, value.line) + + cond = self.builder.binary_op( + attr, + self.builder.accept(value.expr), + "==", + value.line, + ) + + self.builder.add_bool_branch(cond, self.code_block, self.next_block) + def visit_as_pattern(self, pattern: AsPattern) -> None: if pattern.pattern: self.as_pattern = pattern diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index 1466ea598a0dd..a97168f6bc515 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -691,3 +691,88 @@ L5: L6: r32 = box(None, 1) return r32 +[case testMatchClassPatternWithKeywordPatterns_python3_10] +class Position: + x: int + y: int + z: int + +def f(x): + match x: + case Position(z=1, y=2, x=3): + print("matched") +[out] +def f(x): + x, r0 :: object + r1 :: int32 + r2 :: bit + r3 :: bool + r4 :: str + r5, r6, r7 :: object + r8 :: int32 + r9 :: bit + r10 :: bool + r11 :: str + r12, r13, r14 :: object + r15 :: int32 + r16 :: bit + r17 :: bool + r18 :: str + r19, r20, r21 :: object + r22 :: int32 + r23 :: bit + r24 :: bool + r25 :: str + r26 :: object + r27 :: str + r28 :: object + r29 :: object[1] + r30 :: object_ptr + r31, r32 :: object +L0: + r0 = __main__.Position :: type + r1 = PyObject_IsInstance(x, r0) + r2 = r1 >= 0 :: signed + r3 = truncate r1: int32 to builtins.bool + if r3 goto L1 else goto L5 :: bool +L1: + r4 = 'z' + r5 = CPyObject_GetAttr(x, r4) + r6 = object 1 + r7 = PyObject_RichCompare(r5, r6, 2) + r8 = PyObject_IsTrue(r7) + r9 = r8 >= 0 :: signed + r10 = truncate r8: int32 to builtins.bool + if r10 goto L2 else goto L5 :: bool +L2: + r11 = 'y' + r12 = CPyObject_GetAttr(x, r11) + r13 = object 2 + r14 = PyObject_RichCompare(r12, r13, 2) + r15 = PyObject_IsTrue(r14) + r16 = r15 >= 0 :: signed + r17 = truncate r15: int32 to builtins.bool + if r17 goto L3 else goto L5 :: bool +L3: + r18 = 'x' + r19 = CPyObject_GetAttr(x, r18) + r20 = object 3 + r21 = PyObject_RichCompare(r19, r20, 2) + r22 = PyObject_IsTrue(r21) + r23 = r22 >= 0 :: signed + r24 = truncate r22: int32 to builtins.bool + if r24 goto L4 else goto L5 :: bool +L4: + r25 = 'matched' + r26 = builtins :: module + r27 = 'print' + r28 = CPyObject_GetAttr(r26, r27) + r29 = [r25] + r30 = load_address r29 + r31 = _PyObject_Vectorcall(r28, r30, 1, 0) + keep_alive r25 + goto L6 +L5: +L6: + r32 = box(None, 1) + return r32 From c5dd161874b6a704f2a0b2eb4820000870992cd7 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Wed, 19 Oct 2022 15:58:17 -0700 Subject: [PATCH 45/97] Add better scoping entering --- mypyc/irbuild/statement.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index a061ff12b6de5..ce9fe29e80bc8 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -7,9 +7,10 @@ """ from __future__ import annotations +from contextlib import contextmanager import importlib.util -from typing import Callable, Sequence +from typing import Callable, Generator, Sequence from mypy.nodes import ( AssertStmt, @@ -1053,9 +1054,8 @@ def visit_class_pattern(self, pattern: ClassPattern) -> None: def visit_as_pattern(self, pattern: AsPattern) -> None: if pattern.pattern: - self.as_pattern = pattern - pattern.pattern.accept(self) - self.as_pattern = None + with self.enter_subpattern(pattern, self.subject): + pattern.pattern.accept(self) self.builder.goto(self.code_block) @@ -1076,6 +1076,16 @@ def bind_as_pattern(self, value: Value) -> None: target = self.builder.get_assignment_target(self.as_pattern.name) self.builder.assign(target, value, self.as_pattern.pattern.line) # type: ignore + @contextmanager + def enter_subpattern(self, pattern: AsPattern, subject: Value) -> Generator[None, None, None]: + old_pattern = self.as_pattern + old_subject = self.subject + self.as_pattern = pattern + self.subject = subject + yield + self.subject = old_subject + self.as_pattern = old_pattern + def transform_match_stmt(builder: IRBuilder, m: MatchStmt) -> None: m.accept(MatchVisitor(builder, m)) From c64c343dcb381477178f376e67bd4bdb8dae470a Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Wed, 19 Oct 2022 16:02:42 -0700 Subject: [PATCH 46/97] Split context managers --- mypyc/irbuild/statement.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index ce9fe29e80bc8..10ca38b825eee 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -1054,7 +1054,7 @@ def visit_class_pattern(self, pattern: ClassPattern) -> None: def visit_as_pattern(self, pattern: AsPattern) -> None: if pattern.pattern: - with self.enter_subpattern(pattern, self.subject): + with self.enter_as_pattern(pattern): pattern.pattern.accept(self) self.builder.goto(self.code_block) @@ -1077,13 +1077,17 @@ def bind_as_pattern(self, value: Value) -> None: self.builder.assign(target, value, self.as_pattern.pattern.line) # type: ignore @contextmanager - def enter_subpattern(self, pattern: AsPattern, subject: Value) -> Generator[None, None, None]: - old_pattern = self.as_pattern + def enter_subpattern(self, subject: Value) -> Generator[None, None, None]: old_subject = self.subject - self.as_pattern = pattern self.subject = subject yield self.subject = old_subject + + @contextmanager + def enter_as_pattern(self, pattern: AsPattern) -> Generator[None, None, None]: + old_pattern = self.as_pattern + self.as_pattern = pattern + yield self.as_pattern = old_pattern From 49d60c83701a4209d6e3b2638d8a126356e626ba Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Wed, 19 Oct 2022 16:28:29 -0700 Subject: [PATCH 47/97] Support nested patterns in class pattern --- mypyc/irbuild/statement.py | 24 ++--------- mypyc/test-data/irbuild-match.test | 68 ++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 20 deletions(-) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index 10ca38b825eee..79d8765d6ae71 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -1020,37 +1020,21 @@ def visit_class_pattern(self, pattern: ClassPattern) -> None: self.builder.activate_block(self.code_block) self.code_block = BasicBlock() - assert isinstance(expr, ValuePattern) - value = self.builder.py_get_attr( self.subject, match_args[i], expr.line ) - cond = self.builder.binary_op( - value, - self.builder.accept(expr.expr), - "==", - expr.line, - ) - - self.builder.add_bool_branch(cond, self.code_block, self.next_block) + with self.enter_subpattern(value): + expr.accept(self) for key, value in zip(pattern.keyword_keys, pattern.keyword_values): - assert isinstance(value, ValuePattern) - self.builder.activate_block(self.code_block) self.code_block = BasicBlock() attr = self.builder.py_get_attr(self.subject, key, value.line) - cond = self.builder.binary_op( - attr, - self.builder.accept(value.expr), - "==", - value.line, - ) - - self.builder.add_bool_branch(cond, self.code_block, self.next_block) + with self.enter_subpattern(attr): + value.accept(self) def visit_as_pattern(self, pattern: AsPattern) -> None: if pattern.pattern: diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index a97168f6bc515..03c002ce62afc 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -776,3 +776,71 @@ L5: L6: r32 = box(None, 1) return r32 +[case testMatchClassPatternWithNestedPattern_python3_10] +class C: + num: int + +def f(x): + match x: + case C(num=1 | 2): + print("matched") +[out] +def f(x): + x, r0 :: object + r1 :: int32 + r2 :: bit + r3 :: bool + r4 :: str + r5, r6, r7 :: object + r8 :: int32 + r9 :: bit + r10 :: bool + r11, r12 :: object + r13 :: int32 + r14 :: bit + r15 :: bool + r16 :: str + r17 :: object + r18 :: str + r19 :: object + r20 :: object[1] + r21 :: object_ptr + r22, r23 :: object +L0: + r0 = __main__.C :: type + r1 = PyObject_IsInstance(x, r0) + r2 = r1 >= 0 :: signed + r3 = truncate r1: int32 to builtins.bool + if r3 goto L1 else goto L2 :: bool +L1: + r4 = 'num' + r5 = CPyObject_GetAttr(x, r4) + r6 = object 1 + r7 = PyObject_RichCompare(r5, r6, 2) + r8 = PyObject_IsTrue(r7) + r9 = r8 >= 0 :: signed + r10 = truncate r8: int32 to builtins.bool + if r10 goto L4 else goto L2 :: bool +L2: + r11 = object 2 + r12 = PyObject_RichCompare(r5, r11, 2) + r13 = PyObject_IsTrue(r12) + r14 = r13 >= 0 :: signed + r15 = truncate r13: int32 to builtins.bool + if r15 goto L4 else goto L3 :: bool +L3: + goto L5 +L4: + r16 = 'matched' + r17 = builtins :: module + r18 = 'print' + r19 = CPyObject_GetAttr(r17, r18) + r20 = [r16] + r21 = load_address r20 + r22 = _PyObject_Vectorcall(r19, r21, 1, 0) + keep_alive r16 + goto L6 +L5: +L6: + r23 = box(None, 1) + return r23 From 1deecb4a897ae8c74c0922d9980858f196d49524 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Wed, 19 Oct 2022 17:04:14 -0700 Subject: [PATCH 48/97] Fix as pattern binding to subpatterns --- mypyc/irbuild/statement.py | 6 +++ mypyc/test-data/irbuild-match.test | 84 ++++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index 79d8765d6ae71..837f091c23619 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -980,7 +980,11 @@ def visit_value_pattern(self, pattern: ValuePattern) -> None: def visit_or_pattern(self, pattern: OrPattern) -> None: for p in pattern.patterns: + # Hack to ensure the as pattern is bound to each pattern in the + # "or" pattern, but not every subpattern + backup = self.as_pattern p.accept(self) + self.as_pattern = backup self.builder.activate_block(self.next_block) self.next_block = BasicBlock() @@ -1060,6 +1064,8 @@ def bind_as_pattern(self, value: Value) -> None: target = self.builder.get_assignment_target(self.as_pattern.name) self.builder.assign(target, value, self.as_pattern.pattern.line) # type: ignore + self.as_pattern = None + @contextmanager def enter_subpattern(self, subject: Value) -> Generator[None, None, None]: old_subject = self.subject diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index 03c002ce62afc..c11d2d6142353 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -844,3 +844,87 @@ L5: L6: r23 = box(None, 1) return r23 +[case testing_python3_10] +class C: + __match_args__ = ("a", "b") + a: int + b: int + +def f(x): + match x: + case C(1, 2) as y: + print("matched") +[out] +def C.__mypyc_defaults_setup(__mypyc_self__): + __mypyc_self__ :: __main__.C + r0, r1 :: str + r2 :: tuple[str, str] +L0: + r0 = 'a' + r1 = 'b' + r2 = (r0, r1) + __mypyc_self__.__match_args__ = r2 + return 1 +def f(x): + x, r0 :: object + r1 :: int32 + r2 :: bit + r3 :: bool + r4, y :: __main__.C + r5 :: str + r6, r7, r8 :: object + r9 :: int32 + r10 :: bit + r11 :: bool + r12 :: str + r13, r14, r15 :: object + r16 :: int32 + r17 :: bit + r18 :: bool + r19 :: str + r20 :: object + r21 :: str + r22 :: object + r23 :: object[1] + r24 :: object_ptr + r25, r26 :: object +L0: + r0 = __main__.C :: type + r1 = PyObject_IsInstance(x, r0) + r2 = r1 >= 0 :: signed + r3 = truncate r1: int32 to builtins.bool + r4 = cast(__main__.C, x) + y = r4 + if r3 goto L1 else goto L4 :: bool +L1: + r5 = 'a' + r6 = CPyObject_GetAttr(x, r5) + r7 = object 1 + r8 = PyObject_RichCompare(r6, r7, 2) + r9 = PyObject_IsTrue(r8) + r10 = r9 >= 0 :: signed + r11 = truncate r9: int32 to builtins.bool + if r11 goto L2 else goto L4 :: bool +L2: + r12 = 'b' + r13 = CPyObject_GetAttr(x, r12) + r14 = object 2 + r15 = PyObject_RichCompare(r13, r14, 2) + r16 = PyObject_IsTrue(r15) + r17 = r16 >= 0 :: signed + r18 = truncate r16: int32 to builtins.bool + if r18 goto L3 else goto L4 :: bool +L3: + r19 = 'matched' + r20 = builtins :: module + r21 = 'print' + r22 = CPyObject_GetAttr(r20, r21) + r23 = [r19] + r24 = load_address r23 + r25 = _PyObject_Vectorcall(r22, r24, 1, 0) + keep_alive r19 + goto L5 +L4: +L5: + r26 = box(None, 1) + return r26 From df4a146be0c0961bc89016f8618d5d8dcb5a5203 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Wed, 19 Oct 2022 17:25:45 -0700 Subject: [PATCH 49/97] Add positional captures --- mypyc/irbuild/statement.py | 5 +++ mypyc/test-data/irbuild-match.test | 60 ++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index 837f091c23619..99a9cab59e4af 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -1045,6 +1045,11 @@ def visit_as_pattern(self, pattern: AsPattern) -> None: with self.enter_as_pattern(pattern): pattern.pattern.accept(self) + elif pattern.name: + target = self.builder.get_assignment_target(pattern.name) + + self.builder.assign(target, self.subject, pattern.line) + self.builder.goto(self.code_block) def visit_singleton_pattern(self, pattern: SingletonPattern) -> None: diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index c11d2d6142353..f4c1fbc158062 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -928,3 +928,63 @@ L4: L5: r26 = box(None, 1) return r26 +[case testMatchClassPatternPositionalCapture_python3_10] +class C: + __match_args__ = ("x",) + + x: int + +def f(x): + match x: + case C(num): + print("matched") +[out] +def C.__mypyc_defaults_setup(__mypyc_self__): + __mypyc_self__ :: __main__.C + r0 :: str + r1 :: tuple[str] +L0: + r0 = 'x' + r1 = (r0) + __mypyc_self__.__match_args__ = r1 + return 1 +def f(x): + x, r0 :: object + r1 :: int32 + r2 :: bit + r3 :: bool + r4 :: str + r5 :: object + r6, num :: int + r7 :: str + r8 :: object + r9 :: str + r10 :: object + r11 :: object[1] + r12 :: object_ptr + r13, r14 :: object +L0: + r0 = __main__.C :: type + r1 = PyObject_IsInstance(x, r0) + r2 = r1 >= 0 :: signed + r3 = truncate r1: int32 to builtins.bool + if r3 goto L1 else goto L3 :: bool +L1: + r4 = 'x' + r5 = CPyObject_GetAttr(x, r4) + r6 = unbox(int, r5) + num = r6 +L2: + r7 = 'matched' + r8 = builtins :: module + r9 = 'print' + r10 = CPyObject_GetAttr(r8, r9) + r11 = [r7] + r12 = load_address r11 + r13 = _PyObject_Vectorcall(r10, r12, 1, 0) + keep_alive r7 + goto L4 +L3: +L4: + r14 = box(None, 1) + return r14 From 037652695f8928bcc89147875fc545035e11cad4 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Thu, 20 Oct 2022 15:38:25 -0700 Subject: [PATCH 50/97] Add basic mapping support --- mypyc/irbuild/statement.py | 19 +++++++++++----- mypyc/primitives/misc_ops.py | 8 +++++++ mypyc/test-data/irbuild-match.test | 35 ++++++++++++++++++++++++++++++ 3 files changed, 57 insertions(+), 5 deletions(-) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index 99a9cab59e4af..6e8346896cf77 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -39,7 +39,6 @@ TryStmt, TupleExpr, TypeInfo, - Var, WhileStmt, WithStmt, YieldExpr, @@ -49,13 +48,12 @@ AsPattern, ClassPattern, OrPattern, - Pattern, - PatternVisitor, + MappingPattern, SingletonPattern, ValuePattern, ) from mypy.traverser import TraverserVisitor -from mypy.types import Instance, LiteralType, TupleType +from mypy.types import Instance, TupleType from mypyc.ir.ops import ( NO_TRACEBACK_LINE_NO, Assign, @@ -106,10 +104,10 @@ ) from mypyc.primitives.generic_ops import iter_op, next_raw_op, py_delattr_op from mypyc.primitives.misc_ops import ( + check_mapping_protocol, check_stop_op, coro_op, import_from_op, - none_object_op, send_op, slow_isinstance_op, type_op, @@ -1064,6 +1062,17 @@ def visit_singleton_pattern(self, pattern: SingletonPattern) -> None: self.builder.add_bool_branch(cond, self.code_block, self.next_block) + def visit_mapping_pattern(self, pattern: MappingPattern) -> None: + assert not pattern.rest + + is_map = self.builder.call_c( + check_mapping_protocol, + [self.subject], + pattern.line, + ) + + self.builder.add_bool_branch(is_map, self.code_block, self.next_block) + def bind_as_pattern(self, value: Value) -> None: if self.as_pattern and self.as_pattern.name: target = self.builder.get_assignment_target(self.as_pattern.name) diff --git a/mypyc/primitives/misc_ops.py b/mypyc/primitives/misc_ops.py index 07df9c69714ba..1bdf2b041153a 100644 --- a/mypyc/primitives/misc_ops.py +++ b/mypyc/primitives/misc_ops.py @@ -238,3 +238,11 @@ c_function_name="CPySingledispatch_RegisterFunction", error_kind=ERR_MAGIC, ) + +# Check that object supports Mapping protocol +check_mapping_protocol = custom_op( + arg_types=[object_rprimitive], + return_type=int_rprimitive, + c_function_name="PyMapping_Check", + error_kind=ERR_NEVER, +) diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index f4c1fbc158062..9c0d856d8cad5 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -988,3 +988,38 @@ L3: L4: r14 = box(None, 1) return r14 +[case testMatchMappingEmpty_python3_10] +def f(x): + match x: + case {}: + print("matched") +[out] +def f(x): + x :: object + r0 :: int + r1 :: bit + r2 :: str + r3 :: object + r4 :: str + r5 :: object + r6 :: object[1] + r7 :: object_ptr + r8, r9 :: object +L0: + r0 = PyMapping_Check(x) + r1 = r0 != 0 + if r1 goto L1 else goto L2 :: bool +L1: + r2 = 'matched' + r3 = builtins :: module + r4 = 'print' + r5 = CPyObject_GetAttr(r3, r4) + r6 = [r2] + r7 = load_address r6 + r8 = _PyObject_Vectorcall(r5, r7, 1, 0) + keep_alive r2 + goto L3 +L2: +L3: + r9 = box(None, 1) + return r9 From 5e479b0d20b53ca914c3f17232623e7bccde09b8 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Thu, 20 Oct 2022 17:54:49 -0700 Subject: [PATCH 51/97] Add key value patterns --- mypyc/irbuild/statement.py | 21 ++++++++++-- mypyc/test-data/irbuild-match.test | 53 +++++++++++++++++++++++++++++- 2 files changed, 71 insertions(+), 3 deletions(-) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index 6e8346896cf77..c3aacb63f85da 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -102,7 +102,12 @@ reraise_exception_op, restore_exc_info_op, ) -from mypyc.primitives.generic_ops import iter_op, next_raw_op, py_delattr_op +from mypyc.primitives.generic_ops import ( + iter_op, + next_raw_op, + py_delattr_op, + py_getattr_op, +) from mypyc.primitives.misc_ops import ( check_mapping_protocol, check_stop_op, @@ -990,7 +995,6 @@ def visit_or_pattern(self, pattern: OrPattern) -> None: self.builder.goto(self.next_block) def visit_class_pattern(self, pattern: ClassPattern) -> None: - # TODO: add ability to handle class pattern when not at top level cond = self.builder.call_c( slow_isinstance_op, [self.subject, self.builder.accept(pattern.class_ref)], @@ -1073,6 +1077,19 @@ def visit_mapping_pattern(self, pattern: MappingPattern) -> None: self.builder.add_bool_branch(is_map, self.code_block, self.next_block) + for key, value in zip(pattern.keys, pattern.values): + self.builder.activate_block(self.code_block) + self.code_block = BasicBlock() + + attr = self.builder.call_c( + py_getattr_op, + [self.subject, self.builder.accept(key)], + pattern.line + ) + + with self.enter_subpattern(attr): + value.accept(self) + def bind_as_pattern(self, value: Value) -> None: if self.as_pattern and self.as_pattern.name: target = self.builder.get_assignment_target(self.as_pattern.name) diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index 9c0d856d8cad5..361c68226b8a8 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -844,7 +844,7 @@ L5: L6: r23 = box(None, 1) return r23 -[case testing_python3_10] +[case testAsPatternDoesntBleedIntoSubPatterns_python3_10] class C: __match_args__ = ("a", "b") a: int @@ -1023,3 +1023,54 @@ L2: L3: r9 = box(None, 1) return r9 +[case testing_python3_10] +def f(x): + match x: + case {"key": "value"}: + print("matched") +[out] +def f(x): + x :: object + r0 :: int + r1 :: bit + r2 :: str + r3 :: object + r4 :: str + r5 :: object + r6 :: int32 + r7 :: bit + r8 :: bool + r9 :: str + r10 :: object + r11 :: str + r12 :: object + r13 :: object[1] + r14 :: object_ptr + r15, r16 :: object +L0: + r0 = PyMapping_Check(x) + r1 = r0 != 0 + if r1 goto L1 else goto L3 :: bool +L1: + r2 = 'key' + r3 = CPyObject_GetAttr(x, r2) + r4 = 'value' + r5 = PyObject_RichCompare(r3, r4, 2) + r6 = PyObject_IsTrue(r5) + r7 = r6 >= 0 :: signed + r8 = truncate r6: int32 to builtins.bool + if r8 goto L2 else goto L3 :: bool +L2: + r9 = 'matched' + r10 = builtins :: module + r11 = 'print' + r12 = CPyObject_GetAttr(r10, r11) + r13 = [r9] + r14 = load_address r13 + r15 = _PyObject_Vectorcall(r12, r14, 1, 0) + keep_alive r9 + goto L4 +L3: +L4: + r16 = box(None, 1) + return r16 From 820b9bdf3183a55bcb43cf1d98df955cbc206c6f Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Thu, 20 Oct 2022 18:22:13 -0700 Subject: [PATCH 52/97] Add basic mapping rest --- mypyc/irbuild/statement.py | 18 +++++++++++-- mypyc/primitives/misc_ops.py | 8 ++++++ mypyc/test-data/irbuild-match.test | 41 ++++++++++++++++++++++++++++++ 3 files changed, 65 insertions(+), 2 deletions(-) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index c3aacb63f85da..c90413c0c9b91 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -112,6 +112,7 @@ check_mapping_protocol, check_stop_op, coro_op, + dict_copy, import_from_op, send_op, slow_isinstance_op, @@ -1067,8 +1068,6 @@ def visit_singleton_pattern(self, pattern: SingletonPattern) -> None: self.builder.add_bool_branch(cond, self.code_block, self.next_block) def visit_mapping_pattern(self, pattern: MappingPattern) -> None: - assert not pattern.rest - is_map = self.builder.call_c( check_mapping_protocol, [self.subject], @@ -1090,6 +1089,21 @@ def visit_mapping_pattern(self, pattern: MappingPattern) -> None: with self.enter_subpattern(attr): value.accept(self) + if pattern.rest: + self.builder.activate_block(self.code_block) + self.code_block = BasicBlock() + + copy = self.builder.call_c( + dict_copy, + [self.subject], + pattern.rest.line, + ) + + target = self.builder.get_assignment_target(pattern.rest) + + self.builder.assign(target, copy, pattern.rest.line) + self.builder.goto(self.code_block) + def bind_as_pattern(self, value: Value) -> None: if self.as_pattern and self.as_pattern.name: target = self.builder.get_assignment_target(self.as_pattern.name) diff --git a/mypyc/primitives/misc_ops.py b/mypyc/primitives/misc_ops.py index 1bdf2b041153a..71b7901e16a9a 100644 --- a/mypyc/primitives/misc_ops.py +++ b/mypyc/primitives/misc_ops.py @@ -246,3 +246,11 @@ c_function_name="PyMapping_Check", error_kind=ERR_NEVER, ) + +# Copy an object into a dict +dict_copy = custom_op( + arg_types=[object_rprimitive], + return_type=object_rprimitive, + c_function_name="PyDict_Copy", + error_kind=ERR_NEVER, +) diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index 361c68226b8a8..d6d8d239462f5 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -1074,3 +1074,44 @@ L3: L4: r16 = box(None, 1) return r16 +[case testMatchMappingPatternWithRest_python3_10] +def f(x): + match x: + case {**rest}: + print("matched") +[out] +def f(x): + x :: object + r0 :: int + r1 :: bit + r2 :: object + r3, rest :: dict + r4 :: str + r5 :: object + r6 :: str + r7 :: object + r8 :: object[1] + r9 :: object_ptr + r10, r11 :: object +L0: + r0 = PyMapping_Check(x) + r1 = r0 != 0 + if r1 goto L1 else goto L3 :: bool +L1: + r2 = PyDict_Copy(x) + r3 = cast(dict, r2) + rest = r3 +L2: + r4 = 'matched' + r5 = builtins :: module + r6 = 'print' + r7 = CPyObject_GetAttr(r5, r6) + r8 = [r4] + r9 = load_address r8 + r10 = _PyObject_Vectorcall(r7, r9, 1, 0) + keep_alive r4 + goto L4 +L3: +L4: + r11 = box(None, 1) + return r11 From 29eadd5405b40763fb4501f0a7ad289d0b97c1c5 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Thu, 20 Oct 2022 18:37:09 -0700 Subject: [PATCH 53/97] Make sure to pop keys from rest dict --- mypyc/irbuild/statement.py | 16 ++++++-- mypyc/primitives/misc_ops.py | 8 ++++ mypyc/test-data/irbuild-match.test | 61 ++++++++++++++++++++++++++++++ 3 files changed, 82 insertions(+), 3 deletions(-) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index c90413c0c9b91..f2d29d63113f2 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -113,6 +113,7 @@ check_stop_op, coro_op, dict_copy, + dict_del_item, import_from_op, send_op, slow_isinstance_op, @@ -1076,13 +1077,18 @@ def visit_mapping_pattern(self, pattern: MappingPattern) -> None: self.builder.add_bool_branch(is_map, self.code_block, self.next_block) + keys: list[Value] = [] + for key, value in zip(pattern.keys, pattern.values): self.builder.activate_block(self.code_block) self.code_block = BasicBlock() + key_value = self.builder.accept(key) + keys.append(key_value) + attr = self.builder.call_c( py_getattr_op, - [self.subject, self.builder.accept(key)], + [self.subject, key_value], pattern.line ) @@ -1093,7 +1099,7 @@ def visit_mapping_pattern(self, pattern: MappingPattern) -> None: self.builder.activate_block(self.code_block) self.code_block = BasicBlock() - copy = self.builder.call_c( + rest = self.builder.call_c( dict_copy, [self.subject], pattern.rest.line, @@ -1101,7 +1107,11 @@ def visit_mapping_pattern(self, pattern: MappingPattern) -> None: target = self.builder.get_assignment_target(pattern.rest) - self.builder.assign(target, copy, pattern.rest.line) + self.builder.assign(target, rest, pattern.rest.line) + + for i, key in enumerate(keys): + self.builder.call_c(dict_del_item, [rest, key], pattern.keys[i].line) + self.builder.goto(self.code_block) def bind_as_pattern(self, value: Value) -> None: diff --git a/mypyc/primitives/misc_ops.py b/mypyc/primitives/misc_ops.py index 71b7901e16a9a..ddfbed5f08a85 100644 --- a/mypyc/primitives/misc_ops.py +++ b/mypyc/primitives/misc_ops.py @@ -254,3 +254,11 @@ c_function_name="PyDict_Copy", error_kind=ERR_NEVER, ) + +# Delete an item from a dict +dict_del_item = custom_op( + arg_types=[object_rprimitive, object_rprimitive], + return_type=int_rprimitive, + c_function_name="PyDict_DelItem", + error_kind=ERR_NEG_INT, +) diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index d6d8d239462f5..e762909bb8b70 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -1115,3 +1115,64 @@ L3: L4: r11 = box(None, 1) return r11 +[case testMatchMappingPatternWithRestPopKeys_python3_10] +def f(x): + match x: + case {"key": "value", **rest}: + print("matched") +[out] +def f(x): + x :: object + r0 :: int + r1 :: bit + r2 :: str + r3 :: object + r4 :: str + r5 :: object + r6 :: int32 + r7 :: bit + r8 :: bool + r9 :: object + r10, rest :: dict + r11 :: int + r12 :: bit + r13 :: str + r14 :: object + r15 :: str + r16 :: object + r17 :: object[1] + r18 :: object_ptr + r19, r20 :: object +L0: + r0 = PyMapping_Check(x) + r1 = r0 != 0 + if r1 goto L1 else goto L4 :: bool +L1: + r2 = 'key' + r3 = CPyObject_GetAttr(x, r2) + r4 = 'value' + r5 = PyObject_RichCompare(r3, r4, 2) + r6 = PyObject_IsTrue(r5) + r7 = r6 >= 0 :: signed + r8 = truncate r6: int32 to builtins.bool + if r8 goto L2 else goto L4 :: bool +L2: + r9 = PyDict_Copy(x) + r10 = cast(dict, r9) + rest = r10 + r11 = PyDict_DelItem(r9, r2) + r12 = r11 >= 0 :: signed +L3: + r13 = 'matched' + r14 = builtins :: module + r15 = 'print' + r16 = CPyObject_GetAttr(r14, r15) + r17 = [r13] + r18 = load_address r17 + r19 = _PyObject_Vectorcall(r16, r18, 1, 0) + keep_alive r13 + goto L5 +L4: +L5: + r20 = box(None, 1) + return r20 From d538d6966ac0e243d63960295af0f51515089d34 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Thu, 20 Oct 2022 18:42:09 -0700 Subject: [PATCH 54/97] Cleanup --- mypyc/irbuild/statement.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index f2d29d63113f2..52e0f7ed68ef9 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -941,16 +941,14 @@ def __init__(self, builder: IRBuilder, match_node: MatchStmt) -> None: self.match = match_node self.subject = builder.accept(match_node.subject) - def build_match_body( - self, index: int, code_block: BasicBlock, next_block: BasicBlock - ) -> None: - self.builder.activate_block(code_block) + def build_match_body(self, index: int) -> None: + self.builder.activate_block(self.code_block) if guard := self.match.guards[index]: self.code_block = BasicBlock() cond = self.builder.accept(guard) - self.builder.add_bool_branch(cond, self.code_block, next_block) + self.builder.add_bool_branch(cond, self.code_block, self.next_block) self.builder.activate_block(self.code_block) @@ -964,7 +962,7 @@ def visit_match_stmt(self, m: MatchStmt) -> None: pattern.accept(self) - self.build_match_body(i, self.code_block, self.next_block) + self.build_match_body(i) self.builder.activate_block(self.next_block) self.builder.goto_and_activate(self.final_block) From 553e41cd3482fcd5462d497b1832b1740ea12832 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Thu, 20 Oct 2022 18:46:05 -0700 Subject: [PATCH 55/97] Split match stuff into its own file --- mypyc/irbuild/match.py | 236 +++++++++++++++++++++++++++++++++++++ mypyc/irbuild/statement.py | 235 +----------------------------------- 2 files changed, 239 insertions(+), 232 deletions(-) create mode 100644 mypyc/irbuild/match.py diff --git a/mypyc/irbuild/match.py b/mypyc/irbuild/match.py new file mode 100644 index 0000000000000..d8e421238fbb7 --- /dev/null +++ b/mypyc/irbuild/match.py @@ -0,0 +1,236 @@ +from contextlib import contextmanager +from typing import Generator + +from mypy.nodes import MatchStmt, TypeInfo +from mypyc.ir.ops import Value, BasicBlock +from mypy.patterns import ( + AsPattern, + ClassPattern, + OrPattern, + MappingPattern, + SingletonPattern, + ValuePattern, +) +from mypy.traverser import TraverserVisitor +from mypy.types import Instance, TupleType + +from mypyc.primitives.generic_ops import py_getattr_op +from mypyc.primitives.misc_ops import ( + check_mapping_protocol, + dict_copy, + dict_del_item, + slow_isinstance_op, +) +from mypyc.irbuild.builder import IRBuilder + +class MatchVisitor(TraverserVisitor): + builder: IRBuilder + code_block: BasicBlock + next_block: BasicBlock + final_block: BasicBlock + subject: Value + match: MatchStmt + + as_pattern: AsPattern | None = None + + def __init__(self, builder: IRBuilder, match_node: MatchStmt) -> None: + self.builder = builder + + self.code_block = BasicBlock() + self.next_block = BasicBlock() + self.final_block = BasicBlock() + + self.match = match_node + self.subject = builder.accept(match_node.subject) + + def build_match_body(self, index: int) -> None: + self.builder.activate_block(self.code_block) + + if guard := self.match.guards[index]: + self.code_block = BasicBlock() + + cond = self.builder.accept(guard) + self.builder.add_bool_branch(cond, self.code_block, self.next_block) + + self.builder.activate_block(self.code_block) + + self.builder.accept(self.match.bodies[index]) + self.builder.goto(self.final_block) + + def visit_match_stmt(self, m: MatchStmt) -> None: + for i, pattern in enumerate(m.patterns): + self.code_block = BasicBlock() + self.next_block = BasicBlock() + + pattern.accept(self) + + self.build_match_body(i) + self.builder.activate_block(self.next_block) + + self.builder.goto_and_activate(self.final_block) + + def visit_value_pattern(self, pattern: ValuePattern) -> None: + value = self.builder.accept(pattern.expr) + + cond = self.builder.binary_op( + self.subject, + value, + "==", + pattern.expr.line + ) + + self.bind_as_pattern(value) + + self.builder.add_bool_branch(cond, self.code_block, self.next_block) + + def visit_or_pattern(self, pattern: OrPattern) -> None: + for p in pattern.patterns: + # Hack to ensure the as pattern is bound to each pattern in the + # "or" pattern, but not every subpattern + backup = self.as_pattern + p.accept(self) + self.as_pattern = backup + + self.builder.activate_block(self.next_block) + self.next_block = BasicBlock() + + self.builder.goto(self.next_block) + + def visit_class_pattern(self, pattern: ClassPattern) -> None: + cond = self.builder.call_c( + slow_isinstance_op, + [self.subject, self.builder.accept(pattern.class_ref)], + pattern.line + ) + + self.bind_as_pattern(self.subject) + + self.builder.add_bool_branch(cond, self.code_block, self.next_block) + + if pattern.positionals: + node = pattern.class_ref.node + assert isinstance(node, TypeInfo) + + ty = node.names.get("__match_args__") + assert ty and isinstance(ty.type, TupleType) + + match_args: list[str] = [] + + for item in ty.type.items: + assert isinstance(item, Instance) and item.last_known_value + + value = item.last_known_value.value + assert isinstance(value, str) + + match_args.append(value) + + for i, expr in enumerate(pattern.positionals): + self.builder.activate_block(self.code_block) + self.code_block = BasicBlock() + + value = self.builder.py_get_attr( + self.subject, match_args[i], expr.line + ) + + with self.enter_subpattern(value): + expr.accept(self) + + for key, value in zip(pattern.keyword_keys, pattern.keyword_values): + self.builder.activate_block(self.code_block) + self.code_block = BasicBlock() + + attr = self.builder.py_get_attr(self.subject, key, value.line) + + with self.enter_subpattern(attr): + value.accept(self) + + def visit_as_pattern(self, pattern: AsPattern) -> None: + if pattern.pattern: + with self.enter_as_pattern(pattern): + pattern.pattern.accept(self) + + elif pattern.name: + target = self.builder.get_assignment_target(pattern.name) + + self.builder.assign(target, self.subject, pattern.line) + + self.builder.goto(self.code_block) + + def visit_singleton_pattern(self, pattern: SingletonPattern) -> None: + if pattern.value is None: + obj = self.builder.none_object() + elif pattern.value is True: + obj = self.builder.true() + else: + obj = self.builder.false() + + cond = self.builder.binary_op(self.subject, obj, "is", pattern.line) + + self.builder.add_bool_branch(cond, self.code_block, self.next_block) + + def visit_mapping_pattern(self, pattern: MappingPattern) -> None: + is_map = self.builder.call_c( + check_mapping_protocol, + [self.subject], + pattern.line, + ) + + self.builder.add_bool_branch(is_map, self.code_block, self.next_block) + + keys: list[Value] = [] + + for key, value in zip(pattern.keys, pattern.values): + self.builder.activate_block(self.code_block) + self.code_block = BasicBlock() + + key_value = self.builder.accept(key) + keys.append(key_value) + + attr = self.builder.call_c( + py_getattr_op, + [self.subject, key_value], + pattern.line + ) + + with self.enter_subpattern(attr): + value.accept(self) + + if pattern.rest: + self.builder.activate_block(self.code_block) + self.code_block = BasicBlock() + + rest = self.builder.call_c( + dict_copy, + [self.subject], + pattern.rest.line, + ) + + target = self.builder.get_assignment_target(pattern.rest) + + self.builder.assign(target, rest, pattern.rest.line) + + for i, key in enumerate(keys): + self.builder.call_c(dict_del_item, [rest, key], pattern.keys[i].line) + + self.builder.goto(self.code_block) + + def bind_as_pattern(self, value: Value) -> None: + if self.as_pattern and self.as_pattern.name: + target = self.builder.get_assignment_target(self.as_pattern.name) + self.builder.assign(target, value, self.as_pattern.pattern.line) # type: ignore + + self.as_pattern = None + + @contextmanager + def enter_subpattern(self, subject: Value) -> Generator[None, None, None]: + old_subject = self.subject + self.subject = subject + yield + self.subject = old_subject + + @contextmanager + def enter_as_pattern(self, pattern: AsPattern) -> Generator[None, None, None]: + old_pattern = self.as_pattern + self.as_pattern = pattern + yield + self.as_pattern = old_pattern diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index 52e0f7ed68ef9..0200c5cadf9ca 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -7,10 +7,9 @@ """ from __future__ import annotations -from contextlib import contextmanager import importlib.util -from typing import Callable, Generator, Sequence +from typing import Callable, Sequence from mypy.nodes import ( AssertStmt, @@ -38,22 +37,11 @@ TempNode, TryStmt, TupleExpr, - TypeInfo, WhileStmt, WithStmt, YieldExpr, YieldFromExpr, ) -from mypy.patterns import ( - AsPattern, - ClassPattern, - OrPattern, - MappingPattern, - SingletonPattern, - ValuePattern, -) -from mypy.traverser import TraverserVisitor -from mypy.types import Instance, TupleType from mypyc.ir.ops import ( NO_TRACEBACK_LINE_NO, Assign, @@ -106,21 +94,18 @@ iter_op, next_raw_op, py_delattr_op, - py_getattr_op, ) from mypyc.primitives.misc_ops import ( - check_mapping_protocol, check_stop_op, coro_op, - dict_copy, - dict_del_item, import_from_op, send_op, - slow_isinstance_op, type_op, yield_from_except_op, ) +from .match import MatchVisitor + GenFunc = Callable[[], None] ValueGenFunc = Callable[[], Value] @@ -920,219 +905,5 @@ def transform_await_expr(builder: IRBuilder, o: AwaitExpr) -> Value: return emit_yield_from_or_await(builder, builder.accept(o.expr), o.line, is_await=True) - -class MatchVisitor(TraverserVisitor): - builder: IRBuilder - code_block: BasicBlock - next_block: BasicBlock - final_block: BasicBlock - subject: Value - match: MatchStmt - - as_pattern: AsPattern | None = None - - def __init__(self, builder: IRBuilder, match_node: MatchStmt) -> None: - self.builder = builder - - self.code_block = BasicBlock() - self.next_block = BasicBlock() - self.final_block = BasicBlock() - - self.match = match_node - self.subject = builder.accept(match_node.subject) - - def build_match_body(self, index: int) -> None: - self.builder.activate_block(self.code_block) - - if guard := self.match.guards[index]: - self.code_block = BasicBlock() - - cond = self.builder.accept(guard) - self.builder.add_bool_branch(cond, self.code_block, self.next_block) - - self.builder.activate_block(self.code_block) - - self.builder.accept(self.match.bodies[index]) - self.builder.goto(self.final_block) - - def visit_match_stmt(self, m: MatchStmt) -> None: - for i, pattern in enumerate(m.patterns): - self.code_block = BasicBlock() - self.next_block = BasicBlock() - - pattern.accept(self) - - self.build_match_body(i) - self.builder.activate_block(self.next_block) - - self.builder.goto_and_activate(self.final_block) - - def visit_value_pattern(self, pattern: ValuePattern) -> None: - value = self.builder.accept(pattern.expr) - - cond = self.builder.binary_op( - self.subject, - value, - "==", - pattern.expr.line - ) - - self.bind_as_pattern(value) - - self.builder.add_bool_branch(cond, self.code_block, self.next_block) - - def visit_or_pattern(self, pattern: OrPattern) -> None: - for p in pattern.patterns: - # Hack to ensure the as pattern is bound to each pattern in the - # "or" pattern, but not every subpattern - backup = self.as_pattern - p.accept(self) - self.as_pattern = backup - - self.builder.activate_block(self.next_block) - self.next_block = BasicBlock() - - self.builder.goto(self.next_block) - - def visit_class_pattern(self, pattern: ClassPattern) -> None: - cond = self.builder.call_c( - slow_isinstance_op, - [self.subject, self.builder.accept(pattern.class_ref)], - pattern.line - ) - - self.bind_as_pattern(self.subject) - - self.builder.add_bool_branch(cond, self.code_block, self.next_block) - - if pattern.positionals: - node = pattern.class_ref.node - assert isinstance(node, TypeInfo) - - ty = node.names.get("__match_args__") - assert ty and isinstance(ty.type, TupleType) - - match_args: list[str] = [] - - for item in ty.type.items: - assert isinstance(item, Instance) and item.last_known_value - - value = item.last_known_value.value - assert isinstance(value, str) - - match_args.append(value) - - for i, expr in enumerate(pattern.positionals): - self.builder.activate_block(self.code_block) - self.code_block = BasicBlock() - - value = self.builder.py_get_attr( - self.subject, match_args[i], expr.line - ) - - with self.enter_subpattern(value): - expr.accept(self) - - for key, value in zip(pattern.keyword_keys, pattern.keyword_values): - self.builder.activate_block(self.code_block) - self.code_block = BasicBlock() - - attr = self.builder.py_get_attr(self.subject, key, value.line) - - with self.enter_subpattern(attr): - value.accept(self) - - def visit_as_pattern(self, pattern: AsPattern) -> None: - if pattern.pattern: - with self.enter_as_pattern(pattern): - pattern.pattern.accept(self) - - elif pattern.name: - target = self.builder.get_assignment_target(pattern.name) - - self.builder.assign(target, self.subject, pattern.line) - - self.builder.goto(self.code_block) - - def visit_singleton_pattern(self, pattern: SingletonPattern) -> None: - if pattern.value is None: - obj = self.builder.none_object() - elif pattern.value is True: - obj = self.builder.true() - else: - obj = self.builder.false() - - cond = self.builder.binary_op(self.subject, obj, "is", pattern.line) - - self.builder.add_bool_branch(cond, self.code_block, self.next_block) - - def visit_mapping_pattern(self, pattern: MappingPattern) -> None: - is_map = self.builder.call_c( - check_mapping_protocol, - [self.subject], - pattern.line, - ) - - self.builder.add_bool_branch(is_map, self.code_block, self.next_block) - - keys: list[Value] = [] - - for key, value in zip(pattern.keys, pattern.values): - self.builder.activate_block(self.code_block) - self.code_block = BasicBlock() - - key_value = self.builder.accept(key) - keys.append(key_value) - - attr = self.builder.call_c( - py_getattr_op, - [self.subject, key_value], - pattern.line - ) - - with self.enter_subpattern(attr): - value.accept(self) - - if pattern.rest: - self.builder.activate_block(self.code_block) - self.code_block = BasicBlock() - - rest = self.builder.call_c( - dict_copy, - [self.subject], - pattern.rest.line, - ) - - target = self.builder.get_assignment_target(pattern.rest) - - self.builder.assign(target, rest, pattern.rest.line) - - for i, key in enumerate(keys): - self.builder.call_c(dict_del_item, [rest, key], pattern.keys[i].line) - - self.builder.goto(self.code_block) - - def bind_as_pattern(self, value: Value) -> None: - if self.as_pattern and self.as_pattern.name: - target = self.builder.get_assignment_target(self.as_pattern.name) - self.builder.assign(target, value, self.as_pattern.pattern.line) # type: ignore - - self.as_pattern = None - - @contextmanager - def enter_subpattern(self, subject: Value) -> Generator[None, None, None]: - old_subject = self.subject - self.subject = subject - yield - self.subject = old_subject - - @contextmanager - def enter_as_pattern(self, pattern: AsPattern) -> Generator[None, None, None]: - old_pattern = self.as_pattern - self.as_pattern = pattern - yield - self.as_pattern = old_pattern - - def transform_match_stmt(builder: IRBuilder, m: MatchStmt) -> None: m.accept(MatchVisitor(builder, m)) From ec128ad2886f81d3fb5b40885ece3ca09814eedf Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Fri, 21 Oct 2022 16:25:02 -0700 Subject: [PATCH 56/97] Add a bunch of tests --- mypyc/test-data/run-match.test | 93 +++++++++++++++++++++++++++++++--- 1 file changed, 86 insertions(+), 7 deletions(-) diff --git a/mypyc/test-data/run-match.test b/mypyc/test-data/run-match.test index 5f90e98fbdf82..c6604f748a804 100644 --- a/mypyc/test-data/run-match.test +++ b/mypyc/test-data/run-match.test @@ -1,16 +1,95 @@ -[case testMatchBasic_python3_10] +[case testTheBigMatch_python3_10] +class Person: + __match_args__ = ("name", "age") + + name: str + age: int + + def __init__(self, name: str, age: int) -> None: + self.name = name + self.age = age + + def f(x): match x: case 123: - print("matched!") + print("test 1") + + case 456 | 789: + print("test 2") + + case True | False | None: + print("test 3") + + case Person("bob" as name, age): + print(f"test 4 ({name=}, {age=})") + + case num if num == 5: + print("test 5") + + case 6 as num: + print(f"test 6 ({num=})") + + case (7 | "7") as value: + print(f"test 7 ({value=})") + + case Person("alice", age=123): + print("test 8") + + # case Person("charlie", age=123 | 456): + # print("test 9") + case _: - print("no match") + print("test final") [file driver.py] -from native import f +from native import f, Person +# test 1 f(123) -f(321) + +# test 2 +f(456) +f(789) + +# test 3 +f(True) +f(False) +f(None) + +# test 4 +f(Person("bob", 123)) + +# test 5 +f(5) + +# test 6 +f(6) + +# test 7 +f(7) +f("7") + +# test 8 +f(Person("alice", 123)) + +# test 9 +# f(Person("charlie", 123)) +# f(Person("charlie", 456)) + +# test final +f("") [out] -matched! -no match +test 1 +test 2 +test 2 +test 3 +test 3 +test 3 +test 4 (name='bob', age=123) +test 5 +test 6 (num=6) +test 7 (value=7) +test 7 (value='7') +test 8 +test final From 1b0523bdbb4c367bee3d70c42f0b83a5d6a16c91 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Sat, 22 Oct 2022 20:19:01 -0700 Subject: [PATCH 57/97] Fix next_block not being setup for or pattern --- mypyc/irbuild/match.py | 4 ++++ mypyc/test-data/irbuild-match.test | 4 ++-- mypyc/test-data/run-match.test | 10 ++++++---- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/mypyc/irbuild/match.py b/mypyc/irbuild/match.py index d8e421238fbb7..2f0270df2f68b 100644 --- a/mypyc/irbuild/match.py +++ b/mypyc/irbuild/match.py @@ -84,6 +84,9 @@ def visit_value_pattern(self, pattern: ValuePattern) -> None: self.builder.add_bool_branch(cond, self.code_block, self.next_block) def visit_or_pattern(self, pattern: OrPattern) -> None: + backup_block = self.next_block + self.next_block = BasicBlock() + for p in pattern.patterns: # Hack to ensure the as pattern is bound to each pattern in the # "or" pattern, but not every subpattern @@ -94,6 +97,7 @@ def visit_or_pattern(self, pattern: OrPattern) -> None: self.builder.activate_block(self.next_block) self.next_block = BasicBlock() + self.next_block = backup_block self.builder.goto(self.next_block) def visit_class_pattern(self, pattern: ClassPattern) -> None: diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index e762909bb8b70..16a1da97d8e11 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -811,7 +811,7 @@ L0: r1 = PyObject_IsInstance(x, r0) r2 = r1 >= 0 :: signed r3 = truncate r1: int32 to builtins.bool - if r3 goto L1 else goto L2 :: bool + if r3 goto L1 else goto L5 :: bool L1: r4 = 'num' r5 = CPyObject_GetAttr(x, r4) @@ -1023,7 +1023,7 @@ L2: L3: r9 = box(None, 1) return r9 -[case testing_python3_10] +[case testMatchMappingPatternWithKeys_python3_10] def f(x): match x: case {"key": "value"}: diff --git a/mypyc/test-data/run-match.test b/mypyc/test-data/run-match.test index c6604f748a804..cd3fe63b8afa5 100644 --- a/mypyc/test-data/run-match.test +++ b/mypyc/test-data/run-match.test @@ -36,8 +36,8 @@ def f(x): case Person("alice", age=123): print("test 8") - # case Person("charlie", age=123 | 456): - # print("test 9") + case Person("charlie", age=123 | 456): + print("test 9") case _: print("test final") @@ -73,8 +73,10 @@ f("7") f(Person("alice", 123)) # test 9 -# f(Person("charlie", 123)) -# f(Person("charlie", 456)) +f(Person("charlie", 123)) +f(Person("charlie", 456)) + +# stopping at irbuild-match.test line 779 # test final f("") From 626d8f11663db40a2c7d3cfc5f00ba1727e505a6 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Sat, 22 Oct 2022 20:57:51 -0700 Subject: [PATCH 58/97] Fix as pattern variable being assigned if condition is false --- mypyc/irbuild/match.py | 13 ++++++++++--- mypyc/test-data/irbuild-match.test | 26 ++++++++++++++------------ mypyc/test-data/run-match.test | 12 +++++++++++- 3 files changed, 35 insertions(+), 16 deletions(-) diff --git a/mypyc/irbuild/match.py b/mypyc/irbuild/match.py index 2f0270df2f68b..5472e9d94060d 100644 --- a/mypyc/irbuild/match.py +++ b/mypyc/irbuild/match.py @@ -107,10 +107,10 @@ def visit_class_pattern(self, pattern: ClassPattern) -> None: pattern.line ) - self.bind_as_pattern(self.subject) - self.builder.add_bool_branch(cond, self.code_block, self.next_block) + self.bind_as_pattern(self.subject, new_block=True) + if pattern.positionals: node = pattern.class_ref.node assert isinstance(node, TypeInfo) @@ -218,13 +218,20 @@ def visit_mapping_pattern(self, pattern: MappingPattern) -> None: self.builder.goto(self.code_block) - def bind_as_pattern(self, value: Value) -> None: + def bind_as_pattern(self, value: Value, new_block: bool = False) -> None: if self.as_pattern and self.as_pattern.name: + if new_block: + self.builder.activate_block(self.code_block) + self.code_block = BasicBlock() + target = self.builder.get_assignment_target(self.as_pattern.name) self.builder.assign(target, value, self.as_pattern.pattern.line) # type: ignore self.as_pattern = None + if new_block: + self.builder.goto(self.code_block) + @contextmanager def enter_subpattern(self, subject: Value) -> Generator[None, None, None]: old_subject = self.subject diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index 16a1da97d8e11..0c1233c053c0e 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -577,9 +577,10 @@ L0: r2 = PyObject_IsInstance(r1, r0) r3 = r2 >= 0 :: signed r4 = truncate r2: int32 to builtins.bool - i = 246 - if r4 goto L1 else goto L2 :: bool + if r4 goto L1 else goto L3 :: bool L1: + i = 246 +L2: r5 = builtins :: module r6 = 'print' r7 = CPyObject_GetAttr(r5, r6) @@ -588,9 +589,9 @@ L1: r10 = load_address r9 r11 = _PyObject_Vectorcall(r7, r10, 1, 0) keep_alive r8 - goto L3 -L2: + goto L4 L3: +L4: r12 = box(None, 1) return r12 [case testMatchClassPatternWithPositionalArgs_python3_10] @@ -893,10 +894,11 @@ L0: r1 = PyObject_IsInstance(x, r0) r2 = r1 >= 0 :: signed r3 = truncate r1: int32 to builtins.bool + if r3 goto L1 else goto L5 :: bool +L1: r4 = cast(__main__.C, x) y = r4 - if r3 goto L1 else goto L4 :: bool -L1: +L2: r5 = 'a' r6 = CPyObject_GetAttr(x, r5) r7 = object 1 @@ -904,8 +906,8 @@ L1: r9 = PyObject_IsTrue(r8) r10 = r9 >= 0 :: signed r11 = truncate r9: int32 to builtins.bool - if r11 goto L2 else goto L4 :: bool -L2: + if r11 goto L3 else goto L5 :: bool +L3: r12 = 'b' r13 = CPyObject_GetAttr(x, r12) r14 = object 2 @@ -913,8 +915,8 @@ L2: r16 = PyObject_IsTrue(r15) r17 = r16 >= 0 :: signed r18 = truncate r16: int32 to builtins.bool - if r18 goto L3 else goto L4 :: bool -L3: + if r18 goto L4 else goto L5 :: bool +L4: r19 = 'matched' r20 = builtins :: module r21 = 'print' @@ -923,9 +925,9 @@ L3: r24 = load_address r23 r25 = _PyObject_Vectorcall(r22, r24, 1, 0) keep_alive r19 - goto L5 -L4: + goto L6 L5: +L6: r26 = box(None, 1) return r26 [case testMatchClassPatternPositionalCapture_python3_10] diff --git a/mypyc/test-data/run-match.test b/mypyc/test-data/run-match.test index cd3fe63b8afa5..314c407363f01 100644 --- a/mypyc/test-data/run-match.test +++ b/mypyc/test-data/run-match.test @@ -9,6 +9,9 @@ class Person: self.name = name self.age = age + def __str__(self) -> str: + return f"Person(name={self.name!r}, age={self.age})" + def f(x): match x: @@ -39,6 +42,9 @@ def f(x): case Person("charlie", age=123 | 456): print("test 9") + case Person("dave", 123) as dave: + print(f"test 10 {dave}") + case _: print("test final") [file driver.py] @@ -76,7 +82,8 @@ f(Person("alice", 123)) f(Person("charlie", 123)) f(Person("charlie", 456)) -# stopping at irbuild-match.test line 779 +# test 10 +f(Person("dave", 123)) # test final f("") @@ -94,4 +101,7 @@ test 6 (num=6) test 7 (value=7) test 7 (value='7') test 8 +test 9 +test 9 +test 10 Person(name='dave', age=123) test final From 3bb2e55781a95a08a5f550112b26ef72e4bb352e Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Sun, 23 Oct 2022 11:56:17 -0700 Subject: [PATCH 59/97] Sorta fix mapping issue: For some reason PyMapping_Check returns true for string types, dont know why. This probably means I have to to do a subclass check. --- mypyc/primitives/misc_ops.py | 4 ++-- mypyc/test-data/run-match.test | 7 +++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/mypyc/primitives/misc_ops.py b/mypyc/primitives/misc_ops.py index ddfbed5f08a85..3b5cef97a8ce8 100644 --- a/mypyc/primitives/misc_ops.py +++ b/mypyc/primitives/misc_ops.py @@ -242,7 +242,7 @@ # Check that object supports Mapping protocol check_mapping_protocol = custom_op( arg_types=[object_rprimitive], - return_type=int_rprimitive, + return_type=c_int_rprimitive, c_function_name="PyMapping_Check", error_kind=ERR_NEVER, ) @@ -258,7 +258,7 @@ # Delete an item from a dict dict_del_item = custom_op( arg_types=[object_rprimitive, object_rprimitive], - return_type=int_rprimitive, + return_type=c_int_rprimitive, c_function_name="PyDict_DelItem", error_kind=ERR_NEG_INT, ) diff --git a/mypyc/test-data/run-match.test b/mypyc/test-data/run-match.test index 314c407363f01..af1b810157224 100644 --- a/mypyc/test-data/run-match.test +++ b/mypyc/test-data/run-match.test @@ -45,6 +45,9 @@ def f(x): case Person("dave", 123) as dave: print(f"test 10 {dave}") + case {}: + print("test 11") + case _: print("test final") [file driver.py] @@ -85,6 +88,9 @@ f(Person("charlie", 456)) # test 10 f(Person("dave", 123)) +# test 11 +f({}) + # test final f("") @@ -104,4 +110,5 @@ test 8 test 9 test 9 test 10 Person(name='dave', age=123) +test 11 test final From 59f140e4edbcb2d3578c8f27460611b4019f3223 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Sun, 23 Oct 2022 12:26:03 -0700 Subject: [PATCH 60/97] Switch to using PyDict_Check --- mypyc/irbuild/match.py | 13 +++++++++---- mypyc/primitives/misc_ops.py | 6 +++--- mypyc/test-data/irbuild-match.test | 18 +++++++++--------- 3 files changed, 21 insertions(+), 16 deletions(-) diff --git a/mypyc/irbuild/match.py b/mypyc/irbuild/match.py index 5472e9d94060d..04bb7d8ffca3f 100644 --- a/mypyc/irbuild/match.py +++ b/mypyc/irbuild/match.py @@ -16,7 +16,7 @@ from mypyc.primitives.generic_ops import py_getattr_op from mypyc.primitives.misc_ops import ( - check_mapping_protocol, + check_dict, dict_copy, dict_del_item, slow_isinstance_op, @@ -173,13 +173,18 @@ def visit_singleton_pattern(self, pattern: SingletonPattern) -> None: self.builder.add_bool_branch(cond, self.code_block, self.next_block) def visit_mapping_pattern(self, pattern: MappingPattern) -> None: - is_map = self.builder.call_c( - check_mapping_protocol, + # TODO: technically this should accept any object that supports the + # mapping protocol, but the PyMapping_Check function returns true for + # string types, which is confusing. This should work for the time + # being, but will need to be changed at some point. + + is_dict = self.builder.call_c( + check_dict, [self.subject], pattern.line, ) - self.builder.add_bool_branch(is_map, self.code_block, self.next_block) + self.builder.add_bool_branch(is_dict, self.code_block, self.next_block) keys: list[Value] = [] diff --git a/mypyc/primitives/misc_ops.py b/mypyc/primitives/misc_ops.py index 3b5cef97a8ce8..60865a642240c 100644 --- a/mypyc/primitives/misc_ops.py +++ b/mypyc/primitives/misc_ops.py @@ -239,11 +239,11 @@ error_kind=ERR_MAGIC, ) -# Check that object supports Mapping protocol -check_mapping_protocol = custom_op( +# Check that the object is a dict or a subclass of dict +check_dict = custom_op( arg_types=[object_rprimitive], return_type=c_int_rprimitive, - c_function_name="PyMapping_Check", + c_function_name="PyDict_Check", error_kind=ERR_NEVER, ) diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index 0c1233c053c0e..705d41330875f 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -998,7 +998,7 @@ def f(x): [out] def f(x): x :: object - r0 :: int + r0 :: int32 r1 :: bit r2 :: str r3 :: object @@ -1008,7 +1008,7 @@ def f(x): r7 :: object_ptr r8, r9 :: object L0: - r0 = PyMapping_Check(x) + r0 = PyDict_Check(x) r1 = r0 != 0 if r1 goto L1 else goto L2 :: bool L1: @@ -1033,7 +1033,7 @@ def f(x): [out] def f(x): x :: object - r0 :: int + r0 :: int32 r1 :: bit r2 :: str r3 :: object @@ -1050,7 +1050,7 @@ def f(x): r14 :: object_ptr r15, r16 :: object L0: - r0 = PyMapping_Check(x) + r0 = PyDict_Check(x) r1 = r0 != 0 if r1 goto L1 else goto L3 :: bool L1: @@ -1084,7 +1084,7 @@ def f(x): [out] def f(x): x :: object - r0 :: int + r0 :: int32 r1 :: bit r2 :: object r3, rest :: dict @@ -1096,7 +1096,7 @@ def f(x): r9 :: object_ptr r10, r11 :: object L0: - r0 = PyMapping_Check(x) + r0 = PyDict_Check(x) r1 = r0 != 0 if r1 goto L1 else goto L3 :: bool L1: @@ -1125,7 +1125,7 @@ def f(x): [out] def f(x): x :: object - r0 :: int + r0 :: int32 r1 :: bit r2 :: str r3 :: object @@ -1136,7 +1136,7 @@ def f(x): r8 :: bool r9 :: object r10, rest :: dict - r11 :: int + r11 :: int32 r12 :: bit r13 :: str r14 :: object @@ -1146,7 +1146,7 @@ def f(x): r18 :: object_ptr r19, r20 :: object L0: - r0 = PyMapping_Check(x) + r0 = PyDict_Check(x) r1 = r0 != 0 if r1 goto L1 else goto L4 :: bool L1: From 67efc089e9502149474c67bb10395b7d3c5a2c1a Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Sun, 23 Oct 2022 12:56:33 -0700 Subject: [PATCH 61/97] Fix dict item being accessed via get attr instead of get item --- mypyc/irbuild/match.py | 8 +- mypyc/test-data/irbuild-match.test | 142 +++++++++++++++-------------- mypyc/test-data/run-match.test | 13 ++- 3 files changed, 88 insertions(+), 75 deletions(-) diff --git a/mypyc/irbuild/match.py b/mypyc/irbuild/match.py index 04bb7d8ffca3f..3839acaf39257 100644 --- a/mypyc/irbuild/match.py +++ b/mypyc/irbuild/match.py @@ -14,7 +14,7 @@ from mypy.traverser import TraverserVisitor from mypy.types import Instance, TupleType -from mypyc.primitives.generic_ops import py_getattr_op +from mypyc.primitives.dict_ops import dict_get_item_op from mypyc.primitives.misc_ops import ( check_dict, dict_copy, @@ -195,13 +195,13 @@ def visit_mapping_pattern(self, pattern: MappingPattern) -> None: key_value = self.builder.accept(key) keys.append(key_value) - attr = self.builder.call_c( - py_getattr_op, + item = self.builder.call_c( + dict_get_item_op, [self.subject, key_value], pattern.line ) - with self.enter_subpattern(attr): + with self.enter_subpattern(item): value.accept(self) if pattern.rest: diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index 705d41330875f..dd2d6ab436af7 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -1036,46 +1036,48 @@ def f(x): r0 :: int32 r1 :: bit r2 :: str - r3 :: object - r4 :: str - r5 :: object - r6 :: int32 - r7 :: bit - r8 :: bool - r9 :: str - r10 :: object - r11 :: str - r12 :: object - r13 :: object[1] - r14 :: object_ptr - r15, r16 :: object + r3 :: dict + r4 :: object + r5 :: str + r6 :: object + r7 :: int32 + r8 :: bit + r9 :: bool + r10 :: str + r11 :: object + r12 :: str + r13 :: object + r14 :: object[1] + r15 :: object_ptr + r16, r17 :: object L0: r0 = PyDict_Check(x) r1 = r0 != 0 if r1 goto L1 else goto L3 :: bool L1: r2 = 'key' - r3 = CPyObject_GetAttr(x, r2) - r4 = 'value' - r5 = PyObject_RichCompare(r3, r4, 2) - r6 = PyObject_IsTrue(r5) - r7 = r6 >= 0 :: signed - r8 = truncate r6: int32 to builtins.bool - if r8 goto L2 else goto L3 :: bool + r3 = cast(dict, x) + r4 = CPyDict_GetItem(r3, r2) + r5 = 'value' + r6 = PyObject_RichCompare(r4, r5, 2) + r7 = PyObject_IsTrue(r6) + r8 = r7 >= 0 :: signed + r9 = truncate r7: int32 to builtins.bool + if r9 goto L2 else goto L3 :: bool L2: - r9 = 'matched' - r10 = builtins :: module - r11 = 'print' - r12 = CPyObject_GetAttr(r10, r11) - r13 = [r9] - r14 = load_address r13 - r15 = _PyObject_Vectorcall(r12, r14, 1, 0) - keep_alive r9 + r10 = 'matched' + r11 = builtins :: module + r12 = 'print' + r13 = CPyObject_GetAttr(r11, r12) + r14 = [r10] + r15 = load_address r14 + r16 = _PyObject_Vectorcall(r13, r15, 1, 0) + keep_alive r10 goto L4 L3: L4: - r16 = box(None, 1) - return r16 + r17 = box(None, 1) + return r17 [case testMatchMappingPatternWithRest_python3_10] def f(x): match x: @@ -1128,53 +1130,55 @@ def f(x): r0 :: int32 r1 :: bit r2 :: str - r3 :: object - r4 :: str - r5 :: object - r6 :: int32 - r7 :: bit - r8 :: bool - r9 :: object - r10, rest :: dict - r11 :: int32 - r12 :: bit - r13 :: str - r14 :: object - r15 :: str - r16 :: object - r17 :: object[1] - r18 :: object_ptr - r19, r20 :: object + r3 :: dict + r4 :: object + r5 :: str + r6 :: object + r7 :: int32 + r8 :: bit + r9 :: bool + r10 :: object + r11, rest :: dict + r12 :: int32 + r13 :: bit + r14 :: str + r15 :: object + r16 :: str + r17 :: object + r18 :: object[1] + r19 :: object_ptr + r20, r21 :: object L0: r0 = PyDict_Check(x) r1 = r0 != 0 if r1 goto L1 else goto L4 :: bool L1: r2 = 'key' - r3 = CPyObject_GetAttr(x, r2) - r4 = 'value' - r5 = PyObject_RichCompare(r3, r4, 2) - r6 = PyObject_IsTrue(r5) - r7 = r6 >= 0 :: signed - r8 = truncate r6: int32 to builtins.bool - if r8 goto L2 else goto L4 :: bool + r3 = cast(dict, x) + r4 = CPyDict_GetItem(r3, r2) + r5 = 'value' + r6 = PyObject_RichCompare(r4, r5, 2) + r7 = PyObject_IsTrue(r6) + r8 = r7 >= 0 :: signed + r9 = truncate r7: int32 to builtins.bool + if r9 goto L2 else goto L4 :: bool L2: - r9 = PyDict_Copy(x) - r10 = cast(dict, r9) - rest = r10 - r11 = PyDict_DelItem(r9, r2) - r12 = r11 >= 0 :: signed + r10 = PyDict_Copy(x) + r11 = cast(dict, r10) + rest = r11 + r12 = PyDict_DelItem(r10, r2) + r13 = r12 >= 0 :: signed L3: - r13 = 'matched' - r14 = builtins :: module - r15 = 'print' - r16 = CPyObject_GetAttr(r14, r15) - r17 = [r13] - r18 = load_address r17 - r19 = _PyObject_Vectorcall(r16, r18, 1, 0) - keep_alive r13 + r14 = 'matched' + r15 = builtins :: module + r16 = 'print' + r17 = CPyObject_GetAttr(r15, r16) + r18 = [r14] + r19 = load_address r18 + r20 = _PyObject_Vectorcall(r17, r19, 1, 0) + keep_alive r14 goto L5 L4: L5: - r20 = box(None, 1) - return r20 + r21 = box(None, 1) + return r21 diff --git a/mypyc/test-data/run-match.test b/mypyc/test-data/run-match.test index af1b810157224..60a630b7af257 100644 --- a/mypyc/test-data/run-match.test +++ b/mypyc/test-data/run-match.test @@ -45,9 +45,12 @@ def f(x): case Person("dave", 123) as dave: print(f"test 10 {dave}") - case {}: + case {"test": 11}: print("test 11") + case {}: + print("test final-1") + case _: print("test final") [file driver.py] @@ -88,7 +91,11 @@ f(Person("charlie", 456)) # test 10 f(Person("dave", 123)) -# test 11 +# test 12 +f({"test": 12}) +f({"test": 12, "some": "key"}) + +# test final-1 f({}) # test final @@ -111,4 +118,6 @@ test 9 test 9 test 10 Person(name='dave', age=123) test 11 +test 11 +test final-1 test final From e57f0a19ada6babe2a466486a472535ee3ad6b42 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Sun, 23 Oct 2022 13:15:53 -0700 Subject: [PATCH 62/97] Check that key is contained in dict before grabbing it --- mypyc/irbuild/match.py | 8 ++ mypyc/test-data/irbuild-match.test | 170 ++++++++++++++++------------- 2 files changed, 101 insertions(+), 77 deletions(-) diff --git a/mypyc/irbuild/match.py b/mypyc/irbuild/match.py index 3839acaf39257..a6c6e6bb03b1c 100644 --- a/mypyc/irbuild/match.py +++ b/mypyc/irbuild/match.py @@ -195,6 +195,14 @@ def visit_mapping_pattern(self, pattern: MappingPattern) -> None: key_value = self.builder.accept(key) keys.append(key_value) + exists = self.builder.binary_op( + key_value, self.subject, "in", pattern.line + ) + + self.builder.add_bool_branch(exists, self.code_block, self.next_block) + self.builder.activate_block(self.code_block) + self.code_block = BasicBlock() + item = self.builder.call_c( dict_get_item_op, [self.subject, key_value], diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index dd2d6ab436af7..a4c7443bd6d3a 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -1036,48 +1036,56 @@ def f(x): r0 :: int32 r1 :: bit r2 :: str - r3 :: dict - r4 :: object - r5 :: str - r6 :: object - r7 :: int32 - r8 :: bit - r9 :: bool - r10 :: str - r11 :: object - r12 :: str - r13 :: object - r14 :: object[1] - r15 :: object_ptr - r16, r17 :: object + r3 :: int32 + r4 :: bit + r5 :: bool + r6 :: dict + r7 :: object + r8 :: str + r9 :: object + r10 :: int32 + r11 :: bit + r12 :: bool + r13 :: str + r14 :: object + r15 :: str + r16 :: object + r17 :: object[1] + r18 :: object_ptr + r19, r20 :: object L0: r0 = PyDict_Check(x) r1 = r0 != 0 - if r1 goto L1 else goto L3 :: bool + if r1 goto L1 else goto L4 :: bool L1: r2 = 'key' - r3 = cast(dict, x) - r4 = CPyDict_GetItem(r3, r2) - r5 = 'value' - r6 = PyObject_RichCompare(r4, r5, 2) - r7 = PyObject_IsTrue(r6) - r8 = r7 >= 0 :: signed - r9 = truncate r7: int32 to builtins.bool - if r9 goto L2 else goto L3 :: bool + r3 = PySequence_Contains(x, r2) + r4 = r3 >= 0 :: signed + r5 = truncate r3: int32 to builtins.bool + if r5 goto L2 else goto L4 :: bool L2: - r10 = 'matched' - r11 = builtins :: module - r12 = 'print' - r13 = CPyObject_GetAttr(r11, r12) - r14 = [r10] - r15 = load_address r14 - r16 = _PyObject_Vectorcall(r13, r15, 1, 0) - keep_alive r10 - goto L4 + r6 = cast(dict, x) + r7 = CPyDict_GetItem(r6, r2) + r8 = 'value' + r9 = PyObject_RichCompare(r7, r8, 2) + r10 = PyObject_IsTrue(r9) + r11 = r10 >= 0 :: signed + r12 = truncate r10: int32 to builtins.bool + if r12 goto L3 else goto L4 :: bool L3: + r13 = 'matched' + r14 = builtins :: module + r15 = 'print' + r16 = CPyObject_GetAttr(r14, r15) + r17 = [r13] + r18 = load_address r17 + r19 = _PyObject_Vectorcall(r16, r18, 1, 0) + keep_alive r13 + goto L5 L4: - r17 = box(None, 1) - return r17 +L5: + r20 = box(None, 1) + return r20 [case testMatchMappingPatternWithRest_python3_10] def f(x): match x: @@ -1130,55 +1138,63 @@ def f(x): r0 :: int32 r1 :: bit r2 :: str - r3 :: dict - r4 :: object - r5 :: str - r6 :: object - r7 :: int32 - r8 :: bit - r9 :: bool - r10 :: object - r11, rest :: dict - r12 :: int32 - r13 :: bit - r14 :: str - r15 :: object - r16 :: str - r17 :: object - r18 :: object[1] - r19 :: object_ptr - r20, r21 :: object + r3 :: int32 + r4 :: bit + r5 :: bool + r6 :: dict + r7 :: object + r8 :: str + r9 :: object + r10 :: int32 + r11 :: bit + r12 :: bool + r13 :: object + r14, rest :: dict + r15 :: int32 + r16 :: bit + r17 :: str + r18 :: object + r19 :: str + r20 :: object + r21 :: object[1] + r22 :: object_ptr + r23, r24 :: object L0: r0 = PyDict_Check(x) r1 = r0 != 0 - if r1 goto L1 else goto L4 :: bool + if r1 goto L1 else goto L5 :: bool L1: r2 = 'key' - r3 = cast(dict, x) - r4 = CPyDict_GetItem(r3, r2) - r5 = 'value' - r6 = PyObject_RichCompare(r4, r5, 2) - r7 = PyObject_IsTrue(r6) - r8 = r7 >= 0 :: signed - r9 = truncate r7: int32 to builtins.bool - if r9 goto L2 else goto L4 :: bool + r3 = PySequence_Contains(x, r2) + r4 = r3 >= 0 :: signed + r5 = truncate r3: int32 to builtins.bool + if r5 goto L2 else goto L5 :: bool L2: - r10 = PyDict_Copy(x) - r11 = cast(dict, r10) - rest = r11 - r12 = PyDict_DelItem(r10, r2) - r13 = r12 >= 0 :: signed + r6 = cast(dict, x) + r7 = CPyDict_GetItem(r6, r2) + r8 = 'value' + r9 = PyObject_RichCompare(r7, r8, 2) + r10 = PyObject_IsTrue(r9) + r11 = r10 >= 0 :: signed + r12 = truncate r10: int32 to builtins.bool + if r12 goto L3 else goto L5 :: bool L3: - r14 = 'matched' - r15 = builtins :: module - r16 = 'print' - r17 = CPyObject_GetAttr(r15, r16) - r18 = [r14] - r19 = load_address r18 - r20 = _PyObject_Vectorcall(r17, r19, 1, 0) - keep_alive r14 - goto L5 + r13 = PyDict_Copy(x) + r14 = cast(dict, r13) + rest = r14 + r15 = PyDict_DelItem(r13, r2) + r16 = r15 >= 0 :: signed L4: + r17 = 'matched' + r18 = builtins :: module + r19 = 'print' + r20 = CPyObject_GetAttr(r18, r19) + r21 = [r17] + r22 = load_address r21 + r23 = _PyObject_Vectorcall(r20, r22, 1, 0) + keep_alive r17 + goto L6 L5: - r21 = box(None, 1) - return r21 +L6: + r24 = box(None, 1) + return r24 From 214dceac0355f41da3931c0d07bd41abecf8e600 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Sun, 23 Oct 2022 14:03:28 -0700 Subject: [PATCH 63/97] Finish map runtime tests --- mypyc/test-data/run-match.test | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/mypyc/test-data/run-match.test b/mypyc/test-data/run-match.test index 60a630b7af257..53cc7cd6c0da5 100644 --- a/mypyc/test-data/run-match.test +++ b/mypyc/test-data/run-match.test @@ -48,8 +48,11 @@ def f(x): case {"test": 11}: print("test 11") + case {"test": 12, **rest}: + print(f"test 12 (rest={rest})") + case {}: - print("test final-1") + print("test map final") case _: print("test final") @@ -91,11 +94,16 @@ f(Person("charlie", 456)) # test 10 f(Person("dave", 123)) +# test 11 +f({"test": 11}) +f({"test": 11, "some": "key"}) + # test 12 f({"test": 12}) -f({"test": 12, "some": "key"}) +f({"test": 12, "key": "value"}) +f({"test": 12, "key": "value", "abc": "123"}) -# test final-1 +# test map final f({}) # test final @@ -119,5 +127,8 @@ test 9 test 10 Person(name='dave', age=123) test 11 test 11 -test final-1 +test 12 (rest={}) +test 12 (rest={'key': 'value'}) +test 12 (rest={'key': 'value', 'abc': '123'}) +test map final test final From ad82d9f0c027f7390d92343dd659bb695f50fcd1 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Sun, 23 Oct 2022 14:11:21 -0700 Subject: [PATCH 64/97] Add empty sequence pattern matching --- mypyc/irbuild/match.py | 11 ++++++++++ mypyc/primitives/list_ops.py | 8 +++++++ mypyc/test-data/irbuild-match.test | 35 ++++++++++++++++++++++++++++++ mypyc/test-data/run-match.test | 7 ++++++ 4 files changed, 61 insertions(+) diff --git a/mypyc/irbuild/match.py b/mypyc/irbuild/match.py index a6c6e6bb03b1c..cb316c270067e 100644 --- a/mypyc/irbuild/match.py +++ b/mypyc/irbuild/match.py @@ -9,6 +9,7 @@ OrPattern, MappingPattern, SingletonPattern, + SequencePattern, ValuePattern, ) from mypy.traverser import TraverserVisitor @@ -21,6 +22,7 @@ dict_del_item, slow_isinstance_op, ) +from mypyc.primitives.list_ops import check_list from mypyc.irbuild.builder import IRBuilder class MatchVisitor(TraverserVisitor): @@ -231,6 +233,15 @@ def visit_mapping_pattern(self, pattern: MappingPattern) -> None: self.builder.goto(self.code_block) + def visit_sequence_pattern(self, pattern: SequencePattern) -> None: + is_list = self.builder.call_c( + check_list, + [self.subject], + pattern.line, + ) + + self.builder.add_bool_branch(is_list, self.code_block, self.next_block) + def bind_as_pattern(self, value: Value, new_block: bool = False) -> None: if self.as_pattern and self.as_pattern.name: if new_block: diff --git a/mypyc/primitives/list_ops.py b/mypyc/primitives/list_ops.py index c729e264fc14e..283e702734985 100644 --- a/mypyc/primitives/list_ops.py +++ b/mypyc/primitives/list_ops.py @@ -277,3 +277,11 @@ c_function_name="CPyList_GetSlice", error_kind=ERR_MAGIC, ) + +# Check that the object is a list or a subclass of list +check_list = custom_op( + arg_types=[object_rprimitive], + return_type=c_int_rprimitive, + c_function_name="PyList_Check", + error_kind=ERR_NEVER, +) diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index a4c7443bd6d3a..e178edde4f27b 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -1198,3 +1198,38 @@ L5: L6: r24 = box(None, 1) return r24 +[case testMatchEmptySequencePattern_python3_10] +def f(x): + match x: + case []: + print("matched") +[out] +def f(x): + x :: object + r0 :: int32 + r1 :: bit + r2 :: str + r3 :: object + r4 :: str + r5 :: object + r6 :: object[1] + r7 :: object_ptr + r8, r9 :: object +L0: + r0 = PyList_Check(x) + r1 = r0 != 0 + if r1 goto L1 else goto L2 :: bool +L1: + r2 = 'matched' + r3 = builtins :: module + r4 = 'print' + r5 = CPyObject_GetAttr(r3, r4) + r6 = [r2] + r7 = load_address r6 + r8 = _PyObject_Vectorcall(r5, r7, 1, 0) + keep_alive r2 + goto L3 +L2: +L3: + r9 = box(None, 1) + return r9 diff --git a/mypyc/test-data/run-match.test b/mypyc/test-data/run-match.test index 53cc7cd6c0da5..07f846dc15277 100644 --- a/mypyc/test-data/run-match.test +++ b/mypyc/test-data/run-match.test @@ -54,6 +54,9 @@ def f(x): case {}: print("test map final") + case []: + print("test sequence final") + case _: print("test final") [file driver.py] @@ -106,6 +109,9 @@ f({"test": 12, "key": "value", "abc": "123"}) # test map final f({}) +# test sequence final +f([]) + # test final f("") @@ -131,4 +137,5 @@ test 12 (rest={}) test 12 (rest={'key': 'value'}) test 12 (rest={'key': 'value', 'abc': '123'}) test map final +test sequence final test final From c59c46521c260056fddac4c555c21cb9ea5eb819 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Sun, 23 Oct 2022 14:34:13 -0700 Subject: [PATCH 65/97] Add basic sequence support --- mypyc/irbuild/match.py | 18 ++++++++- mypyc/test-data/irbuild-match.test | 63 ++++++++++++++++++++++++++++++ mypyc/test-data/run-match.test | 7 ++++ 3 files changed, 87 insertions(+), 1 deletion(-) diff --git a/mypyc/irbuild/match.py b/mypyc/irbuild/match.py index cb316c270067e..12517fa70306a 100644 --- a/mypyc/irbuild/match.py +++ b/mypyc/irbuild/match.py @@ -10,6 +10,7 @@ MappingPattern, SingletonPattern, SequencePattern, + StarredPattern, ValuePattern, ) from mypy.traverser import TraverserVisitor @@ -22,7 +23,7 @@ dict_del_item, slow_isinstance_op, ) -from mypyc.primitives.list_ops import check_list +from mypyc.primitives.list_ops import check_list, list_get_item_op from mypyc.irbuild.builder import IRBuilder class MatchVisitor(TraverserVisitor): @@ -234,6 +235,8 @@ def visit_mapping_pattern(self, pattern: MappingPattern) -> None: self.builder.goto(self.code_block) def visit_sequence_pattern(self, pattern: SequencePattern) -> None: + assert not any(isinstance(p, StarredPattern) for p in pattern.patterns) + is_list = self.builder.call_c( check_list, [self.subject], @@ -242,6 +245,19 @@ def visit_sequence_pattern(self, pattern: SequencePattern) -> None: self.builder.add_bool_branch(is_list, self.code_block, self.next_block) + for i, p in enumerate(pattern.patterns): + self.builder.activate_block(self.code_block) + self.code_block = BasicBlock() + + item = self.builder.call_c( + list_get_item_op, + [self.subject, self.builder.load_int(i)], + p.line, + ) + + with self.enter_subpattern(item): + p.accept(self) + def bind_as_pattern(self, value: Value, new_block: bool = False) -> None: if self.as_pattern and self.as_pattern.name: if new_block: diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index e178edde4f27b..c00a3b5d4ddc1 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -1233,3 +1233,66 @@ L2: L3: r9 = box(None, 1) return r9 +[case testMatchFixedLengthSequencePattern_python3_10] +def f(x): + match x: + case [1, 2]: + print("matched") +[out] +def f(x): + x :: object + r0 :: int32 + r1 :: bit + r2 :: list + r3, r4, r5 :: object + r6 :: int32 + r7 :: bit + r8 :: bool + r9 :: list + r10, r11, r12 :: object + r13 :: int32 + r14 :: bit + r15 :: bool + r16 :: str + r17 :: object + r18 :: str + r19 :: object + r20 :: object[1] + r21 :: object_ptr + r22, r23 :: object +L0: + r0 = PyList_Check(x) + r1 = r0 != 0 + if r1 goto L1 else goto L4 :: bool +L1: + r2 = cast(list, x) + r3 = CPyList_GetItem(r2, 0) + r4 = object 1 + r5 = PyObject_RichCompare(r3, r4, 2) + r6 = PyObject_IsTrue(r5) + r7 = r6 >= 0 :: signed + r8 = truncate r6: int32 to builtins.bool + if r8 goto L2 else goto L4 :: bool +L2: + r9 = cast(list, x) + r10 = CPyList_GetItem(r9, 2) + r11 = object 2 + r12 = PyObject_RichCompare(r10, r11, 2) + r13 = PyObject_IsTrue(r12) + r14 = r13 >= 0 :: signed + r15 = truncate r13: int32 to builtins.bool + if r15 goto L3 else goto L4 :: bool +L3: + r16 = 'matched' + r17 = builtins :: module + r18 = 'print' + r19 = CPyObject_GetAttr(r17, r18) + r20 = [r16] + r21 = load_address r20 + r22 = _PyObject_Vectorcall(r19, r21, 1, 0) + keep_alive r16 + goto L5 +L4: +L5: + r23 = box(None, 1) + return r23 diff --git a/mypyc/test-data/run-match.test b/mypyc/test-data/run-match.test index 07f846dc15277..d8a6dc29b4cb3 100644 --- a/mypyc/test-data/run-match.test +++ b/mypyc/test-data/run-match.test @@ -54,6 +54,9 @@ def f(x): case {}: print("test map final") + case ["test", 13]: + print("test 13") + case []: print("test sequence final") @@ -109,6 +112,9 @@ f({"test": 12, "key": "value", "abc": "123"}) # test map final f({}) +# test 13 +f(["test", 13]) + # test sequence final f([]) @@ -137,5 +143,6 @@ test 12 (rest={}) test 12 (rest={'key': 'value'}) test 12 (rest={'key': 'value', 'abc': '123'}) test map final +test 13 test sequence final test final From fc74fa226b3e52db01489714122ed45850ca8211 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Sun, 23 Oct 2022 16:03:08 -0700 Subject: [PATCH 66/97] Get unbound sequence capture working at end of list (almost): We need to do a check to see that the list is not too short for the number of patterns we have. Currently there is a failing test, which will be fixed momentarily. --- mypyc/irbuild/match.py | 13 +++++- mypyc/test-data/irbuild-match.test | 63 ++++++++++++++++++++++++++++++ mypyc/test-data/run-match.test | 9 +++++ 3 files changed, 84 insertions(+), 1 deletion(-) diff --git a/mypyc/irbuild/match.py b/mypyc/irbuild/match.py index 12517fa70306a..5f7e6dc0ca41e 100644 --- a/mypyc/irbuild/match.py +++ b/mypyc/irbuild/match.py @@ -235,7 +235,15 @@ def visit_mapping_pattern(self, pattern: MappingPattern) -> None: self.builder.goto(self.code_block) def visit_sequence_pattern(self, pattern: SequencePattern) -> None: - assert not any(isinstance(p, StarredPattern) for p in pattern.patterns) + index = -1 + + for i, p in enumerate(pattern.patterns): + if isinstance(p, StarredPattern): + index = i + + assert not p.capture + + assert index in (-1, len(pattern.patterns) - 1) is_list = self.builder.call_c( check_list, @@ -246,6 +254,9 @@ def visit_sequence_pattern(self, pattern: SequencePattern) -> None: self.builder.add_bool_branch(is_list, self.code_block, self.next_block) for i, p in enumerate(pattern.patterns): + if i == index: + continue + self.builder.activate_block(self.code_block) self.code_block = BasicBlock() diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index c00a3b5d4ddc1..8ce0a44cf9d42 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -1296,3 +1296,66 @@ L4: L5: r23 = box(None, 1) return r23 +[case testMatchSequencePatternWithTrailingUnboundStar_python3_10] +def f(x): + match x: + case [1, 2, *_]: + print("matched") +[out] +def f(x): + x :: object + r0 :: int32 + r1 :: bit + r2 :: list + r3, r4, r5 :: object + r6 :: int32 + r7 :: bit + r8 :: bool + r9 :: list + r10, r11, r12 :: object + r13 :: int32 + r14 :: bit + r15 :: bool + r16 :: str + r17 :: object + r18 :: str + r19 :: object + r20 :: object[1] + r21 :: object_ptr + r22, r23 :: object +L0: + r0 = PyList_Check(x) + r1 = r0 != 0 + if r1 goto L1 else goto L4 :: bool +L1: + r2 = cast(list, x) + r3 = CPyList_GetItem(r2, 0) + r4 = object 1 + r5 = PyObject_RichCompare(r3, r4, 2) + r6 = PyObject_IsTrue(r5) + r7 = r6 >= 0 :: signed + r8 = truncate r6: int32 to builtins.bool + if r8 goto L2 else goto L4 :: bool +L2: + r9 = cast(list, x) + r10 = CPyList_GetItem(r9, 2) + r11 = object 2 + r12 = PyObject_RichCompare(r10, r11, 2) + r13 = PyObject_IsTrue(r12) + r14 = r13 >= 0 :: signed + r15 = truncate r13: int32 to builtins.bool + if r15 goto L3 else goto L4 :: bool +L3: + r16 = 'matched' + r17 = builtins :: module + r18 = 'print' + r19 = CPyObject_GetAttr(r17, r18) + r20 = [r16] + r21 = load_address r20 + r22 = _PyObject_Vectorcall(r19, r21, 1, 0) + keep_alive r16 + goto L5 +L4: +L5: + r23 = box(None, 1) + return r23 diff --git a/mypyc/test-data/run-match.test b/mypyc/test-data/run-match.test index d8a6dc29b4cb3..074636f3221e7 100644 --- a/mypyc/test-data/run-match.test +++ b/mypyc/test-data/run-match.test @@ -57,6 +57,9 @@ def f(x): case ["test", 13]: print("test 13") + case ["test", 14, *_]: + print("test 14") + case []: print("test sequence final") @@ -115,6 +118,10 @@ f({}) # test 13 f(["test", 13]) +# test 14 +f(["test", 14]) +f(["test", 14, "something"]) + # test sequence final f([]) @@ -144,5 +151,7 @@ test 12 (rest={'key': 'value'}) test 12 (rest={'key': 'value', 'abc': '123'}) test map final test 13 +test 14 +test 14 test sequence final test final From 739962fd3b710aefd81a709e0f7049e7ae02b5db Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Sun, 23 Oct 2022 16:18:46 -0700 Subject: [PATCH 67/97] Fix last commit --- mypyc/irbuild/match.py | 21 +++- mypyc/test-data/irbuild-match.test | 194 ++++++++++++++++------------- 2 files changed, 124 insertions(+), 91 deletions(-) diff --git a/mypyc/irbuild/match.py b/mypyc/irbuild/match.py index 5f7e6dc0ca41e..e5e3ba49856ce 100644 --- a/mypyc/irbuild/match.py +++ b/mypyc/irbuild/match.py @@ -24,6 +24,7 @@ slow_isinstance_op, ) from mypyc.primitives.list_ops import check_list, list_get_item_op +from mypyc.primitives.generic_ops import generic_ssize_t_len_op from mypyc.irbuild.builder import IRBuilder class MatchVisitor(TraverserVisitor): @@ -253,9 +254,27 @@ def visit_sequence_pattern(self, pattern: SequencePattern) -> None: self.builder.add_bool_branch(is_list, self.code_block, self.next_block) + min_len = len(pattern.patterns) - (0 if index == -1 else 1) + + if min_len: + self.builder.activate_block(self.code_block) + self.code_block = BasicBlock() + + actual_len = self.builder.call_c( + generic_ssize_t_len_op, + [self.subject], + pattern.line, + ) + + is_long_enough = self.builder.binary_op( + self.builder.load_int(min_len), actual_len, "<=", pattern.line + ) + + self.builder.add_bool_branch(is_long_enough, self.code_block, self.next_block) + for i, p in enumerate(pattern.patterns): if i == index: - continue + break self.builder.activate_block(self.code_block) self.code_block = BasicBlock() diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index 8ce0a44cf9d42..3877d9df26dbc 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -1243,59 +1243,66 @@ def f(x): x :: object r0 :: int32 r1 :: bit - r2 :: list - r3, r4, r5 :: object - r6 :: int32 - r7 :: bit - r8 :: bool - r9 :: list - r10, r11, r12 :: object - r13 :: int32 - r14 :: bit - r15 :: bool - r16 :: str - r17 :: object - r18 :: str - r19 :: object - r20 :: object[1] - r21 :: object_ptr - r22, r23 :: object + r2 :: native_int + r3, r4 :: bit + r5 :: list + r6, r7, r8 :: object + r9 :: int32 + r10 :: bit + r11 :: bool + r12 :: list + r13, r14, r15 :: object + r16 :: int32 + r17 :: bit + r18 :: bool + r19 :: str + r20 :: object + r21 :: str + r22 :: object + r23 :: object[1] + r24 :: object_ptr + r25, r26 :: object L0: r0 = PyList_Check(x) r1 = r0 != 0 - if r1 goto L1 else goto L4 :: bool + if r1 goto L1 else goto L5 :: bool L1: - r2 = cast(list, x) - r3 = CPyList_GetItem(r2, 0) - r4 = object 1 - r5 = PyObject_RichCompare(r3, r4, 2) - r6 = PyObject_IsTrue(r5) - r7 = r6 >= 0 :: signed - r8 = truncate r6: int32 to builtins.bool - if r8 goto L2 else goto L4 :: bool + r2 = PyObject_Size(x) + r3 = r2 >= 0 :: signed + r4 = 2 >= r2 :: signed + if r4 goto L2 else goto L5 :: bool L2: - r9 = cast(list, x) - r10 = CPyList_GetItem(r9, 2) - r11 = object 2 - r12 = PyObject_RichCompare(r10, r11, 2) - r13 = PyObject_IsTrue(r12) - r14 = r13 >= 0 :: signed - r15 = truncate r13: int32 to builtins.bool - if r15 goto L3 else goto L4 :: bool + r5 = cast(list, x) + r6 = CPyList_GetItem(r5, 0) + r7 = object 1 + r8 = PyObject_RichCompare(r6, r7, 2) + r9 = PyObject_IsTrue(r8) + r10 = r9 >= 0 :: signed + r11 = truncate r9: int32 to builtins.bool + if r11 goto L3 else goto L5 :: bool L3: - r16 = 'matched' - r17 = builtins :: module - r18 = 'print' - r19 = CPyObject_GetAttr(r17, r18) - r20 = [r16] - r21 = load_address r20 - r22 = _PyObject_Vectorcall(r19, r21, 1, 0) - keep_alive r16 - goto L5 + r12 = cast(list, x) + r13 = CPyList_GetItem(r12, 2) + r14 = object 2 + r15 = PyObject_RichCompare(r13, r14, 2) + r16 = PyObject_IsTrue(r15) + r17 = r16 >= 0 :: signed + r18 = truncate r16: int32 to builtins.bool + if r18 goto L4 else goto L5 :: bool L4: + r19 = 'matched' + r20 = builtins :: module + r21 = 'print' + r22 = CPyObject_GetAttr(r20, r21) + r23 = [r19] + r24 = load_address r23 + r25 = _PyObject_Vectorcall(r22, r24, 1, 0) + keep_alive r19 + goto L6 L5: - r23 = box(None, 1) - return r23 +L6: + r26 = box(None, 1) + return r26 [case testMatchSequencePatternWithTrailingUnboundStar_python3_10] def f(x): match x: @@ -1306,56 +1313,63 @@ def f(x): x :: object r0 :: int32 r1 :: bit - r2 :: list - r3, r4, r5 :: object - r6 :: int32 - r7 :: bit - r8 :: bool - r9 :: list - r10, r11, r12 :: object - r13 :: int32 - r14 :: bit - r15 :: bool - r16 :: str - r17 :: object - r18 :: str - r19 :: object - r20 :: object[1] - r21 :: object_ptr - r22, r23 :: object + r2 :: native_int + r3, r4 :: bit + r5 :: list + r6, r7, r8 :: object + r9 :: int32 + r10 :: bit + r11 :: bool + r12 :: list + r13, r14, r15 :: object + r16 :: int32 + r17 :: bit + r18 :: bool + r19 :: str + r20 :: object + r21 :: str + r22 :: object + r23 :: object[1] + r24 :: object_ptr + r25, r26 :: object L0: r0 = PyList_Check(x) r1 = r0 != 0 - if r1 goto L1 else goto L4 :: bool + if r1 goto L1 else goto L5 :: bool L1: - r2 = cast(list, x) - r3 = CPyList_GetItem(r2, 0) - r4 = object 1 - r5 = PyObject_RichCompare(r3, r4, 2) - r6 = PyObject_IsTrue(r5) - r7 = r6 >= 0 :: signed - r8 = truncate r6: int32 to builtins.bool - if r8 goto L2 else goto L4 :: bool + r2 = PyObject_Size(x) + r3 = r2 >= 0 :: signed + r4 = 2 >= r2 :: signed + if r4 goto L2 else goto L5 :: bool L2: - r9 = cast(list, x) - r10 = CPyList_GetItem(r9, 2) - r11 = object 2 - r12 = PyObject_RichCompare(r10, r11, 2) - r13 = PyObject_IsTrue(r12) - r14 = r13 >= 0 :: signed - r15 = truncate r13: int32 to builtins.bool - if r15 goto L3 else goto L4 :: bool + r5 = cast(list, x) + r6 = CPyList_GetItem(r5, 0) + r7 = object 1 + r8 = PyObject_RichCompare(r6, r7, 2) + r9 = PyObject_IsTrue(r8) + r10 = r9 >= 0 :: signed + r11 = truncate r9: int32 to builtins.bool + if r11 goto L3 else goto L5 :: bool L3: - r16 = 'matched' - r17 = builtins :: module - r18 = 'print' - r19 = CPyObject_GetAttr(r17, r18) - r20 = [r16] - r21 = load_address r20 - r22 = _PyObject_Vectorcall(r19, r21, 1, 0) - keep_alive r16 - goto L5 + r12 = cast(list, x) + r13 = CPyList_GetItem(r12, 2) + r14 = object 2 + r15 = PyObject_RichCompare(r13, r14, 2) + r16 = PyObject_IsTrue(r15) + r17 = r16 >= 0 :: signed + r18 = truncate r16: int32 to builtins.bool + if r18 goto L4 else goto L5 :: bool L4: + r19 = 'matched' + r20 = builtins :: module + r21 = 'print' + r22 = CPyObject_GetAttr(r20, r21) + r23 = [r19] + r24 = load_address r23 + r25 = _PyObject_Vectorcall(r22, r24, 1, 0) + keep_alive r19 + goto L6 L5: - r23 = box(None, 1) - return r23 +L6: + r26 = box(None, 1) + return r26 From 011538163fea3337a5da731eb287c7ca7dd66655 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Sun, 23 Oct 2022 17:06:00 -0700 Subject: [PATCH 68/97] Updates --- mypyc/irbuild/match.py | 54 ++++++++++++----- mypyc/test-data/irbuild-match.test | 95 +++++++++++++++++++++++++++++- mypyc/test-data/run-match.test | 10 ++++ 3 files changed, 141 insertions(+), 18 deletions(-) diff --git a/mypyc/irbuild/match.py b/mypyc/irbuild/match.py index e5e3ba49856ce..d3846a2dae9c7 100644 --- a/mypyc/irbuild/match.py +++ b/mypyc/irbuild/match.py @@ -1,7 +1,7 @@ from contextlib import contextmanager from typing import Generator -from mypy.nodes import MatchStmt, TypeInfo +from mypy.nodes import MatchStmt, NameExpr, TypeInfo from mypyc.ir.ops import Value, BasicBlock from mypy.patterns import ( AsPattern, @@ -23,7 +23,7 @@ dict_del_item, slow_isinstance_op, ) -from mypyc.primitives.list_ops import check_list, list_get_item_op +from mypyc.primitives.list_ops import check_list, list_get_item_op, list_slice_op from mypyc.primitives.generic_ops import generic_ssize_t_len_op from mypyc.irbuild.builder import IRBuilder @@ -237,12 +237,12 @@ def visit_mapping_pattern(self, pattern: MappingPattern) -> None: def visit_sequence_pattern(self, pattern: SequencePattern) -> None: index = -1 + capture: NameExpr | None = None for i, p in enumerate(pattern.patterns): if isinstance(p, StarredPattern): index = i - - assert not p.capture + capture = p.capture assert index in (-1, len(pattern.patterns) - 1) @@ -256,21 +256,23 @@ def visit_sequence_pattern(self, pattern: SequencePattern) -> None: min_len = len(pattern.patterns) - (0 if index == -1 else 1) - if min_len: - self.builder.activate_block(self.code_block) - self.code_block = BasicBlock() + if not min_len: + return - actual_len = self.builder.call_c( - generic_ssize_t_len_op, - [self.subject], - pattern.line, - ) + self.builder.activate_block(self.code_block) + self.code_block = BasicBlock() - is_long_enough = self.builder.binary_op( - self.builder.load_int(min_len), actual_len, "<=", pattern.line - ) + actual_len = self.builder.call_c( + generic_ssize_t_len_op, + [self.subject], + pattern.line, + ) - self.builder.add_bool_branch(is_long_enough, self.code_block, self.next_block) + is_long_enough = self.builder.binary_op( + self.builder.load_int(min_len), actual_len, "<=", pattern.line + ) + + self.builder.add_bool_branch(is_long_enough, self.code_block, self.next_block) for i, p in enumerate(pattern.patterns): if i == index: @@ -288,6 +290,26 @@ def visit_sequence_pattern(self, pattern: SequencePattern) -> None: with self.enter_subpattern(item): p.accept(self) + if capture: + self.builder.activate_block(self.code_block) + self.code_block = BasicBlock() + + target = self.builder.get_assignment_target(capture) + + rest = self.builder.call_c( + list_slice_op, + [ + self.subject, + self.builder.load_int(index), + actual_len, + ], + capture.line, + ) + + self.builder.assign(target, rest, capture.line) + + self.builder.goto(self.code_block) + def bind_as_pattern(self, value: Value, new_block: bool = False) -> None: if self.as_pattern and self.as_pattern.name: if new_block: diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index 3877d9df26dbc..d63515a35c27c 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -1269,7 +1269,7 @@ L0: L1: r2 = PyObject_Size(x) r3 = r2 >= 0 :: signed - r4 = 2 >= r2 :: signed + r4 = 2 <= r2 :: signed if r4 goto L2 else goto L5 :: bool L2: r5 = cast(list, x) @@ -1339,7 +1339,7 @@ L0: L1: r2 = PyObject_Size(x) r3 = r2 >= 0 :: signed - r4 = 2 >= r2 :: signed + r4 = 2 <= r2 :: signed if r4 goto L2 else goto L5 :: bool L2: r5 = cast(list, x) @@ -1373,3 +1373,94 @@ L5: L6: r26 = box(None, 1) return r26 +[case testMatchSequencePatternWithTrailingBoundStar_python3_10] +def f(x): + match x: + case [1, 2, *rest]: + print("matched") +[out] +def f(x): + x :: object + r0 :: int32 + r1 :: bit + r2 :: native_int + r3, r4 :: bit + r5 :: list + r6, r7, r8 :: object + r9 :: int32 + r10 :: bit + r11 :: bool + r12 :: list + r13, r14, r15 :: object + r16 :: int32 + r17 :: bit + r18 :: bool + r19 :: list + r20, r21 :: bit + r22, r23, r24 :: int + r25, rest :: object + r26 :: str + r27 :: object + r28 :: str + r29 :: object + r30 :: object[1] + r31 :: object_ptr + r32, r33 :: object +L0: + r0 = PyList_Check(x) + r1 = r0 != 0 + if r1 goto L1 else goto L10 :: bool +L1: + r2 = PyObject_Size(x) + r3 = r2 >= 0 :: signed + r4 = 2 <= r2 :: signed + if r4 goto L2 else goto L10 :: bool +L2: + r5 = cast(list, x) + r6 = CPyList_GetItem(r5, 0) + r7 = object 1 + r8 = PyObject_RichCompare(r6, r7, 2) + r9 = PyObject_IsTrue(r8) + r10 = r9 >= 0 :: signed + r11 = truncate r9: int32 to builtins.bool + if r11 goto L3 else goto L10 :: bool +L3: + r12 = cast(list, x) + r13 = CPyList_GetItem(r12, 2) + r14 = object 2 + r15 = PyObject_RichCompare(r13, r14, 2) + r16 = PyObject_IsTrue(r15) + r17 = r16 >= 0 :: signed + r18 = truncate r16: int32 to builtins.bool + if r18 goto L4 else goto L10 :: bool +L4: + r19 = cast(list, x) + r20 = r2 <= 4611686018427387903 :: signed + if r20 goto L5 else goto L6 :: bool +L5: + r21 = r2 >= -4611686018427387904 :: signed + if r21 goto L7 else goto L6 :: bool +L6: + r22 = CPyTagged_FromInt64(r2) + r23 = r22 + goto L8 +L7: + r24 = r2 << 1 + r23 = r24 +L8: + r25 = CPyList_GetSlice(r19, 4, r23) + rest = r25 +L9: + r26 = 'matched' + r27 = builtins :: module + r28 = 'print' + r29 = CPyObject_GetAttr(r27, r28) + r30 = [r26] + r31 = load_address r30 + r32 = _PyObject_Vectorcall(r29, r31, 1, 0) + keep_alive r26 + goto L11 +L10: +L11: + r33 = box(None, 1) + return r33 diff --git a/mypyc/test-data/run-match.test b/mypyc/test-data/run-match.test index 074636f3221e7..277c42460d802 100644 --- a/mypyc/test-data/run-match.test +++ b/mypyc/test-data/run-match.test @@ -60,6 +60,10 @@ def f(x): case ["test", 14, *_]: print("test 14") + # TODO: Fix "rest" being used here coliding with above "rest" + case ["test", 15, *rest2]: + print(f"test 15 ({rest2})") + case []: print("test sequence final") @@ -122,6 +126,10 @@ f(["test", 13]) f(["test", 14]) f(["test", 14, "something"]) +# test 15 +f(["test", 15]) +f(["test", 15, "something"]) + # test sequence final f([]) @@ -153,5 +161,7 @@ test map final test 13 test 14 test 14 +test 15 ([]) +test 15 (['something']) test sequence final test final From 38a5cc511534b18d29fc59d14212cb560ce3ef2e Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Mon, 24 Oct 2022 14:33:51 -0700 Subject: [PATCH 69/97] Require exact size for fixed length sequences --- mypyc/irbuild/match.py | 13 +++++++------ mypyc/test-data/run-match.test | 7 +++++++ 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/mypyc/irbuild/match.py b/mypyc/irbuild/match.py index d3846a2dae9c7..25399ac895e6d 100644 --- a/mypyc/irbuild/match.py +++ b/mypyc/irbuild/match.py @@ -236,7 +236,7 @@ def visit_mapping_pattern(self, pattern: MappingPattern) -> None: self.builder.goto(self.code_block) def visit_sequence_pattern(self, pattern: SequencePattern) -> None: - index = -1 + index: int | None = None capture: NameExpr | None = None for i, p in enumerate(pattern.patterns): @@ -244,8 +244,6 @@ def visit_sequence_pattern(self, pattern: SequencePattern) -> None: index = i capture = p.capture - assert index in (-1, len(pattern.patterns) - 1) - is_list = self.builder.call_c( check_list, [self.subject], @@ -254,7 +252,7 @@ def visit_sequence_pattern(self, pattern: SequencePattern) -> None: self.builder.add_bool_branch(is_list, self.code_block, self.next_block) - min_len = len(pattern.patterns) - (0 if index == -1 else 1) + min_len = len(pattern.patterns) - (0 if index is None else 1) if not min_len: return @@ -269,7 +267,10 @@ def visit_sequence_pattern(self, pattern: SequencePattern) -> None: ) is_long_enough = self.builder.binary_op( - self.builder.load_int(min_len), actual_len, "<=", pattern.line + self.builder.load_int(min_len), + actual_len, + "==" if index is None else "<=", + pattern.line ) self.builder.add_bool_branch(is_long_enough, self.code_block, self.next_block) @@ -290,7 +291,7 @@ def visit_sequence_pattern(self, pattern: SequencePattern) -> None: with self.enter_subpattern(item): p.accept(self) - if capture: + if capture and index: self.builder.activate_block(self.code_block) self.code_block = BasicBlock() diff --git a/mypyc/test-data/run-match.test b/mypyc/test-data/run-match.test index 277c42460d802..266a7094a0565 100644 --- a/mypyc/test-data/run-match.test +++ b/mypyc/test-data/run-match.test @@ -57,6 +57,9 @@ def f(x): case ["test", 13]: print("test 13") + case ["test", 13, _]: + print("test 13b") + case ["test", 14, *_]: print("test 14") @@ -122,6 +125,9 @@ f({}) # test 13 f(["test", 13]) +# test 13b +f(["test", 13, "fail"]) + # test 14 f(["test", 14]) f(["test", 14, "something"]) @@ -159,6 +165,7 @@ test 12 (rest={'key': 'value'}) test 12 (rest={'key': 'value', 'abc': '123'}) test map final test 13 +test 13b test 14 test 14 test 15 ([]) From 04f0cbcb9a069bcc8830c946284835993975eaca Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Mon, 24 Oct 2022 15:17:44 -0700 Subject: [PATCH 70/97] Very hacky support for star pattern in middle of list --- mypyc/irbuild/match.py | 23 +++++- mypyc/test-data/irbuild-match.test | 114 ++++++++++++++++++++++++++++- mypyc/test-data/run-match.test | 11 +++ 3 files changed, 144 insertions(+), 4 deletions(-) diff --git a/mypyc/irbuild/match.py b/mypyc/irbuild/match.py index 25399ac895e6d..1d982d9722a69 100644 --- a/mypyc/irbuild/match.py +++ b/mypyc/irbuild/match.py @@ -275,16 +275,33 @@ def visit_sequence_pattern(self, pattern: SequencePattern) -> None: self.builder.add_bool_branch(is_long_enough, self.code_block, self.next_block) + idk = False + capture_end = actual_len + for i, p in enumerate(pattern.patterns): if i == index: - break + idk = True + continue self.builder.activate_block(self.code_block) self.code_block = BasicBlock() + if idk: + current = self.builder.binary_op( + actual_len, + self.builder.load_int(min_len - i + 1), + "-", + p.line, + ) + + capture_end = current + + else: + current = self.builder.load_int(i) + item = self.builder.call_c( list_get_item_op, - [self.subject, self.builder.load_int(i)], + [self.subject, current], p.line, ) @@ -302,7 +319,7 @@ def visit_sequence_pattern(self, pattern: SequencePattern) -> None: [ self.subject, self.builder.load_int(index), - actual_len, + capture_end, ], capture.line, ) diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index d63515a35c27c..77dc946212d4b 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -1269,7 +1269,7 @@ L0: L1: r2 = PyObject_Size(x) r3 = r2 >= 0 :: signed - r4 = 2 <= r2 :: signed + r4 = 2 == r2 if r4 goto L2 else goto L5 :: bool L2: r5 = cast(list, x) @@ -1464,3 +1464,115 @@ L10: L11: r33 = box(None, 1) return r33 +[case testMatchSequenceWithStarPatternInTheMiddle_python3_10] +def f(x): + match x: + case ["start", *rest, "end"]: + print("matched") +[out] +def f(x): + x :: object + r0 :: int32 + r1 :: bit + r2 :: native_int + r3, r4 :: bit + r5 :: list + r6 :: object + r7 :: str + r8 :: object + r9 :: int32 + r10 :: bit + r11 :: bool + r12 :: native_int + r13 :: list + r14, r15 :: bit + r16, r17, r18 :: int + r19 :: object + r20 :: str + r21 :: object + r22 :: int32 + r23 :: bit + r24 :: bool + r25 :: list + r26, r27 :: bit + r28, r29, r30 :: int + r31, rest :: object + r32 :: str + r33 :: object + r34 :: str + r35 :: object + r36 :: object[1] + r37 :: object_ptr + r38, r39 :: object +L0: + r0 = PyList_Check(x) + r1 = r0 != 0 + if r1 goto L1 else goto L14 :: bool +L1: + r2 = PyObject_Size(x) + r3 = r2 >= 0 :: signed + r4 = 2 <= r2 :: signed + if r4 goto L2 else goto L14 :: bool +L2: + r5 = cast(list, x) + r6 = CPyList_GetItem(r5, 0) + r7 = 'start' + r8 = PyObject_RichCompare(r6, r7, 2) + r9 = PyObject_IsTrue(r8) + r10 = r9 >= 0 :: signed + r11 = truncate r9: int32 to builtins.bool + if r11 goto L3 else goto L14 :: bool +L3: + r12 = r2 - 1 + r13 = cast(list, x) + r14 = r12 <= 4611686018427387903 :: signed + if r14 goto L4 else goto L5 :: bool +L4: + r15 = r12 >= -4611686018427387904 :: signed + if r15 goto L6 else goto L5 :: bool +L5: + r16 = CPyTagged_FromInt64(r12) + r17 = r16 + goto L7 +L6: + r18 = r12 << 1 + r17 = r18 +L7: + r19 = CPyList_GetItem(r13, r17) + r20 = 'end' + r21 = PyObject_RichCompare(r19, r20, 2) + r22 = PyObject_IsTrue(r21) + r23 = r22 >= 0 :: signed + r24 = truncate r22: int32 to builtins.bool + if r24 goto L8 else goto L14 :: bool +L8: + r25 = cast(list, x) + r26 = r12 <= 4611686018427387903 :: signed + if r26 goto L9 else goto L10 :: bool +L9: + r27 = r12 >= -4611686018427387904 :: signed + if r27 goto L11 else goto L10 :: bool +L10: + r28 = CPyTagged_FromInt64(r12) + r29 = r28 + goto L12 +L11: + r30 = r12 << 1 + r29 = r30 +L12: + r31 = CPyList_GetSlice(r25, 2, r29) + rest = r31 +L13: + r32 = 'matched' + r33 = builtins :: module + r34 = 'print' + r35 = CPyObject_GetAttr(r33, r34) + r36 = [r32] + r37 = load_address r36 + r38 = _PyObject_Vectorcall(r35, r37, 1, 0) + keep_alive r32 + goto L15 +L14: +L15: + r39 = box(None, 1) + return r39 diff --git a/mypyc/test-data/run-match.test b/mypyc/test-data/run-match.test index 266a7094a0565..587d6b5f8a4c2 100644 --- a/mypyc/test-data/run-match.test +++ b/mypyc/test-data/run-match.test @@ -67,6 +67,9 @@ def f(x): case ["test", 15, *rest2]: print(f"test 15 ({rest2})") + case ["test", *rest3, 16]: + print(f"test 16 ({rest3})") + case []: print("test sequence final") @@ -136,6 +139,11 @@ f(["test", 14, "something"]) f(["test", 15]) f(["test", 15, "something"]) +# test 16 +f(["test", 16]) +f(["test", "filler", 16]) +f(["test", "more", "filler", 16]) + # test sequence final f([]) @@ -170,5 +178,8 @@ test 14 test 14 test 15 ([]) test 15 (['something']) +test 16 ([]) +test 16 (['filler']) +test 16 (['more', 'filler']) test sequence final test final From 360f686c9e3b192371d38b78eb5a314984166fd7 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Mon, 24 Oct 2022 15:44:09 -0700 Subject: [PATCH 71/97] Hackily add leading star pattern --- mypyc/irbuild/match.py | 10 ++- mypyc/test-data/irbuild-match.test | 127 +++++++++++++++++++++++++++++ mypyc/test-data/run-match.test | 11 +++ 3 files changed, 147 insertions(+), 1 deletion(-) diff --git a/mypyc/irbuild/match.py b/mypyc/irbuild/match.py index 1d982d9722a69..9f887a4157046 100644 --- a/mypyc/irbuild/match.py +++ b/mypyc/irbuild/match.py @@ -308,12 +308,20 @@ def visit_sequence_pattern(self, pattern: SequencePattern) -> None: with self.enter_subpattern(item): p.accept(self) - if capture and index: + if capture and index is not None: self.builder.activate_block(self.code_block) self.code_block = BasicBlock() target = self.builder.get_assignment_target(capture) + if index == 0: + capture_end = self.builder.binary_op( + capture_end, + self.builder.load_int(1), + "-", + pattern.patterns[0].line, + ) + rest = self.builder.call_c( list_slice_op, [ diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index 77dc946212d4b..9e5d8037394aa 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -1576,3 +1576,130 @@ L14: L15: r39 = box(None, 1) return r39 +[case testMatchSequenceWithStarPatternAtTheStart_python3_10] +def f(x): + match x: + case [*rest, 1, 2]: + print("matched") +[out] +def f(x): + x :: object + r0 :: int32 + r1 :: bit + r2 :: native_int + r3, r4 :: bit + r5 :: native_int + r6 :: list + r7, r8 :: bit + r9, r10, r11 :: int + r12, r13, r14 :: object + r15 :: int32 + r16 :: bit + r17 :: bool + r18 :: native_int + r19 :: list + r20, r21 :: bit + r22, r23, r24 :: int + r25, r26, r27 :: object + r28 :: int32 + r29 :: bit + r30 :: bool + r31 :: native_int + r32 :: list + r33, r34 :: bit + r35, r36, r37 :: int + r38, rest :: object + r39 :: str + r40 :: object + r41 :: str + r42 :: object + r43 :: object[1] + r44 :: object_ptr + r45, r46 :: object +L0: + r0 = PyList_Check(x) + r1 = r0 != 0 + if r1 goto L1 else goto L18 :: bool +L1: + r2 = PyObject_Size(x) + r3 = r2 >= 0 :: signed + r4 = 2 <= r2 :: signed + if r4 goto L2 else goto L18 :: bool +L2: + r5 = r2 - 2 + r6 = cast(list, x) + r7 = r5 <= 4611686018427387903 :: signed + if r7 goto L3 else goto L4 :: bool +L3: + r8 = r5 >= -4611686018427387904 :: signed + if r8 goto L5 else goto L4 :: bool +L4: + r9 = CPyTagged_FromInt64(r5) + r10 = r9 + goto L6 +L5: + r11 = r5 << 1 + r10 = r11 +L6: + r12 = CPyList_GetItem(r6, r10) + r13 = object 1 + r14 = PyObject_RichCompare(r12, r13, 2) + r15 = PyObject_IsTrue(r14) + r16 = r15 >= 0 :: signed + r17 = truncate r15: int32 to builtins.bool + if r17 goto L7 else goto L18 :: bool +L7: + r18 = r2 - 1 + r19 = cast(list, x) + r20 = r18 <= 4611686018427387903 :: signed + if r20 goto L8 else goto L9 :: bool +L8: + r21 = r18 >= -4611686018427387904 :: signed + if r21 goto L10 else goto L9 :: bool +L9: + r22 = CPyTagged_FromInt64(r18) + r23 = r22 + goto L11 +L10: + r24 = r18 << 1 + r23 = r24 +L11: + r25 = CPyList_GetItem(r19, r23) + r26 = object 2 + r27 = PyObject_RichCompare(r25, r26, 2) + r28 = PyObject_IsTrue(r27) + r29 = r28 >= 0 :: signed + r30 = truncate r28: int32 to builtins.bool + if r30 goto L12 else goto L18 :: bool +L12: + r31 = r18 - 1 + r32 = cast(list, x) + r33 = r31 <= 4611686018427387903 :: signed + if r33 goto L13 else goto L14 :: bool +L13: + r34 = r31 >= -4611686018427387904 :: signed + if r34 goto L15 else goto L14 :: bool +L14: + r35 = CPyTagged_FromInt64(r31) + r36 = r35 + goto L16 +L15: + r37 = r31 << 1 + r36 = r37 +L16: + r38 = CPyList_GetSlice(r32, 0, r36) + rest = r38 +L17: + r39 = 'matched' + r40 = builtins :: module + r41 = 'print' + r42 = CPyObject_GetAttr(r40, r41) + r43 = [r39] + r44 = load_address r43 + r45 = _PyObject_Vectorcall(r42, r44, 1, 0) + keep_alive r39 + goto L19 +L18: +L19: + r46 = box(None, 1) + return r46 diff --git a/mypyc/test-data/run-match.test b/mypyc/test-data/run-match.test index 587d6b5f8a4c2..5bda532aadd5e 100644 --- a/mypyc/test-data/run-match.test +++ b/mypyc/test-data/run-match.test @@ -70,6 +70,9 @@ def f(x): case ["test", *rest3, 16]: print(f"test 16 ({rest3})") + case [*rest4, "test", 17]: + print(f"test 17 ({rest4})") + case []: print("test sequence final") @@ -144,6 +147,11 @@ f(["test", 16]) f(["test", "filler", 16]) f(["test", "more", "filler", 16]) +# test 17 +f(["test", 17]) +f(["stuff", "test", 17]) +f(["more", "stuff", "test", 17]) + # test sequence final f([]) @@ -181,5 +189,8 @@ test 15 (['something']) test 16 ([]) test 16 (['filler']) test 16 (['more', 'filler']) +test 17 ([]) +test 17 (['stuff']) +test 17 (['more', 'stuff']) test sequence final test final From 9714a2c622b22d1803ba8869ce5977ccfd57272a Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Mon, 24 Oct 2022 16:53:05 -0700 Subject: [PATCH 72/97] Cleanups --- mypyc/irbuild/match.py | 18 +++++++++--------- mypyc/test-data/irbuild-match.test | 10 +++++----- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/mypyc/irbuild/match.py b/mypyc/irbuild/match.py index 9f887a4157046..be474afd9ff1a 100644 --- a/mypyc/irbuild/match.py +++ b/mypyc/irbuild/match.py @@ -236,12 +236,12 @@ def visit_mapping_pattern(self, pattern: MappingPattern) -> None: self.builder.goto(self.code_block) def visit_sequence_pattern(self, pattern: SequencePattern) -> None: - index: int | None = None + star_index: int | None = None capture: NameExpr | None = None for i, p in enumerate(pattern.patterns): if isinstance(p, StarredPattern): - index = i + star_index = i capture = p.capture is_list = self.builder.call_c( @@ -252,7 +252,7 @@ def visit_sequence_pattern(self, pattern: SequencePattern) -> None: self.builder.add_bool_branch(is_list, self.code_block, self.next_block) - min_len = len(pattern.patterns) - (0 if index is None else 1) + min_len = len(pattern.patterns) - (0 if star_index is None else 1) if not min_len: return @@ -267,9 +267,9 @@ def visit_sequence_pattern(self, pattern: SequencePattern) -> None: ) is_long_enough = self.builder.binary_op( - self.builder.load_int(min_len), actual_len, - "==" if index is None else "<=", + self.builder.load_int(min_len), + "==" if star_index is None else ">=", pattern.line ) @@ -279,7 +279,7 @@ def visit_sequence_pattern(self, pattern: SequencePattern) -> None: capture_end = actual_len for i, p in enumerate(pattern.patterns): - if i == index: + if i == star_index: idk = True continue @@ -308,13 +308,13 @@ def visit_sequence_pattern(self, pattern: SequencePattern) -> None: with self.enter_subpattern(item): p.accept(self) - if capture and index is not None: + if capture and star_index is not None: self.builder.activate_block(self.code_block) self.code_block = BasicBlock() target = self.builder.get_assignment_target(capture) - if index == 0: + if star_index == 0: capture_end = self.builder.binary_op( capture_end, self.builder.load_int(1), @@ -326,7 +326,7 @@ def visit_sequence_pattern(self, pattern: SequencePattern) -> None: list_slice_op, [ self.subject, - self.builder.load_int(index), + self.builder.load_int(star_index), capture_end, ], capture.line, diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index 9e5d8037394aa..34f5464dbf6e9 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -1269,7 +1269,7 @@ L0: L1: r2 = PyObject_Size(x) r3 = r2 >= 0 :: signed - r4 = 2 == r2 + r4 = r2 == 2 if r4 goto L2 else goto L5 :: bool L2: r5 = cast(list, x) @@ -1339,7 +1339,7 @@ L0: L1: r2 = PyObject_Size(x) r3 = r2 >= 0 :: signed - r4 = 2 <= r2 :: signed + r4 = r2 >= 2 :: signed if r4 goto L2 else goto L5 :: bool L2: r5 = cast(list, x) @@ -1413,7 +1413,7 @@ L0: L1: r2 = PyObject_Size(x) r3 = r2 >= 0 :: signed - r4 = 2 <= r2 :: signed + r4 = r2 >= 2 :: signed if r4 goto L2 else goto L10 :: bool L2: r5 = cast(list, x) @@ -1511,7 +1511,7 @@ L0: L1: r2 = PyObject_Size(x) r3 = r2 >= 0 :: signed - r4 = 2 <= r2 :: signed + r4 = r2 >= 2 :: signed if r4 goto L2 else goto L14 :: bool L2: r5 = cast(list, x) @@ -1623,7 +1623,7 @@ L0: L1: r2 = PyObject_Size(x) r3 = r2 >= 0 :: signed - r4 = 2 <= r2 :: signed + r4 = r2 >= 2 :: signed if r4 goto L2 else goto L18 :: bool L2: r5 = r2 - 2 From 5b70d51bd18361a324450b2c7965d55d8c6138c5 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Mon, 24 Oct 2022 16:59:40 -0700 Subject: [PATCH 73/97] Renaming --- mypyc/irbuild/match.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mypyc/irbuild/match.py b/mypyc/irbuild/match.py index be474afd9ff1a..91c2c432c6ad0 100644 --- a/mypyc/irbuild/match.py +++ b/mypyc/irbuild/match.py @@ -275,18 +275,18 @@ def visit_sequence_pattern(self, pattern: SequencePattern) -> None: self.builder.add_bool_branch(is_long_enough, self.code_block, self.next_block) - idk = False + on_rhs_of_star_pattern = False capture_end = actual_len for i, p in enumerate(pattern.patterns): if i == star_index: - idk = True + on_rhs_of_star_pattern = True continue self.builder.activate_block(self.code_block) self.code_block = BasicBlock() - if idk: + if on_rhs_of_star_pattern: current = self.builder.binary_op( actual_len, self.builder.load_int(min_len - i + 1), From 25564c869cd13ec99b792adc81545777f15aa87a Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Mon, 24 Oct 2022 17:07:49 -0700 Subject: [PATCH 74/97] Cleanup --- mypyc/irbuild/match.py | 7 ++++--- mypyc/test-data/run-match.test | 11 +++++++++++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/mypyc/irbuild/match.py b/mypyc/irbuild/match.py index 91c2c432c6ad0..7ce957ea5401a 100644 --- a/mypyc/irbuild/match.py +++ b/mypyc/irbuild/match.py @@ -294,7 +294,8 @@ def visit_sequence_pattern(self, pattern: SequencePattern) -> None: p.line, ) - capture_end = current + if star_index != 0: + capture_end = current else: current = self.builder.load_int(i) @@ -316,8 +317,8 @@ def visit_sequence_pattern(self, pattern: SequencePattern) -> None: if star_index == 0: capture_end = self.builder.binary_op( - capture_end, - self.builder.load_int(1), + actual_len, + self.builder.load_int(min_len), "-", pattern.patterns[0].line, ) diff --git a/mypyc/test-data/run-match.test b/mypyc/test-data/run-match.test index 5bda532aadd5e..10b5e7a20a08f 100644 --- a/mypyc/test-data/run-match.test +++ b/mypyc/test-data/run-match.test @@ -73,6 +73,9 @@ def f(x): case [*rest4, "test", 17]: print(f"test 17 ({rest4})") + case [*rest4, "test", 18, "some", "fluff"]: + print(f"test 18 ({rest4})") + case []: print("test sequence final") @@ -152,6 +155,11 @@ f(["test", 17]) f(["stuff", "test", 17]) f(["more", "stuff", "test", 17]) +# test 18 +f(["test", 18, "some", "fluff"]) +f(["stuff", "test", 18, "some", "fluff"]) +f(["more", "stuff", "test", 18, "some", "fluff"]) + # test sequence final f([]) @@ -192,5 +200,8 @@ test 16 (['more', 'filler']) test 17 ([]) test 17 (['stuff']) test 17 (['more', 'stuff']) +test 18 ([]) +test 18 (['stuff']) +test 18 (['more', 'stuff']) test sequence final test final From 9601a2371db5559aa6d419220e2a12d98c89b6d5 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Mon, 24 Oct 2022 17:20:34 -0700 Subject: [PATCH 75/97] Cleanups --- mypyc/irbuild/match.py | 17 ++-- mypyc/test-data/irbuild-match.test | 134 +++++++++++++++-------------- 2 files changed, 75 insertions(+), 76 deletions(-) diff --git a/mypyc/irbuild/match.py b/mypyc/irbuild/match.py index 7ce957ea5401a..fc192a3605824 100644 --- a/mypyc/irbuild/match.py +++ b/mypyc/irbuild/match.py @@ -276,7 +276,6 @@ def visit_sequence_pattern(self, pattern: SequencePattern) -> None: self.builder.add_bool_branch(is_long_enough, self.code_block, self.next_block) on_rhs_of_star_pattern = False - capture_end = actual_len for i, p in enumerate(pattern.patterns): if i == star_index: @@ -294,9 +293,6 @@ def visit_sequence_pattern(self, pattern: SequencePattern) -> None: p.line, ) - if star_index != 0: - capture_end = current - else: current = self.builder.load_int(i) @@ -315,13 +311,12 @@ def visit_sequence_pattern(self, pattern: SequencePattern) -> None: target = self.builder.get_assignment_target(capture) - if star_index == 0: - capture_end = self.builder.binary_op( - actual_len, - self.builder.load_int(min_len), - "-", - pattern.patterns[0].line, - ) + capture_end = self.builder.binary_op( + actual_len, + self.builder.load_int(min_len - star_index), + "-", + pattern.patterns[0].line, + ) rest = self.builder.call_c( list_slice_op, diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index 34f5464dbf6e9..0074bff09c1a0 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -1395,17 +1395,18 @@ def f(x): r16 :: int32 r17 :: bit r18 :: bool - r19 :: list - r20, r21 :: bit - r22, r23, r24 :: int - r25, rest :: object - r26 :: str - r27 :: object - r28 :: str - r29 :: object - r30 :: object[1] - r31 :: object_ptr - r32, r33 :: object + r19 :: native_int + r20 :: list + r21, r22 :: bit + r23, r24, r25 :: int + r26, rest :: object + r27 :: str + r28 :: object + r29 :: str + r30 :: object + r31 :: object[1] + r32 :: object_ptr + r33, r34 :: object L0: r0 = PyList_Check(x) r1 = r0 != 0 @@ -1434,36 +1435,37 @@ L3: r18 = truncate r16: int32 to builtins.bool if r18 goto L4 else goto L10 :: bool L4: - r19 = cast(list, x) - r20 = r2 <= 4611686018427387903 :: signed - if r20 goto L5 else goto L6 :: bool + r19 = r2 - 0 + r20 = cast(list, x) + r21 = r19 <= 4611686018427387903 :: signed + if r21 goto L5 else goto L6 :: bool L5: - r21 = r2 >= -4611686018427387904 :: signed - if r21 goto L7 else goto L6 :: bool + r22 = r19 >= -4611686018427387904 :: signed + if r22 goto L7 else goto L6 :: bool L6: - r22 = CPyTagged_FromInt64(r2) - r23 = r22 + r23 = CPyTagged_FromInt64(r19) + r24 = r23 goto L8 L7: - r24 = r2 << 1 - r23 = r24 + r25 = r19 << 1 + r24 = r25 L8: - r25 = CPyList_GetSlice(r19, 4, r23) - rest = r25 + r26 = CPyList_GetSlice(r20, 4, r24) + rest = r26 L9: - r26 = 'matched' - r27 = builtins :: module - r28 = 'print' - r29 = CPyObject_GetAttr(r27, r28) - r30 = [r26] - r31 = load_address r30 - r32 = _PyObject_Vectorcall(r29, r31, 1, 0) - keep_alive r26 + r27 = 'matched' + r28 = builtins :: module + r29 = 'print' + r30 = CPyObject_GetAttr(r28, r29) + r31 = [r27] + r32 = load_address r31 + r33 = _PyObject_Vectorcall(r30, r32, 1, 0) + keep_alive r27 goto L11 L10: L11: - r33 = box(None, 1) - return r33 + r34 = box(None, 1) + return r34 [case testMatchSequenceWithStarPatternInTheMiddle_python3_10] def f(x): match x: @@ -1493,17 +1495,18 @@ def f(x): r22 :: int32 r23 :: bit r24 :: bool - r25 :: list - r26, r27 :: bit - r28, r29, r30 :: int - r31, rest :: object - r32 :: str - r33 :: object - r34 :: str - r35 :: object - r36 :: object[1] - r37 :: object_ptr - r38, r39 :: object + r25 :: native_int + r26 :: list + r27, r28 :: bit + r29, r30, r31 :: int + r32, rest :: object + r33 :: str + r34 :: object + r35 :: str + r36 :: object + r37 :: object[1] + r38 :: object_ptr + r39, r40 :: object L0: r0 = PyList_Check(x) r1 = r0 != 0 @@ -1546,36 +1549,37 @@ L7: r24 = truncate r22: int32 to builtins.bool if r24 goto L8 else goto L14 :: bool L8: - r25 = cast(list, x) - r26 = r12 <= 4611686018427387903 :: signed - if r26 goto L9 else goto L10 :: bool + r25 = r2 - 1 + r26 = cast(list, x) + r27 = r25 <= 4611686018427387903 :: signed + if r27 goto L9 else goto L10 :: bool L9: - r27 = r12 >= -4611686018427387904 :: signed - if r27 goto L11 else goto L10 :: bool + r28 = r25 >= -4611686018427387904 :: signed + if r28 goto L11 else goto L10 :: bool L10: - r28 = CPyTagged_FromInt64(r12) - r29 = r28 + r29 = CPyTagged_FromInt64(r25) + r30 = r29 goto L12 L11: - r30 = r12 << 1 - r29 = r30 + r31 = r25 << 1 + r30 = r31 L12: - r31 = CPyList_GetSlice(r25, 2, r29) - rest = r31 + r32 = CPyList_GetSlice(r26, 2, r30) + rest = r32 L13: - r32 = 'matched' - r33 = builtins :: module - r34 = 'print' - r35 = CPyObject_GetAttr(r33, r34) - r36 = [r32] - r37 = load_address r36 - r38 = _PyObject_Vectorcall(r35, r37, 1, 0) - keep_alive r32 + r33 = 'matched' + r34 = builtins :: module + r35 = 'print' + r36 = CPyObject_GetAttr(r34, r35) + r37 = [r33] + r38 = load_address r37 + r39 = _PyObject_Vectorcall(r36, r38, 1, 0) + keep_alive r33 goto L15 L14: L15: - r39 = box(None, 1) - return r39 + r40 = box(None, 1) + return r40 [case testMatchSequenceWithStarPatternAtTheStart_python3_10] def f(x): match x: @@ -1672,7 +1676,7 @@ L11: r30 = truncate r28: int32 to builtins.bool if r30 goto L12 else goto L18 :: bool L12: - r31 = r18 - 1 + r31 = r2 - 2 r32 = cast(list, x) r33 = r31 <= 4611686018427387903 :: signed if r33 goto L13 else goto L14 :: bool From 16e8d120236c782bcaa6addc91b749b61711974e Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Mon, 24 Oct 2022 17:30:47 -0700 Subject: [PATCH 76/97] Renames --- mypyc/irbuild/match.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/mypyc/irbuild/match.py b/mypyc/irbuild/match.py index fc192a3605824..9ac66ba0d2bd6 100644 --- a/mypyc/irbuild/match.py +++ b/mypyc/irbuild/match.py @@ -235,24 +235,24 @@ def visit_mapping_pattern(self, pattern: MappingPattern) -> None: self.builder.goto(self.code_block) - def visit_sequence_pattern(self, pattern: SequencePattern) -> None: + def visit_sequence_pattern(self, seq_pattern: SequencePattern) -> None: star_index: int | None = None capture: NameExpr | None = None - for i, p in enumerate(pattern.patterns): - if isinstance(p, StarredPattern): + for i, pattern in enumerate(seq_pattern.patterns): + if isinstance(pattern, StarredPattern): star_index = i - capture = p.capture + capture = pattern.capture is_list = self.builder.call_c( check_list, [self.subject], - pattern.line, + seq_pattern.line, ) self.builder.add_bool_branch(is_list, self.code_block, self.next_block) - min_len = len(pattern.patterns) - (0 if star_index is None else 1) + min_len = len(seq_pattern.patterns) - (0 if star_index is None else 1) if not min_len: return @@ -263,21 +263,21 @@ def visit_sequence_pattern(self, pattern: SequencePattern) -> None: actual_len = self.builder.call_c( generic_ssize_t_len_op, [self.subject], - pattern.line, + seq_pattern.line, ) is_long_enough = self.builder.binary_op( actual_len, self.builder.load_int(min_len), "==" if star_index is None else ">=", - pattern.line + seq_pattern.line ) self.builder.add_bool_branch(is_long_enough, self.code_block, self.next_block) on_rhs_of_star_pattern = False - for i, p in enumerate(pattern.patterns): + for i, pattern in enumerate(seq_pattern.patterns): if i == star_index: on_rhs_of_star_pattern = True continue @@ -290,7 +290,7 @@ def visit_sequence_pattern(self, pattern: SequencePattern) -> None: actual_len, self.builder.load_int(min_len - i + 1), "-", - p.line, + pattern.line, ) else: @@ -299,23 +299,21 @@ def visit_sequence_pattern(self, pattern: SequencePattern) -> None: item = self.builder.call_c( list_get_item_op, [self.subject, current], - p.line, + pattern.line, ) with self.enter_subpattern(item): - p.accept(self) + pattern.accept(self) if capture and star_index is not None: self.builder.activate_block(self.code_block) self.code_block = BasicBlock() - target = self.builder.get_assignment_target(capture) - capture_end = self.builder.binary_op( actual_len, self.builder.load_int(min_len - star_index), "-", - pattern.patterns[0].line, + seq_pattern.patterns[0].line, ) rest = self.builder.call_c( @@ -328,6 +326,7 @@ def visit_sequence_pattern(self, pattern: SequencePattern) -> None: capture.line, ) + target = self.builder.get_assignment_target(capture) self.builder.assign(target, rest, capture.line) self.builder.goto(self.code_block) From b241029e31b4d774469a38415d63ae70d98faa57 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Mon, 24 Oct 2022 17:35:49 -0700 Subject: [PATCH 77/97] Cleanup --- mypyc/irbuild/match.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/mypyc/irbuild/match.py b/mypyc/irbuild/match.py index 9ac66ba0d2bd6..f6123d0d6bf3c 100644 --- a/mypyc/irbuild/match.py +++ b/mypyc/irbuild/match.py @@ -244,6 +244,8 @@ def visit_sequence_pattern(self, seq_pattern: SequencePattern) -> None: star_index = i capture = pattern.capture + ### + is_list = self.builder.call_c( check_list, [self.subject], @@ -252,11 +254,15 @@ def visit_sequence_pattern(self, seq_pattern: SequencePattern) -> None: self.builder.add_bool_branch(is_list, self.code_block, self.next_block) + ### + min_len = len(seq_pattern.patterns) - (0 if star_index is None else 1) if not min_len: return + ### + self.builder.activate_block(self.code_block) self.code_block = BasicBlock() @@ -275,17 +281,16 @@ def visit_sequence_pattern(self, seq_pattern: SequencePattern) -> None: self.builder.add_bool_branch(is_long_enough, self.code_block, self.next_block) - on_rhs_of_star_pattern = False + ### for i, pattern in enumerate(seq_pattern.patterns): if i == star_index: - on_rhs_of_star_pattern = True continue self.builder.activate_block(self.code_block) self.code_block = BasicBlock() - if on_rhs_of_star_pattern: + if star_index is not None and i > star_index: current = self.builder.binary_op( actual_len, self.builder.load_int(min_len - i + 1), @@ -305,6 +310,8 @@ def visit_sequence_pattern(self, seq_pattern: SequencePattern) -> None: with self.enter_subpattern(item): pattern.accept(self) + ### + if capture and star_index is not None: self.builder.activate_block(self.code_block) self.code_block = BasicBlock() From 4b2cfc352590a946c81b017ef7fee7689f834ee4 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Mon, 24 Oct 2022 17:53:16 -0700 Subject: [PATCH 78/97] Cleanups --- mypyc/irbuild/match.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/mypyc/irbuild/match.py b/mypyc/irbuild/match.py index f6123d0d6bf3c..85489a795616b 100644 --- a/mypyc/irbuild/match.py +++ b/mypyc/irbuild/match.py @@ -8,6 +8,7 @@ ClassPattern, OrPattern, MappingPattern, + Pattern, SingletonPattern, SequencePattern, StarredPattern, @@ -238,25 +239,27 @@ def visit_mapping_pattern(self, pattern: MappingPattern) -> None: def visit_sequence_pattern(self, seq_pattern: SequencePattern) -> None: star_index: int | None = None capture: NameExpr | None = None + patterns: list[Pattern] = [] for i, pattern in enumerate(seq_pattern.patterns): if isinstance(pattern, StarredPattern): star_index = i capture = pattern.capture + else: + patterns.append(pattern) + ### is_list = self.builder.call_c( - check_list, - [self.subject], - seq_pattern.line, + check_list, [self.subject], seq_pattern.line ) self.builder.add_bool_branch(is_list, self.code_block, self.next_block) ### - min_len = len(seq_pattern.patterns) - (0 if star_index is None else 1) + min_len = len(patterns) if not min_len: return @@ -283,17 +286,14 @@ def visit_sequence_pattern(self, seq_pattern: SequencePattern) -> None: ### - for i, pattern in enumerate(seq_pattern.patterns): - if i == star_index: - continue - + for i, pattern in enumerate(patterns): self.builder.activate_block(self.code_block) self.code_block = BasicBlock() - if star_index is not None and i > star_index: + if star_index is not None and i >= star_index: current = self.builder.binary_op( actual_len, - self.builder.load_int(min_len - i + 1), + self.builder.load_int(min_len - i), "-", pattern.line, ) @@ -320,7 +320,7 @@ def visit_sequence_pattern(self, seq_pattern: SequencePattern) -> None: actual_len, self.builder.load_int(min_len - star_index), "-", - seq_pattern.patterns[0].line, + patterns[0].line, ) rest = self.builder.call_c( From ede2e1422876cb551e24706796d750e33da1ed4f Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Mon, 24 Oct 2022 18:03:48 -0700 Subject: [PATCH 79/97] More cleanups --- mypyc/irbuild/match.py | 42 ++++++++++++++++++++---------------------- 1 file changed, 20 insertions(+), 22 deletions(-) diff --git a/mypyc/irbuild/match.py b/mypyc/irbuild/match.py index 85489a795616b..2dbd622768cf4 100644 --- a/mypyc/irbuild/match.py +++ b/mypyc/irbuild/match.py @@ -237,19 +237,7 @@ def visit_mapping_pattern(self, pattern: MappingPattern) -> None: self.builder.goto(self.code_block) def visit_sequence_pattern(self, seq_pattern: SequencePattern) -> None: - star_index: int | None = None - capture: NameExpr | None = None - patterns: list[Pattern] = [] - - for i, pattern in enumerate(seq_pattern.patterns): - if isinstance(pattern, StarredPattern): - star_index = i - capture = pattern.capture - - else: - patterns.append(pattern) - - ### + star_index, capture, patterns = prep_sequence_pattern(seq_pattern) is_list = self.builder.call_c( check_list, [self.subject], seq_pattern.line @@ -257,15 +245,11 @@ def visit_sequence_pattern(self, seq_pattern: SequencePattern) -> None: self.builder.add_bool_branch(is_list, self.code_block, self.next_block) - ### - min_len = len(patterns) if not min_len: return - ### - self.builder.activate_block(self.code_block) self.code_block = BasicBlock() @@ -284,8 +268,6 @@ def visit_sequence_pattern(self, seq_pattern: SequencePattern) -> None: self.builder.add_bool_branch(is_long_enough, self.code_block, self.next_block) - ### - for i, pattern in enumerate(patterns): self.builder.activate_block(self.code_block) self.code_block = BasicBlock() @@ -310,8 +292,6 @@ def visit_sequence_pattern(self, seq_pattern: SequencePattern) -> None: with self.enter_subpattern(item): pattern.accept(self) - ### - if capture and star_index is not None: self.builder.activate_block(self.code_block) self.code_block = BasicBlock() @@ -320,7 +300,7 @@ def visit_sequence_pattern(self, seq_pattern: SequencePattern) -> None: actual_len, self.builder.load_int(min_len - star_index), "-", - patterns[0].line, + capture.line, ) rest = self.builder.call_c( @@ -365,3 +345,21 @@ def enter_as_pattern(self, pattern: AsPattern) -> Generator[None, None, None]: self.as_pattern = pattern yield self.as_pattern = old_pattern + + +def prep_sequence_pattern(seq_pattern: SequencePattern) -> tuple[ + int | None, NameExpr | None, list[Pattern] +]: + star_index: int | None = None + capture: NameExpr | None = None + patterns: list[Pattern] = [] + + for i, pattern in enumerate(seq_pattern.patterns): + if isinstance(pattern, StarredPattern): + star_index = i + capture = pattern.capture + + else: + patterns.append(pattern) + + return star_index, capture, patterns From 0b30e8b6c63a91d27de5be5514ae5b580334dc27 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Tue, 25 Oct 2022 14:35:13 -0700 Subject: [PATCH 80/97] Add class pattern support for builtins --- mypyc/irbuild/match.py | 24 +++++++++++++++++ mypyc/test-data/irbuild-match.test | 42 ++++++++++++++++++++++++++++++ mypyc/test-data/run-match.test | 16 ++++++++++++ 3 files changed, 82 insertions(+) diff --git a/mypyc/irbuild/match.py b/mypyc/irbuild/match.py index 2dbd622768cf4..72a62447db5e5 100644 --- a/mypyc/irbuild/match.py +++ b/mypyc/irbuild/match.py @@ -28,6 +28,22 @@ from mypyc.primitives.generic_ops import generic_ssize_t_len_op from mypyc.irbuild.builder import IRBuilder +# From: https://peps.python.org/pep-0634/#class-patterns +MATCHABLE_BUILTINS = { + "builtins.bool", + "builtins.bytearray", + "builtins.bytes", + "builtins.dict", + "builtins.float", + "builtins.frozenset", + "builtins.int", + "builtins.list", + "builtins.set", + "builtins.str", + "builtins.tuple", +} + + class MatchVisitor(TraverserVisitor): builder: IRBuilder code_block: BasicBlock @@ -117,6 +133,14 @@ def visit_class_pattern(self, pattern: ClassPattern) -> None: self.bind_as_pattern(self.subject, new_block=True) if pattern.positionals: + if pattern.class_ref.fullname in MATCHABLE_BUILTINS: + self.builder.activate_block(self.code_block) + self.code_block = BasicBlock() + + pattern.positionals[0].accept(self) + + return + node = pattern.class_ref.node assert isinstance(node, TypeInfo) diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index 0074bff09c1a0..336a5965116ed 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -1707,3 +1707,45 @@ L18: L19: r46 = box(None, 1) return r46 +[case testMatchBuiltinClassPattern_python3_10] +def f(x): + match x: + case int(y): + print("matched") +[out] +def f(x): + x, r0 :: object + r1 :: int32 + r2 :: bit + r3 :: bool + r4, y :: int + r5 :: str + r6 :: object + r7 :: str + r8 :: object + r9 :: object[1] + r10 :: object_ptr + r11, r12 :: object +L0: + r0 = load_address PyLong_Type + r1 = PyObject_IsInstance(x, r0) + r2 = r1 >= 0 :: signed + r3 = truncate r1: int32 to builtins.bool + if r3 goto L1 else goto L3 :: bool +L1: + r4 = unbox(int, x) + y = r4 +L2: + r5 = 'matched' + r6 = builtins :: module + r7 = 'print' + r8 = CPyObject_GetAttr(r6, r7) + r9 = [r5] + r10 = load_address r9 + r11 = _PyObject_Vectorcall(r8, r10, 1, 0) + keep_alive r5 + goto L4 +L3: +L4: + r12 = box(None, 1) + return r12 diff --git a/mypyc/test-data/run-match.test b/mypyc/test-data/run-match.test index 10b5e7a20a08f..c602c8722a3bb 100644 --- a/mypyc/test-data/run-match.test +++ b/mypyc/test-data/run-match.test @@ -76,6 +76,12 @@ def f(x): case [*rest4, "test", 18, "some", "fluff"]: print(f"test 18 ({rest4})") + case str("test 19"): + print("test 19") + + case str(test_20) if test_20.startswith("test 20"): + print(f"test 20 ({test_20[7:]!r})") + case []: print("test sequence final") @@ -160,6 +166,13 @@ f(["test", 18, "some", "fluff"]) f(["stuff", "test", 18, "some", "fluff"]) f(["more", "stuff", "test", 18, "some", "fluff"]) +# test 19 +f("test 19") + +# test 20 +f("test 20") +f("test 20 something else") + # test sequence final f([]) @@ -203,5 +216,8 @@ test 17 (['more', 'stuff']) test 18 ([]) test 18 (['stuff']) test 18 (['more', 'stuff']) +test 19 +test 20 ('') +test 20 (' something else') test sequence final test final From 8f0a3bf89dee022bc21415d7b8ba69bd5367c676 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Thu, 27 Oct 2022 14:16:27 -0700 Subject: [PATCH 81/97] Cleanup --- mypyc/irbuild/match.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/mypyc/irbuild/match.py b/mypyc/irbuild/match.py index 72a62447db5e5..a5836622f46bb 100644 --- a/mypyc/irbuild/match.py +++ b/mypyc/irbuild/match.py @@ -179,8 +179,10 @@ def visit_class_pattern(self, pattern: ClassPattern) -> None: def visit_as_pattern(self, pattern: AsPattern) -> None: if pattern.pattern: - with self.enter_as_pattern(pattern): - pattern.pattern.accept(self) + old_pattern = self.as_pattern + self.as_pattern = pattern + pattern.pattern.accept(self) + self.as_pattern = old_pattern elif pattern.name: target = self.builder.get_assignment_target(pattern.name) @@ -343,13 +345,13 @@ def visit_sequence_pattern(self, seq_pattern: SequencePattern) -> None: self.builder.goto(self.code_block) def bind_as_pattern(self, value: Value, new_block: bool = False) -> None: - if self.as_pattern and self.as_pattern.name: + if self.as_pattern and self.as_pattern.pattern and self.as_pattern.name: if new_block: self.builder.activate_block(self.code_block) self.code_block = BasicBlock() target = self.builder.get_assignment_target(self.as_pattern.name) - self.builder.assign(target, value, self.as_pattern.pattern.line) # type: ignore + self.builder.assign(target, value, self.as_pattern.pattern.line) self.as_pattern = None @@ -363,13 +365,6 @@ def enter_subpattern(self, subject: Value) -> Generator[None, None, None]: yield self.subject = old_subject - @contextmanager - def enter_as_pattern(self, pattern: AsPattern) -> Generator[None, None, None]: - old_pattern = self.as_pattern - self.as_pattern = pattern - yield - self.as_pattern = old_pattern - def prep_sequence_pattern(seq_pattern: SequencePattern) -> tuple[ int | None, NameExpr | None, list[Pattern] From 96cd6b1df16e52cc4c756aafbfdf8c84d2888100 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Thu, 27 Oct 2022 14:44:55 -0700 Subject: [PATCH 82/97] Add more tests --- mypyc/test-data/run-match.test | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/mypyc/test-data/run-match.test b/mypyc/test-data/run-match.test index c602c8722a3bb..6df4efbb6aed6 100644 --- a/mypyc/test-data/run-match.test +++ b/mypyc/test-data/run-match.test @@ -82,6 +82,9 @@ def f(x): case str(test_20) if test_20.startswith("test 20"): print(f"test 20 ({test_20[7:]!r})") + case ("test 21" as value) | ("test 21 as well" as value): + print(f"test 21 ({value[7:]!r})") + case []: print("test sequence final") @@ -173,6 +176,10 @@ f("test 19") f("test 20") f("test 20 something else") +# test 21 +f("test 21") +f("test 21 as well") + # test sequence final f([]) @@ -219,5 +226,7 @@ test 18 (['more', 'stuff']) test 19 test 20 ('') test 20 (' something else') +test 21 ('') +test 21 (' as well') test sequence final test final From 47f96825b32c8229ddfdfa9aefbfb1848a5b1044 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Thu, 27 Oct 2022 14:47:53 -0700 Subject: [PATCH 83/97] Black --- mypyc/irbuild/match.py | 82 +++++++++----------------------------- mypyc/irbuild/statement.py | 6 +-- mypyc/test/test_irbuild.py | 2 +- mypyc/test/test_run.py | 2 +- 4 files changed, 21 insertions(+), 71 deletions(-) diff --git a/mypyc/irbuild/match.py b/mypyc/irbuild/match.py index a5836622f46bb..b3af4ef35aed5 100644 --- a/mypyc/irbuild/match.py +++ b/mypyc/irbuild/match.py @@ -18,12 +18,7 @@ from mypy.types import Instance, TupleType from mypyc.primitives.dict_ops import dict_get_item_op -from mypyc.primitives.misc_ops import ( - check_dict, - dict_copy, - dict_del_item, - slow_isinstance_op, -) +from mypyc.primitives.misc_ops import check_dict, dict_copy, dict_del_item, slow_isinstance_op from mypyc.primitives.list_ops import check_list, list_get_item_op, list_slice_op from mypyc.primitives.generic_ops import generic_ssize_t_len_op from mypyc.irbuild.builder import IRBuilder @@ -93,12 +88,7 @@ def visit_match_stmt(self, m: MatchStmt) -> None: def visit_value_pattern(self, pattern: ValuePattern) -> None: value = self.builder.accept(pattern.expr) - cond = self.builder.binary_op( - self.subject, - value, - "==", - pattern.expr.line - ) + cond = self.builder.binary_op(self.subject, value, "==", pattern.expr.line) self.bind_as_pattern(value) @@ -125,7 +115,7 @@ def visit_class_pattern(self, pattern: ClassPattern) -> None: cond = self.builder.call_c( slow_isinstance_op, [self.subject, self.builder.accept(pattern.class_ref)], - pattern.line + pattern.line, ) self.builder.add_bool_branch(cond, self.code_block, self.next_block) @@ -161,9 +151,7 @@ def visit_class_pattern(self, pattern: ClassPattern) -> None: self.builder.activate_block(self.code_block) self.code_block = BasicBlock() - value = self.builder.py_get_attr( - self.subject, match_args[i], expr.line - ) + value = self.builder.py_get_attr(self.subject, match_args[i], expr.line) with self.enter_subpattern(value): expr.accept(self) @@ -209,11 +197,7 @@ def visit_mapping_pattern(self, pattern: MappingPattern) -> None: # string types, which is confusing. This should work for the time # being, but will need to be changed at some point. - is_dict = self.builder.call_c( - check_dict, - [self.subject], - pattern.line, - ) + is_dict = self.builder.call_c(check_dict, [self.subject], pattern.line) self.builder.add_bool_branch(is_dict, self.code_block, self.next_block) @@ -226,19 +210,13 @@ def visit_mapping_pattern(self, pattern: MappingPattern) -> None: key_value = self.builder.accept(key) keys.append(key_value) - exists = self.builder.binary_op( - key_value, self.subject, "in", pattern.line - ) + exists = self.builder.binary_op(key_value, self.subject, "in", pattern.line) self.builder.add_bool_branch(exists, self.code_block, self.next_block) self.builder.activate_block(self.code_block) self.code_block = BasicBlock() - item = self.builder.call_c( - dict_get_item_op, - [self.subject, key_value], - pattern.line - ) + item = self.builder.call_c(dict_get_item_op, [self.subject, key_value], pattern.line) with self.enter_subpattern(item): value.accept(self) @@ -247,11 +225,7 @@ def visit_mapping_pattern(self, pattern: MappingPattern) -> None: self.builder.activate_block(self.code_block) self.code_block = BasicBlock() - rest = self.builder.call_c( - dict_copy, - [self.subject], - pattern.rest.line, - ) + rest = self.builder.call_c(dict_copy, [self.subject], pattern.rest.line) target = self.builder.get_assignment_target(pattern.rest) @@ -265,9 +239,7 @@ def visit_mapping_pattern(self, pattern: MappingPattern) -> None: def visit_sequence_pattern(self, seq_pattern: SequencePattern) -> None: star_index, capture, patterns = prep_sequence_pattern(seq_pattern) - is_list = self.builder.call_c( - check_list, [self.subject], seq_pattern.line - ) + is_list = self.builder.call_c(check_list, [self.subject], seq_pattern.line) self.builder.add_bool_branch(is_list, self.code_block, self.next_block) @@ -279,17 +251,13 @@ def visit_sequence_pattern(self, seq_pattern: SequencePattern) -> None: self.builder.activate_block(self.code_block) self.code_block = BasicBlock() - actual_len = self.builder.call_c( - generic_ssize_t_len_op, - [self.subject], - seq_pattern.line, - ) + actual_len = self.builder.call_c(generic_ssize_t_len_op, [self.subject], seq_pattern.line) is_long_enough = self.builder.binary_op( actual_len, self.builder.load_int(min_len), "==" if star_index is None else ">=", - seq_pattern.line + seq_pattern.line, ) self.builder.add_bool_branch(is_long_enough, self.code_block, self.next_block) @@ -300,20 +268,13 @@ def visit_sequence_pattern(self, seq_pattern: SequencePattern) -> None: if star_index is not None and i >= star_index: current = self.builder.binary_op( - actual_len, - self.builder.load_int(min_len - i), - "-", - pattern.line, + actual_len, self.builder.load_int(min_len - i), "-", pattern.line ) else: current = self.builder.load_int(i) - item = self.builder.call_c( - list_get_item_op, - [self.subject, current], - pattern.line, - ) + item = self.builder.call_c(list_get_item_op, [self.subject, current], pattern.line) with self.enter_subpattern(item): pattern.accept(self) @@ -323,19 +284,12 @@ def visit_sequence_pattern(self, seq_pattern: SequencePattern) -> None: self.code_block = BasicBlock() capture_end = self.builder.binary_op( - actual_len, - self.builder.load_int(min_len - star_index), - "-", - capture.line, + actual_len, self.builder.load_int(min_len - star_index), "-", capture.line ) rest = self.builder.call_c( list_slice_op, - [ - self.subject, - self.builder.load_int(star_index), - capture_end, - ], + [self.subject, self.builder.load_int(star_index), capture_end], capture.line, ) @@ -366,9 +320,9 @@ def enter_subpattern(self, subject: Value) -> Generator[None, None, None]: self.subject = old_subject -def prep_sequence_pattern(seq_pattern: SequencePattern) -> tuple[ - int | None, NameExpr | None, list[Pattern] -]: +def prep_sequence_pattern( + seq_pattern: SequencePattern, +) -> tuple[int | None, NameExpr | None, list[Pattern]]: star_index: int | None = None capture: NameExpr | None = None patterns: list[Pattern] = [] diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index 0200c5cadf9ca..ca6be13b10b19 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -90,11 +90,7 @@ reraise_exception_op, restore_exc_info_op, ) -from mypyc.primitives.generic_ops import ( - iter_op, - next_raw_op, - py_delattr_op, -) +from mypyc.primitives.generic_ops import iter_op, next_raw_op, py_delattr_op from mypyc.primitives.misc_ops import ( check_stop_op, coro_op, diff --git a/mypyc/test/test_irbuild.py b/mypyc/test/test_irbuild.py index 51af88580edbd..abde61c1c5054 100644 --- a/mypyc/test/test_irbuild.py +++ b/mypyc/test/test_irbuild.py @@ -45,7 +45,7 @@ # "irbuild-dunders.test", # "irbuild-singledispatch.test", # "irbuild-constant-fold.test", - "irbuild-match.test", + "irbuild-match.test" ] diff --git a/mypyc/test/test_run.py b/mypyc/test/test_run.py index 2e18da6df67ce..913ec1a232b00 100644 --- a/mypyc/test/test_run.py +++ b/mypyc/test/test_run.py @@ -62,7 +62,7 @@ # "run-dunders.test", # "run-singledispatch.test", # "run-attrs.test", - "run-match.test", + "run-match.test" ] files.append("run-python37.test") From 532ff8a622a66d8a33af9a78ecedb908d536b7de Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Thu, 27 Oct 2022 15:02:53 -0700 Subject: [PATCH 84/97] Uncomment old tests, add python_version flag --- mypyc/options.py | 2 ++ mypyc/test/test_irbuild.py | 50 +++++++++++++++++----------------- mypyc/test/test_run.py | 56 +++++++++++++++++++------------------- mypyc/test/testutil.py | 3 +- 4 files changed, 57 insertions(+), 54 deletions(-) diff --git a/mypyc/options.py b/mypyc/options.py index d554cbed164f7..5f0cf12aeefed 100644 --- a/mypyc/options.py +++ b/mypyc/options.py @@ -13,6 +13,7 @@ def __init__( target_dir: str | None = None, include_runtime_files: bool | None = None, capi_version: tuple[int, int] | None = None, + python_version: tuple[int, int] | None = None, ) -> None: self.strip_asserts = strip_asserts self.multi_file = multi_file @@ -28,3 +29,4 @@ def __init__( # binaries are backward compatible even if no recent API # features are used. self.capi_version = capi_version or sys.version_info[:2] + self.python_version = python_version diff --git a/mypyc/test/test_irbuild.py b/mypyc/test/test_irbuild.py index abde61c1c5054..cf94cc3274c33 100644 --- a/mypyc/test/test_irbuild.py +++ b/mypyc/test/test_irbuild.py @@ -21,31 +21,31 @@ ) files = [ - # "irbuild-basic.test", - # "irbuild-int.test", - # "irbuild-lists.test", - # "irbuild-tuple.test", - # "irbuild-dict.test", - # "irbuild-set.test", - # "irbuild-str.test", - # "irbuild-bytes.test", - # "irbuild-statements.test", - # "irbuild-nested.test", - # "irbuild-classes.test", - # "irbuild-optional.test", - # "irbuild-any.test", - # "irbuild-generics.test", - # "irbuild-try.test", - # "irbuild-strip-asserts.test", - # "irbuild-i64.test", - # "irbuild-i32.test", - # "irbuild-vectorcall.test", - # "irbuild-unreachable.test", - # "irbuild-isinstance.test", - # "irbuild-dunders.test", - # "irbuild-singledispatch.test", - # "irbuild-constant-fold.test", - "irbuild-match.test" + "irbuild-basic.test", + "irbuild-int.test", + "irbuild-lists.test", + "irbuild-tuple.test", + "irbuild-dict.test", + "irbuild-set.test", + "irbuild-str.test", + "irbuild-bytes.test", + "irbuild-statements.test", + "irbuild-nested.test", + "irbuild-classes.test", + "irbuild-optional.test", + "irbuild-any.test", + "irbuild-generics.test", + "irbuild-try.test", + "irbuild-strip-asserts.test", + "irbuild-i64.test", + "irbuild-i32.test", + "irbuild-vectorcall.test", + "irbuild-unreachable.test", + "irbuild-isinstance.test", + "irbuild-dunders.test", + "irbuild-singledispatch.test", + "irbuild-constant-fold.test", + "irbuild-match.test", ] diff --git a/mypyc/test/test_run.py b/mypyc/test/test_run.py index 913ec1a232b00..99804eb24bcdc 100644 --- a/mypyc/test/test_run.py +++ b/mypyc/test/test_run.py @@ -35,34 +35,34 @@ ) files = [ - # "run-async.test", - # "run-misc.test", - # "run-functions.test", - # "run-integers.test", - # "run-i64.test", - # "run-i32.test", - # "run-floats.test", - # "run-bools.test", - # "run-strings.test", - # "run-bytes.test", - # "run-tuples.test", - # "run-lists.test", - # "run-dicts.test", - # "run-sets.test", - # "run-primitives.test", - # "run-loops.test", - # "run-exceptions.test", - # "run-imports.test", - # "run-classes.test", - # "run-traits.test", - # "run-generators.test", - # "run-multimodule.test", - # "run-bench.test", - # "run-mypy-sim.test", - # "run-dunders.test", - # "run-singledispatch.test", - # "run-attrs.test", - "run-match.test" + "run-async.test", + "run-misc.test", + "run-functions.test", + "run-integers.test", + "run-i64.test", + "run-i32.test", + "run-floats.test", + "run-bools.test", + "run-strings.test", + "run-bytes.test", + "run-tuples.test", + "run-lists.test", + "run-dicts.test", + "run-sets.test", + "run-primitives.test", + "run-loops.test", + "run-exceptions.test", + "run-imports.test", + "run-classes.test", + "run-traits.test", + "run-generators.test", + "run-multimodule.test", + "run-bench.test", + "run-mypy-sim.test", + "run-dunders.test", + "run-singledispatch.test", + "run-attrs.test", + "run-match.test", ] files.append("run-python37.test") diff --git a/mypyc/test/testutil.py b/mypyc/test/testutil.py index b97d8887e0f72..609ffc27385ea 100644 --- a/mypyc/test/testutil.py +++ b/mypyc/test/testutil.py @@ -108,7 +108,7 @@ def build_ir_for_single_file2( options.hide_error_codes = True options.use_builtins_fixtures = True options.strict_optional = True - options.python_version = (3, 10) + options.python_version = compiler_options.python_version or (3, 6) options.export_types = True options.preserve_asts = True options.allow_empty_bodies = True @@ -277,6 +277,7 @@ def infer_ir_build_options_from_test_name(name: str) -> CompilerOptions | None: m = re.search(r"_python([3-9]+)_([0-9]+)(_|\b)", name) if m: options.capi_version = (int(m.group(1)), int(m.group(2))) + options.python_version = options.capi_version elif "_py" in name or "_Python" in name: assert False, f"Invalid _py* suffix (should be _pythonX_Y): {name}" return options From f73b60a657288ae21bdd9bcdd5f23fba6717dc3a Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Thu, 27 Oct 2022 15:45:53 -0700 Subject: [PATCH 85/97] Reorganize ops --- mypyc/irbuild/match.py | 4 +- mypyc/primitives/dict_ops.py | 18 +++++- mypyc/primitives/misc_ops.py | 24 -------- mypyc/test-data/irbuild-match.test | 92 ++++++++++++++---------------- 4 files changed, 63 insertions(+), 75 deletions(-) diff --git a/mypyc/irbuild/match.py b/mypyc/irbuild/match.py index b3af4ef35aed5..68424c58f7f65 100644 --- a/mypyc/irbuild/match.py +++ b/mypyc/irbuild/match.py @@ -17,8 +17,8 @@ from mypy.traverser import TraverserVisitor from mypy.types import Instance, TupleType -from mypyc.primitives.dict_ops import dict_get_item_op -from mypyc.primitives.misc_ops import check_dict, dict_copy, dict_del_item, slow_isinstance_op +from mypyc.primitives.dict_ops import check_dict, dict_copy, dict_del_item, dict_get_item_op +from mypyc.primitives.misc_ops import slow_isinstance_op from mypyc.primitives.list_ops import check_list, list_get_item_op, list_slice_op from mypyc.primitives.generic_ops import generic_ssize_t_len_op from mypyc.irbuild.builder import IRBuilder diff --git a/mypyc/primitives/dict_ops.py b/mypyc/primitives/dict_ops.py index d1dca5a79e635..455b8922b53f7 100644 --- a/mypyc/primitives/dict_ops.py +++ b/mypyc/primitives/dict_ops.py @@ -63,7 +63,7 @@ ) # Generic one-argument dict constructor: dict(obj) -function_op( +dict_copy = function_op( name="builtins.dict", arg_types=[object_rprimitive], return_type=dict_rprimitive, @@ -301,3 +301,19 @@ c_function_name="PyDict_Size", error_kind=ERR_NEVER, ) + +# Check that the object is a dict or a subclass of dict +check_dict = custom_op( + arg_types=[object_rprimitive], + return_type=c_int_rprimitive, + c_function_name="PyDict_Check", + error_kind=ERR_NEVER, +) + +# Delete an item from a dict +dict_del_item = custom_op( + arg_types=[object_rprimitive, object_rprimitive], + return_type=c_int_rprimitive, + c_function_name="PyDict_DelItem", + error_kind=ERR_NEG_INT, +) diff --git a/mypyc/primitives/misc_ops.py b/mypyc/primitives/misc_ops.py index 60865a642240c..07df9c69714ba 100644 --- a/mypyc/primitives/misc_ops.py +++ b/mypyc/primitives/misc_ops.py @@ -238,27 +238,3 @@ c_function_name="CPySingledispatch_RegisterFunction", error_kind=ERR_MAGIC, ) - -# Check that the object is a dict or a subclass of dict -check_dict = custom_op( - arg_types=[object_rprimitive], - return_type=c_int_rprimitive, - c_function_name="PyDict_Check", - error_kind=ERR_NEVER, -) - -# Copy an object into a dict -dict_copy = custom_op( - arg_types=[object_rprimitive], - return_type=object_rprimitive, - c_function_name="PyDict_Copy", - error_kind=ERR_NEVER, -) - -# Delete an item from a dict -dict_del_item = custom_op( - arg_types=[object_rprimitive, object_rprimitive], - return_type=c_int_rprimitive, - c_function_name="PyDict_DelItem", - error_kind=ERR_NEG_INT, -) diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index 336a5965116ed..b42268374d34b 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -1096,37 +1096,35 @@ def f(x): x :: object r0 :: int32 r1 :: bit - r2 :: object - r3, rest :: dict - r4 :: str - r5 :: object - r6 :: str - r7 :: object - r8 :: object[1] - r9 :: object_ptr - r10, r11 :: object + r2, rest :: dict + r3 :: str + r4 :: object + r5 :: str + r6 :: object + r7 :: object[1] + r8 :: object_ptr + r9, r10 :: object L0: r0 = PyDict_Check(x) r1 = r0 != 0 if r1 goto L1 else goto L3 :: bool L1: - r2 = PyDict_Copy(x) - r3 = cast(dict, r2) - rest = r3 + r2 = CPyDict_FromAny(x) + rest = r2 L2: - r4 = 'matched' - r5 = builtins :: module - r6 = 'print' - r7 = CPyObject_GetAttr(r5, r6) - r8 = [r4] - r9 = load_address r8 - r10 = _PyObject_Vectorcall(r7, r9, 1, 0) - keep_alive r4 + r3 = 'matched' + r4 = builtins :: module + r5 = 'print' + r6 = CPyObject_GetAttr(r4, r5) + r7 = [r3] + r8 = load_address r7 + r9 = _PyObject_Vectorcall(r6, r8, 1, 0) + keep_alive r3 goto L4 L3: L4: - r11 = box(None, 1) - return r11 + r10 = box(None, 1) + return r10 [case testMatchMappingPatternWithRestPopKeys_python3_10] def f(x): match x: @@ -1148,17 +1146,16 @@ def f(x): r10 :: int32 r11 :: bit r12 :: bool - r13 :: object - r14, rest :: dict - r15 :: int32 - r16 :: bit - r17 :: str - r18 :: object - r19 :: str - r20 :: object - r21 :: object[1] - r22 :: object_ptr - r23, r24 :: object + r13, rest :: dict + r14 :: int32 + r15 :: bit + r16 :: str + r17 :: object + r18 :: str + r19 :: object + r20 :: object[1] + r21 :: object_ptr + r22, r23 :: object L0: r0 = PyDict_Check(x) r1 = r0 != 0 @@ -1179,25 +1176,24 @@ L2: r12 = truncate r10: int32 to builtins.bool if r12 goto L3 else goto L5 :: bool L3: - r13 = PyDict_Copy(x) - r14 = cast(dict, r13) - rest = r14 - r15 = PyDict_DelItem(r13, r2) - r16 = r15 >= 0 :: signed + r13 = CPyDict_FromAny(x) + rest = r13 + r14 = PyDict_DelItem(r13, r2) + r15 = r14 >= 0 :: signed L4: - r17 = 'matched' - r18 = builtins :: module - r19 = 'print' - r20 = CPyObject_GetAttr(r18, r19) - r21 = [r17] - r22 = load_address r21 - r23 = _PyObject_Vectorcall(r20, r22, 1, 0) - keep_alive r17 + r16 = 'matched' + r17 = builtins :: module + r18 = 'print' + r19 = CPyObject_GetAttr(r17, r18) + r20 = [r16] + r21 = load_address r20 + r22 = _PyObject_Vectorcall(r19, r21, 1, 0) + keep_alive r16 goto L6 L5: L6: - r24 = box(None, 1) - return r24 + r23 = box(None, 1) + return r23 [case testMatchEmptySequencePattern_python3_10] def f(x): match x: From b52518829c85e7f108ee94e02481c75593e59358 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Thu, 27 Oct 2022 15:46:46 -0700 Subject: [PATCH 86/97] Isort --- mypyc/irbuild/match.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/mypyc/irbuild/match.py b/mypyc/irbuild/match.py index 68424c58f7f65..d7e9d85a41b15 100644 --- a/mypyc/irbuild/match.py +++ b/mypyc/irbuild/match.py @@ -2,26 +2,25 @@ from typing import Generator from mypy.nodes import MatchStmt, NameExpr, TypeInfo -from mypyc.ir.ops import Value, BasicBlock from mypy.patterns import ( AsPattern, ClassPattern, - OrPattern, MappingPattern, + OrPattern, Pattern, - SingletonPattern, SequencePattern, + SingletonPattern, StarredPattern, ValuePattern, ) from mypy.traverser import TraverserVisitor from mypy.types import Instance, TupleType - +from mypyc.ir.ops import BasicBlock, Value +from mypyc.irbuild.builder import IRBuilder from mypyc.primitives.dict_ops import check_dict, dict_copy, dict_del_item, dict_get_item_op -from mypyc.primitives.misc_ops import slow_isinstance_op -from mypyc.primitives.list_ops import check_list, list_get_item_op, list_slice_op from mypyc.primitives.generic_ops import generic_ssize_t_len_op -from mypyc.irbuild.builder import IRBuilder +from mypyc.primitives.list_ops import check_list, list_get_item_op, list_slice_op +from mypyc.primitives.misc_ops import slow_isinstance_op # From: https://peps.python.org/pep-0634/#class-patterns MATCHABLE_BUILTINS = { From a117794d0964621dbaa54f6b1ff9a996681f37a5 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Thu, 27 Oct 2022 17:45:56 -0700 Subject: [PATCH 87/97] Switch to using older typing syntax --- mypyc/irbuild/match.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/mypyc/irbuild/match.py b/mypyc/irbuild/match.py index d7e9d85a41b15..129ecd26989be 100644 --- a/mypyc/irbuild/match.py +++ b/mypyc/irbuild/match.py @@ -1,5 +1,5 @@ from contextlib import contextmanager -from typing import Generator +from typing import Generator, Optional, List, Tuple from mypy.nodes import MatchStmt, NameExpr, TypeInfo from mypy.patterns import ( @@ -46,7 +46,7 @@ class MatchVisitor(TraverserVisitor): subject: Value match: MatchStmt - as_pattern: AsPattern | None = None + as_pattern: Optional[AsPattern] = None def __init__(self, builder: IRBuilder, match_node: MatchStmt) -> None: self.builder = builder @@ -61,7 +61,9 @@ def __init__(self, builder: IRBuilder, match_node: MatchStmt) -> None: def build_match_body(self, index: int) -> None: self.builder.activate_block(self.code_block) - if guard := self.match.guards[index]: + guard = self.match.guards[index] + + if guard: self.code_block = BasicBlock() cond = self.builder.accept(guard) @@ -136,7 +138,7 @@ def visit_class_pattern(self, pattern: ClassPattern) -> None: ty = node.names.get("__match_args__") assert ty and isinstance(ty.type, TupleType) - match_args: list[str] = [] + match_args: List[str] = [] for item in ty.type.items: assert isinstance(item, Instance) and item.last_known_value @@ -200,7 +202,7 @@ def visit_mapping_pattern(self, pattern: MappingPattern) -> None: self.builder.add_bool_branch(is_dict, self.code_block, self.next_block) - keys: list[Value] = [] + keys: List[Value] = [] for key, value in zip(pattern.keys, pattern.values): self.builder.activate_block(self.code_block) @@ -321,10 +323,10 @@ def enter_subpattern(self, subject: Value) -> Generator[None, None, None]: def prep_sequence_pattern( seq_pattern: SequencePattern, -) -> tuple[int | None, NameExpr | None, list[Pattern]]: - star_index: int | None = None - capture: NameExpr | None = None - patterns: list[Pattern] = [] +) -> Tuple[Optional[int], Optional[NameExpr], List[Pattern]]: + star_index: Optional[int] = None + capture: Optional[NameExpr] = None + patterns: List[Pattern] = [] for i, pattern in enumerate(seq_pattern.patterns): if isinstance(pattern, StarredPattern): From 22c132752e348e45cf98de7adc8ab383f1adafae Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Thu, 27 Oct 2022 19:15:56 -0700 Subject: [PATCH 88/97] Fix build errors: Turns out that the `-ffat-lto-objects` flag is only available in GCC, and clang will complain if there are any unused optimization flags passed. Adding the `-Wno-ignored-optimization-argument` flag seems to fix this. --- mypyc/build.py | 1 + mypyc/irbuild/match.py | 16 ++++++++-------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/mypyc/build.py b/mypyc/build.py index 51696e86fa941..8b556db3cc5de 100644 --- a/mypyc/build.py +++ b/mypyc/build.py @@ -535,6 +535,7 @@ def mypycify( "-Wno-unknown-warning-option", "-Wno-unused-but-set-variable", "-Wno-cpp", + "-Wno-ignored-optimization-argument", ] elif compiler.compiler_type == "msvc": # msvc doesn't have levels, '/O2' is full and '/Od' is disable diff --git a/mypyc/irbuild/match.py b/mypyc/irbuild/match.py index 129ecd26989be..5c6098d67aabb 100644 --- a/mypyc/irbuild/match.py +++ b/mypyc/irbuild/match.py @@ -1,5 +1,5 @@ from contextlib import contextmanager -from typing import Generator, Optional, List, Tuple +from typing import Generator, List, Optional, Tuple from mypy.nodes import MatchStmt, NameExpr, TypeInfo from mypy.patterns import ( @@ -143,18 +143,18 @@ def visit_class_pattern(self, pattern: ClassPattern) -> None: for item in ty.type.items: assert isinstance(item, Instance) and item.last_known_value - value = item.last_known_value.value - assert isinstance(value, str) + match_arg = item.last_known_value.value + assert isinstance(match_arg, str) - match_args.append(value) + match_args.append(match_arg) for i, expr in enumerate(pattern.positionals): self.builder.activate_block(self.code_block) self.code_block = BasicBlock() - value = self.builder.py_get_attr(self.subject, match_args[i], expr.line) + positional = self.builder.py_get_attr(self.subject, match_args[i], expr.line) - with self.enter_subpattern(value): + with self.enter_subpattern(positional): expr.accept(self) for key, value in zip(pattern.keyword_keys, pattern.keyword_values): @@ -232,8 +232,8 @@ def visit_mapping_pattern(self, pattern: MappingPattern) -> None: self.builder.assign(target, rest, pattern.rest.line) - for i, key in enumerate(keys): - self.builder.call_c(dict_del_item, [rest, key], pattern.keys[i].line) + for i, key_name in enumerate(keys): + self.builder.call_c(dict_del_item, [rest, key_name], pattern.keys[i].line) self.builder.goto(self.code_block) From 3836f009978573e282a99b9efb9d4173d259a0b4 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Thu, 27 Oct 2022 20:12:05 -0700 Subject: [PATCH 89/97] Only run match code for Python 3.10+ --- mypyc/irbuild/match.py | 14 +++++++++----- mypyc/test/test_irbuild.py | 5 ++++- mypyc/test/test_run.py | 4 +++- 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/mypyc/irbuild/match.py b/mypyc/irbuild/match.py index 5c6098d67aabb..ebabb0b91babe 100644 --- a/mypyc/irbuild/match.py +++ b/mypyc/irbuild/match.py @@ -14,7 +14,7 @@ ValuePattern, ) from mypy.traverser import TraverserVisitor -from mypy.types import Instance, TupleType +from mypy.types import Instance, TupleType, get_proper_type from mypyc.ir.ops import BasicBlock, Value from mypyc.irbuild.builder import IRBuilder from mypyc.primitives.dict_ops import check_dict, dict_copy, dict_del_item, dict_get_item_op @@ -136,14 +136,18 @@ def visit_class_pattern(self, pattern: ClassPattern) -> None: assert isinstance(node, TypeInfo) ty = node.names.get("__match_args__") - assert ty and isinstance(ty.type, TupleType) + assert ty + + match_args_type = get_proper_type(ty.type) + assert isinstance(match_args_type, TupleType) match_args: List[str] = [] - for item in ty.type.items: - assert isinstance(item, Instance) and item.last_known_value + for item in match_args_type.items: + proper_item = get_proper_type(item) + assert isinstance(proper_item, Instance) and proper_item.last_known_value - match_arg = item.last_known_value.value + match_arg = proper_item.last_known_value.value assert isinstance(match_arg, str) match_args.append(match_arg) diff --git a/mypyc/test/test_irbuild.py b/mypyc/test/test_irbuild.py index b31cf15d58225..8928f94d62110 100644 --- a/mypyc/test/test_irbuild.py +++ b/mypyc/test/test_irbuild.py @@ -3,6 +3,7 @@ from __future__ import annotations import os.path +import sys from mypy.errors import CompileError from mypy.test.config import test_temp_dir @@ -45,10 +46,12 @@ "irbuild-dunders.test", "irbuild-singledispatch.test", "irbuild-constant-fold.test", - "irbuild-match.test", "irbuild-glue-methods.test", ] +if sys.version_info >= (3, 10): + files.append("irbuild-match.test") + class TestGenOps(MypycDataSuite): files = files diff --git a/mypyc/test/test_run.py b/mypyc/test/test_run.py index 99804eb24bcdc..fff775ebfab5f 100644 --- a/mypyc/test/test_run.py +++ b/mypyc/test/test_run.py @@ -62,13 +62,15 @@ "run-dunders.test", "run-singledispatch.test", "run-attrs.test", - "run-match.test", ] files.append("run-python37.test") if sys.version_info >= (3, 8): files.append("run-python38.test") +if sys.version_info >= (3, 10): + files.append("run-match.test") + setup_format = """\ from setuptools import setup from mypyc.build import mypycify From 4f18437c94ffc10070bce1601f424023a60b11e6 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Fri, 28 Oct 2022 16:34:58 -0700 Subject: [PATCH 90/97] Fix length of empty sequence patterns not being checked --- mypyc/irbuild/match.py | 6 ++-- mypyc/test-data/irbuild-match.test | 45 +++++++++++++++++------------- 2 files changed, 29 insertions(+), 22 deletions(-) diff --git a/mypyc/irbuild/match.py b/mypyc/irbuild/match.py index ebabb0b91babe..9412eec75387a 100644 --- a/mypyc/irbuild/match.py +++ b/mypyc/irbuild/match.py @@ -250,9 +250,6 @@ def visit_sequence_pattern(self, seq_pattern: SequencePattern) -> None: min_len = len(patterns) - if not min_len: - return - self.builder.activate_block(self.code_block) self.code_block = BasicBlock() @@ -267,6 +264,9 @@ def visit_sequence_pattern(self, seq_pattern: SequencePattern) -> None: self.builder.add_bool_branch(is_long_enough, self.code_block, self.next_block) + if not min_len: + return + for i, pattern in enumerate(patterns): self.builder.activate_block(self.code_block) self.code_block = BasicBlock() diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index b42268374d34b..f0a8a40fd6fe0 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -1204,31 +1204,38 @@ def f(x): x :: object r0 :: int32 r1 :: bit - r2 :: str - r3 :: object - r4 :: str - r5 :: object - r6 :: object[1] - r7 :: object_ptr - r8, r9 :: object + r2 :: native_int + r3, r4 :: bit + r5 :: str + r6 :: object + r7 :: str + r8 :: object + r9 :: object[1] + r10 :: object_ptr + r11, r12 :: object L0: r0 = PyList_Check(x) r1 = r0 != 0 - if r1 goto L1 else goto L2 :: bool + if r1 goto L1 else goto L3 :: bool L1: - r2 = 'matched' - r3 = builtins :: module - r4 = 'print' - r5 = CPyObject_GetAttr(r3, r4) - r6 = [r2] - r7 = load_address r6 - r8 = _PyObject_Vectorcall(r5, r7, 1, 0) - keep_alive r2 - goto L3 + r2 = PyObject_Size(x) + r3 = r2 >= 0 :: signed + r4 = r2 == 0 + if r4 goto L2 else goto L3 :: bool L2: + r5 = 'matched' + r6 = builtins :: module + r7 = 'print' + r8 = CPyObject_GetAttr(r6, r7) + r9 = [r5] + r10 = load_address r9 + r11 = _PyObject_Vectorcall(r8, r10, 1, 0) + keep_alive r5 + goto L4 L3: - r9 = box(None, 1) - return r9 +L4: + r12 = box(None, 1) + return r12 [case testMatchFixedLengthSequencePattern_python3_10] def f(x): match x: From 4f218b954d969fe1ae7a9fca0182ebc734f97dab Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Fri, 28 Oct 2022 16:43:39 -0700 Subject: [PATCH 91/97] Fix `[*rest]` patterns not binding to `rest` --- mypyc/irbuild/match.py | 6 +-- mypyc/test-data/irbuild-match.test | 65 ++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 5 deletions(-) diff --git a/mypyc/irbuild/match.py b/mypyc/irbuild/match.py index 9412eec75387a..b5f1da15f2639 100644 --- a/mypyc/irbuild/match.py +++ b/mypyc/irbuild/match.py @@ -248,12 +248,11 @@ def visit_sequence_pattern(self, seq_pattern: SequencePattern) -> None: self.builder.add_bool_branch(is_list, self.code_block, self.next_block) - min_len = len(patterns) - self.builder.activate_block(self.code_block) self.code_block = BasicBlock() actual_len = self.builder.call_c(generic_ssize_t_len_op, [self.subject], seq_pattern.line) + min_len = len(patterns) is_long_enough = self.builder.binary_op( actual_len, @@ -264,9 +263,6 @@ def visit_sequence_pattern(self, seq_pattern: SequencePattern) -> None: self.builder.add_bool_branch(is_long_enough, self.code_block, self.next_block) - if not min_len: - return - for i, pattern in enumerate(patterns): self.builder.activate_block(self.code_block) self.code_block = BasicBlock() diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index f0a8a40fd6fe0..4810d46d8886d 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -1752,3 +1752,68 @@ L3: L4: r12 = box(None, 1) return r12 +[case testMatchSequenceCaptureAll_python3_10] +def f(x): + match x: + case [*rest]: + print("matched") +[out] +def f(x): + x :: object + r0 :: int32 + r1 :: bit + r2 :: native_int + r3, r4 :: bit + r5 :: native_int + r6 :: list + r7, r8 :: bit + r9, r10, r11 :: int + r12, rest :: object + r13 :: str + r14 :: object + r15 :: str + r16 :: object + r17 :: object[1] + r18 :: object_ptr + r19, r20 :: object +L0: + r0 = PyList_Check(x) + r1 = r0 != 0 + if r1 goto L1 else goto L8 :: bool +L1: + r2 = PyObject_Size(x) + r3 = r2 >= 0 :: signed + r4 = r2 >= 0 :: signed + if r4 goto L2 else goto L8 :: bool +L2: + r5 = r2 - 0 + r6 = cast(list, x) + r7 = r5 <= 4611686018427387903 :: signed + if r7 goto L3 else goto L4 :: bool +L3: + r8 = r5 >= -4611686018427387904 :: signed + if r8 goto L5 else goto L4 :: bool +L4: + r9 = CPyTagged_FromInt64(r5) + r10 = r9 + goto L6 +L5: + r11 = r5 << 1 + r10 = r11 +L6: + r12 = CPyList_GetSlice(r6, 0, r10) + rest = r12 +L7: + r13 = 'matched' + r14 = builtins :: module + r15 = 'print' + r16 = CPyObject_GetAttr(r14, r15) + r17 = [r13] + r18 = load_address r17 + r19 = _PyObject_Vectorcall(r16, r18, 1, 0) + keep_alive r13 + goto L9 +L8: +L9: + r20 = box(None, 1) + return r20 From 6c5de3381f5c52f23780af808f729b4e8656fe26 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Mon, 31 Oct 2022 16:11:25 -0700 Subject: [PATCH 92/97] Allow for pattern matching Mapping and Sequence protocols --- mypyc/irbuild/match.py | 33 +- mypyc/lib-rt/CPy.h | 2 + mypyc/lib-rt/dict_ops.c | 4 + mypyc/lib-rt/list_ops.c | 4 + mypyc/primitives/dict_ops.py | 20 +- mypyc/primitives/list_ops.py | 19 +- mypyc/test-data/irbuild-match.test | 783 ++++++++++++----------------- mypyc/test-data/run-match.test | 51 ++ 8 files changed, 431 insertions(+), 485 deletions(-) diff --git a/mypyc/irbuild/match.py b/mypyc/irbuild/match.py index b5f1da15f2639..ab7ffd4f26c62 100644 --- a/mypyc/irbuild/match.py +++ b/mypyc/irbuild/match.py @@ -16,10 +16,20 @@ from mypy.traverser import TraverserVisitor from mypy.types import Instance, TupleType, get_proper_type from mypyc.ir.ops import BasicBlock, Value +from mypyc.ir.rtypes import object_rprimitive from mypyc.irbuild.builder import IRBuilder -from mypyc.primitives.dict_ops import check_dict, dict_copy, dict_del_item, dict_get_item_op +from mypyc.primitives.dict_ops import ( + dict_copy, + dict_del_item, + mapping_has_key, + supports_mapping_protocol, +) from mypyc.primitives.generic_ops import generic_ssize_t_len_op -from mypyc.primitives.list_ops import check_list, list_get_item_op, list_slice_op +from mypyc.primitives.list_ops import ( + sequence_get_item, + sequence_get_slice, + supports_sequence_protocol, +) from mypyc.primitives.misc_ops import slow_isinstance_op # From: https://peps.python.org/pep-0634/#class-patterns @@ -197,12 +207,7 @@ def visit_singleton_pattern(self, pattern: SingletonPattern) -> None: self.builder.add_bool_branch(cond, self.code_block, self.next_block) def visit_mapping_pattern(self, pattern: MappingPattern) -> None: - # TODO: technically this should accept any object that supports the - # mapping protocol, but the PyMapping_Check function returns true for - # string types, which is confusing. This should work for the time - # being, but will need to be changed at some point. - - is_dict = self.builder.call_c(check_dict, [self.subject], pattern.line) + is_dict = self.builder.call_c(supports_mapping_protocol, [self.subject], pattern.line) self.builder.add_bool_branch(is_dict, self.code_block, self.next_block) @@ -215,13 +220,15 @@ def visit_mapping_pattern(self, pattern: MappingPattern) -> None: key_value = self.builder.accept(key) keys.append(key_value) - exists = self.builder.binary_op(key_value, self.subject, "in", pattern.line) + exists = self.builder.call_c(mapping_has_key, [self.subject, key_value], pattern.line) self.builder.add_bool_branch(exists, self.code_block, self.next_block) self.builder.activate_block(self.code_block) self.code_block = BasicBlock() - item = self.builder.call_c(dict_get_item_op, [self.subject, key_value], pattern.line) + item = self.builder.gen_method_call( + self.subject, "__getitem__", [key_value], object_rprimitive, pattern.line + ) with self.enter_subpattern(item): value.accept(self) @@ -244,7 +251,7 @@ def visit_mapping_pattern(self, pattern: MappingPattern) -> None: def visit_sequence_pattern(self, seq_pattern: SequencePattern) -> None: star_index, capture, patterns = prep_sequence_pattern(seq_pattern) - is_list = self.builder.call_c(check_list, [self.subject], seq_pattern.line) + is_list = self.builder.call_c(supports_sequence_protocol, [self.subject], seq_pattern.line) self.builder.add_bool_branch(is_list, self.code_block, self.next_block) @@ -275,7 +282,7 @@ def visit_sequence_pattern(self, seq_pattern: SequencePattern) -> None: else: current = self.builder.load_int(i) - item = self.builder.call_c(list_get_item_op, [self.subject, current], pattern.line) + item = self.builder.call_c(sequence_get_item, [self.subject, current], pattern.line) with self.enter_subpattern(item): pattern.accept(self) @@ -289,7 +296,7 @@ def visit_sequence_pattern(self, seq_pattern: SequencePattern) -> None: ) rest = self.builder.call_c( - list_slice_op, + sequence_get_slice, [self.subject, self.builder.load_int(star_index), capture_end], capture.line, ) diff --git a/mypyc/lib-rt/CPy.h b/mypyc/lib-rt/CPy.h index cffbbb3e16661..166c851d0155c 100644 --- a/mypyc/lib-rt/CPy.h +++ b/mypyc/lib-rt/CPy.h @@ -371,6 +371,7 @@ CPyTagged CPyList_Index(PyObject *list, PyObject *obj); PyObject *CPySequence_Multiply(PyObject *seq, CPyTagged t_size); PyObject *CPySequence_RMultiply(CPyTagged t_size, PyObject *seq); PyObject *CPyList_GetSlice(PyObject *obj, CPyTagged start, CPyTagged end); +int CPySequence_Check(PyObject *obj); // Dict operations @@ -402,6 +403,7 @@ PyObject *CPyDict_GetValuesIter(PyObject *dict); tuple_T3CIO CPyDict_NextKey(PyObject *dict_or_iter, CPyTagged offset); tuple_T3CIO CPyDict_NextValue(PyObject *dict_or_iter, CPyTagged offset); tuple_T4CIOO CPyDict_NextItem(PyObject *dict_or_iter, CPyTagged offset); +int CPyMapping_Check(PyObject *obj); // Check that dictionary didn't change size during iteration. static inline char CPyDict_CheckSize(PyObject *dict, CPyTagged size) { diff --git a/mypyc/lib-rt/dict_ops.c b/mypyc/lib-rt/dict_ops.c index b013a8a5f0b93..ccf3732a59d59 100644 --- a/mypyc/lib-rt/dict_ops.c +++ b/mypyc/lib-rt/dict_ops.c @@ -436,3 +436,7 @@ tuple_T4CIOO CPyDict_NextItem(PyObject *dict_or_iter, CPyTagged offset) { Py_INCREF(ret.f3); return ret; } + +int CPyMapping_Check(PyObject *obj) { + return Py_TYPE(obj)->tp_flags & Py_TPFLAGS_MAPPING; +} diff --git a/mypyc/lib-rt/list_ops.c b/mypyc/lib-rt/list_ops.c index cb72662e22eee..9849c8c3d7da1 100644 --- a/mypyc/lib-rt/list_ops.c +++ b/mypyc/lib-rt/list_ops.c @@ -325,3 +325,7 @@ PyObject *CPyList_GetSlice(PyObject *obj, CPyTagged start, CPyTagged end) { } return CPyObject_GetSlice(obj, start, end); } + +int CPySequence_Check(PyObject *obj) { + return Py_TYPE(obj)->tp_flags & Py_TPFLAGS_SEQUENCE; +} diff --git a/mypyc/primitives/dict_ops.py b/mypyc/primitives/dict_ops.py index 455b8922b53f7..9f477d0b7b90c 100644 --- a/mypyc/primitives/dict_ops.py +++ b/mypyc/primitives/dict_ops.py @@ -302,18 +302,24 @@ error_kind=ERR_NEVER, ) -# Check that the object is a dict or a subclass of dict -check_dict = custom_op( +# Delete an item from a dict +dict_del_item = custom_op( + arg_types=[object_rprimitive, object_rprimitive], + return_type=c_int_rprimitive, + c_function_name="PyDict_DelItem", + error_kind=ERR_NEG_INT, +) + +supports_mapping_protocol = custom_op( arg_types=[object_rprimitive], return_type=c_int_rprimitive, - c_function_name="PyDict_Check", + c_function_name="CPyMapping_Check", error_kind=ERR_NEVER, ) -# Delete an item from a dict -dict_del_item = custom_op( +mapping_has_key = custom_op( arg_types=[object_rprimitive, object_rprimitive], return_type=c_int_rprimitive, - c_function_name="PyDict_DelItem", - error_kind=ERR_NEG_INT, + c_function_name="PyMapping_HasKey", + error_kind=ERR_NEVER, ) diff --git a/mypyc/primitives/list_ops.py b/mypyc/primitives/list_ops.py index 283e702734985..7fe3157f3a382 100644 --- a/mypyc/primitives/list_ops.py +++ b/mypyc/primitives/list_ops.py @@ -278,10 +278,23 @@ error_kind=ERR_MAGIC, ) -# Check that the object is a list or a subclass of list -check_list = custom_op( +supports_sequence_protocol = custom_op( arg_types=[object_rprimitive], return_type=c_int_rprimitive, - c_function_name="PyList_Check", + c_function_name="CPySequence_Check", error_kind=ERR_NEVER, ) + +sequence_get_item = custom_op( + arg_types=[object_rprimitive, c_pyssize_t_rprimitive], + return_type=object_rprimitive, + c_function_name="PySequence_GetItem", + error_kind=ERR_NEVER, +) + +sequence_get_slice = custom_op( + arg_types=[object_rprimitive, c_pyssize_t_rprimitive, c_pyssize_t_rprimitive], + return_type=object_rprimitive, + c_function_name="PySequence_GetSlice", + error_kind=ERR_MAGIC, +) diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index 4810d46d8886d..fa0faf0774bb1 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -1008,7 +1008,7 @@ def f(x): r7 :: object_ptr r8, r9 :: object L0: - r0 = PyDict_Check(x) + r0 = CPyMapping_Check(x) r1 = r0 != 0 if r1 goto L1 else goto L2 :: bool L1: @@ -1038,54 +1038,50 @@ def f(x): r2 :: str r3 :: int32 r4 :: bit - r5 :: bool - r6 :: dict + r5 :: object + r6 :: str r7 :: object - r8 :: str - r9 :: object - r10 :: int32 - r11 :: bit - r12 :: bool + r8 :: int32 + r9 :: bit + r10 :: bool + r11 :: str + r12 :: object r13 :: str r14 :: object - r15 :: str - r16 :: object - r17 :: object[1] - r18 :: object_ptr - r19, r20 :: object + r15 :: object[1] + r16 :: object_ptr + r17, r18 :: object L0: - r0 = PyDict_Check(x) + r0 = CPyMapping_Check(x) r1 = r0 != 0 if r1 goto L1 else goto L4 :: bool L1: r2 = 'key' - r3 = PySequence_Contains(x, r2) - r4 = r3 >= 0 :: signed - r5 = truncate r3: int32 to builtins.bool - if r5 goto L2 else goto L4 :: bool + r3 = PyMapping_HasKey(x, r2) + r4 = r3 != 0 + if r4 goto L2 else goto L4 :: bool L2: - r6 = cast(dict, x) - r7 = CPyDict_GetItem(r6, r2) - r8 = 'value' - r9 = PyObject_RichCompare(r7, r8, 2) - r10 = PyObject_IsTrue(r9) - r11 = r10 >= 0 :: signed - r12 = truncate r10: int32 to builtins.bool - if r12 goto L3 else goto L4 :: bool + r5 = PyObject_GetItem(x, r2) + r6 = 'value' + r7 = PyObject_RichCompare(r5, r6, 2) + r8 = PyObject_IsTrue(r7) + r9 = r8 >= 0 :: signed + r10 = truncate r8: int32 to builtins.bool + if r10 goto L3 else goto L4 :: bool L3: - r13 = 'matched' - r14 = builtins :: module - r15 = 'print' - r16 = CPyObject_GetAttr(r14, r15) - r17 = [r13] - r18 = load_address r17 - r19 = _PyObject_Vectorcall(r16, r18, 1, 0) - keep_alive r13 + r11 = 'matched' + r12 = builtins :: module + r13 = 'print' + r14 = CPyObject_GetAttr(r12, r13) + r15 = [r11] + r16 = load_address r15 + r17 = _PyObject_Vectorcall(r14, r16, 1, 0) + keep_alive r11 goto L5 L4: L5: - r20 = box(None, 1) - return r20 + r18 = box(None, 1) + return r18 [case testMatchMappingPatternWithRest_python3_10] def f(x): match x: @@ -1105,7 +1101,7 @@ def f(x): r8 :: object_ptr r9, r10 :: object L0: - r0 = PyDict_Check(x) + r0 = CPyMapping_Check(x) r1 = r0 != 0 if r1 goto L1 else goto L3 :: bool L1: @@ -1138,62 +1134,58 @@ def f(x): r2 :: str r3 :: int32 r4 :: bit - r5 :: bool - r6 :: dict + r5 :: object + r6 :: str r7 :: object - r8 :: str - r9 :: object - r10 :: int32 - r11 :: bit - r12 :: bool - r13, rest :: dict - r14 :: int32 - r15 :: bit + r8 :: int32 + r9 :: bit + r10 :: bool + r11, rest :: dict + r12 :: int32 + r13 :: bit + r14 :: str + r15 :: object r16 :: str r17 :: object - r18 :: str - r19 :: object - r20 :: object[1] - r21 :: object_ptr - r22, r23 :: object + r18 :: object[1] + r19 :: object_ptr + r20, r21 :: object L0: - r0 = PyDict_Check(x) + r0 = CPyMapping_Check(x) r1 = r0 != 0 if r1 goto L1 else goto L5 :: bool L1: r2 = 'key' - r3 = PySequence_Contains(x, r2) - r4 = r3 >= 0 :: signed - r5 = truncate r3: int32 to builtins.bool - if r5 goto L2 else goto L5 :: bool + r3 = PyMapping_HasKey(x, r2) + r4 = r3 != 0 + if r4 goto L2 else goto L5 :: bool L2: - r6 = cast(dict, x) - r7 = CPyDict_GetItem(r6, r2) - r8 = 'value' - r9 = PyObject_RichCompare(r7, r8, 2) - r10 = PyObject_IsTrue(r9) - r11 = r10 >= 0 :: signed - r12 = truncate r10: int32 to builtins.bool - if r12 goto L3 else goto L5 :: bool + r5 = PyObject_GetItem(x, r2) + r6 = 'value' + r7 = PyObject_RichCompare(r5, r6, 2) + r8 = PyObject_IsTrue(r7) + r9 = r8 >= 0 :: signed + r10 = truncate r8: int32 to builtins.bool + if r10 goto L3 else goto L5 :: bool L3: - r13 = CPyDict_FromAny(x) - rest = r13 - r14 = PyDict_DelItem(r13, r2) - r15 = r14 >= 0 :: signed + r11 = CPyDict_FromAny(x) + rest = r11 + r12 = PyDict_DelItem(r11, r2) + r13 = r12 >= 0 :: signed L4: - r16 = 'matched' - r17 = builtins :: module - r18 = 'print' - r19 = CPyObject_GetAttr(r17, r18) - r20 = [r16] - r21 = load_address r20 - r22 = _PyObject_Vectorcall(r19, r21, 1, 0) - keep_alive r16 + r14 = 'matched' + r15 = builtins :: module + r16 = 'print' + r17 = CPyObject_GetAttr(r15, r16) + r18 = [r14] + r19 = load_address r18 + r20 = _PyObject_Vectorcall(r17, r19, 1, 0) + keep_alive r14 goto L6 L5: L6: - r23 = box(None, 1) - return r23 + r21 = box(None, 1) + return r21 [case testMatchEmptySequencePattern_python3_10] def f(x): match x: @@ -1214,7 +1206,7 @@ def f(x): r10 :: object_ptr r11, r12 :: object L0: - r0 = PyList_Check(x) + r0 = CPySequence_Check(x) r1 = r0 != 0 if r1 goto L1 else goto L3 :: bool L1: @@ -1248,25 +1240,23 @@ def f(x): r1 :: bit r2 :: native_int r3, r4 :: bit - r5 :: list - r6, r7, r8 :: object - r9 :: int32 - r10 :: bit - r11 :: bool - r12 :: list - r13, r14, r15 :: object - r16 :: int32 - r17 :: bit - r18 :: bool + r5, r6, r7 :: object + r8 :: int32 + r9 :: bit + r10 :: bool + r11, r12, r13 :: object + r14 :: int32 + r15 :: bit + r16 :: bool + r17 :: str + r18 :: object r19 :: str r20 :: object - r21 :: str - r22 :: object - r23 :: object[1] - r24 :: object_ptr - r25, r26 :: object + r21 :: object[1] + r22 :: object_ptr + r23, r24 :: object L0: - r0 = PyList_Check(x) + r0 = CPySequence_Check(x) r1 = r0 != 0 if r1 goto L1 else goto L5 :: bool L1: @@ -1275,37 +1265,35 @@ L1: r4 = r2 == 2 if r4 goto L2 else goto L5 :: bool L2: - r5 = cast(list, x) - r6 = CPyList_GetItem(r5, 0) - r7 = object 1 - r8 = PyObject_RichCompare(r6, r7, 2) - r9 = PyObject_IsTrue(r8) - r10 = r9 >= 0 :: signed - r11 = truncate r9: int32 to builtins.bool - if r11 goto L3 else goto L5 :: bool + r5 = PySequence_GetItem(x, 0) + r6 = object 1 + r7 = PyObject_RichCompare(r5, r6, 2) + r8 = PyObject_IsTrue(r7) + r9 = r8 >= 0 :: signed + r10 = truncate r8: int32 to builtins.bool + if r10 goto L3 else goto L5 :: bool L3: - r12 = cast(list, x) - r13 = CPyList_GetItem(r12, 2) - r14 = object 2 - r15 = PyObject_RichCompare(r13, r14, 2) - r16 = PyObject_IsTrue(r15) - r17 = r16 >= 0 :: signed - r18 = truncate r16: int32 to builtins.bool - if r18 goto L4 else goto L5 :: bool + r11 = PySequence_GetItem(x, 1) + r12 = object 2 + r13 = PyObject_RichCompare(r11, r12, 2) + r14 = PyObject_IsTrue(r13) + r15 = r14 >= 0 :: signed + r16 = truncate r14: int32 to builtins.bool + if r16 goto L4 else goto L5 :: bool L4: - r19 = 'matched' - r20 = builtins :: module - r21 = 'print' - r22 = CPyObject_GetAttr(r20, r21) - r23 = [r19] - r24 = load_address r23 - r25 = _PyObject_Vectorcall(r22, r24, 1, 0) - keep_alive r19 + r17 = 'matched' + r18 = builtins :: module + r19 = 'print' + r20 = CPyObject_GetAttr(r18, r19) + r21 = [r17] + r22 = load_address r21 + r23 = _PyObject_Vectorcall(r20, r22, 1, 0) + keep_alive r17 goto L6 L5: L6: - r26 = box(None, 1) - return r26 + r24 = box(None, 1) + return r24 [case testMatchSequencePatternWithTrailingUnboundStar_python3_10] def f(x): match x: @@ -1318,25 +1306,23 @@ def f(x): r1 :: bit r2 :: native_int r3, r4 :: bit - r5 :: list - r6, r7, r8 :: object - r9 :: int32 - r10 :: bit - r11 :: bool - r12 :: list - r13, r14, r15 :: object - r16 :: int32 - r17 :: bit - r18 :: bool + r5, r6, r7 :: object + r8 :: int32 + r9 :: bit + r10 :: bool + r11, r12, r13 :: object + r14 :: int32 + r15 :: bit + r16 :: bool + r17 :: str + r18 :: object r19 :: str r20 :: object - r21 :: str - r22 :: object - r23 :: object[1] - r24 :: object_ptr - r25, r26 :: object + r21 :: object[1] + r22 :: object_ptr + r23, r24 :: object L0: - r0 = PyList_Check(x) + r0 = CPySequence_Check(x) r1 = r0 != 0 if r1 goto L1 else goto L5 :: bool L1: @@ -1345,37 +1331,35 @@ L1: r4 = r2 >= 2 :: signed if r4 goto L2 else goto L5 :: bool L2: - r5 = cast(list, x) - r6 = CPyList_GetItem(r5, 0) - r7 = object 1 - r8 = PyObject_RichCompare(r6, r7, 2) - r9 = PyObject_IsTrue(r8) - r10 = r9 >= 0 :: signed - r11 = truncate r9: int32 to builtins.bool - if r11 goto L3 else goto L5 :: bool + r5 = PySequence_GetItem(x, 0) + r6 = object 1 + r7 = PyObject_RichCompare(r5, r6, 2) + r8 = PyObject_IsTrue(r7) + r9 = r8 >= 0 :: signed + r10 = truncate r8: int32 to builtins.bool + if r10 goto L3 else goto L5 :: bool L3: - r12 = cast(list, x) - r13 = CPyList_GetItem(r12, 2) - r14 = object 2 - r15 = PyObject_RichCompare(r13, r14, 2) - r16 = PyObject_IsTrue(r15) - r17 = r16 >= 0 :: signed - r18 = truncate r16: int32 to builtins.bool - if r18 goto L4 else goto L5 :: bool + r11 = PySequence_GetItem(x, 1) + r12 = object 2 + r13 = PyObject_RichCompare(r11, r12, 2) + r14 = PyObject_IsTrue(r13) + r15 = r14 >= 0 :: signed + r16 = truncate r14: int32 to builtins.bool + if r16 goto L4 else goto L5 :: bool L4: - r19 = 'matched' - r20 = builtins :: module - r21 = 'print' - r22 = CPyObject_GetAttr(r20, r21) - r23 = [r19] - r24 = load_address r23 - r25 = _PyObject_Vectorcall(r22, r24, 1, 0) - keep_alive r19 + r17 = 'matched' + r18 = builtins :: module + r19 = 'print' + r20 = CPyObject_GetAttr(r18, r19) + r21 = [r17] + r22 = load_address r21 + r23 = _PyObject_Vectorcall(r20, r22, 1, 0) + keep_alive r17 goto L6 L5: L6: - r26 = box(None, 1) - return r26 + r24 = box(None, 1) + return r24 [case testMatchSequencePatternWithTrailingBoundStar_python3_10] def f(x): match x: @@ -1388,87 +1372,66 @@ def f(x): r1 :: bit r2 :: native_int r3, r4 :: bit - r5 :: list - r6, r7, r8 :: object - r9 :: int32 - r10 :: bit - r11 :: bool - r12 :: list - r13, r14, r15 :: object - r16 :: int32 - r17 :: bit - r18 :: bool - r19 :: native_int - r20 :: list - r21, r22 :: bit - r23, r24, r25 :: int - r26, rest :: object - r27 :: str - r28 :: object - r29 :: str - r30 :: object - r31 :: object[1] - r32 :: object_ptr - r33, r34 :: object + r5, r6, r7 :: object + r8 :: int32 + r9 :: bit + r10 :: bool + r11, r12, r13 :: object + r14 :: int32 + r15 :: bit + r16 :: bool + r17 :: native_int + r18, rest :: object + r19 :: str + r20 :: object + r21 :: str + r22 :: object + r23 :: object[1] + r24 :: object_ptr + r25, r26 :: object L0: - r0 = PyList_Check(x) + r0 = CPySequence_Check(x) r1 = r0 != 0 - if r1 goto L1 else goto L10 :: bool + if r1 goto L1 else goto L6 :: bool L1: r2 = PyObject_Size(x) r3 = r2 >= 0 :: signed r4 = r2 >= 2 :: signed - if r4 goto L2 else goto L10 :: bool + if r4 goto L2 else goto L6 :: bool L2: - r5 = cast(list, x) - r6 = CPyList_GetItem(r5, 0) - r7 = object 1 - r8 = PyObject_RichCompare(r6, r7, 2) - r9 = PyObject_IsTrue(r8) - r10 = r9 >= 0 :: signed - r11 = truncate r9: int32 to builtins.bool - if r11 goto L3 else goto L10 :: bool + r5 = PySequence_GetItem(x, 0) + r6 = object 1 + r7 = PyObject_RichCompare(r5, r6, 2) + r8 = PyObject_IsTrue(r7) + r9 = r8 >= 0 :: signed + r10 = truncate r8: int32 to builtins.bool + if r10 goto L3 else goto L6 :: bool L3: - r12 = cast(list, x) - r13 = CPyList_GetItem(r12, 2) - r14 = object 2 - r15 = PyObject_RichCompare(r13, r14, 2) - r16 = PyObject_IsTrue(r15) - r17 = r16 >= 0 :: signed - r18 = truncate r16: int32 to builtins.bool - if r18 goto L4 else goto L10 :: bool + r11 = PySequence_GetItem(x, 1) + r12 = object 2 + r13 = PyObject_RichCompare(r11, r12, 2) + r14 = PyObject_IsTrue(r13) + r15 = r14 >= 0 :: signed + r16 = truncate r14: int32 to builtins.bool + if r16 goto L4 else goto L6 :: bool L4: - r19 = r2 - 0 - r20 = cast(list, x) - r21 = r19 <= 4611686018427387903 :: signed - if r21 goto L5 else goto L6 :: bool + r17 = r2 - 0 + r18 = PySequence_GetSlice(x, 2, r17) + rest = r18 L5: - r22 = r19 >= -4611686018427387904 :: signed - if r22 goto L7 else goto L6 :: bool + r19 = 'matched' + r20 = builtins :: module + r21 = 'print' + r22 = CPyObject_GetAttr(r20, r21) + r23 = [r19] + r24 = load_address r23 + r25 = _PyObject_Vectorcall(r22, r24, 1, 0) + keep_alive r19 + goto L7 L6: - r23 = CPyTagged_FromInt64(r19) - r24 = r23 - goto L8 L7: - r25 = r19 << 1 - r24 = r25 -L8: - r26 = CPyList_GetSlice(r20, 4, r24) - rest = r26 -L9: - r27 = 'matched' - r28 = builtins :: module - r29 = 'print' - r30 = CPyObject_GetAttr(r28, r29) - r31 = [r27] - r32 = load_address r31 - r33 = _PyObject_Vectorcall(r30, r32, 1, 0) - keep_alive r27 - goto L11 -L10: -L11: - r34 = box(None, 1) - return r34 + r26 = box(None, 1) + return r26 [case testMatchSequenceWithStarPatternInTheMiddle_python3_10] def f(x): match x: @@ -1481,108 +1444,72 @@ def f(x): r1 :: bit r2 :: native_int r3, r4 :: bit - r5 :: list - r6 :: object - r7 :: str - r8 :: object - r9 :: int32 - r10 :: bit - r11 :: bool - r12 :: native_int - r13 :: list - r14, r15 :: bit - r16, r17, r18 :: int - r19 :: object + r5 :: object + r6 :: str + r7 :: object + r8 :: int32 + r9 :: bit + r10 :: bool + r11 :: native_int + r12 :: object + r13 :: str + r14 :: object + r15 :: int32 + r16 :: bit + r17 :: bool + r18 :: native_int + r19, rest :: object r20 :: str r21 :: object - r22 :: int32 - r23 :: bit - r24 :: bool - r25 :: native_int - r26 :: list - r27, r28 :: bit - r29, r30, r31 :: int - r32, rest :: object - r33 :: str - r34 :: object - r35 :: str - r36 :: object - r37 :: object[1] - r38 :: object_ptr - r39, r40 :: object + r22 :: str + r23 :: object + r24 :: object[1] + r25 :: object_ptr + r26, r27 :: object L0: - r0 = PyList_Check(x) + r0 = CPySequence_Check(x) r1 = r0 != 0 - if r1 goto L1 else goto L14 :: bool + if r1 goto L1 else goto L6 :: bool L1: r2 = PyObject_Size(x) r3 = r2 >= 0 :: signed r4 = r2 >= 2 :: signed - if r4 goto L2 else goto L14 :: bool + if r4 goto L2 else goto L6 :: bool L2: - r5 = cast(list, x) - r6 = CPyList_GetItem(r5, 0) - r7 = 'start' - r8 = PyObject_RichCompare(r6, r7, 2) - r9 = PyObject_IsTrue(r8) - r10 = r9 >= 0 :: signed - r11 = truncate r9: int32 to builtins.bool - if r11 goto L3 else goto L14 :: bool + r5 = PySequence_GetItem(x, 0) + r6 = 'start' + r7 = PyObject_RichCompare(r5, r6, 2) + r8 = PyObject_IsTrue(r7) + r9 = r8 >= 0 :: signed + r10 = truncate r8: int32 to builtins.bool + if r10 goto L3 else goto L6 :: bool L3: - r12 = r2 - 1 - r13 = cast(list, x) - r14 = r12 <= 4611686018427387903 :: signed - if r14 goto L4 else goto L5 :: bool + r11 = r2 - 1 + r12 = PySequence_GetItem(x, r11) + r13 = 'end' + r14 = PyObject_RichCompare(r12, r13, 2) + r15 = PyObject_IsTrue(r14) + r16 = r15 >= 0 :: signed + r17 = truncate r15: int32 to builtins.bool + if r17 goto L4 else goto L6 :: bool L4: - r15 = r12 >= -4611686018427387904 :: signed - if r15 goto L6 else goto L5 :: bool + r18 = r2 - 1 + r19 = PySequence_GetSlice(x, 1, r18) + rest = r19 L5: - r16 = CPyTagged_FromInt64(r12) - r17 = r16 + r20 = 'matched' + r21 = builtins :: module + r22 = 'print' + r23 = CPyObject_GetAttr(r21, r22) + r24 = [r20] + r25 = load_address r24 + r26 = _PyObject_Vectorcall(r23, r25, 1, 0) + keep_alive r20 goto L7 L6: - r18 = r12 << 1 - r17 = r18 L7: - r19 = CPyList_GetItem(r13, r17) - r20 = 'end' - r21 = PyObject_RichCompare(r19, r20, 2) - r22 = PyObject_IsTrue(r21) - r23 = r22 >= 0 :: signed - r24 = truncate r22: int32 to builtins.bool - if r24 goto L8 else goto L14 :: bool -L8: - r25 = r2 - 1 - r26 = cast(list, x) - r27 = r25 <= 4611686018427387903 :: signed - if r27 goto L9 else goto L10 :: bool -L9: - r28 = r25 >= -4611686018427387904 :: signed - if r28 goto L11 else goto L10 :: bool -L10: - r29 = CPyTagged_FromInt64(r25) - r30 = r29 - goto L12 -L11: - r31 = r25 << 1 - r30 = r31 -L12: - r32 = CPyList_GetSlice(r26, 2, r30) - rest = r32 -L13: - r33 = 'matched' - r34 = builtins :: module - r35 = 'print' - r36 = CPyObject_GetAttr(r34, r35) - r37 = [r33] - r38 = load_address r37 - r39 = _PyObject_Vectorcall(r36, r38, 1, 0) - keep_alive r33 - goto L15 -L14: -L15: - r40 = box(None, 1) - return r40 + r27 = box(None, 1) + return r27 [case testMatchSequenceWithStarPatternAtTheStart_python3_10] def f(x): match x: @@ -1596,120 +1523,69 @@ def f(x): r2 :: native_int r3, r4 :: bit r5 :: native_int - r6 :: list - r7, r8 :: bit - r9, r10, r11 :: int - r12, r13, r14 :: object - r15 :: int32 - r16 :: bit - r17 :: bool - r18 :: native_int - r19 :: list - r20, r21 :: bit - r22, r23, r24 :: int - r25, r26, r27 :: object - r28 :: int32 - r29 :: bit - r30 :: bool - r31 :: native_int - r32 :: list - r33, r34 :: bit - r35, r36, r37 :: int - r38, rest :: object - r39 :: str - r40 :: object - r41 :: str - r42 :: object - r43 :: object[1] - r44 :: object_ptr - r45, r46 :: object + r6, r7, r8 :: object + r9 :: int32 + r10 :: bit + r11 :: bool + r12 :: native_int + r13, r14, r15 :: object + r16 :: int32 + r17 :: bit + r18 :: bool + r19 :: native_int + r20, rest :: object + r21 :: str + r22 :: object + r23 :: str + r24 :: object + r25 :: object[1] + r26 :: object_ptr + r27, r28 :: object L0: - r0 = PyList_Check(x) + r0 = CPySequence_Check(x) r1 = r0 != 0 - if r1 goto L1 else goto L18 :: bool + if r1 goto L1 else goto L6 :: bool L1: r2 = PyObject_Size(x) r3 = r2 >= 0 :: signed r4 = r2 >= 2 :: signed - if r4 goto L2 else goto L18 :: bool + if r4 goto L2 else goto L6 :: bool L2: r5 = r2 - 2 - r6 = cast(list, x) - r7 = r5 <= 4611686018427387903 :: signed - if r7 goto L3 else goto L4 :: bool + r6 = PySequence_GetItem(x, r5) + r7 = object 1 + r8 = PyObject_RichCompare(r6, r7, 2) + r9 = PyObject_IsTrue(r8) + r10 = r9 >= 0 :: signed + r11 = truncate r9: int32 to builtins.bool + if r11 goto L3 else goto L6 :: bool L3: - r8 = r5 >= -4611686018427387904 :: signed - if r8 goto L5 else goto L4 :: bool + r12 = r2 - 1 + r13 = PySequence_GetItem(x, r12) + r14 = object 2 + r15 = PyObject_RichCompare(r13, r14, 2) + r16 = PyObject_IsTrue(r15) + r17 = r16 >= 0 :: signed + r18 = truncate r16: int32 to builtins.bool + if r18 goto L4 else goto L6 :: bool L4: - r9 = CPyTagged_FromInt64(r5) - r10 = r9 - goto L6 + r19 = r2 - 2 + r20 = PySequence_GetSlice(x, 0, r19) + rest = r20 L5: - r11 = r5 << 1 - r10 = r11 + r21 = 'matched' + r22 = builtins :: module + r23 = 'print' + r24 = CPyObject_GetAttr(r22, r23) + r25 = [r21] + r26 = load_address r25 + r27 = _PyObject_Vectorcall(r24, r26, 1, 0) + keep_alive r21 + goto L7 L6: - r12 = CPyList_GetItem(r6, r10) - r13 = object 1 - r14 = PyObject_RichCompare(r12, r13, 2) - r15 = PyObject_IsTrue(r14) - r16 = r15 >= 0 :: signed - r17 = truncate r15: int32 to builtins.bool - if r17 goto L7 else goto L18 :: bool L7: - r18 = r2 - 1 - r19 = cast(list, x) - r20 = r18 <= 4611686018427387903 :: signed - if r20 goto L8 else goto L9 :: bool -L8: - r21 = r18 >= -4611686018427387904 :: signed - if r21 goto L10 else goto L9 :: bool -L9: - r22 = CPyTagged_FromInt64(r18) - r23 = r22 - goto L11 -L10: - r24 = r18 << 1 - r23 = r24 -L11: - r25 = CPyList_GetItem(r19, r23) - r26 = object 2 - r27 = PyObject_RichCompare(r25, r26, 2) - r28 = PyObject_IsTrue(r27) - r29 = r28 >= 0 :: signed - r30 = truncate r28: int32 to builtins.bool - if r30 goto L12 else goto L18 :: bool -L12: - r31 = r2 - 2 - r32 = cast(list, x) - r33 = r31 <= 4611686018427387903 :: signed - if r33 goto L13 else goto L14 :: bool -L13: - r34 = r31 >= -4611686018427387904 :: signed - if r34 goto L15 else goto L14 :: bool -L14: - r35 = CPyTagged_FromInt64(r31) - r36 = r35 - goto L16 -L15: - r37 = r31 << 1 - r36 = r37 -L16: - r38 = CPyList_GetSlice(r32, 0, r36) - rest = r38 -L17: - r39 = 'matched' - r40 = builtins :: module - r41 = 'print' - r42 = CPyObject_GetAttr(r40, r41) - r43 = [r39] - r44 = load_address r43 - r45 = _PyObject_Vectorcall(r42, r44, 1, 0) - keep_alive r39 - goto L19 -L18: -L19: - r46 = box(None, 1) - return r46 + r28 = box(None, 1) + return r28 [case testMatchBuiltinClassPattern_python3_10] def f(x): match x: @@ -1765,55 +1641,38 @@ def f(x): r2 :: native_int r3, r4 :: bit r5 :: native_int - r6 :: list - r7, r8 :: bit - r9, r10, r11 :: int - r12, rest :: object - r13 :: str - r14 :: object - r15 :: str - r16 :: object - r17 :: object[1] - r18 :: object_ptr - r19, r20 :: object + r6, rest :: object + r7 :: str + r8 :: object + r9 :: str + r10 :: object + r11 :: object[1] + r12 :: object_ptr + r13, r14 :: object L0: - r0 = PyList_Check(x) + r0 = CPySequence_Check(x) r1 = r0 != 0 - if r1 goto L1 else goto L8 :: bool + if r1 goto L1 else goto L4 :: bool L1: r2 = PyObject_Size(x) r3 = r2 >= 0 :: signed r4 = r2 >= 0 :: signed - if r4 goto L2 else goto L8 :: bool + if r4 goto L2 else goto L4 :: bool L2: r5 = r2 - 0 - r6 = cast(list, x) - r7 = r5 <= 4611686018427387903 :: signed - if r7 goto L3 else goto L4 :: bool + r6 = PySequence_GetSlice(x, 0, r5) + rest = r6 L3: - r8 = r5 >= -4611686018427387904 :: signed - if r8 goto L5 else goto L4 :: bool + r7 = 'matched' + r8 = builtins :: module + r9 = 'print' + r10 = CPyObject_GetAttr(r8, r9) + r11 = [r7] + r12 = load_address r11 + r13 = _PyObject_Vectorcall(r10, r12, 1, 0) + keep_alive r7 + goto L5 L4: - r9 = CPyTagged_FromInt64(r5) - r10 = r9 - goto L6 L5: - r11 = r5 << 1 - r10 = r11 -L6: - r12 = CPyList_GetSlice(r6, 0, r10) - rest = r12 -L7: - r13 = 'matched' - r14 = builtins :: module - r15 = 'print' - r16 = CPyObject_GetAttr(r14, r15) - r17 = [r13] - r18 = load_address r17 - r19 = _PyObject_Vectorcall(r16, r18, 1, 0) - keep_alive r13 - goto L9 -L8: -L9: - r20 = box(None, 1) - return r20 + r14 = box(None, 1) + return r14 diff --git a/mypyc/test-data/run-match.test b/mypyc/test-data/run-match.test index 6df4efbb6aed6..5b1da28e4ebaf 100644 --- a/mypyc/test-data/run-match.test +++ b/mypyc/test-data/run-match.test @@ -230,3 +230,54 @@ test 21 ('') test 21 (' as well') test sequence final test final +[case testCustomMappingAndSequenceObjects_python3_10] +def f(x): + match x: + case {"key": "value", **rest}: + print(rest, type(rest)) + + case [1, 2, *rest2]: + print(rest2, type(rest2)) + +[file driver.py] +from collections.abc import Mapping, Sequence + +from native import f + +class CustomMapping(Mapping): + inner: dict + + def __init__(self, inner: dict) -> None: + self.inner = inner + + def __getitem__(self, key): + return self.inner[key] + + def __iter__(self): + return iter(self.inner) + + def __len__(self) -> int: + return len(self.inner) + + +class CustomSequence(Sequence): + inner: list + + def __init__(self, inner: list) -> None: + self.inner = inner + + def __getitem__(self, index: int) -> None: + return self.inner[index] + + def __len__(self) -> int: + return len(self.inner) + +mapping = CustomMapping({"key": "value", "some": "data"}) +sequence = CustomSequence([1, 2, 3]) + +f(mapping) +f(sequence) + +[out] +{'some': 'data'} +[3] From c8f8c8432248ea0e2417e90e0832b68c0537c901 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Mon, 31 Oct 2022 16:15:24 -0700 Subject: [PATCH 93/97] Remove previously added command-line arguments: These lines will be moved to a separate PR so that they can be reviewed separately from the `match` stuff going on here. --- mypyc/build.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/mypyc/build.py b/mypyc/build.py index 8b556db3cc5de..4f40a6cd08659 100644 --- a/mypyc/build.py +++ b/mypyc/build.py @@ -534,8 +534,6 @@ def mypycify( "-Wno-unused-command-line-argument", "-Wno-unknown-warning-option", "-Wno-unused-but-set-variable", - "-Wno-cpp", - "-Wno-ignored-optimization-argument", ] elif compiler.compiler_type == "msvc": # msvc doesn't have levels, '/O2' is full and '/Od' is disable From dff0c714a9052dae807a9c1477012b8286a43777 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Mon, 31 Oct 2022 16:59:00 -0700 Subject: [PATCH 94/97] Fix flags not being defined in Python 3.9 and below: Since the `Py_TPFLAGS_MAPPING` and `Py_TPFLAGS_SEQUENCE` are specific to the new pattern matching feature, they are not defined for Python 3.9 and below. Now they are defined if they don't exist, which should fix the issue. --- mypyc/lib-rt/dict_ops.c | 4 ++++ mypyc/lib-rt/list_ops.c | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/mypyc/lib-rt/dict_ops.c b/mypyc/lib-rt/dict_ops.c index ccf3732a59d59..ba565257fd724 100644 --- a/mypyc/lib-rt/dict_ops.c +++ b/mypyc/lib-rt/dict_ops.c @@ -5,6 +5,10 @@ #include #include "CPy.h" +#ifndef Py_TPFLAGS_MAPPING +#define Py_TPFLAGS_MAPPING (1 << 6) +#endif + // Dict subclasses like defaultdict override things in interesting // ways, so we don't want to just directly use the dict methods. Not // sure if it is actually worth doing all this stuff, but it saves diff --git a/mypyc/lib-rt/list_ops.c b/mypyc/lib-rt/list_ops.c index 9849c8c3d7da1..df87228a0d100 100644 --- a/mypyc/lib-rt/list_ops.c +++ b/mypyc/lib-rt/list_ops.c @@ -5,6 +5,10 @@ #include #include "CPy.h" +#ifndef Py_TPFLAGS_SEQUENCE +#define Py_TPFLAGS_SEQUENCE (1 << 5) +#endif + PyObject *CPyList_Build(Py_ssize_t len, ...) { Py_ssize_t i; From 07486a071a9a7872d076b2073d3ee0ce90f56e8c Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Fri, 4 Nov 2022 22:04:14 -0700 Subject: [PATCH 95/97] Attempt to fix last commit: It would appear that 91251559 is now causing issues with the IR tests. Reverting that commit seems to fix it, but I cannot seem to remember why I added it in the first place, and the build artifacts have long since expired. --- mypyc/irbuild/classdef.py | 2 +- mypyc/irbuild/prepare.py | 6 +----- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/mypyc/irbuild/classdef.py b/mypyc/irbuild/classdef.py index 4502c201a2e87..34fc1fd766b00 100644 --- a/mypyc/irbuild/classdef.py +++ b/mypyc/irbuild/classdef.py @@ -629,7 +629,7 @@ def find_attr_initializers( and not isinstance(stmt.rvalue, TempNode) ): name = stmt.lvalues[0].name - if name in ("__slots__", "__match_args__"): + if name == "__slots__": continue if name == "__deletable__": diff --git a/mypyc/irbuild/prepare.py b/mypyc/irbuild/prepare.py index dc153ea11561e..82162d1d0d0e5 100644 --- a/mypyc/irbuild/prepare.py +++ b/mypyc/irbuild/prepare.py @@ -226,11 +226,7 @@ def prepare_class_def( if isinstance(node.node, Var): assert node.node.type, "Class member %s missing type" % name - if not node.node.is_classvar and name not in ( - "__slots__", - "__deletable__", - "__match_args__", - ): + if not node.node.is_classvar and name not in ("__slots__", "__deletable__"): ir.attributes[name] = mapper.type_to_rtype(node.node.type) elif isinstance(node.node, (FuncDef, Decorator)): prepare_method_def(ir, module_name, cdef, mapper, node.node) From 9728bc69b1f5eaec189ac06d9d85e8f0f07f24cf Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Tue, 15 Nov 2022 00:08:42 -0800 Subject: [PATCH 96/97] Trigger CI From af5bca85efa841ef4808d55870f4b92f05ec3880 Mon Sep 17 00:00:00 2001 From: dosisod <39638017+dosisod@users.noreply.github.com> Date: Thu, 1 Dec 2022 13:38:33 -0800 Subject: [PATCH 97/97] Add review suggestions: * Use faster "isinstance" op if the expression is a built-in type * Add type annotations to run tests * Add type annotated IR example * Add todo comments --- mypyc/irbuild/match.py | 17 ++- mypyc/test-data/irbuild-match.test | 214 ++++++++++++++++------------- mypyc/test-data/run-match.test | 4 +- 3 files changed, 137 insertions(+), 98 deletions(-) diff --git a/mypyc/irbuild/match.py b/mypyc/irbuild/match.py index ab7ffd4f26c62..a1e671911ea59 100644 --- a/mypyc/irbuild/match.py +++ b/mypyc/irbuild/match.py @@ -30,7 +30,7 @@ sequence_get_slice, supports_sequence_protocol, ) -from mypyc.primitives.misc_ops import slow_isinstance_op +from mypyc.primitives.misc_ops import fast_isinstance_op, slow_isinstance_op # From: https://peps.python.org/pep-0634/#class-patterns MATCHABLE_BUILTINS = { @@ -123,10 +123,16 @@ def visit_or_pattern(self, pattern: OrPattern) -> None: self.builder.goto(self.next_block) def visit_class_pattern(self, pattern: ClassPattern) -> None: + # TODO: use faster instance check for native classes (while still + # making sure to account for inheritence) + isinstance_op = ( + fast_isinstance_op + if self.builder.is_builtin_ref_expr(pattern.class_ref) + else slow_isinstance_op + ) + cond = self.builder.call_c( - slow_isinstance_op, - [self.subject, self.builder.accept(pattern.class_ref)], - pattern.line, + isinstance_op, [self.subject, self.builder.accept(pattern.class_ref)], pattern.line ) self.builder.add_bool_branch(cond, self.code_block, self.next_block) @@ -166,6 +172,8 @@ def visit_class_pattern(self, pattern: ClassPattern) -> None: self.builder.activate_block(self.code_block) self.code_block = BasicBlock() + # TODO: use faster "get_attr" method instead when calling on native or + # builtin objects positional = self.builder.py_get_attr(self.subject, match_args[i], expr.line) with self.enter_subpattern(positional): @@ -175,6 +183,7 @@ def visit_class_pattern(self, pattern: ClassPattern) -> None: self.builder.activate_block(self.code_block) self.code_block = BasicBlock() + # TODO: same as above "get_attr" comment attr = self.builder.py_get_attr(self.subject, key, value.line) with self.enter_subpattern(attr): diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index fa0faf0774bb1..2afe3d862f517 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -118,37 +118,33 @@ def f(): [out] def f(): r0, r1 :: object - r2 :: int32 - r3 :: bit - r4 :: bool + r2 :: bool + r3 :: str + r4 :: object r5 :: str r6 :: object - r7 :: str - r8 :: object - r9 :: object[1] - r10 :: object_ptr - r11, r12 :: object + r7 :: object[1] + r8 :: object_ptr + r9, r10 :: object L0: r0 = load_address PyLong_Type r1 = object 123 - r2 = PyObject_IsInstance(r1, r0) - r3 = r2 >= 0 :: signed - r4 = truncate r2: int32 to builtins.bool - if r4 goto L1 else goto L2 :: bool + r2 = CPy_TypeCheck(r1, r0) + if r2 goto L1 else goto L2 :: bool L1: - r5 = 'matched' - r6 = builtins :: module - r7 = 'print' - r8 = CPyObject_GetAttr(r6, r7) - r9 = [r5] - r10 = load_address r9 - r11 = _PyObject_Vectorcall(r8, r10, 1, 0) - keep_alive r5 + r3 = 'matched' + r4 = builtins :: module + r5 = 'print' + r6 = CPyObject_GetAttr(r4, r5) + r7 = [r3] + r8 = load_address r7 + r9 = _PyObject_Vectorcall(r6, r8, 1, 0) + keep_alive r3 goto L3 L2: L3: - r12 = box(None, 1) - return r12 + r10 = box(None, 1) + return r10 [case testMatchExaustivePattern_python3_10] def f(): match 123: @@ -444,42 +440,38 @@ def f(): def f(): r0 :: bit r1, r2 :: object - r3 :: int32 - r4 :: bit - r5 :: bool + r3 :: bool + r4 :: str + r5 :: object r6 :: str r7 :: object - r8 :: str - r9 :: object - r10 :: object[1] - r11 :: object_ptr - r12, r13 :: object + r8 :: object[1] + r9 :: object_ptr + r10, r11 :: object L0: r0 = 2 == 2 if r0 goto L3 else goto L1 :: bool L1: r1 = load_address PyLong_Type r2 = object 1 - r3 = PyObject_IsInstance(r2, r1) - r4 = r3 >= 0 :: signed - r5 = truncate r3: int32 to builtins.bool - if r5 goto L3 else goto L2 :: bool + r3 = CPy_TypeCheck(r2, r1) + if r3 goto L3 else goto L2 :: bool L2: goto L4 L3: - r6 = 'matched' - r7 = builtins :: module - r8 = 'print' - r9 = CPyObject_GetAttr(r7, r8) - r10 = [r6] - r11 = load_address r10 - r12 = _PyObject_Vectorcall(r9, r11, 1, 0) - keep_alive r6 + r4 = 'matched' + r5 = builtins :: module + r6 = 'print' + r7 = CPyObject_GetAttr(r5, r6) + r8 = [r4] + r9 = load_address r8 + r10 = _PyObject_Vectorcall(r7, r9, 1, 0) + keep_alive r4 goto L5 L4: L5: - r13 = box(None, 1) - return r13 + r11 = box(None, 1) + return r11 [case testMatchAsPattern_python3_10] def f(): match 123: @@ -561,39 +553,35 @@ def f(): [out] def f(): r0, r1 :: object - r2 :: int32 - r3 :: bit - r4 :: bool + r2 :: bool i :: int - r5 :: object - r6 :: str - r7, r8 :: object - r9 :: object[1] - r10 :: object_ptr - r11, r12 :: object + r3 :: object + r4 :: str + r5, r6 :: object + r7 :: object[1] + r8 :: object_ptr + r9, r10 :: object L0: r0 = load_address PyLong_Type r1 = object 123 - r2 = PyObject_IsInstance(r1, r0) - r3 = r2 >= 0 :: signed - r4 = truncate r2: int32 to builtins.bool - if r4 goto L1 else goto L3 :: bool + r2 = CPy_TypeCheck(r1, r0) + if r2 goto L1 else goto L3 :: bool L1: i = 246 L2: - r5 = builtins :: module - r6 = 'print' - r7 = CPyObject_GetAttr(r5, r6) - r8 = box(int, i) - r9 = [r8] - r10 = load_address r9 - r11 = _PyObject_Vectorcall(r7, r10, 1, 0) - keep_alive r8 + r3 = builtins :: module + r4 = 'print' + r5 = CPyObject_GetAttr(r3, r4) + r6 = box(int, i) + r7 = [r6] + r8 = load_address r7 + r9 = _PyObject_Vectorcall(r5, r8, 1, 0) + keep_alive r6 goto L4 L3: L4: - r12 = box(None, 1) - return r12 + r10 = box(None, 1) + return r10 [case testMatchClassPatternWithPositionalArgs_python3_10] class Position: __match_args__ = ("x", "y", "z") @@ -1594,40 +1582,36 @@ def f(x): [out] def f(x): x, r0 :: object - r1 :: int32 - r2 :: bit - r3 :: bool - r4, y :: int + r1 :: bool + r2, y :: int + r3 :: str + r4 :: object r5 :: str r6 :: object - r7 :: str - r8 :: object - r9 :: object[1] - r10 :: object_ptr - r11, r12 :: object + r7 :: object[1] + r8 :: object_ptr + r9, r10 :: object L0: r0 = load_address PyLong_Type - r1 = PyObject_IsInstance(x, r0) - r2 = r1 >= 0 :: signed - r3 = truncate r1: int32 to builtins.bool - if r3 goto L1 else goto L3 :: bool + r1 = CPy_TypeCheck(x, r0) + if r1 goto L1 else goto L3 :: bool L1: - r4 = unbox(int, x) - y = r4 + r2 = unbox(int, x) + y = r2 L2: - r5 = 'matched' - r6 = builtins :: module - r7 = 'print' - r8 = CPyObject_GetAttr(r6, r7) - r9 = [r5] - r10 = load_address r9 - r11 = _PyObject_Vectorcall(r8, r10, 1, 0) - keep_alive r5 + r3 = 'matched' + r4 = builtins :: module + r5 = 'print' + r6 = CPyObject_GetAttr(r4, r5) + r7 = [r3] + r8 = load_address r7 + r9 = _PyObject_Vectorcall(r6, r8, 1, 0) + keep_alive r3 goto L4 L3: L4: - r12 = box(None, 1) - return r12 + r10 = box(None, 1) + return r10 [case testMatchSequenceCaptureAll_python3_10] def f(x): match x: @@ -1676,3 +1660,49 @@ L4: L5: r14 = box(None, 1) return r14 +[case testMatchTypeAnnotatedNativeClass_python3_10] +class A: + a: int + +def f(x: A | int) -> int: + match x: + case A(a=a): + return a + case int(): + return x +[out] +def f(x): + x :: union[__main__.A, int] + r0 :: object + r1 :: int32 + r2 :: bit + r3 :: bool + r4 :: str + r5 :: object + r6, a :: int + r7 :: object + r8 :: bool + r9 :: int +L0: + r0 = __main__.A :: type + r1 = PyObject_IsInstance(x, r0) + r2 = r1 >= 0 :: signed + r3 = truncate r1: int32 to builtins.bool + if r3 goto L1 else goto L3 :: bool +L1: + r4 = 'a' + r5 = CPyObject_GetAttr(x, r4) + r6 = unbox(int, r5) + a = r6 +L2: + return a +L3: + r7 = load_address PyLong_Type + r8 = CPy_TypeCheck(x, r7) + if r8 goto L4 else goto L5 :: bool +L4: + r9 = unbox(int, x) + return r9 +L5: +L6: + unreachable diff --git a/mypyc/test-data/run-match.test b/mypyc/test-data/run-match.test index 5b1da28e4ebaf..7b7ad9a4342ce 100644 --- a/mypyc/test-data/run-match.test +++ b/mypyc/test-data/run-match.test @@ -13,7 +13,7 @@ class Person: return f"Person(name={self.name!r}, age={self.age})" -def f(x): +def f(x: object) -> None: match x: case 123: print("test 1") @@ -231,7 +231,7 @@ test 21 (' as well') test sequence final test final [case testCustomMappingAndSequenceObjects_python3_10] -def f(x): +def f(x: object) -> None: match x: case {"key": "value", **rest}: print(rest, type(rest))