-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
226 additions
and
5 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""ASTx Transpilers.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
"""ASTx Python transpiler.""" | ||
|
||
from typing import Type | ||
|
||
from plum import dispatch | ||
|
||
import astx | ||
|
||
|
||
class ASTxPythonTranspiler: | ||
""" | ||
Transpiler that converts ASTx nodes to Python code. | ||
Notes | ||
----- | ||
Please keep the visit method in alphabet order according to the node type. | ||
The visit method for astx.AST should be the first one. | ||
""" | ||
|
||
def __init__(self) -> None: | ||
self.indent_level = 0 | ||
self.indent_str = " " # 4 spaces | ||
|
||
def _generate_block(self, block: astx.Block) -> str: | ||
"""Generate code for a block of statements with proper indentation.""" | ||
self.indent_level += 1 | ||
indent = self.indent_str * self.indent_level | ||
lines = [indent + self.visit(node) for node in block.nodes] | ||
result = ( | ||
"\n".join(lines) | ||
if lines | ||
else self.indent_str * self.indent_level + "pass" | ||
) | ||
self.indent_level -= 1 | ||
return result | ||
|
||
@dispatch.abstract | ||
def visit(self, expr: astx.AST) -> str: | ||
"""Translate an ASTx expression.""" | ||
raise Exception(f"Not implemented yet ({expr}).") | ||
|
||
@dispatch # type: ignore[no-redef] | ||
def visit(self, node: astx.Argument) -> str: | ||
"""Handle UnaryOp nodes.""" | ||
type_ = self.visit(node.type_) | ||
return f"{node.name}: {type_}" | ||
|
||
@dispatch # type: ignore[no-redef] | ||
def visit(self, node: astx.Arguments) -> str: | ||
"""Handle UnaryOp nodes.""" | ||
return ", ".join([self.visit(arg) for arg in node.nodes]) | ||
|
||
@dispatch # type: ignore[no-redef] | ||
def visit(self, node: astx.BinaryOp) -> str: | ||
"""Handle BinaryOp nodes.""" | ||
lhs = self.visit(node.lhs) | ||
rhs = self.visit(node.rhs) | ||
return f"({lhs} {node.op_code} {rhs})" | ||
|
||
@dispatch # type: ignore[no-redef] | ||
def visit(self, node: astx.Block) -> str: | ||
"""Handle Block nodes.""" | ||
return self._generate_block(node) | ||
|
||
@dispatch # type: ignore[no-redef] | ||
def visit(self, node: Type[astx.Int32]) -> str: | ||
"""Handle Int32 nodes.""" | ||
return "int" | ||
|
||
@dispatch # type: ignore[no-redef] | ||
def visit(self, node: astx.LiteralBoolean) -> str: | ||
"""Handle LiteralBoolean nodes.""" | ||
return "True" if node.value else "False" | ||
|
||
@dispatch # type: ignore[no-redef] | ||
def visit(self, node: astx.LiteralInt32) -> str: | ||
"""Handle LiteralInt32 nodes.""" | ||
return str(node.value) | ||
|
||
@dispatch # type: ignore[no-redef] | ||
def visit(self, node: astx.Function) -> str: | ||
"""Handle Function nodes.""" | ||
args = self.visit(node.prototype.args) | ||
returns = ( | ||
f" -> {self.visit(node.prototype.return_type)}" | ||
if node.prototype.return_type | ||
else "" | ||
) | ||
header = f"def {node.name}({args}){returns}:" | ||
body = self.visit(node.body) | ||
return f"{header}\n{body}" | ||
|
||
@dispatch # type: ignore[no-redef] | ||
def visit(self, node: astx.FunctionReturn) -> str: | ||
"""Handle FunctionReturn nodes.""" | ||
value = self.visit(node.value) if node.value else "" | ||
return f"return {value}" | ||
|
||
@dispatch # type: ignore[no-redef] | ||
def visit(self, node: astx.UnaryOp) -> str: | ||
"""Handle UnaryOp nodes.""" | ||
operand = self.visit(node.operand) | ||
return f"({node.op_code}{operand})" | ||
|
||
@dispatch # type: ignore[no-redef] | ||
def visit(self, node: astx.Variable) -> str: | ||
"""Handle Variable nodes.""" | ||
return node.name | ||
|
||
@dispatch # type: ignore[no-redef] | ||
def visit(self, node: astx.VariableAssignment) -> str: | ||
"""Handle VariableAssignment nodes.""" | ||
target = node.name | ||
value = self.visit(node.value) | ||
return f"{target} = {value}" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""Set of tests for transpilers.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
"""Test Python Transpiler.""" | ||
|
||
import astx | ||
|
||
from astx.transpilers import python as astx2py | ||
|
||
|
||
def test_function() -> None: | ||
"""Test astx.Function.""" | ||
# Function parameters | ||
args = astx.Arguments( | ||
astx.Argument(name="x", type_=astx.Int32), | ||
astx.Argument(name="y", type_=astx.Int32), | ||
) | ||
|
||
# Function body | ||
body = astx.Block() | ||
body.append( | ||
astx.VariableAssignment( | ||
name="result", | ||
value=astx.BinaryOp( | ||
op_code="+", | ||
lhs=astx.Variable(name="x"), | ||
rhs=astx.Variable(name="y"), | ||
loc=astx.SourceLocation(line=2, col=8), | ||
), | ||
loc=astx.SourceLocation(line=2, col=4), | ||
) | ||
) | ||
body.append( | ||
astx.FunctionReturn( | ||
value=astx.Variable(name="result"), | ||
loc=astx.SourceLocation(line=3, col=4), | ||
) | ||
) | ||
|
||
# Function definition | ||
add_function = astx.Function( | ||
prototype=astx.FunctionPrototype( | ||
name="add", | ||
args=args, | ||
return_type=astx.Int32, | ||
), | ||
body=body, | ||
loc=astx.SourceLocation(line=1, col=0), | ||
) | ||
|
||
# Initialize the generator | ||
generator = astx2py.ASTxPythonTranspiler() | ||
|
||
# Generate Python code | ||
generated_code = generator.visit(add_function) | ||
expected_code = "\n".join( | ||
[ | ||
"def add(x: int, y: int) -> int:", | ||
" result = (x + y)", | ||
" return result", | ||
] | ||
) | ||
|
||
assert generated_code == expected_code, "generated_code != expected_code" |