Skip to content

Commit

Permalink
feat: ImportExpr and ImportFromExpr (#122)
Browse files Browse the repository at this point in the history
  • Loading branch information
apkrelling authored Oct 15, 2024
1 parent 8fdb774 commit 7785ba7
Show file tree
Hide file tree
Showing 7 changed files with 325 additions and 34 deletions.
2 changes: 1 addition & 1 deletion conda/release.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ channels:
- conda-forge
dependencies:
- graphviz
- python >=3.9,<4
- python >=3.9,<3.13
- openjdk
- poetry >=1.5
- nodejs >=18.17 # used by semantic-release
Expand Down
4 changes: 4 additions & 0 deletions src/astx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@
)
from astx.packages import (
AliasExpr,
ImportExpr,
ImportFromExpr,
ImportFromStmt,
ImportStmt,
Module,
Expand Down Expand Up @@ -149,6 +151,8 @@ def get_version() -> str:
"FunctionReturn",
"get_version",
"If",
"ImportFromExpr",
"ImportExpr",
"ImportStmt",
"ImportFromStmt",
"InlineVariableDeclaration",
Expand Down
2 changes: 2 additions & 0 deletions src/astx/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ class ASTKind(Enum):
ImportStmtKind = -700
ImportFromStmtKind = -701
AliasExprKind = -702
ImportExprKind = -800
ImportFromExprKind = -801


class ASTMeta(type):
Expand Down
114 changes: 92 additions & 22 deletions src/astx/packages.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ class AliasExpr(Expr):
"""Represents an alias in an import statement."""

name: str
asname: Optional[str]
asname: str

@typechecked
def __init__(
Expand All @@ -185,14 +185,11 @@ def __str__(self) -> str:

def get_struct(self, simplified: bool = False) -> ReprStruct:
"""Return the AST structure of the alias."""
key = "Alias"

name_dict = {"name": self.name}
asname_dict = {"asname": self.asname} if self.asname else {}
value: ReprStruct = {
**name_dict,
**asname_dict,
}
str_asname = f", {self.asname}" if self.asname else ""
str_name_asname = f"[{self.name}{str_asname}]"
key = f"Alias {str_name_asname}"
value = ""

return self._prepare_struct(key, value, simplified)


Expand Down Expand Up @@ -220,7 +217,7 @@ def __str__(self) -> str:

def get_struct(self, simplified: bool = False) -> ReprStruct:
"""Return the AST structure of the import statement."""
key = "Import"
key = "ImportStmt"
value = cast(
ReprStruct, [name.get_struct(simplified) for name in self.names]
)
Expand Down Expand Up @@ -261,20 +258,93 @@ def __str__(self) -> str:

def get_struct(self, simplified: bool = False) -> ReprStruct:
"""Return the AST structure of the import-from statement."""
key = "ImportFrom"
level_dots = "." * self.level
module_str = (
f"{level_dots}{self.module}" if self.module else level_dots
)

module_dict = {"module": self.module} if self.module else {}
level_dict = {"level": self.level}
names_values = cast(
ReprStruct,
[name.get_struct(simplified) for name in self.names],
key = f"ImportFromStmt [{module_str}]"
value = cast(
ReprStruct, [name.get_struct(simplified) for name in self.names]
)

return self._prepare_struct(key, value, simplified)


@public
class ImportExpr(Expr):
"""Represents an import operation as an expression."""

names: list[AliasExpr]

@typechecked
def __init__(
self,
names: list[AliasExpr],
loc: SourceLocation = NO_SOURCE_LOCATION,
parent: Optional[ASTNodes] = None,
) -> None:
super().__init__(loc=loc, parent=parent)
self.names = names
self.kind = ASTKind.ImportExprKind

def __str__(self) -> str:
"""Return a string representation of the import expression."""
names_str = ", ".join(str(name) for name in self.names)
return f"import {names_str}"

def get_struct(self, simplified: bool = False) -> ReprStruct:
"""Return the AST structure of the import expression."""
key = "ImportExpr"
value = cast(
ReprStruct, [name.get_struct(simplified) for name in self.names]
)
return self._prepare_struct(key, value, simplified)


@public
class ImportFromExpr(Expr):
"""Represents a 'from ... import ...' operation as an expression."""

module: str
names: list[AliasExpr]
level: int # Number of leading dots for relative imports

@typechecked
def __init__(
self,
names: list[AliasExpr],
module: str = "",
level: int = 0,
loc: SourceLocation = NO_SOURCE_LOCATION,
parent: Optional[ASTNodes] = None,
) -> None:
super().__init__(loc=loc, parent=parent)
self.names = names
self.module = module
self.level = level
self.kind = ASTKind.ImportFromExprKind

def __str__(self) -> str:
"""Return a string representation of the import-from expression."""
level_dots = "." * self.level
module_str = (
f"{level_dots}{self.module}" if self.module else level_dots
)
names_str = ", ".join(str(name) for name in self.names)

return f"from {module_str} import {names_str}"

def get_struct(self, simplified: bool = False) -> ReprStruct:
"""Return the AST structure of the import-from expression."""
level_dots = "." * self.level
module_str = (
f"{level_dots}{self.module}" if self.module else level_dots
)
names_dict = {"names": names_values}

value: ReprStruct = {
**module_dict,
**level_dict,
**names_dict,
}
key = f"ImportFromExpr [{module_str}]"
value = cast(
ReprStruct, [name.get_struct(simplified) for name in self.names]
)

return self._prepare_struct(key, value, simplified)
55 changes: 55 additions & 0 deletions src/astx/transpilers/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,61 @@ def visit(self, node: astx.ImportStmt) -> str:
names_str = ", ".join(x for x in names)
return f"import {names_str}"

@dispatch # type: ignore[no-redef]
def visit(self, node: astx.ImportFromExpr) -> str:
"""Handle ImportFromExpr nodes."""
names = [self.visit(name) for name in node.names]
level_dots = "." * node.level
module_str = (
f"{level_dots}{node.module}" if node.module else level_dots
)
names_list = []
for name in names:
str_ = (
f"getattr(__import__('{module_str}', "
f"fromlist=['{name}']), '{name}')"
)
names_list.append(str_)
names_str = ", ".join(x for x in names_list)

# name if one import or name1, name2, etc if multiple imports
num = [
"" if len(names) == 1 else str(n) for n in range(1, len(names) + 1)
]
call = ["name" + str(n) for n in num]
call_str = ", ".join(x for x in call)

# assign tuple if multiple imports
names_str = (
names_str if len(names_list) == 1 else "(" + names_str + ")"
)

return f"{call_str} = {names_str}"

@dispatch # type: ignore[no-redef]
def visit(self, node: astx.ImportExpr) -> str:
"""Handle ImportExpr nodes."""
names = [self.visit(name) for name in node.names]
names_list = []
for name in names:
str_ = f"__import__('{name}') "
names_list.append(str_)
names_str = ", ".join(x for x in names_list)

# name if one import or name1, name2, etc if multiple imports
num = [
"" if len(names) == 1 else str(n) for n in range(1, len(names) + 1)
]
call = ["module" + str(n) for n in num]
call_str = ", ".join(x for x in call)

# assign tuple if multiple imports
names_str = (
names_str if len(names_list) == 1 else "(" + names_str + ")"
)

return f"{call_str} = {names_str}"

@dispatch # type: ignore[no-redef]
def visit(self, node: Type[astx.Int32]) -> str:
"""Handle Int32 nodes."""
Expand Down
62 changes: 58 additions & 4 deletions tests/test_packages.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from astx.operators import BinaryOp
from astx.packages import (
AliasExpr,
ImportExpr,
ImportFromExpr,
ImportFromStmt,
ImportStmt,
Module,
Expand Down Expand Up @@ -93,7 +95,7 @@ def test_program() -> None:
visualize(program.get_struct())


def test_multiple_imports() -> None:
def test_multiple_imports_stmt() -> None:
"""Test ImportStmt multiple imports."""
alias1 = AliasExpr(name="math")
alias2 = AliasExpr(name="matplotlib", asname="mtlb")
Expand All @@ -105,7 +107,7 @@ def test_multiple_imports() -> None:
assert import_stmt.get_struct(simplified=True)


def test_import_from() -> None:
def test_import_from_stmt() -> None:
"""Test ImportFromStmt importing from module."""
alias = AliasExpr(name="pyplot", asname="plt")

Expand All @@ -117,7 +119,7 @@ def test_import_from() -> None:
assert import_from_stmt.get_struct(simplified=True)


def test_wildcard_import_from() -> None:
def test_wildcard_import_from_stmt() -> None:
"""Test ImportFromStmt wildcard import from module."""
alias = AliasExpr(name="*")

Expand All @@ -127,10 +129,62 @@ def test_wildcard_import_from() -> None:
assert import_from_stmt.get_struct(simplified=True)


def test_future_import_from() -> None:
def test_future_import_from_stmt() -> None:
"""Test ImportFromStmt from future import."""
alias = AliasExpr(name="division")

import_from_stmt = ImportFromStmt(module="__future__", names=[alias])
assert import_from_stmt.get_struct()
assert import_from_stmt.get_struct(simplified=True)


def test_multiple_imports_expr() -> None:
"""Test ImportExpr multiple imports."""
alias1 = AliasExpr(name="sqrt", asname="square_root")
alias2 = AliasExpr(name="pi")

import_expr = ImportExpr([alias1, alias2])

assert import_expr.get_struct()
assert import_expr.get_struct(simplified=True)


def test_import_from_expr() -> None:
"""Test ImportFromExpr importing from module."""
alias1 = AliasExpr(name="sqrt", asname="square_root")

import_from_expr = ImportFromExpr(module="math", names=[alias1])

assert import_from_expr.get_struct()
assert import_from_expr.get_struct(simplified=True)


def test_wildcard_import_from_expr() -> None:
"""Test ImportFromExpr wildcard import from module."""
alias1 = AliasExpr(name="*")

import_from_expr = ImportFromExpr(module="math", names=[alias1])

assert import_from_expr.get_struct()
assert import_from_expr.get_struct(simplified=True)


def test_future_import_from_expr() -> None:
"""Test ImportFromExpr from future import."""
alias1 = AliasExpr(name="division")

import_from_expr = ImportFromExpr(module="__future__", names=[alias1])

assert import_from_expr.get_struct()
assert import_from_expr.get_struct(simplified=True)


def test_relative_import_from_expr() -> None:
"""Test ImportFromExpr relative imports."""
alias1 = AliasExpr(name="division")
alias2 = AliasExpr(name="matplotlib", asname="mtlb")

import_from_expr = ImportFromExpr(names=[alias1, alias2], level=1)

assert import_from_expr.get_struct()
assert import_from_expr.get_struct(simplified=True)
Loading

0 comments on commit 7785ba7

Please sign in to comment.