Skip to content

Commit

Permalink
feat: Add support to Import, and ImportFrom statement and Alias
Browse files Browse the repository at this point in the history
… expression (#118)
  • Loading branch information
apkrelling authored Oct 7, 2024
1 parent 225f398 commit 617f506
Show file tree
Hide file tree
Showing 6 changed files with 293 additions and 4 deletions.
13 changes: 12 additions & 1 deletion src/astx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,15 @@
BinaryOp,
UnaryOp,
)
from astx.packages import Module, Package, Program, Target
from astx.packages import (
AliasExpr,
ImportFromStmt,
ImportStmt,
Module,
Package,
Program,
Target,
)
from astx.variables import (
InlineVariableDeclaration,
Variable,
Expand All @@ -109,6 +117,7 @@ def get_version() -> str:


__all__ = [
"AliasExpr",
"Argument",
"Arguments",
"AST",
Expand Down Expand Up @@ -140,6 +149,8 @@ def get_version() -> str:
"FunctionReturn",
"get_version",
"If",
"ImportStmt",
"ImportFromStmt",
"InlineVariableDeclaration",
"Int16",
"Int32",
Expand Down
5 changes: 5 additions & 0 deletions src/astx/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,11 @@ class ASTKind(Enum):
Decimal128DTKind = -620
Decimal256DTKind = -621

# imports
ImportStmtKind = -700
ImportFromStmtKind = -701
AliasExprKind = -702


class ASTMeta(type):
def __str__(cls) -> str:
Expand Down
124 changes: 123 additions & 1 deletion src/astx/packages.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import copy

from typing import cast
from typing import Optional, cast

from public import public

Expand All @@ -15,6 +15,7 @@
ASTNodes,
Expr,
SourceLocation,
StatementType,
)
from astx.blocks import Block
from astx.types import ReprStruct
Expand Down Expand Up @@ -148,3 +149,124 @@ def __init__(
def __str__(self) -> str:
"""Return the string representation of the object."""
return f"PROGRAM[{self.name}]"


@public
class AliasExpr(Expr):
"""Represents an alias in an import statement."""

name: str
asname: Optional[str]

def __init__(
self,
name: str,
asname: str = "",
loc: SourceLocation = NO_SOURCE_LOCATION,
parent: Optional[ASTNodes] = None,
) -> None:
super().__init__(loc=loc, parent=parent)
self.name = name
self.asname = asname
self.kind = ASTKind.AliasExprKind

def __str__(self) -> str:
"""Return a string representation of the alias."""
if self.asname:
return f"{self.name} as {self.asname}"
else:
return self.name

def get_struct(self, simplified: bool = False) -> ReprStruct:
"""Return the AST structure of the alias."""
key = "Alias"

name_dict = {"name": self.name}
asname_dict = {"asname": self.asname} if self.asname else {}
value: ReprStruct = {
**name_dict,
**asname_dict,
}
return self._prepare_struct(key, value, simplified)


@public
class ImportStmt(StatementType):
"""Represents an import statement."""

names: list[AliasExpr]

def __init__(
self,
names: list[AliasExpr],
loc: SourceLocation = NO_SOURCE_LOCATION,
parent: Optional[ASTNodes] = None,
) -> None:
super().__init__(loc=loc, parent=parent)
self.names = names
self.kind = ASTKind.ImportStmtKind

def __str__(self) -> str:
"""Return a string representation of the import statement."""
names_str = ", ".join(str(name) for name in self.names)
return f"import {names_str}"

def get_struct(self, simplified: bool = False) -> ReprStruct:
"""Return the AST structure of the import statement."""
key = "Import"
value = cast(
ReprStruct, [name.get_struct(simplified) for name in self.names]
)
return self._prepare_struct(key, value, simplified)


@public
class ImportFromStmt(StatementType):
"""Represents an import-from statement."""

module: Optional[str]
names: list[AliasExpr]
level: int

def __init__(
self,
names: list[AliasExpr],
module: str = "",
level: int = 0,
loc: SourceLocation = NO_SOURCE_LOCATION,
parent: Optional[ASTNodes] = None,
) -> None:
super().__init__(loc=loc, parent=parent)
self.module = module
self.names = names
self.level = level
self.kind = ASTKind.ImportFromStmtKind

def __str__(self) -> str:
"""Return a string representation of the import-from statement."""
level_dots = "." * self.level
module_str = (
f"{level_dots}{self.module}" if self.module else level_dots
)
names_str = ", ".join(str(name) for name in self.names)
return f"from {module_str} import {names_str}"

def get_struct(self, simplified: bool = False) -> ReprStruct:
"""Return the AST structure of the import-from statement."""
key = "ImportFrom"

module_dict = {"module": self.module} if self.module else {}
level_dict = {"level": self.level}
names_values = cast(
ReprStruct,
[name.get_struct(simplified) for name in self.names],
)
names_dict = {"names": names_values}

value: ReprStruct = {
**module_dict,
**level_dict,
**names_dict,
}

return self._prepare_struct(key, value, simplified)
25 changes: 25 additions & 0 deletions src/astx/transpilers/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@ def visit(self, expr: astx.AST) -> str:
"""Translate an ASTx expression."""
raise Exception(f"Not implemented yet ({expr}).")

@dispatch # type: ignore[no-redef]
def visit(self, node: astx.AliasExpr) -> str:
"""Handle AliasExpr nodes."""
if node.asname:
return f"{node.name} as {node.asname}"
return f"{node.name}"

@dispatch # type: ignore[no-redef]
def visit(self, node: astx.Argument) -> str:
"""Handle UnaryOp nodes."""
Expand All @@ -62,6 +69,24 @@ def visit(self, node: astx.Block) -> str:
"""Handle Block nodes."""
return self._generate_block(node)

@dispatch # type: ignore[no-redef]
def visit(self, node: astx.ImportFromStmt) -> str:
"""Handle ImportFromStmt nodes."""
names = [self.visit(name) for name in node.names]
level_dots = "." * node.level
module_str = (
f"{level_dots}{node.module}" if node.module else level_dots
)
names_str = ", ".join(str(name) for name in names)
return f"from {module_str} import {names_str}"

@dispatch # type: ignore[no-redef]
def visit(self, node: astx.ImportStmt) -> str:
"""Handle ImportStmt nodes."""
names = [self.visit(name) for name in node.names]
names_str = ", ".join(x for x in names)
return f"import {names_str}"

@dispatch # type: ignore[no-redef]
def visit(self, node: Type[astx.Int32]) -> str:
"""Handle Int32 nodes."""
Expand Down
53 changes: 52 additions & 1 deletion tests/test_packages.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,15 @@

from astx.datatypes import Int32, LiteralInt32
from astx.operators import BinaryOp
from astx.packages import Module, Package, Program, Target
from astx.packages import (
AliasExpr,
ImportFromStmt,
ImportStmt,
Module,
Package,
Program,
Target,
)
from astx.variables import Variable, VariableDeclaration
from astx.viz import visualize

Expand Down Expand Up @@ -83,3 +91,46 @@ def test_program() -> None:
assert program.get_struct(simplified=True)

visualize(program.get_struct())


def test_multiple_imports() -> None:
"""Test ImportStmt multiple imports."""
alias1 = AliasExpr(name="math")
alias2 = AliasExpr(name="matplotlib", asname="mtlb")

# Create an import statement
import_stmt = ImportStmt(names=[alias1, alias2])

assert import_stmt.get_struct()
assert import_stmt.get_struct(simplified=True)


def test_import_from() -> None:
"""Test ImportFromStmt importing from module."""
alias = AliasExpr(name="pyplot", asname="plt")

import_from_stmt = ImportFromStmt(
module="matplotlib", names=[alias], level=1
)

assert import_from_stmt.get_struct()
assert import_from_stmt.get_struct(simplified=True)


def test_wildcard_import_from() -> None:
"""Test ImportFromStmt wildcard import from module."""
alias = AliasExpr(name="*")

import_from_stmt = ImportFromStmt(module="matplotlib", names=[alias])

assert import_from_stmt.get_struct()
assert import_from_stmt.get_struct(simplified=True)


def test_future_import_from() -> None:
"""Test ImportFromStmt from future import."""
alias = AliasExpr(name="division")

import_from_stmt = ImportFromStmt(module="__future__", names=[alias])
assert import_from_stmt.get_struct()
assert import_from_stmt.get_struct(simplified=True)
77 changes: 76 additions & 1 deletion tests/transpilers/test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,82 @@
from astx.transpilers import python as astx2py


def test_function() -> None:
def test_transpiler_multiple_imports() -> None:
"""Test astx.ImportStmt multiple imports."""
alias1 = astx.AliasExpr(name="math")
alias2 = astx.AliasExpr(name="matplotlib", asname="mtlb")

# Create an import statement
import_stmt = astx.ImportStmt(names=[alias1, alias2])

# Initialize the generator
generator = astx2py.ASTxPythonTranspiler()

# Generate Python code
generated_code = generator.visit(import_stmt)

expected_code = "import math, matplotlib as mtlb"

assert generated_code == expected_code, "generated_code != expected_code"


def test_transpiler_import_from() -> None:
"""Test astx.ImportFromStmt importing from module."""
alias = astx.AliasExpr(name="pyplot", asname="plt")

import_from_stmt = astx.ImportFromStmt(
module="matplotlib", names=[alias], level=0
)

# Initialize the generator
generator = astx2py.ASTxPythonTranspiler()

# Generate Python code
generated_code = generator.visit(import_from_stmt)

# print generated code
generated_code

expected_code = "from matplotlib import pyplot as plt"

assert generated_code == expected_code, "generated_code != expected_code"


def test_transpiler_wildcard_import_from() -> None:
"""Test astx.ImportFromStmt wildcard import from module."""
alias = astx.AliasExpr(name="*")

import_from_stmt = astx.ImportFromStmt(module="matplotlib", names=[alias])

# Initialize the generator
generator = astx2py.ASTxPythonTranspiler()

# Generate Python code
generated_code = generator.visit(import_from_stmt)

expected_code = "from matplotlib import *"

assert generated_code == expected_code, "generated_code != expected_code"


def test_transpiler_future_import_from() -> None:
"""Test astx.ImportFromStmt from future import."""
alias = astx.AliasExpr(name="division")

import_from_stmt = astx.ImportFromStmt(module="__future__", names=[alias])

# Initialize the generator
generator = astx2py.ASTxPythonTranspiler()

# Generate Python code
generated_code = generator.visit(import_from_stmt)

expected_code = "from __future__ import division"

assert generated_code == expected_code, "generated_code != expected_code"


def test_transpiler_function() -> None:
"""Test astx.Function."""
# Function parameters
args = astx.Arguments(
Expand Down

0 comments on commit 617f506

Please sign in to comment.