From 5eb3b5a4359e8fe734c04c6041d7ce1f2a82a688 Mon Sep 17 00:00:00 2001 From: Ivan Ogasawara Date: Sun, 21 Apr 2024 23:23:13 -0400 Subject: [PATCH] feat: Move local implemention for AST to ASTx --- examples/average.arx | 2 +- poetry.lock | 16 +- pyproject.toml | 4 +- src/arx/ast.py | 278 -------------- src/arx/codegen.py | 82 ++++ src/arx/codegen/__init__.py | 1 - src/arx/codegen/ast_output.py | 244 ------------ src/arx/codegen/base.py | 170 --------- src/arx/codegen/file_object.py | 678 --------------------------------- src/arx/lexer.py | 27 +- src/arx/main.py | 13 +- src/arx/parser.py | 173 ++++----- tests/test_parser.py | 56 +-- 13 files changed, 221 insertions(+), 1523 deletions(-) delete mode 100644 src/arx/ast.py create mode 100644 src/arx/codegen.py delete mode 100644 src/arx/codegen/__init__.py delete mode 100644 src/arx/codegen/ast_output.py delete mode 100644 src/arx/codegen/base.py delete mode 100644 src/arx/codegen/file_object.py diff --git a/examples/average.arx b/examples/average.arx index 94639b4..4477189 100644 --- a/examples/average.arx +++ b/examples/average.arx @@ -1,2 +1,2 @@ fn average(x, y): - return (x + y) * 0.5; + return (x + y) * 1; diff --git a/poetry.lock b/poetry.lock index da3b03c..d289b7f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -31,13 +31,13 @@ test = ["astroid (>=1,<2)", "astroid (>=2,<4)", "pytest"] [[package]] name = "astx" -version = "0.9.1" +version = "0.11.0" description = "ASTx is an agnostic expression structure for AST." optional = false python-versions = "<4,>=3.8.1" files = [ - {file = "astx-0.9.1-py3-none-any.whl", hash = "sha256:15285358c8e337f8cf6d060199e84672101681d3453fce9e8daf8e4fe098dea1"}, - {file = "astx-0.9.1.tar.gz", hash = "sha256:ce36f1243577cae87668c6c944231fdcfd2b5ea0682780005051b6a05e357921"}, + {file = "astx-0.11.0-py3-none-any.whl", hash = "sha256:d6d8771f69979d89bfa4a6a87db6f8db27318c66af7cc8c4fadb608a3574eaa6"}, + {file = "astx-0.11.0.tar.gz", hash = "sha256:ec73d73c3cca36d89e72a3bca63a77d4b0ead85960ea51a68e8085cc956160ec"}, ] [package.dependencies] @@ -1893,17 +1893,17 @@ windows-terminal = ["colorama (>=0.4.6)"] [[package]] name = "pyirx" -version = "1.1.0" +version = "1.2.0" description = "IRx" optional = false python-versions = "<4,>=3.8.1" files = [ - {file = "pyirx-1.1.0-py3-none-any.whl", hash = "sha256:e22fd47d249e72f2623b375afc1dda84738fbe921996d3283448f7a180705066"}, - {file = "pyirx-1.1.0.tar.gz", hash = "sha256:5151e30216dfb1645ebfa4220ed582b1ddeb2c3609d12a782f0bafef8e42e9e3"}, + {file = "pyirx-1.2.0-py3-none-any.whl", hash = "sha256:18e1733635c707a09b0217a094f5a0b520be4205299fe02349eb2cf3c4e21c48"}, + {file = "pyirx-1.2.0.tar.gz", hash = "sha256:7a0dac401d00fd8a49d4be93d8a002a1d94b05d10615d6a22e197522a92679ee"}, ] [package.dependencies] -astx = ">=0.9" +astx = "==0.11.*" atpublic = ">=4.0" llvmlite = ">=0.41.1" plum-dispatch = ">=2.2.2" @@ -3009,4 +3009,4 @@ test = ["coverage (>=5.3.1)", "prompt-toolkit (>=3.0.29,<3.0.41)", "pygments (>= [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "6b441e58f7752fdb4638dcd90711b09ff190a6a34d5c201451746c8e505806bb" +content-hash = "e5dc6b48e90197e684c51d28d2f99241f9fe582c6d30ee3ab027fb0f66aafc78" diff --git a/pyproject.toml b/pyproject.toml index 9f6e7a0..3552e9b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,8 +21,8 @@ packages = [ [tool.poetry.dependencies] python = "^3.10" pyyaml = ">=4" -astx = "0.9.1" -pyirx = "1.1.0" +astx = "0.11.*" +pyirx = "1.2.*" [tool.poetry.group.dev.dependencies] pytest = ">=7" diff --git a/src/arx/ast.py b/src/arx/ast.py deleted file mode 100644 index 3d40d2e..0000000 --- a/src/arx/ast.py +++ /dev/null @@ -1,278 +0,0 @@ -"""AST classes and functions.""" - -from enum import Enum -from typing import List, Tuple - -from arx.lexer import SourceLocation - - -class ExprKind(Enum): - """The expression kind class used for downcasting.""" - - GenericKind = -1 - ModuleKind = -2 - - # variables - VariableKind = -10 - VarKind = -11 # var keyword for variable declaration - - # operators - UnaryOpKind = -20 - BinaryOpKind = -21 - - # functions - PrototypeKind = -30 - FunctionKind = -31 - CallKind = -32 - ReturnKind = -33 - - # control flow - IfKind = -40 - ForKind = -41 - - # data types - NullDTKind = -100 - BooleanDTKind = -101 - Int8DTKind = -102 - UInt8DTKind = -103 - Int16DTKind = -104 - UInt16DTKind = -105 - Int32DTKind = -106 - UInt32DTKind = -107 - Int64DTKind = -108 - UInt64DTKind = -109 - FloatDTKind = -110 - DoubleDTKind = -111 - BinaryDTKind = -112 - StringDTKind = -113 - FixedSizeBinaryDTKind = -114 - Date32DTKind = -115 - Date64DTKind = -116 - TimestampDTKind = -117 - Time32DTKind = -118 - Time64DTKind = -119 - Decimal128DTKind = -120 - Decimal256DTKind = -121 - - -class ExprAST: - """AST main expression class.""" - - loc: SourceLocation - kind: ExprKind - - def __init__(self, loc: SourceLocation = SourceLocation(0, 0)) -> None: - """Initialize the ExprAST instance.""" - self.kind = ExprKind.GenericKind - self.loc = loc - - -class BlockAST(ExprAST): - """The AST tree.""" - - nodes: List[ExprAST] - - def __init__(self) -> None: - """Initialize the BlockAST instance.""" - super().__init__() - self.nodes: List[ExprAST] = [] - - -class ModuleAST(BlockAST): - """AST main expression class.""" - - name: str - - def __init__(self, name: str) -> None: - """Initialize the ExprAST instance.""" - super().__init__() - self.name = name - self.kind = ExprKind.ModuleKind - - -class FloatExprAST(ExprAST): - """AST for the literal float number.""" - - value: float - - def __init__(self, val: float) -> None: - """Initialize the FloatAST instance.""" - super().__init__() - self.value = val - self.kind = ExprKind.FloatDTKind - - -class VariableExprAST(ExprAST): - """AST class for the variable usage.""" - - def __init__(self, loc: SourceLocation, name: str, type_name: str) -> None: - """Initialize the VariableExprAST instance.""" - super().__init__(loc) - self.name = name - self.type_name = type_name - self.kind = ExprKind.VariableKind - - def get_name(self) -> str: - """Return the variable name.""" - return self.name - - -class UnaryExprAST(ExprAST): - """AST class for the unary operator.""" - - def __init__(self, op_code: str, operand: ExprAST) -> None: - """Initialize the UnaryExprAST instance.""" - super().__init__() - self.op_code = op_code - self.operand = operand - self.kind = ExprKind.UnaryOpKind - - -class BinaryExprAST(ExprAST): - """AST class for the binary operator.""" - - def __init__( - self, loc: SourceLocation, op: str, lhs: ExprAST, rhs: ExprAST - ) -> None: - """Initialize the BinaryExprAST instance.""" - super().__init__(loc) - self.op = op - self.lhs = lhs - self.rhs = rhs - self.kind = ExprKind.BinaryOpKind - - -class CallExprAST(ExprAST): - """AST class for function call.""" - - def __init__( - self, loc: SourceLocation, callee: str, args: List[ExprAST] - ) -> None: - """Initialize the CallExprAST instance.""" - super().__init__(loc) - self.callee = callee - self.args = args - self.kind = ExprKind.CallKind - - -class IfStmtAST(ExprAST): - """AST class for `if` statement.""" - - cond: ExprAST - then_: BlockAST - else_: BlockAST - - def __init__( - self, - loc: SourceLocation, - cond: ExprAST, - then_: BlockAST, - else_: BlockAST, - ) -> None: - """Initialize the IfStmtAST instance.""" - super().__init__(loc) - self.cond = cond - self.then_ = then_ - self.else_ = else_ - self.kind = ExprKind.IfKind - - -class ForStmtAST(ExprAST): - """AST class for `For` statement.""" - - var_name: str - start: ExprAST - end: ExprAST - step: ExprAST - body: BlockAST - - def __init__( - self, - var_name: str, - start: ExprAST, - end: ExprAST, - step: ExprAST, - body: BlockAST, - ) -> None: - """Initialize the ForStmtAST instance.""" - super().__init__() - self.var_name = var_name - self.start = start - self.end = end - self.step = step - self.body = body - self.kind = ExprKind.ForKind - - -class VarExprAST(ExprAST): - """AST class for variable declaration.""" - - var_names: List[Tuple[str, ExprAST]] - type_name: str - body: ExprAST - - def __init__( - self, - var_names: List[Tuple[str, ExprAST]], - type_name: str, - body: ExprAST, - ) -> None: - """Initialize the VarExprAST instance.""" - super().__init__() - self.var_names = var_names - self.type_name = type_name - self.body = body - self.kind = ExprKind.VarKind - - -class PrototypeAST(ExprAST): - """AST class for function prototype declaration.""" - - name: str - args: List[VariableExprAST] - type_name: str - - def __init__( - self, - loc: SourceLocation, - name: str, - type_name: str, - args: List[VariableExprAST], - ) -> None: - """Initialize the PrototypeAST instance.""" - super().__init__() - self.name = name - self.args = args - self.type_name = type_name - self.line = loc.line - self.kind = ExprKind.PrototypeKind - - def get_name(self) -> str: - """Return the prototype name.""" - return self.name - - -class ReturnStmtAST(ExprAST): - """AST class for function `return` statement.""" - - value: ExprAST - - def __init__(self, value: ExprAST) -> None: - """Initialize the ReturnStmtAST instance.""" - super().__init__() - self.value = value - self.kind = ExprKind.ReturnKind - - -class FunctionAST(ExprAST): - """AST class for function definition.""" - - proto: PrototypeAST - body: BlockAST - - def __init__(self, proto: PrototypeAST, body: BlockAST) -> None: - """Initialize the FunctionAST instance.""" - super().__init__() - self.proto = proto - self.body = body - self.kind = ExprKind.FunctionKind diff --git a/src/arx/codegen.py b/src/arx/codegen.py new file mode 100644 index 0000000..21f185f --- /dev/null +++ b/src/arx/codegen.py @@ -0,0 +1,82 @@ +"""File Object, Executable or LLVM IR generation.""" + +import logging + +import astx + +from irx.builders.llvmliteir import LLVMLiteIR + +from arx.io import ArxIO +from arx.lexer import Lexer +from arx.parser import Parser + +logging.basicConfig(level=logging.INFO) +LOG = logging.getLogger(__name__) + + +INPUT_FILE: str = "" +OUTPUT_FILE: str = "" +ARX_VERSION: str = "" +IS_BUILD_LIB: bool = True + + +class ObjectGenerator: + """Generate object files or executable from an AST.""" + + output_file: str = "" + input_file: str = "" + # is_lib: bool = True + + def __init__( + self, + input_file: str = "", + output_file: str = "tmp.o", + is_lib: bool = True, + ): + self.input_file = input_file + self.output_file = output_file or f"{input_file}.o" + self.is_lib = is_lib + + def evaluate(self, ast_node: astx.AST, show_llvm_ir: bool = False) -> None: + """ + Compile an AST to an object file. + + Parameters + ---------- + ast_node: An AST object. + + Returns + ------- + int: The compilation result. + """ + logging.info("Starting main_loop") + + builder = LLVMLiteIR() + + # Convert LLVM IR into in-memory representation + if show_llvm_ir: + return print(str(builder.translate(ast_node))) + + builder.build(ast_node, self.output_file) + + def open_interactive(self) -> None: + """ + Open the Arx shell. + + Returns + ------- + int: The compilation result. + """ + # Prime the first token. + print(f"Arx {ARX_VERSION} \n") + print(">>> ") + + lexer = Lexer() + parser = Parser() + + while True: + try: + ArxIO.string_to_buffer(input()) + self.evaluate(parser.parse(lexer.lex())) + except KeyboardInterrupt: + break diff --git a/src/arx/codegen/__init__.py b/src/arx/codegen/__init__.py deleted file mode 100644 index 6973610..0000000 --- a/src/arx/codegen/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Set of modules for code generation.""" diff --git a/src/arx/codegen/ast_output.py b/src/arx/codegen/ast_output.py deleted file mode 100644 index 7ab6c5c..0000000 --- a/src/arx/codegen/ast_output.py +++ /dev/null @@ -1,244 +0,0 @@ -"""Set of classes and functions to emit the AST from a given source code.""" - -from typing import Any, Dict, List, TypeAlias, Union - -import yaml - -from arx import ast -from arx.codegen.base import CodeGenBase - -OutputValueAST: TypeAlias = Union[str, int, float, List[Any], Dict[str, Any]] - - -class ASTtoOutput(CodeGenBase): - """Show the AST for the given source code.""" - - result_stack: List[OutputValueAST] - - def __init__(self) -> None: - self.result_stack: List[OutputValueAST] = [] - - def visit_binary_expr(self, expr: ast.BinaryExprAST) -> None: - """ - Visit a ast.BinaryExprAST node. - - Parameters - ---------- - expr: The ast.BinaryExprAST node to visit. - """ - self.visit(expr.lhs) - lhs = self.result_stack.pop() - - self.visit(expr.rhs) - rhs = self.result_stack.pop() - - node = {f"BINARY[{expr.op}]": {"lhs": lhs, "rhs": rhs}} - self.result_stack.append(node) - - def visit_block(self, expr: ast.BlockAST) -> None: - """ - Visit method for tree ast. - - Parameters - ---------- - expr: The ast.BlockAST node to visit. - """ - block_node = [] - - for node in expr.nodes: - self.visit(node) - block_node.append(self.result_stack.pop()) - - self.result_stack.append(block_node) - - def visit_call_expr(self, expr: ast.CallExprAST) -> None: - """ - Visit a ast.CallExprAST node. - - Parameters - ---------- - expr: The ast.CallExprAST node to visit. - """ - call_args = [] - - for node in expr.args: - self.visit(node) - call_args.append(self.result_stack.pop()) - - call_node = {f"CALL[{expr.callee}]": {"args": call_args}} - self.result_stack.append(call_node) - - def visit_float_expr(self, expr: ast.FloatExprAST) -> None: - """ - Visit a ast.FloatExprAST node. - - Parameters - ---------- - expr: The ast.FloatExprAST node to visit. - """ - self.result_stack.append(f"FLOAT[{expr.value}]") - - def visit_if_stmt(self, expr: ast.IfStmtAST) -> None: - """ - Visit an ast.IfStmtAST node. - - Parameters - ---------- - expr: The ast.IfStmtAST node to visit. - """ - self.visit(expr.cond) - if_condition = self.result_stack.pop() - - self.visit(expr.then_) - if_then = self.result_stack.pop() - - if expr.else_: - self.visit(expr.else_) - if_else = self.result_stack.pop() - else: - if_else = [] - - node = { - "IF-STMT": { - "CONDITION": if_condition, - "THEN": if_then, - } - } - - if if_else: - node["IF-STMT"]["ELSE"] = if_else - - self.result_stack.append(node) - - def visit_for_stmt(self, expr: ast.ForStmtAST) -> None: - """ - Visit a ast.ForStmtAST node. - - Parameters - ---------- - expr: The ast.ForStmtAST node to visit. - """ - self.visit(expr.start) - for_start = self.result_stack.pop() - - self.visit(expr.end) - for_end = self.result_stack.pop() - - self.visit(expr.step) - for_step = self.result_stack.pop() - - self.visit(expr.body) - for_body = self.result_stack.pop() - - node = { - "FOR-STMT": { - "start": for_start, - "end": for_end, - "step": for_step, - "body": for_body, - } - } - self.result_stack.append(node) - - def visit_function(self, expr: ast.FunctionAST) -> None: - """ - Visit a ast.FunctionAST node. - - Parameters - ---------- - expr: The ast.FunctionAST node to visit. - """ - fn_args = [] - for node in expr.proto.args: - self.visit(node) - fn_args.append(self.result_stack.pop()) - - self.visit(expr.body) - fn_body = self.result_stack.pop() - - fn = {} - fn[f"FUNCTION[{expr.proto.name}]"] = { - "args": fn_args, - "body": fn_body, - } - - self.result_stack.append(fn) - - def visit_module(self, expr: ast.ModuleAST) -> None: - """ - Visit method for tree ast. - - Parameters - ---------- - expr: The ast.BlockAST node to visit. - """ - block_node = [] - - for node in expr.nodes: - self.visit(node) - block_node.append(self.result_stack.pop()) - - module_node = {f"MODULE[{expr.name}]": block_node} - - self.result_stack.append(module_node) - - def visit_prototype(self, expr: ast.PrototypeAST) -> None: - """ - Visit a ast.PrototypeAST node. - - Parameters - ---------- - expr: The ast.PrototypeAST node to visit. - """ - raise Exception("Visitor method not necessary") - - def visit_return_stmt(self, expr: ast.ReturnStmtAST) -> None: - """ - Visit a ast.ReturnStmtAST node. - - Parameters - ---------- - expr: The ast.ReturnStmtAST node to visit. - """ - self.visit(expr.value) - node = {"RETURN": self.result_stack.pop()} - self.result_stack.append(node) - - def visit_unary_expr(self, expr: ast.UnaryExprAST) -> None: - """ - Visit a ast.UnaryExprAST node. - - Parameters - ---------- - expr: The ast.UnaryExprAST node to visit. - """ - self.visit(expr.operand) - node = {f"UNARY[{expr.op_code}]": self.result_stack.pop()} - self.result_stack.append(node) - - def visit_var_expr(self, expr: ast.VarExprAST) -> None: - """ - Visit a ast.VarExprAST node. - - Parameters - ---------- - expr: The ast.VarExprAST node to visit. - """ - raise Exception("Variable declaration will be changed soon.") - - def visit_variable_expr(self, expr: ast.VariableExprAST) -> None: - """ - Visit a ast.VariableExprAST node. - - Parameters - ---------- - expr: The ast.VariableExprAST node to visit. - """ - self.result_stack.append(f"VARIABLE[{expr.name, expr.type_name}]") - - def emit_ast(self, tree_ast: ast.BlockAST) -> None: - """Print the AST for the given source code.""" - self.visit_block(tree_ast) - - ast_output = {"ROOT": self.result_stack.pop()} - print(yaml.dump(ast_output, sort_keys=False)) diff --git a/src/arx/codegen/base.py b/src/arx/codegen/base.py deleted file mode 100644 index 12bf2a9..0000000 --- a/src/arx/codegen/base.py +++ /dev/null @@ -1,170 +0,0 @@ -"""Base module for code generation.""" - -from typing import Any, Callable, Dict, Type - -import llvmlite.binding as llvm - -from llvmlite import ir - -from arx import ast -from arx.exceptions import CodeGenException - - -class CodeGenBase: - """A base Visitor pattern class.""" - - def visit(self, expr: ast.ExprAST) -> None: - """Call the correspondent visit function for the given expr type.""" - map_visit_expr: Dict[Type[ast.ExprAST], Callable[[Any], None]] = { - ast.BinaryExprAST: self.visit_binary_expr, - ast.BlockAST: self.visit_block, - ast.CallExprAST: self.visit_call_expr, - ast.FloatExprAST: self.visit_float_expr, - ast.ForStmtAST: self.visit_for_stmt, - ast.FunctionAST: self.visit_function, - ast.IfStmtAST: self.visit_if_stmt, - ast.ModuleAST: self.visit_module, - ast.PrototypeAST: self.visit_prototype, - ast.ReturnStmtAST: self.visit_return_stmt, - ast.UnaryExprAST: self.visit_unary_expr, - ast.VarExprAST: self.visit_var_expr, - ast.VariableExprAST: self.visit_variable_expr, - } - - fn = map_visit_expr.get(type(expr)) - - if not fn: - print("Fail to downcasting ExprAST.") - return - - fn(expr) - - def visit_binary_expr(self, expr: ast.BinaryExprAST) -> None: - """Visit method for binary expression.""" - raise CodeGenException("Not implemented yet.") - - def visit_block(self, expr: ast.BlockAST) -> None: - """Visit method for tree ast.""" - raise CodeGenException("Not implemented yet.") - - def visit_call_expr(self, expr: ast.CallExprAST) -> None: - """Visit method for function call.""" - raise CodeGenException("Not implemented yet.") - - def visit_float_expr(self, expr: ast.FloatExprAST) -> None: - """Visit method for float.""" - raise CodeGenException("Not implemented yet.") - - def visit_for_stmt(self, expr: ast.ForStmtAST) -> None: - """Visit method for `for` loop.""" - raise CodeGenException("Not implemented yet.") - - def visit_if_stmt(self, expr: ast.IfStmtAST) -> None: - """Visit method for if statement.""" - raise CodeGenException("Not implemented yet.") - - def visit_function(self, expr: ast.FunctionAST) -> None: - """Visit method for function definition.""" - raise CodeGenException("Not implemented yet.") - - def visit_module(self, expr: ast.ModuleAST) -> None: - """Visit method for module.""" - raise CodeGenException("Not implemented yet.") - - def visit_prototype(self, expr: ast.PrototypeAST) -> None: - """Visit method for prototype.""" - raise CodeGenException("Not implemented yet.") - - def visit_return_stmt(self, expr: ast.ReturnStmtAST) -> None: - """Visit method for expression.""" - raise CodeGenException("Not implemented yet.") - - def visit_unary_expr(self, expr: ast.UnaryExprAST) -> None: - """Visit method for unary expression.""" - raise CodeGenException("Not implemented yet.") - - def visit_var_expr(self, expr: ast.VarExprAST) -> None: - """Visit method for variable declaration.""" - raise CodeGenException("Not implemented yet.") - - def visit_variable_expr(self, expr: ast.VariableExprAST) -> None: - """Visit method for variable usage.""" - raise CodeGenException("Not implemented yet.") - - -class VariablesLLVM: - """Store all the LLVM variables that is used for the code generation.""" - - FLOAT_TYPE: ir.types.Type - DOUBLE_TYPE: ir.types.Type - INT8_TYPE: ir.types.Type - INT32_TYPE: ir.types.Type - VOID_TYPE: ir.types.Type - - context: ir.context.Context - module: ir.module.Module - - ir_builder: ir.builder.IRBuilder - - def get_data_type(self, type_name: str) -> ir.types.Type: - """ - Get the LLVM data type for the given type name. - - Parameters - ---------- - type_name (str): The name of the type. - - Returns - ------- - ir.Type: The LLVM data type. - """ - if type_name == "float": - return self.FLOAT_TYPE - elif type_name == "double": - return self.DOUBLE_TYPE - elif type_name == "int8": - return self.INT8_TYPE - elif type_name == "int32": - return self.INT32_TYPE - elif type_name == "char": - return self.INT8_TYPE - elif type_name == "void": - return self.VOID_TYPE - - raise CodeGenException("[EE] CodeGen(LLVM): type_name not valid.") - - -class CodeGenLLVMBase(CodeGenBase): - """ArxLLVM gathers all the main global variables for LLVM workflow.""" - - # AllocaInst - named_values: Dict[str, Any] = {} # noqa: RUF012 - _llvm: VariablesLLVM - - def initialize(self) -> None: - """Initialize self.""" - # self._llvm.context = ir.context.Context() - self._llvm = VariablesLLVM() - self._llvm.module = ir.module.Module("Arx") - - # initialize the target registry etc. - llvm.initialize() - llvm.initialize_all_asmprinters() - llvm.initialize_all_targets() - llvm.initialize_native_target() - llvm.initialize_native_asmparser() - llvm.initialize_native_asmprinter() - - # Create a new builder for the module. - self._llvm.ir_builder = ir.IRBuilder() - - # Data Types - self._llvm.FLOAT_TYPE = ir.FloatType() - self._llvm.DOUBLE_TYPE = ir.DoubleType() - self._llvm.INT8_TYPE = ir.IntType(8) - self._llvm.INT32_TYPE = ir.IntType(32) - self._llvm.VOID_TYPE = ir.VoidType() - - def evaluate(self, tree: ast.BlockAST) -> None: - """Evaluate the given AST object.""" - raise CodeGenException(f"Not an evaluation for {tree} implement yet.") diff --git a/src/arx/codegen/file_object.py b/src/arx/codegen/file_object.py deleted file mode 100644 index 165b160..0000000 --- a/src/arx/codegen/file_object.py +++ /dev/null @@ -1,678 +0,0 @@ -"""File Object, Executable or LLVM IR generation.""" - -import logging -import os - -from typing import Any, Dict, List, Union - -from llvmlite import binding as llvm -from llvmlite import ir - -from arx import ast -from arx.codegen.base import CodeGenLLVMBase -from arx.io import ArxFile, ArxIO -from arx.lexer import Lexer -from arx.parser import Parser - -logging.basicConfig(level=logging.INFO) -LOG = logging.getLogger(__name__) - - -INPUT_FILE: str = "" -OUTPUT_FILE: str = "" -ARX_VERSION: str = "" -IS_BUILD_LIB: bool = True - - -class ObjectGenerator(CodeGenLLVMBase): - """Generate object files or executable from an AST.""" - - function_protos: Dict[str, ast.PrototypeAST] - output_file: str = "" - input_file: str = "" - is_lib: bool = True - result_stack: List[Union[ir.Value, ir.Function]] = [] # noqa: RUF012 - - def __init__( - self, - input_file: str = "", - output_file: str = "tmp.o", - is_lib: bool = True, - ): - self.input_file = input_file - self.output_file = output_file or f"{input_file}.o" - self.is_lib = is_lib - - self.function_protos: Dict[str, ast.PrototypeAST] = {} - self.module = ir.Module() - - self.result_stack: List[Union[ir.Value, ir.Function]] = [] - - super().initialize() - - logging.info("target_triple") - self.target = llvm.Target.from_default_triple() - self.target_machine = self.target.create_target_machine( - codemodel="small" - ) - - self._add_builtins() - - def _add_builtins(self) -> None: - # The C++ tutorial adds putchard() simply by defining it in the host - # C++ code, which is then accessible to the JIT. It doesn't work as - # simply for us; but luckily it's very easy to define new "C level" - # functions for our JITed code to use - just emit them as LLVM IR. - # This is what this method does. - - # Add the declaration of putchar - putchar_ty = ir.FunctionType( - self._llvm.INT32_TYPE, [self._llvm.INT32_TYPE] - ) - putchar = ir.Function(self._llvm.module, putchar_ty, "putchar") - - # Add putchard - putchard_ty = ir.FunctionType( - self._llvm.FLOAT_TYPE, [self._llvm.FLOAT_TYPE] - ) - putchard = ir.Function(self._llvm.module, putchard_ty, "putchard") - - ir_builder = ir.IRBuilder(putchard.append_basic_block("entry")) - - ival = ir_builder.fptoui( - putchard.args[0], self._llvm.INT32_TYPE, "intcast" - ) - - ir_builder.call(putchar, [ival]) - ir_builder.ret(ir.Constant(self._llvm.FLOAT_TYPE, 0)) - - def evaluate( - self, block_ast: ast.BlockAST, show_llvm_ir: bool = False - ) -> None: - """ - Compile an AST to an object file. - - Parameters - ---------- - block_ast: The AST tree object. - - Returns - ------- - int: The compilation result. - """ - logging.info("Starting main_loop") - self.emit_object(block_ast) - - # Convert LLVM IR into in-memory representation - if show_llvm_ir: - return print(str(self._llvm.module)) - - result_mod = llvm.parse_assembly(str(self._llvm.module)) - result_object = self.target_machine.emit_object(result_mod) - - if self.output_file == "": - self.output_file = self.input_file + ".o" - - # Output object code to a file. - with open(self.output_file, "wb") as obj_file: - obj_file.write(result_object) - print("Wrote " + self.output_file) - - if not self.is_lib: - self.compile_executable() - - def compile_executable(self) -> None: - """Compile into an executable file.""" - print("Not fully implemented yet.") - # generate an executable file - - linker_path = "clang++" - executable_path = self.input_file + "c" - # note: it just has a purpose to demonstrate an initial implementation - # it will be improved in a follow-up PR - content = ( - "#include \n" - "int main() {\n" - ' std::cout << "ARX[WARNING]: ' - 'This is an empty executable file" << std::endl;\n' - "}\n" - ) - - main_cpp_path = ArxFile.create_tmp_file(content) - - if main_cpp_path == "": - raise Exception("ARX[FAIL]: Executable file was not created.") - - # Example (running it from a shell prompt): - # clang++ \ - # ${CLANG_EXTRAS} \ - # ${DEBUG_FLAGS} \ - # -fPIC \ - # -std=c++20 \ - # "${TEST_DIR_PATH}/integration/${test_name}.cpp" \ - # ${OBJECT_FILE} \ - # -o "${TMP_DIR}/main" - - compiler_args = [ - "-fPIC", - "-std=c++20", - main_cpp_path, - self.output_file, - "-o", - executable_path, - ] - - # Add any additional compiler flags or include paths as needed - # compiler_args.append("-I/path/to/include") - - linker_path = "clang++" - compiler_cmd = linker_path + " " + " ".join(compiler_args) - - print("ARX[INFO]: ", compiler_cmd) - compile_result = os.system(compiler_cmd) # nosec - - ArxFile.delete_file(main_cpp_path) - - if compile_result != 0: - llvm.errs() << "failed to compile and link object file" - exit(1) - - def open_interactive(self) -> None: - """ - Open the Arx shell. - - Returns - ------- - int: The compilation result. - """ - # Prime the first token. - print(f"Arx {ARX_VERSION} \n") - print(">>> ") - - lexer = Lexer() - parser = Parser() - - while True: - try: - ArxIO.string_to_buffer(input()) - self.evaluate(parser.parse(lexer.lex())) - except KeyboardInterrupt: - break - - def get_function(self, name: str) -> ir.Function: - """ - Put the function defined by the given name to result_func. - - Parameters - ---------- - name: Function name - """ - if name in self._llvm.module.globals: - fn = self._llvm.module.get_global(name) - self.result_stack.append(fn) - return - - if name in self.function_protos: - self.visit(self.function_protos[name]) - return - - def create_entry_block_alloca( - self, var_name: str, type_name: str - ) -> Any: # llvm.AllocaInst - """ - Create an alloca instruction in the entry block of the function. - - This is used for mutable variables, etc. - - Parameters - ---------- - fn: The llvm function - var_name: The variable name - type_name: The type name - - Returns - ------- - An llvm allocation instance. - """ - tmp_builder = ir.IRBuilder() - tmp_builder.position_at_start( - self._llvm.ir_builder.function.entry_basic_block - ) - return tmp_builder.alloca( - self._llvm.get_data_type(type_name), None, var_name - ) - - def emit_object(self, tree: ast.BlockAST) -> None: - """ - Walk the AST and generate code for each node. - - top ::= definition | external | expression | ';' - - Parameters - ---------- - tree: The ast.BlockAST instance. - """ - self.visit_block(tree) - - def visit_float_expr(self, expr: ast.FloatExprAST) -> None: - """ - Code generation for ast.FloatExprAST. - - Parameters - ---------- - expr: The ast.FloatExprAST instance - """ - result = ir.Constant(self._llvm.FLOAT_TYPE, expr.value) - self.result_stack.append(result) - - def visit_variable_expr(self, expr: ast.VariableExprAST) -> None: - """ - Code generation for ast.VariableExprAST. - - Parameters - ---------- - expr: The ast.VariableExprAST instance - """ - expr_var = self.named_values.get(expr.name) - - if not expr_var: - msg = f"Unknown variable name: {expr.name}" - raise Exception(msg) - - result = self._llvm.ir_builder.load(expr_var, expr.name) - self.result_stack.append(result) - - def visit_unary_expr(self, expr: ast.UnaryExprAST) -> None: - """ - Code generation for ast.UnaryExprAST. - - Parameters - ---------- - expr: The ast.UnaryExprAST instance - """ - self.visit(expr.operand) - operand_value = self.result_stack.pop() - if not operand_value: - raise Exception("ObjectGen: Empty unary operand.") - - fn = self.get_function("unary" + expr.op_code) - if not fn: - raise Exception("Unknown unary operator") - - result = self._llvm.ir_builder.call(fn, [operand_value], "unop") - self.result_stack.append(result) - - def visit_binary_expr(self, expr: ast.BinaryExprAST) -> None: - """ - Code generation for ast.BinaryExprAST. - - Parameters - ---------- - expr: The ast.BinaryExprAST instance - """ - if expr.op == "=": - # Special case '=' because we don't want to emit the lhs as an - # expression. - # Assignment requires the lhs to be an identifier. - # This assumes we're building without RTTI because LLVM builds - # that way by default. - # If you build LLVM with RTTI, this can be changed to a - # dynamic_cast for automatic error checking. - var_lhs = expr.lhs - - if not isinstance(var_lhs, ast.VariableExprAST): - raise Exception("destination of '=' must be a variable") - - # Codegen the rhs. - self.visit(expr.rhs) - llvm_rhs = self.result_stack.pop() - - if not llvm_rhs: - raise Exception("codegen: Invalid rhs expression.") - - # Look up the name. - llvm_lhs = self.named_values[var_lhs.get_name()] - - if not llvm_lhs: - raise Exception("codegen: Invalid lhs variable name") - - self._llvm.ir_builder.store(llvm_rhs, llvm_lhs) - result = llvm_rhs - self.result_stack.append(result) - return - - self.visit(expr.lhs) - llvm_lhs = self.result_stack.pop() - self.visit(expr.rhs) - llvm_rhs = self.result_stack.pop() - - if not llvm_lhs or not llvm_rhs: - raise Exception("codegen: Invalid lhs/rhs") - - if expr.op == "+": - result = self._llvm.ir_builder.fadd(llvm_lhs, llvm_rhs, "addtmp") - self.result_stack.append(result) - return - elif expr.op == "-": - result = self._llvm.ir_builder.fsub(llvm_lhs, llvm_rhs, "subtmp") - self.result_stack.append(result) - return - elif expr.op == "*": - result = self._llvm.ir_builder.fmul(llvm_lhs, llvm_rhs, "multmp") - self.result_stack.append(result) - return - elif expr.op == "<": - cmp_result = self._llvm.ir_builder.fcmp_unordered( - "<", llvm_lhs, llvm_rhs, "lttmp" - ) - # Convert bool 0/1 to float 0.0 or 1.0 - result = self._llvm.ir_builder.uitofp( - cmp_result, self._llvm.FLOAT_TYPE, "booltmp" - ) - self.result_stack.append(result) - return - elif expr.op == ">": - cmp_result = self._llvm.ir_builder.fcmp_unordered( - ">", llvm_lhs, llvm_rhs, "gttmp" - ) - # Convert bool 0/1 to float 0.0 or 1.0 - result = self._llvm.ir_builder.uitofp( - cmp_result, self._llvm.FLOAT_TYPE, "booltmp" - ) - self.result_stack.append(result) - return - - # If it wasn't a builtin binary operator, it must be a user defined - # one. Emit a call to it. - fn = self.get_function("binary" + expr.op) - result = self._llvm.ir_builder.call(fn, [llvm_lhs, llvm_rhs], "binop") - self.result_stack.append(result) - - def visit_block(self, expr: ast.BlockAST) -> None: - """Visit method for BlockAST.""" - result = [] - for node in expr.nodes: - self.visit(node) - result.append(self.result_stack.pop()) - self.result_stack.append(result) - - def visit_call_expr(self, expr: ast.CallExprAST) -> None: - """ - Code generation for ast.CallExprAST. - - Parameters - ---------- - expr: The ast.CallExprAST instance - """ - callee_f = self.get_function(expr.callee) - - if not callee_f: - raise Exception("Unknown function referenced") - - if len(callee_f.args) != len(expr.args): - raise Exception("codegen: Incorrect # arguments passed.") - - llvm_args = [] - for arg in expr.args: - self.visit(arg) - llvm_arg = self.result_stack.pop() - if not llvm_arg: - raise Exception("codegen: Invalid callee argument.") - llvm_args.append(llvm_arg) - - result = self._llvm.ir_builder.call(callee_f, llvm_args, "calltmp") - self.result_stack.append(result) - - def visit_if_stmt(self, expr: ast.IfStmtAST) -> None: - """ - Code generation for ast.IfStmtAST. - - Parameters - ---------- - expr: The ast.IfStmtAST instance - """ - self.visit(expr.cond) - cond_v = self.result_stack.pop() - - if not cond_v: - raise Exception("codegen: Invalid condition expression.") - - # Convert condition to a bool by comparing non-equal to 0.0. - cond_v = self._llvm.ir_builder.fcmp_ordered( - "!=", - cond_v, - ir.Constant(self._llvm.FLOAT_TYPE, 0.0), - ) - - # fn = self._llvm.ir_builder.position_at_start().getParent() - - # Create blocks for the then and else cases. Insert the 'then' block - # at the end of the function. - # then_bb = ir.Block(self._llvm.ir_builder.function, "then", fn) - then_bb = self._llvm.ir_builder.function.append_basic_block("then") - else_bb = ir.Block(self._llvm.ir_builder.function, "else") - merge_bb = ir.Block(self._llvm.ir_builder.function, "ifcont") - - self._llvm.ir_builder.cbranch(cond_v, then_bb, else_bb) - - # Emit then value. - self._llvm.ir_builder.position_at_start(then_bb) - self.visit(expr.then_) - then_v = self.result_stack.pop() - - if not then_v: - raise Exception("codegen: `Then` expression is invalid.") - - self._llvm.ir_builder.branch(merge_bb) - - # Codegen of 'then' can change the current block, update then_bb - # for the PHI. - then_bb = self._llvm.ir_builder.block - - # Emit else block. - self._llvm.ir_builder.function.basic_blocks.append(else_bb) - self._llvm.ir_builder.position_at_start(else_bb) - self.visit(expr.else_) - else_v = self.result_stack.pop() - if not else_v: - raise Exception("Revisit this!") - - # Emission of else_val could have modified the current basic block. - else_bb = self._llvm.ir_builder.block - self._llvm.ir_builder.branch(merge_bb) - - # Emit merge block. - self._llvm.ir_builder.function.basic_blocks.append(merge_bb) - self._llvm.ir_builder.position_at_start(merge_bb) - phi = self._llvm.ir_builder.phi(self._llvm.FLOAT_TYPE, "iftmp") - - phi.add_incoming(then_v, then_bb) - phi.add_incoming(else_v, else_bb) - - self.result_stack.append(phi) - - def visit_for_stmt(self, expr: ast.ForStmtAST) -> None: - """ - Code generation for ast.ForStmtAST. - - Parameters - ---------- - expr: The ast.ForStmtAST instance. - """ - saved_block = self._llvm.ir_builder.block - var_addr = self.create_entry_block_alloca(expr.var_name, "float") - self._llvm.ir_builder.position_at_end(saved_block) - - # Emit the start code first, without 'variable' in scope. - self.visit(expr.start) - start_val = self.result_stack.pop() - if not start_val: - raise Exception("codegen: Invalid start argument.") - - # Store the value into the alloca. - self._llvm.ir_builder.store(start_val, var_addr) - - # Make the new basic block for the loop header, inserting after - # current block. - loop_bb = self._llvm.ir_builder.function.append_basic_block("loop") - - # Insert an explicit fall through from the current block to the - # loop_bb. - self._llvm.ir_builder.branch(loop_bb) - - # Start insertion in loop_bb. - self._llvm.ir_builder.position_at_start(loop_bb) - - # Within the loop, the variable is defined equal to the PHI node. - # If it shadows an existing variable, we have to restore it, so save - # it now. - old_val = self.named_values.get(expr.var_name) - self.named_values[expr.var_name] = var_addr - - # Emit the body of the loop. This, like any other expr, can change - # the current basic_block. Note that we ignore the value computed by - # the body, but don't allow an error. - self.visit(expr.body) - body_val = self.result_stack.pop() - - if not body_val: - return - - # Emit the step value. - if expr.step: - self.visit(expr.step) - step_val = self.result_stack.pop() - if not step_val: - return - else: - # If not specified, use 1.0. - step_val = ir.Constant(self._llvm.FLOAT_TYPE, 1.0) - - # Compute the end condition. - self.visit(expr.end) - end_cond = self.result_stack.pop() - if not end_cond: - return - - # Reload, increment, and restore the var_addr. This handles the case - # where the body of the loop mutates the variable. - cur_var = self._llvm.ir_builder.load(var_addr, expr.var_name) - next_var = self._llvm.ir_builder.fadd(cur_var, step_val, "nextvar") - self._llvm.ir_builder.store(next_var, var_addr) - - # Convert condition to a bool by comparing non-equal to 0.0. - end_cond = self._llvm.ir_builder.fcmp_ordered( - "!=", - end_cond, - ir.Constant(self._llvm.DOUBLE_TYPE, 0.0), - "loopcond", - ) - - # Create the "after loop" block and insert it. - after_bb = self._llvm.ir_builder.function.append_basic_block( - "afterloop" - ) - - # Insert the conditional branch into the end of loop_bb. - self._llvm.ir_builder.cbranch(end_cond, loop_bb, after_bb) - - # Any new code will be inserted in after_bb. - self._llvm.ir_builder.position_at_start(after_bb) - - # Restore the unshadowed variable. - if old_val: - self.named_values[expr.var_name] = old_val - else: - self.named_values.pop(expr.var_name, None) - - # for expr always returns 0.0. - result = ir.Constant(self._llvm.FLOAT_TYPE, 0.0) - self.result_stack.append(result) - - def visit_var_expr(self, expr: ast.VarExprAST) -> None: - """ - Code generation for ast.VarExprAST. - - Parameters - ---------- - expr: The ast.VarExprAST instance. - """ - raise Exception(f"CodeGen for {expr} not implemented yet.") - - def visit_prototype(self, expr: ast.PrototypeAST) -> ir.Function: - """ - Code generation for PrototypeExprAST. - - Parameters - ---------- - expr: The ast.PrototypeAST instance. - """ - args_type = [self._llvm.FLOAT_TYPE] * len(expr.args) - return_type = self._llvm.get_data_type("float") - fn_type = ir.FunctionType(return_type, args_type, False) - - fn = ir.Function(self._llvm.module, fn_type, expr.name) - - # Set names for all arguments. - for idx, arg in enumerate(fn.args): - fn.args[idx].name = expr.args[idx].name - - return fn - - def visit_function(self, expr: ast.FunctionAST) -> ir.Function: - """ - Code generation for FunctionExprAST. - - Transfer ownership of the prototype to the ArxLLVM::function_protos - map, but keep a reference to it for use below. - - Parameters - ---------- - expr: The ast.FunctionAST instance. - """ - proto = expr.proto - self.function_protos[expr.proto.get_name()] = expr.proto - fn = self.get_function(proto.get_name()) - - if not fn: - raise Exception("codegen: Invalid function.") - - # Create a new basic block to start insertion into. - basic_block = fn.append_basic_block("entry") - self._llvm.ir_builder = ir.IRBuilder(basic_block) - - for llvm_arg in fn.args: - # Create an alloca for this variable. - alloca = self._llvm.ir_builder.alloca( - self._llvm.FLOAT_TYPE, name=llvm_arg.name - ) - - # Store the initial value into the alloca. - self._llvm.ir_builder.store(llvm_arg, alloca) - - # Add arguments to variable symbol table. - self.named_values[llvm_arg.name] = alloca - - self.visit(expr.body) - retval = self.result_stack.pop() - - # Validate the generated code, checking for consistency. - if retval: - self._llvm.ir_builder.ret(retval) - else: - self._llvm.ir_builder.ret(ir.Constant(self._llvm.FLOAT_TYPE, 0)) - return fn - - def visit_return_stmt(self, expr: ast.ReturnStmtAST) -> None: - """ - Code generation for ast.ReturnStmtAST. - - Parameters - ---------- - expr: The ast.ReturnStmtAST instance. - """ - # llvm_return_val = self.result_val - # - # if llvm_return_val: - # self._llvm.ir_builder.CreateRet(llvm_return_val) - return diff --git a/src/arx/lexer.py b/src/arx/lexer.py index 13d393a..236c467 100644 --- a/src/arx/lexer.py +++ b/src/arx/lexer.py @@ -8,28 +8,13 @@ from enum import Enum from typing import Any, Dict, List, cast +from astx import SourceLocation + from arx.io import ArxIO EOF = "" -@dataclass -class SourceLocation: - """ - Represents the source location with line and column information. - - Attributes - ---------- - line : int - Line number. - col : int - Column number. - """ - - line: int = 0 - col: int = 0 - - class TokenKind(Enum): """TokenKind enumeration for known variables returned by the lexer.""" @@ -43,6 +28,7 @@ class TokenKind(Enum): # data types identifier: int = -10 float_literal: int = -11 + int32_literal: int = -12 # control flow kw_if: int = -20 @@ -91,6 +77,7 @@ class TokenKind(Enum): TokenKind.identifier: "identifier", TokenKind.indent: "indent", TokenKind.float_literal: "float", + TokenKind.int32_literal: "int32", TokenKind.kw_if: "if", TokenKind.kw_then: "then", TokenKind.kw_else: "else", @@ -149,7 +136,7 @@ def get_display_value(self) -> str: return "(" + str(self.value) + ")" if self.kind == TokenKind.indent: return "(" + str(self.value) + ")" - elif self.kind == TokenKind.float_literal: + elif self.kind == TokenKind.int32_literal: return "(" + str(self.value) + ")" return "" @@ -327,8 +314,8 @@ def get_token(self) -> Token: self.last_char = self.advance() return Token( - kind=TokenKind.float_literal, - value=float(num_str), + kind=TokenKind.int32_literal, + value=int(num_str), location=self.lex_loc, ) diff --git a/src/arx/main.py b/src/arx/main.py index cf7def7..52a9cab 100644 --- a/src/arx/main.py +++ b/src/arx/main.py @@ -4,9 +4,9 @@ from typing import Any, List -from arx import ast -from arx.codegen.ast_output import ASTtoOutput -from arx.codegen.file_object import ObjectGenerator +import astx + +from arx.codegen import ObjectGenerator from arx.io import ArxIO from arx.lexer import Lexer from arx.parser import Parser @@ -49,7 +49,7 @@ def show_ast(self) -> None: """Print the AST for the given input file.""" lexer = Lexer() parser = Parser() - tree_ast = ast.BlockAST() + tree_ast = astx.Block() for input_file in self.input_files: ArxIO.file_to_buffer(input_file) @@ -57,8 +57,7 @@ def show_ast(self) -> None: module_ast = parser.parse(lexer.lex(), module_name) tree_ast.nodes.append(module_ast) - printer = ASTtoOutput() - printer.emit_ast(tree_ast) + print(tree_ast.get_struct()) def show_tokens(self) -> None: """Print the AST for the given input file.""" @@ -83,7 +82,7 @@ def compile(self, show_llvm_ir: bool = False) -> None: lexer = Lexer() parser = Parser() - tree_ast: ast.BlockAST = ast.BlockAST() + tree_ast: astx.Block = astx.Block() for input_file in self.input_files: ArxIO.file_to_buffer(input_file) diff --git a/src/arx/parser.py b/src/arx/parser.py index 003334b..f08b915 100644 --- a/src/arx/parser.py +++ b/src/arx/parser.py @@ -1,10 +1,13 @@ """parser module gather all functions and classes for parsing.""" -from typing import Dict, List, Tuple +from typing import Dict, List, Tuple, cast + +import astx + +from astx.base import SourceLocation -from arx import ast from arx.exceptions import ParserException -from arx.lexer import SourceLocation, Token, TokenKind, TokenList +from arx.lexer import Token, TokenKind, TokenList INDENT_SIZE = 2 @@ -38,19 +41,19 @@ def clean(self) -> None: def parse( self, tokens: TokenList, module_name: str = "main" - ) -> ast.BlockAST: + ) -> astx.Block: """ Parse the input code. Returns ------- - ast.BlockAST + astx.Block The parsed abstract syntax tree (AST), or None if parsing fails. """ self.clean() self.tokens = tokens - tree: ast.ModuleAST = ast.ModuleAST(module_name) + tree: astx.Module = astx.Module(module_name) self.tokens.get_next_token() if self.tokens.cur_tok.kind == TokenKind.not_initialized: @@ -84,45 +87,45 @@ def get_tok_precedence(self) -> int: """ return self.bin_op_precedence.get(self.tokens.cur_tok.value, -1) - def parse_function(self) -> ast.FunctionAST: + def parse_function(self) -> astx.Function: """ Parse the function definition expression. Returns ------- - ast.FunctionAST + astx.Function The parsed function definition, or None if parsing fails. """ self.tokens.get_next_token() # eat function. - proto: ast.PrototypeAST = self.parse_prototype() - return ast.FunctionAST(proto, self.parse_block()) + proto: astx.FunctionPrototype = self.parse_prototype() + return astx.Function(proto, self.parse_block()) - def parse_extern(self) -> ast.PrototypeAST: + def parse_extern(self) -> astx.FunctionPrototype: """ Parse the extern expression. Returns ------- - ast.PrototypeAST + astx.FunctionPrototype The parsed extern expression as a prototype, or None if parsing fails. """ self.tokens.get_next_token() # eat extern. return self.parse_extern_prototype() - def parse_primary(self) -> ast.ExprAST: + def parse_primary(self) -> astx.AST: """ Parse the primary expression. Returns ------- - ast.ExprAST + astx.Expr The parsed primary expression, or None if parsing fails. """ if self.tokens.cur_tok.kind == TokenKind.identifier: return self.parse_identifier_expr() - elif self.tokens.cur_tok.kind == TokenKind.float_literal: - return self.parse_float_expr() + elif self.tokens.cur_tok.kind == TokenKind.int32_literal: + return self.parse_int32_expr() elif self.tokens.cur_tok == Token(kind=TokenKind.operator, value="("): return self.parse_paren_expr() elif self.tokens.cur_tok.kind == TokenKind.kw_if: @@ -147,13 +150,13 @@ def parse_primary(self) -> ast.ExprAST: self.tokens.get_next_token() # eat unknown token raise Exception(msg) - def parse_block(self) -> ast.BlockAST: + def parse_block(self) -> astx.Block: """Parse a block of nodes.""" cur_indent: int = self.tokens.cur_tok.value self.tokens.get_next_token() # eat indentation - block: ast.BlockAST = ast.BlockAST() + block: astx.Block = astx.Block() if cur_indent == self.indent_level: raise ParserException("There is no new block to be parsed.") @@ -163,7 +166,7 @@ def parse_block(self) -> ast.BlockAST: while expr := self.parse_expression(): block.nodes.append(expr) - # if isinstance(expr, ast.IfStmtAST): + # if isinstance(expr, astx.If): # breakpoint() if self.tokens.cur_tok.kind != TokenKind.indent: break @@ -181,32 +184,32 @@ def parse_block(self) -> ast.BlockAST: self.indent_level -= INDENT_SIZE return block - def parse_expression(self) -> ast.ExprAST: + def parse_expression(self) -> astx.DataType: """ Parse an expression. Returns ------- - ast.ExprAST + astx.Expr The parsed expression, or None if parsing fails. """ - lhs: ast.ExprAST = self.parse_unary() + lhs: astx.Expr = self.parse_unary() return self.parse_bin_op_rhs(0, lhs) - def parse_if_stmt(self) -> ast.IfStmtAST: + def parse_if_stmt(self) -> astx.If: """ Parse the `if` expression. Returns ------- - ast.IfStmtAST + astx.If The parsed `if` expression, or None if parsing fails. """ if_loc: SourceLocation = self.tokens.cur_tok.location self.tokens.get_next_token() # eat the if. - cond: ast.ExprAST = self.parse_expression() + cond: astx.Expr = self.parse_expression() if self.tokens.cur_tok != Token(kind=TokenKind.operator, value=":"): msg = ( @@ -218,8 +221,8 @@ def parse_if_stmt(self) -> ast.IfStmtAST: self.tokens.get_next_token() # eat the ':' - then_block: ast.BlockAST = ast.BlockAST() - else_block: ast.BlockAST = ast.BlockAST() + then_block: astx.Block = astx.Block() + else_block: astx.Block = astx.Block() then_block = self.parse_block() @@ -242,28 +245,28 @@ def parse_if_stmt(self) -> ast.IfStmtAST: self.tokens.get_next_token() # eat the ':' else_block = self.parse_block() - return ast.IfStmtAST(if_loc, cond, then_block, else_block) + return astx.If(cond, then_block, else_block, loc=if_loc) - def parse_float_expr(self) -> ast.FloatExprAST: + def parse_int32_expr(self) -> astx.LiteralInt32: """ Parse the number expression. Returns ------- - ast.FloatExprAST - The parsed float expression. + astx.LiteralInt32 + The parsed int32 expression. """ - result = ast.FloatExprAST(self.tokens.cur_tok.value) + result = astx.LiteralInt32(value=self.tokens.cur_tok.value) self.tokens.get_next_token() # consume the number return result - def parse_paren_expr(self) -> ast.ExprAST: + def parse_paren_expr(self) -> astx.Expr: """ Parse the parenthesis expression. Returns ------- - ast.ExprAST + astx.Expr The parsed expression. """ self.tokens.get_next_token() # eat (. @@ -274,13 +277,13 @@ def parse_paren_expr(self) -> ast.ExprAST: self.tokens.get_next_token() # eat ). return expr - def parse_identifier_expr(self) -> ast.ExprAST: + def parse_identifier_expr(self) -> astx.Expr: """ Parse the identifier expression. Returns ------- - ast.ExprAST + astx.Expr The parsed expression, or None if parsing fails. """ id_name: str = self.tokens.cur_tok.value @@ -292,11 +295,11 @@ def parse_identifier_expr(self) -> ast.ExprAST: if self.tokens.cur_tok != Token(kind=TokenKind.operator, value="("): # Simple variable ref, not a function call # todo: we need to get the variable type from a specific scope - return ast.VariableExprAST(id_loc, id_name, "float") + return astx.Variable(id_name, loc=id_loc) # Call. self.tokens.get_next_token() # eat ( - args: List[ast.ExprAST] = [] + args: List[astx.DataType] = [] if self.tokens.cur_tok != Token(kind=TokenKind.operator, value=")"): while True: args.append(self.parse_expression()) @@ -317,15 +320,15 @@ def parse_identifier_expr(self) -> ast.ExprAST: # Eat the ')'. self.tokens.get_next_token() - return ast.CallExprAST(id_loc, id_name, args) + return astx.FunctionCall(id_name, args=tuple(args), loc=id_loc) - def parse_for_stmt(self) -> ast.ForStmtAST: + def parse_for_stmt(self) -> astx.ForRangeLoop: """ Parse the `for` expression. Returns ------- - ast.ForStmtAST + astx.ForRangeLoop The parsed `for` expression, or None if parsing fails. """ self.tokens.get_next_token() # eat the for. @@ -340,40 +343,40 @@ def parse_for_stmt(self) -> ast.ForStmtAST: raise Exception("Parser: Expected '=' after for") self.tokens.get_next_token() # eat '='. - start: ast.ExprAST = self.parse_expression() + start: astx.Expr = self.parse_expression() if self.tokens.cur_tok != Token(kind=TokenKind.operator, value=","): raise Exception("Parser: Expected ',' after for start value") self.tokens.get_next_token() - end: ast.ExprAST = self.parse_expression() + end: astx.Expr = self.parse_expression() # The step value is optional if self.tokens.cur_tok == Token(kind=TokenKind.operator, value=","): self.tokens.get_next_token() step = self.parse_expression() else: - step = ast.FloatExprAST(1.0) + step = astx.LiteralInt32(1) if self.tokens.cur_tok.kind != TokenKind.kw_in: # type: ignore raise Exception("Parser: Expected 'in' after for") self.tokens.get_next_token() # eat 'in'. - body_block: ast.BlockAST = ast.BlockAST() + body_block: astx.Block = astx.Block() body_block.nodes.append(self.parse_expression()) - return ast.ForStmtAST(id_name, start, end, step, body_block) + return astx.ForRangeLoop(id_name, start, end, step, body_block) - def parse_var_expr(self) -> ast.VarExprAST: + def parse_var_expr(self) -> astx.VariableDeclaration: """ Parse the `var` declaration expression. Returns ------- - ast.VarExprAST + astx.VariableDeclaration The parsed `var` expression, or None if parsing fails. """ self.tokens.get_next_token() # eat the var. - var_names: List[Tuple[str, ast.ExprAST]] = [] + var_names: List[Tuple[str, astx.Expr]] = [] # At least one variable name is required. # if self.tokens.cur_tok.kind != TokenKind.identifier: @@ -384,7 +387,7 @@ def parse_var_expr(self) -> ast.VarExprAST: self.tokens.get_next_token() # eat identifier. # Read the optional initializer. # - Init: ast.ExprAST + Init: astx.Expr if self.tokens.cur_tok == Token( kind=TokenKind.operator, value="=" ): @@ -392,7 +395,7 @@ def parse_var_expr(self) -> ast.VarExprAST: Init = self.parse_expression() else: - Init = ast.FloatExprAST(0.0) + Init = astx.LiteralInt32(0) var_names.append((name, Init)) @@ -411,16 +414,16 @@ def parse_var_expr(self) -> ast.VarExprAST: raise Exception("Parser: Expected 'in' keyword after 'var'") self.tokens.get_next_token() # eat 'in'. - body: ast.ExprAST = self.parse_expression() - return ast.VarExprAST(var_names, "float", body) + body: astx.Expr = self.parse_expression() + return astx.VariableDeclaration(var_names, "int32", body) - def parse_unary(self) -> ast.ExprAST: + def parse_unary(self) -> astx.DataType: """ Parse a unary expression. Returns ------- - ast.ExprAST + astx.Expr The parsed unary expression, or None if parsing fails. """ # If the current token is not an operator, it must be a primary expr. @@ -428,19 +431,19 @@ def parse_unary(self) -> ast.ExprAST: self.tokens.cur_tok.kind != TokenKind.operator or self.tokens.cur_tok.value in ("(", ",") ): - return self.parse_primary() + return cast(astx.DataType, self.parse_primary()) # If this is a unary operator, read it. op_code: str = self.tokens.cur_tok.value self.tokens.get_next_token() - operand: ast.ExprAST = self.parse_unary() - return ast.UnaryExprAST(op_code, operand) + operand: astx.DataType = self.parse_unary() + return cast(astx.DataType, astx.UnaryOp(op_code, operand)) def parse_bin_op_rhs( self, expr_prec: int, - lhs: ast.ExprAST, - ) -> ast.ExprAST: + lhs: astx.DataType, + ) -> astx.DataType: """ Parse a binary expression. @@ -448,12 +451,12 @@ def parse_bin_op_rhs( ---------- expr_prec : int Expression precedence (deprecated). - lhs : ast.ExprAST + lhs : astx.DataType Left-hand side expression. Returns ------- - ast.ExprAST + astx.DataType The parsed binary expression, or None if parsing fails. """ # If this is a binop, find its precedence. # @@ -463,7 +466,7 @@ def parse_bin_op_rhs( # If this is a binop that binds at least as tightly as the current # binop, consume it, otherwise we are done. if cur_prec < expr_prec: - return lhs + return cast(astx.DataType, lhs) # Okay, we know this is a binop. bin_op: str = self.tokens.cur_tok.value @@ -471,7 +474,7 @@ def parse_bin_op_rhs( self.tokens.get_next_token() # eat binop # Parse the unary expression after the binary operator. - rhs: ast.ExprAST = self.parse_unary() + rhs: astx.DataType = self.parse_unary() # If BinOp binds less tightly with rhs than the operator after # rhs, let the pending operator take rhs as its lhs @@ -480,15 +483,17 @@ def parse_bin_op_rhs( rhs = self.parse_bin_op_rhs(cur_prec + 1, rhs) # Merge lhs/rhs. - lhs = ast.BinaryExprAST(bin_loc, bin_op, lhs, rhs) + lhs = cast( + astx.DataType, astx.BinaryOp(bin_op, lhs, rhs, loc=bin_loc) + ) - def parse_prototype(self) -> ast.PrototypeAST: + def parse_prototype(self) -> astx.FunctionPrototype: """ Parse the prototype expression. Returns ------- - ast.PrototypeAST + astx.FunctionPrototype The parsed prototype, or None if parsing fails. """ fn_name: str @@ -508,17 +513,15 @@ def parse_prototype(self) -> ast.PrototypeAST: if self.tokens.cur_tok != Token(kind=TokenKind.operator, value="("): raise Exception("Parser: Expected '(' in the function definition.") - args: List[ast.VariableExprAST] = [] + args: List[astx.Variable] = [] while self.tokens.get_next_token().kind == TokenKind.identifier: # note: this is a workaround identifier_name = self.tokens.cur_tok.value cur_loc = self.tokens.cur_tok.location - var_typing = "float" + var_typing = "int32" - args.append( - ast.VariableExprAST(cur_loc, identifier_name, var_typing) - ) + args.append(astx.Variable(cur_loc, identifier_name, var_typing)) if self.tokens.get_next_token() != Token( kind=TokenKind.operator, value="," @@ -531,22 +534,22 @@ def parse_prototype(self) -> ast.PrototypeAST: # success. # self.tokens.get_next_token() # eat ')'. - ret_typing = "float" + ret_typing = "int32" if self.tokens.cur_tok != Token(kind=TokenKind.operator, value=":"): raise Exception("Parser: Expected ':' in the function definition") self.tokens.get_next_token() # eat ':'. - return ast.PrototypeAST(fn_loc, fn_name, ret_typing, args) + return astx.FunctionPrototype(fn_loc, fn_name, ret_typing, args) - def parse_extern_prototype(self) -> ast.PrototypeAST: + def parse_extern_prototype(self) -> astx.FunctionPrototype: """ Parse an extern prototype expression. Returns ------- - ast.PrototypeAST + astx.FunctionPrototype The parsed extern prototype, or None if parsing fails. """ fn_name: str @@ -566,17 +569,15 @@ def parse_extern_prototype(self) -> ast.PrototypeAST: if self.tokens.cur_tok != Token(kind=TokenKind.operator, value="("): raise Exception("Parser: Expected '(' in the function definition.") - args: List[ast.VariableExprAST] = [] + args: List[astx.Variable] = [] while self.tokens.get_next_token().kind == TokenKind.identifier: # note: this is a workaround identifier_name = self.tokens.cur_tok.value cur_loc = self.tokens.cur_tok.location - var_typing = "float" + var_typing = "int32" - args.append( - ast.VariableExprAST(cur_loc, identifier_name, var_typing) - ) + args.append(astx.Variable(cur_loc, identifier_name, var_typing)) if self.tokens.get_next_token() != Token( kind=TokenKind.operator, value="," @@ -589,18 +590,18 @@ def parse_extern_prototype(self) -> ast.PrototypeAST: # success. # self.tokens.get_next_token() # eat ')'. - ret_typing = "float" + ret_typing = "int32" - return ast.PrototypeAST(fn_loc, fn_name, ret_typing, args) + return astx.FunctionPrototype(fn_name, ret_typing, args, loc=fn_loc) - def parse_return_function(self) -> ast.ReturnStmtAST: + def parse_return_function(self) -> astx.FunctionReturn: """ Parse the return expression. Returns ------- - ast.ReturnStmtAST + astx.FunctionReturn The parsed return expression, or None if parsing fails. """ self.tokens.get_next_token() # eat return - return ast.ReturnStmtAST(self.parse_expression()) + return astx.FunctionReturn(self.parse_expression()) diff --git a/tests/test_parser.py b/tests/test_parser.py index 2065573..ecaad27 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -1,4 +1,4 @@ -from arx import ast +import astx from arx.io import ArxIO from arx.lexer import Lexer from arx.parser import Parser @@ -17,30 +17,30 @@ def test_binop_precedence() -> None: assert parser.bin_op_precedence["*"] == 40 -def test_parse_float_expr() -> None: +def test_parse_int32_expr() -> None: """Test gettok for main tokens""" ArxIO.string_to_buffer("1 2") lexer = Lexer() parser = Parser(lexer.lex()) parser.tokens.get_next_token() - expr = parser.parse_float_expr() + expr = parser.parse_int32_expr() assert expr - assert isinstance(expr, ast.FloatExprAST) - assert expr.value == 1.0 + assert isinstance(expr, astx.LiteralInt32) + assert expr.value == 1 - expr = parser.parse_float_expr() + expr = parser.parse_int32_expr() assert expr - assert isinstance(expr, ast.FloatExprAST) + assert isinstance(expr, astx.LiteralInt32) assert expr.value == 2 ArxIO.string_to_buffer("3") parser = Parser(lexer.lex()) tok = parser.tokens.get_next_token() - expr = parser.parse_float_expr() + expr = parser.parse_int32_expr() assert expr - assert isinstance(expr, ast.FloatExprAST) + assert isinstance(expr, astx.LiteralInt32) assert expr.value == 3 @@ -54,7 +54,7 @@ def test_parse() -> None: expr = parser.parse(lexer.lex()) assert expr - assert isinstance(expr, ast.BlockAST) + assert isinstance(expr, astx.Block) def test_parse_if_stmt() -> None: @@ -69,12 +69,12 @@ def test_parse_if_stmt() -> None: parser.tokens.get_next_token() expr = parser.parse_primary() assert expr - assert isinstance(expr, ast.IfStmtAST) - assert isinstance(expr.cond, ast.BinaryExprAST) - assert isinstance(expr.then_, ast.BlockAST) - assert isinstance(expr.then_.nodes[0], ast.BinaryExprAST) - assert isinstance(expr.else_, ast.BlockAST) - assert isinstance(expr.else_.nodes[0], ast.BinaryExprAST) + assert isinstance(expr, astx.If) + assert isinstance(expr.condition, astx.BinaryOp) + assert isinstance(expr.then, astx.Block) + assert isinstance(expr.then.nodes[0], astx.BinaryOp) + assert isinstance(expr.else_, astx.Block) + assert isinstance(expr.else_.nodes[0], astx.BinaryOp) def test_parse_fn() -> None: @@ -94,16 +94,16 @@ def test_parse_fn() -> None: parser.tokens.get_next_token() expr = parser.parse_function() assert expr - assert isinstance(expr, ast.FunctionAST) - assert isinstance(expr.proto, ast.PrototypeAST) - assert isinstance(expr.proto.args[0], ast.VariableExprAST) - assert isinstance(expr.body, ast.BlockAST) - assert isinstance(expr.body.nodes[0], ast.IfStmtAST) - assert isinstance(expr.body.nodes[0].cond, ast.BinaryExprAST) - assert isinstance(expr.body.nodes[0].then_, ast.BlockAST) - assert isinstance(expr.body.nodes[0].then_.nodes[0], ast.BinaryExprAST) - assert isinstance(expr.body.nodes[0].else_, ast.BlockAST) - assert isinstance(expr.body.nodes[0].else_.nodes[0], ast.BinaryExprAST) - assert isinstance(expr.body.nodes[1], ast.ReturnStmtAST) - assert isinstance(expr.body.nodes[1].value, ast.VariableExprAST) + assert isinstance(expr, astx.Function) + assert isinstance(expr.prototype, astx.FunctionPrototype) + assert isinstance(expr.prototype.args[0], astx.Variable) + assert isinstance(expr.body, astx.Block) + assert isinstance(expr.body.nodes[0], astx.If) + assert isinstance(expr.body.nodes[0].condition, astx.BinaryOp) + assert isinstance(expr.body.nodes[0].then, astx.Block) + assert isinstance(expr.body.nodes[0].then.nodes[0], astx.BinaryOp) + assert isinstance(expr.body.nodes[0].else_, astx.Block) + assert isinstance(expr.body.nodes[0].else_.nodes[0], astx.BinaryOp) + assert isinstance(expr.body.nodes[1], astx.FunctionReturn) + assert isinstance(expr.body.nodes[1].value, astx.Variable) assert expr.body.nodes[1].value.name == "a"