Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for match ... case #60

Merged
merged 30 commits into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
d7b4aac
Add support for match ... case
Mar 19, 2024
d7fe210
Remove unused import in main.py
Mar 19, 2024
31fe10a
Remove unused function
Mar 19, 2024
2b254cd
Merge branch 'main' into main
prsabahrami Mar 20, 2024
eb9188c
Merge branch 'main' into main
prsabahrami Mar 20, 2024
04bdb63
Add match _ case and | functionality
Mar 21, 2024
4f3e06c
Add match , case, * case, and support for py3.9
Mar 21, 2024
603fd26
Create functions_310.py
Mar 21, 2024
e129951
Cleanup extra comments
Mar 21, 2024
0f1b4f4
Add match with guard and multiple variable match
Mar 22, 2024
dd3af8b
Fix py310 test functions
Mar 27, 2024
be67c18
Raise error for matching lists
Mar 27, 2024
e2b48ce
Adding support for guards and some fixes
prsabahrami Apr 27, 2024
d48605a
separating resolved and unresolved case and fixing issues
prsabahrami Apr 30, 2024
28da007
Merge branch 'Quantco:main' into main
prsabahrami May 1, 2024
56279e9
Merge branch 'main' into main
prsabahrami May 15, 2024
1d7a295
Fixing Coverage
prsabahrami May 15, 2024
edecbfe
Fix functions
prsabahrami May 15, 2024
bde858e
Fixing test functions
prsabahrami May 15, 2024
f7723b3
Remove extra test function
prsabahrami May 15, 2024
96b2337
Adding coverage for L329 - L333
prsabahrami May 16, 2024
2258a8c
some improvements
May 17, 2024
2af055c
add failing tests
May 17, 2024
2850d44
Fixing failing cases
prsabahrami May 18, 2024
5779202
Fixing test functions
prsabahrami May 18, 2024
10c9753
small fixes
pavelzw May 24, 2024
cbbedf1
bump version
pavelzw May 24, 2024
ed16d6d
only run release on non-forks
pavelzw May 24, 2024
52ece52
Update main.py
prsabahrami May 24, 2024
701b0d0
Updating comments for translate_match
prsabahrami May 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 31 additions & 14 deletions polarify/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,14 @@ def __iter__(self):
def build_polars_when_then_otherwise(body: Sequence[ResolvedCase], orelse: ast.expr) -> ast.Call:
nodes: list[ast.Call] = []

assert body, "No when-then cases provided."
assert body or orelse, "No when-then cases provided."

if not body:
"""
When a match statement has no valid cases (i.e., all cases except catch-all pattern are ignored),
we return the orelse expression but the test setup does not work with literal expressions.
"""
raise ValueError("No valid cases provided.")
Copy link
Member

@pavelzw pavelzw May 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added support for these cases in 10c9753 (#60) in transform_tree_into_expr


for test, then in body:
when_node = ast.Call(
Expand Down Expand Up @@ -195,9 +202,9 @@ def translate_match(
"""
TODO: Explain the purpose and goal of this method, it's quite complex
"""
if isinstance(pattern, ast.MatchValue) and isinstance(subj, ast.Name):
if isinstance(pattern, ast.MatchValue):
equality_ast = ast.Compare(
left=ast.Name(id=subj.id, ctx=ast.Load()),
left=subj,
ops=[ast.Eq()],
comparators=[pattern.value],
)
Expand All @@ -210,14 +217,12 @@ def translate_match(
)

return equality_ast
elif isinstance(pattern, ast.MatchValue) and isinstance(subj, ast.Tuple):
return self.translate_match(subj, ast.MatchSequence(patterns=[pattern]))
elif isinstance(pattern, ast.MatchAs) and isinstance(subj, ast.Name):
elif isinstance(pattern, ast.MatchAs):
if pattern.name is not None:
self.handle_assign(
ast.Assign(
targets=[ast.Name(id=pattern.name, ctx=ast.Store())],
value=ast.Name(id=subj.id, ctx=ast.Load()),
value=subj,
)
)
return guard
Expand All @@ -234,10 +239,8 @@ def translate_match(
elif isinstance(pattern, ast.MatchSequence):
if isinstance(pattern.patterns[-1], ast.MatchStar):
raise ValueError("starred patterns are not supported.")
if isinstance(subj, ast.Tuple):
while len(subj.elts) > len(pattern.patterns):
pattern.patterns.append(ast.MatchValue(value=ast.Constant(value=None)))

if isinstance(subj, ast.Tuple):
# TODO: Use polars list operations in the future
left = self.translate_match(subj.elts[0], pattern.patterns[0], guard)
right = (
Expand Down Expand Up @@ -298,13 +301,27 @@ def handle_return(self, value: ast.expr):
self.node.orelse.handle_return(value)

def handle_match(self, stmt: ast.Match):
def is_catch_all(pattern: ast.pattern) -> bool:
return isinstance(pattern, ast.MatchAs) and pattern.name is None
def is_catch_all(case: ast.match_case) -> bool:
# We check if the case is a catch-all pattern without a guard
# If it has a guard, we treat it as a regular case
return (
isinstance(case.pattern, ast.MatchAs)
and case.pattern.name is None
and case.guard is None
)

def ignore_case(case: ast.match_case) -> bool:
# if the length of the pattern is not equal to the length of the subject, python ignores the case
return (
isinstance(case.pattern, ast.MatchSequence)
and isinstance(stmt.subject, ast.Tuple)
and len(stmt.subject.elts) != len(case.pattern.patterns)
) or (isinstance(case.pattern, ast.MatchValue) and isinstance(stmt.subject, ast.Tuple))

if isinstance(self.node, UnresolvedState):
# We can always rewrite catch-all patterns to orelse since python throws a SyntaxError if the catch-all pattern is not the last case.
orelse = next(
iter([case.body for case in stmt.cases if is_catch_all(case.pattern)]),
iter([case.body for case in stmt.cases if is_catch_all(case)]),
[],
)
self.node = ConditionalState(
Expand All @@ -319,7 +336,7 @@ def is_catch_all(pattern: ast.pattern) -> bool:
parse_body(case.body, copy(self.node.assignments)),
)
for case in stmt.cases
if not is_catch_all(case.pattern)
if not is_catch_all(case) and not ignore_case(case)
],
orelse=parse_body(
orelse,
Expand Down
39 changes: 37 additions & 2 deletions tests/functions_310.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def match_guarded_match_as(x):
return 3


def match_sequence_padded_length(x):
def match_sequence_padded_length_no_case(x):
y = 2
z = None

Expand All @@ -244,6 +244,38 @@ def match_sequence_padded_length(x):
return -1


def match_sequence_padded_length_return(x):
y = 1
z = 2

match x, y, z:
case 1, 2:
return 1
return -1


def match_sequence_padded_length(x):
y = 1
z = 2

match x, y, z:
case 1, 2:
return 1
case 3, 4:
return -1
case 1, 2, 3:
return 2
return -2


def match_guard_no_assignation(x):
match x:
case _ if x > 1:
return 0
case _:
return 2


functions_310 = [
nested_match,
match_assignments_inside_branch,
Expand All @@ -263,15 +295,18 @@ def match_sequence_padded_length(x):
match_complex_subject,
match_guarded_match_as,
match_sequence_padded_length,
match_guard_no_assignation,
]

xfail_functions_310 = [
match_mapping,
match_sequence_padded_length_no_case,
match_sequence_padded_length_return,
]

unsupported_functions_310 = [
(match_sequence_star, "starred patterns are not supported."),
(match_sequence, "Matching lists is not supported."),
(match_sequence_with_brackets, "Matching lists is not supported."),
(match_guarded_match_as_no_return, "return"),
(match_guarded_match_as_no_return, "Not all branches return"),
]