Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for match ... case #60

Merged
merged 30 commits into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
d7b4aac
Add support for match ... case
Mar 19, 2024
d7fe210
Remove unused import in main.py
Mar 19, 2024
31fe10a
Remove unused function
Mar 19, 2024
2b254cd
Merge branch 'main' into main
prsabahrami Mar 20, 2024
eb9188c
Merge branch 'main' into main
prsabahrami Mar 20, 2024
04bdb63
Add match _ case and | functionality
Mar 21, 2024
4f3e06c
Add match , case, * case, and support for py3.9
Mar 21, 2024
603fd26
Create functions_310.py
Mar 21, 2024
e129951
Cleanup extra comments
Mar 21, 2024
0f1b4f4
Add match with guard and multiple variable match
Mar 22, 2024
dd3af8b
Fix py310 test functions
Mar 27, 2024
be67c18
Raise error for matching lists
Mar 27, 2024
e2b48ce
Adding support for guards and some fixes
prsabahrami Apr 27, 2024
d48605a
separating resolved and unresolved case and fixing issues
prsabahrami Apr 30, 2024
28da007
Merge branch 'Quantco:main' into main
prsabahrami May 1, 2024
56279e9
Merge branch 'main' into main
prsabahrami May 15, 2024
1d7a295
Fixing Coverage
prsabahrami May 15, 2024
edecbfe
Fix functions
prsabahrami May 15, 2024
bde858e
Fixing test functions
prsabahrami May 15, 2024
f7723b3
Remove extra test function
prsabahrami May 15, 2024
96b2337
Adding coverage for L329 - L333
prsabahrami May 16, 2024
2258a8c
some improvements
May 17, 2024
2af055c
add failing tests
May 17, 2024
2850d44
Fixing failing cases
prsabahrami May 18, 2024
5779202
Fixing test functions
prsabahrami May 18, 2024
10c9753
small fixes
pavelzw May 24, 2024
cbbedf1
bump version
pavelzw May 24, 2024
ed16d6d
only run release on non-forks
pavelzw May 24, 2024
52ece52
Update main.py
prsabahrami May 24, 2024
701b0d0
Updating comments for translate_match
prsabahrami May 24, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ jobs:

release:
name: Publish package
if: github.event_name == 'push' && github.ref_name == 'main' && needs.build.outputs.version-changed == 'true'
if: github.event_name == 'push' && github.repository == 'Quantco/polarify' && github.ref_name == 'main' && needs.build.outputs.version-changed == 'true'
needs: [build]
runs-on: ubuntu-latest
permissions:
Expand Down
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -198,14 +198,17 @@ polarIFy is still in an early stage of development and doesn't support the full
- assignments (like `x = 1`)
- polars expressions (like `pl.col("x")`, TODO)
- side-effect free functions that return a polars expression (can be generated by `@polarify`) (TODO)
- `match` statements

### Unsupported operations

- `for` loops
- `while` loops
- `break` statements
- `:=` walrus operator
- `match ... case` statements (TODO)
- dictionary mappings in `match` statements
- list matching in `match` statements
- star patterns in `match statements
- functions with side-effects (`print`, `pl.write_csv`, ...)

## 🚀 Benchmarks
Expand Down
2,981 changes: 1,668 additions & 1,313 deletions pixi.lock

Large diffs are not rendered by default.

14 changes: 7 additions & 7 deletions pixi.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,13 @@ lint = "pre-commit run --all"

[environments]
default = ["test"]
pl014 = ["pl014", "py39", "test"]
pl015 = ["pl015", "py39", "test"]
pl016 = ["pl016", "py39", "test"]
pl017 = ["pl017", "py39", "test"]
pl018 = ["pl018", "py39", "test"]
pl019 = ["pl019", "py39", "test"]
pl020 = ["pl020", "py39", "test"]
pl014 = ["pl014", "py310", "test"]
pl015 = ["pl015", "py310", "test"]
pl016 = ["pl016", "py310", "test"]
pl017 = ["pl017", "py310", "test"]
pl018 = ["pl018", "py310", "test"]
pl019 = ["pl019", "py310", "test"]
pl020 = ["pl020", "py310", "test"]
py39 = ["py39", "test"]
py310 = ["py310", "test"]
py311 = ["py311", "test"]
Expand Down
235 changes: 210 additions & 25 deletions polarify/main.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,74 @@
from __future__ import annotations

import ast
import sys
from collections.abc import Sequence
from copy import copy, deepcopy
from dataclasses import dataclass

PY_39 = sys.version_info <= (3, 9)

# TODO: make walrus throw ValueError
# TODO: match ... case


def build_polars_when_then_otherwise(test: ast.expr, then: ast.expr, orelse: ast.expr) -> ast.Call:
when_node = ast.Call(
func=ast.Attribute(value=ast.Name(id="pl", ctx=ast.Load()), attr="when", ctx=ast.Load()),
args=[test],
keywords=[],
)
@dataclass
class UnresolvedCase:
"""
An unresolved case in a conditional statement. (if, match, etc.)
Each case consists of a test expression and a state.
The value of the state is not yet resolved.
"""

then_node = ast.Call(
func=ast.Attribute(value=when_node, attr="then", ctx=ast.Load()),
args=[then],
keywords=[],
)
test: ast.expr
state: State

def __init__(self, test: ast.expr, then: State):
self.test = test
self.state = then


@dataclass
class ResolvedCase:
"""
A resolved case in a conditional statement. (if, match, etc.)
Each case consists of a test expression and a state.
The value of the state is resolved.
"""

test: ast.expr
state: ast.expr

def __init__(self, test: ast.expr, then: ast.expr):
self.test = test
self.state = then

def __iter__(self):
return iter([self.test, self.state])


def build_polars_when_then_otherwise(body: Sequence[ResolvedCase], orelse: ast.expr) -> ast.Call:
nodes: list[ast.Call] = []

assert body or orelse, "No when-then cases provided."

for test, then in body:
when_node = ast.Call(
func=ast.Attribute(
value=nodes[-1] if nodes else ast.Name(id="pl", ctx=ast.Load()),
attr="when",
ctx=ast.Load(),
),
args=[test],
keywords=[],
)
then_node = ast.Call(
func=ast.Attribute(value=when_node, attr="then", ctx=ast.Load()),
args=[then],
keywords=[],
)
nodes.append(then_node)
final_node = ast.Call(
func=ast.Attribute(value=then_node, attr="otherwise", ctx=ast.Load()),
func=ast.Attribute(value=nodes[-1], attr="otherwise", ctx=ast.Load()),
args=[orelse],
keywords=[],
)
Expand Down Expand Up @@ -63,7 +110,7 @@ def visit_IfExp(self, node: ast.IfExp) -> ast.Call:
test = self.visit(node.test)
body = self.visit(node.body)
orelse = self.visit(node.orelse)
return build_polars_when_then_otherwise(test, body, orelse)
return build_polars_when_then_otherwise([ResolvedCase(test, body)], orelse)

def visit_Constant(self, node: ast.Constant) -> ast.Constant:
return node
Expand Down Expand Up @@ -122,11 +169,11 @@ class ReturnState:
@dataclass
class ConditionalState:
"""
A conditional state, with a test expression and two branches.
A list of conditional states.
Each case consists of a test expression and a state.
"""

test: ast.expr
then: State
body: Sequence[UnresolvedCase]
orelse: State


Expand All @@ -139,25 +186,106 @@ class State:

node: UnresolvedState | ReturnState | ConditionalState

def translate_match(
self,
subj: ast.expr | Sequence[ast.expr] | ast.Tuple,
pattern: ast.pattern,
guard: ast.expr | None = None,
):
"""
Translate a match_case statement into a regular AST expression.
translate_match takes a subject, a pattern and a guard.
patterns can be a MatchValue, MatchAs, MatchOr, or MatchSequence.
subjects can be a single expression (e.g x or (2 * x + 1)) or a list of expressions.
translate_match is called per each case in a match statement.
"""

if isinstance(pattern, ast.MatchValue):
equality_ast = ast.Compare(
left=subj,
ops=[ast.Eq()],
comparators=[pattern.value],
)

if guard is not None:
return ast.BinOp(
left=guard,
op=ast.BitAnd(),
right=equality_ast,
)

return equality_ast
elif isinstance(pattern, ast.MatchAs):
if pattern.name is not None:
self.handle_assign(
ast.Assign(
targets=[ast.Name(id=pattern.name, ctx=ast.Store())],
value=subj,
)
)
return guard
elif isinstance(pattern, ast.MatchOr):
return ast.BinOp(
left=self.translate_match(subj, pattern.patterns[0], guard),
op=ast.BitOr(),
right=(
self.translate_match(subj, ast.MatchOr(patterns=pattern.patterns[1:]))
if pattern.patterns[2:]
else self.translate_match(subj, pattern.patterns[1])
),
)
elif isinstance(pattern, ast.MatchSequence):
if isinstance(pattern.patterns[-1], ast.MatchStar):
raise ValueError("starred patterns are not supported.")

if isinstance(subj, ast.Tuple):
# TODO: Use polars list operations in the future
left = self.translate_match(subj.elts[0], pattern.patterns[0], guard)
right = (
self.translate_match(
ast.Tuple(elts=subj.elts[1:]),
ast.MatchSequence(patterns=pattern.patterns[1:]),
)
if pattern.patterns[2:]
else self.translate_match(subj.elts[1], pattern.patterns[1])
)

return (
left or right
if left is None or right is None
else ast.BinOp(left=left, op=ast.BitAnd(), right=right)
)
raise ValueError("Matching lists is not supported.")
else:
raise ValueError(
f"Incompatible match and subject types: {type(pattern)} and {type(subj)}."
)

def handle_assign(self, expr: ast.Assign | ast.AnnAssign):
if isinstance(expr, ast.AnnAssign):
expr = ast.Assign(targets=[expr.target], value=expr.value)

if isinstance(self.node, UnresolvedState):
self.node.handle_assign(expr)
elif isinstance(self.node, ConditionalState):
self.node.then.handle_assign(expr)
for case in self.node.body:
case.state.handle_assign(expr)
self.node.orelse.handle_assign(expr)

def handle_if(self, stmt: ast.If):
if isinstance(self.node, UnresolvedState):
self.node = ConditionalState(
test=InlineTransformer.inline_expr(stmt.test, self.node.assignments),
then=parse_body(stmt.body, copy(self.node.assignments)),
body=[
UnresolvedCase(
InlineTransformer.inline_expr(stmt.test, self.node.assignments),
parse_body(stmt.body, copy(self.node.assignments)),
)
],
orelse=parse_body(stmt.orelse, copy(self.node.assignments)),
)
elif isinstance(self.node, ConditionalState):
self.node.then.handle_if(stmt)
for case in self.node.body:
case.state.handle_if(stmt)
self.node.orelse.handle_if(stmt)

def handle_return(self, value: ast.expr):
Expand All @@ -166,9 +294,58 @@ def handle_return(self, value: ast.expr):
expr=InlineTransformer.inline_expr(value, self.node.assignments)
)
elif isinstance(self.node, ConditionalState):
self.node.then.handle_return(value)
for case in self.node.body:
case.state.handle_return(value)
self.node.orelse.handle_return(value)

def handle_match(self, stmt: ast.Match):
def is_catch_all(case: ast.match_case) -> bool:
# We check if the case is a catch-all pattern without a guard
# If it has a guard, we treat it as a regular case
return (
isinstance(case.pattern, ast.MatchAs)
and case.pattern.name is None
and case.guard is None
)

def ignore_case(case: ast.match_case) -> bool:
# if the length of the pattern is not equal to the length of the subject, python ignores the case
return (
isinstance(case.pattern, ast.MatchSequence)
and isinstance(stmt.subject, ast.Tuple)
and len(stmt.subject.elts) != len(case.pattern.patterns)
) or (isinstance(case.pattern, ast.MatchValue) and isinstance(stmt.subject, ast.Tuple))

if isinstance(self.node, UnresolvedState):
# We can always rewrite catch-all patterns to orelse since python throws a SyntaxError if the catch-all pattern is not the last case.
orelse = next(
iter([case.body for case in stmt.cases if is_catch_all(case)]),
[],
)
self.node = ConditionalState(
body=[
UnresolvedCase(
# translate_match transforms the match statement case into regular AST expressions so that the InlineTransformer can handle assignments correctly
# Note that by the time parse_body is called this has mutated the assignments
InlineTransformer.inline_expr(
self.translate_match(stmt.subject, case.pattern, case.guard),
self.node.assignments,
),
parse_body(case.body, copy(self.node.assignments)),
)
for case in stmt.cases
if not is_catch_all(case) and not ignore_case(case)
],
orelse=parse_body(
orelse,
copy(self.node.assignments),
),
)
elif isinstance(self.node, ConditionalState):
for case in self.node.body:
case.state.handle_match(stmt)
self.node.orelse.handle_match(stmt)


def parse_body(full_body: list[ast.stmt], assignments: dict[str, ast.expr] | None = None) -> State:
if assignments is None:
Expand All @@ -182,9 +359,11 @@ def parse_body(full_body: list[ast.stmt], assignments: dict[str, ast.expr] | Non
elif isinstance(stmt, ast.Return):
if stmt.value is None:
raise ValueError("return needs a value")

state.handle_return(stmt.value)
break
elif isinstance(stmt, ast.Match):
assert not PY_39
state.handle_match(stmt)
else:
raise ValueError(f"Unsupported statement type: {type(stmt)}")
return state
Expand All @@ -194,9 +373,15 @@ def transform_tree_into_expr(node: State) -> ast.expr:
if isinstance(node.node, ReturnState):
return node.node.expr
elif isinstance(node.node, ConditionalState):
if not node.node.body:
# this happens if none of the cases will ever match or exist
# in these cases we just need to return the orelse body
return transform_tree_into_expr(node.node.orelse)
return build_polars_when_then_otherwise(
node.node.test,
transform_tree_into_expr(node.node.then),
[
ResolvedCase(case.test, transform_tree_into_expr(case.state))
for case in node.node.body
],
transform_tree_into_expr(node.node.orelse),
)
else:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build-backend = "hatchling.build"
[project]
name = "polarify"
description = "Simplifying conditional Polars Expressions with Python 🐍 🐻‍❄️"
version = "0.1.5"
version = "0.2.0"
readme = "README.md"
license = {file = "LICENSE"}
requires-python = ">=3.9"
Expand Down
Loading