Skip to content

Commit

Permalink
feat: add IfExpr class (#143)
Browse files Browse the repository at this point in the history
  • Loading branch information
apkrelling authored Nov 19, 2024
1 parent e94ca1a commit 4301e2a
Show file tree
Hide file tree
Showing 7 changed files with 285 additions and 14 deletions.
4 changes: 2 additions & 2 deletions docs/tutorials/fibonacci.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@
"base_case_return = astx.FunctionReturn(astx.Variable(name=\"n\"))\n",
"base_case_block.append(base_case_return)\n",
"\n",
"base_case_if = astx.If(condition=base_case_cond, then=base_case_block)\n",
"base_case_if = astx.IfStmt(condition=base_case_cond, then=base_case_block)\n",
"\n",
"# Recursive case: return fib(n - 1) + fib(n - 2);\n",
"fib_n1_call = astx.FunctionCall(\n",
Expand Down Expand Up @@ -210,7 +210,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
"version": "3.12.7"
}
},
"nbformat": 4,
Expand Down
6 changes: 4 additions & 2 deletions src/astx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@
ForCountLoopStmt,
ForRangeLoopExpr,
ForRangeLoopStmt,
If,
IfExpr,
IfStmt,
WhileExpr,
WhileStmt,
)
Expand Down Expand Up @@ -169,7 +170,8 @@ def get_version() -> str:
"FunctionPrototype",
"FunctionReturn",
"get_version",
"If",
"IfStmt",
"IfExpr",
"ImportFromExpr",
"ImportExpr",
"ImportStmt",
Expand Down
3 changes: 2 additions & 1 deletion src/astx/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,14 @@ class ASTKind(Enum):
LambdaExprKind = -404

# control flow
IfKind = -500
IfStmtKind = -500
ForCountLoopStmtKind = -501
ForRangeLoopStmtKind = -502
WhileStmtKind = -503
ForRangeLoopExprKind = -504
ForCountLoopExprKind = -505
WhileExprKind = -506
IfExprKind = -507

# data types
NullDTKind = -600
Expand Down
56 changes: 52 additions & 4 deletions src/astx/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

@public
@typechecked
class If(StatementType):
class IfStmt(StatementType):
"""AST class for `if` statement."""

condition: Expr
Expand All @@ -38,17 +38,17 @@ def __init__(
loc: SourceLocation = NO_SOURCE_LOCATION,
parent: Optional[ASTNodes] = None,
) -> None:
"""Initialize the If instance."""
"""Initialize the IfStmt instance."""
super().__init__(loc=loc, parent=parent)
self.loc = loc
self.condition = condition
self.then = then
self.else_ = else_
self.kind = ASTKind.IfKind
self.kind = ASTKind.IfStmtKind

def __str__(self) -> str:
"""Return a string representation of the object."""
return f"If[{self.condition}]"
return f"IfStmt[{self.condition}]"

def get_struct(self, simplified: bool = False) -> ReprStruct:
"""Return the AST structure of the object."""
Expand All @@ -69,6 +69,54 @@ def get_struct(self, simplified: bool = False) -> ReprStruct:
return self._prepare_struct(key, value, simplified)


@public
@typechecked
class IfExpr(Expr):
"""AST class for `if` expression."""

condition: Expr
then: Block
else_: Optional[Block]

def __init__(
self,
condition: Expr,
then: Block,
else_: Optional[Block] = None,
loc: SourceLocation = NO_SOURCE_LOCATION,
parent: Optional[ASTNodes] = None,
) -> None:
"""Initialize the IfExpr instance."""
super().__init__(loc=loc, parent=parent)
self.loc = loc
self.condition = condition
self.then = then
self.else_ = else_
self.kind = ASTKind.IfExprKind

def __str__(self) -> str:
"""Return a string representation of the object."""
return f"IfExpr[{self.condition}]"

def get_struct(self, simplified: bool = False) -> ReprStruct:
"""Return the AST structure of the object."""
if_condition = {"condition": self.condition.get_struct(simplified)}
if_then = {"then-block": self.then.get_struct(simplified)}
if_else: ReprStruct = {}

if self.else_ is not None:
if_else = {"else-block": self.else_.get_struct(simplified)}

key = "IF-EXPR"
value: ReprStruct = {
**cast(DictDataTypesStruct, if_condition),
**cast(DictDataTypesStruct, if_then),
**cast(DictDataTypesStruct, if_else),
}

return self._prepare_struct(key, value, simplified)


@public
@typechecked
class ForRangeLoopStmt(StatementType):
Expand Down
40 changes: 40 additions & 0 deletions src/astx/transpilers/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,46 @@ def visit(self, node: astx.FunctionReturn) -> str:
value = self.visit(node.value) if node.value else ""
return f"return {value}"

@dispatch # type: ignore[no-redef]
def visit(self, node: astx.IfExpr) -> str:
"""Handle IfExpr nodes."""
if node.else_:
return (
f"{self.visit(node.then)} if "
f" {self.visit(node.condition)}"
f" else {self.visit(node.else_)}"
)
return (
f"{self.visit(node.then)} if "
f" {self.visit(node.condition)} else None"
)

@dispatch # type: ignore[no-redef]
def visit(self, node: astx.IfStmt) -> str:
"""Handle IfStmt nodes."""
if node.else_:
return (
f"if {self.visit(node.condition)}:"
f"\n{self._generate_block(node.then)}"
f"\nelse:"
f"\n{self._generate_block(node.else_)}"
)
return (
f"if {self.visit(node.condition)}:"
f"\n{self._generate_block(node.then)}"
)

@dispatch # type: ignore[no-redef]
def visit(self, node: astx.ImportFromStmt) -> str:
"""Handle ImportFromStmt nodes."""
names = [self.visit(name) for name in node.names]
level_dots = "." * node.level
module_str = (
f"{level_dots}{node.module}" if node.module else level_dots
)
names_str = ", ".join(str(name) for name in names)
return f"from {module_str} import {names_str}"

@dispatch # type: ignore[no-redef]
def visit(self, node: astx.ImportExpr) -> str:
"""Handle ImportExpr nodes."""
Expand Down
36 changes: 31 additions & 5 deletions tests/test_flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
ForCountLoopStmt,
ForRangeLoopExpr,
ForRangeLoopStmt,
If,
IfExpr,
IfStmt,
WhileExpr,
WhileStmt,
)
Expand All @@ -17,31 +18,56 @@
from astx.viz import visualize


def test_if() -> None:
def test_if_stmt() -> None:
"""Test `if` statement."""
op = BinaryOp(op_code=">", lhs=LiteralInt32(1), rhs=LiteralInt32(2))
then_block = Block()
if_stmt = If(condition=op, then=then_block)
if_stmt = IfStmt(condition=op, then=then_block)

assert str(if_stmt)
assert if_stmt.get_struct()
assert if_stmt.get_struct(simplified=True)
visualize(if_stmt.get_struct())


def test_if_else() -> None:
def test_if_else_stmt() -> None:
"""Test `if`/`else` statement."""
cond = BinaryOp(op_code=">", lhs=LiteralInt32(1), rhs=LiteralInt32(2))
then_block = Block()
else_block = Block()
if_stmt = If(condition=cond, then=then_block, else_=else_block)
if_stmt = IfStmt(condition=cond, then=then_block, else_=else_block)

assert str(if_stmt)
assert if_stmt.get_struct()
assert if_stmt.get_struct(simplified=True)
visualize(if_stmt.get_struct())


def test_if_expr() -> None:
"""Test `if` expression."""
op = BinaryOp(op_code=">", lhs=LiteralInt32(1), rhs=LiteralInt32(2))
then_block = Block()
if_expr = IfExpr(condition=op, then=then_block)

assert str(if_expr)
assert if_expr.get_struct()
assert if_expr.get_struct(simplified=True)
visualize(if_expr.get_struct())


def test_if_else_expr() -> None:
"""Test `if`/`else` expression."""
cond = BinaryOp(op_code=">", lhs=LiteralInt32(1), rhs=LiteralInt32(2))
then_block = Block()
else_block = Block()
if_expr = IfExpr(condition=cond, then=then_block, else_=else_block)

assert str(if_expr)
assert if_expr.get_struct()
assert if_expr.get_struct(simplified=True)
visualize(if_expr.get_struct())


def test_for_range_loop_expr() -> None:
"""Test `For Range Loop` expression`."""
decl_a = InlineVariableDeclaration(
Expand Down
Loading

0 comments on commit 4301e2a

Please sign in to comment.