Skip to content

Commit

Permalink
fix: Improve the usage of typechecked (#127)
Browse files Browse the repository at this point in the history
  • Loading branch information
xmnlab authored Oct 23, 2024
1 parent 031de47 commit 97fb87c
Show file tree
Hide file tree
Showing 12 changed files with 133 additions and 100 deletions.
44 changes: 35 additions & 9 deletions src/astx/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,10 @@
from abc import abstractmethod
from enum import Enum
from hashlib import sha256
from typing import ClassVar, Optional, Type, Union, cast
from typing import ClassVar, Dict, List, Optional, Type, Union, cast

from typeguard import typechecked

from astx.types import ReprStruct
from astx.viz import graph_to_ascii, traverse_ast_ascii

try:
from typing_extensions import TypeAlias
except ImportError:
Expand All @@ -30,7 +27,13 @@

from public import public

__all__ = ["ExprType"]
__all__ = [
"ExprType",
"PrimitivesStruct",
"DataTypesStruct",
"DictDataTypesStruct",
"ReprStruct",
]


def is_using_jupyter_notebook() -> bool:
Expand All @@ -46,11 +49,11 @@ def is_using_jupyter_notebook() -> bool:


@public
@typechecked
class SourceLocation:
line: int
col: int

@typechecked
def __init__(self, line: int, col: int):
self.line = line
self.col = col
Expand Down Expand Up @@ -138,6 +141,7 @@ def __str__(cls) -> str:


@public
@typechecked
class AST(metaclass=ASTMeta):
"""AST main expression class."""

Expand All @@ -147,7 +151,6 @@ class AST(metaclass=ASTMeta):
parent: Optional[ASTNodes] = None
ref: str

@typechecked
def __init__(
self,
loc: SourceLocation = NO_SOURCE_LOCATION,
Expand All @@ -172,6 +175,8 @@ def __str__(self) -> str:
def __repr__(self) -> str:
"""Return an string that represents the object."""
if not is_using_jupyter_notebook():
from astx.viz import graph_to_ascii, traverse_ast_ascii

graph = traverse_ast_ascii(self.get_struct(simplified=True))
return graph_to_ascii(graph)
return ""
Expand Down Expand Up @@ -233,14 +238,14 @@ def to_json(self, simplified: bool = False) -> str:


@public
@typechecked
class ASTNodes(AST):
"""AST with a list of nodes."""

name: str
nodes: list[AST]
position: int = 0

@typechecked
def __init__(
self,
name: str = "entry",
Expand Down Expand Up @@ -284,6 +289,7 @@ def __len__(self) -> int:


@public
@typechecked
class Expr(AST):
"""AST main expression class."""

Expand All @@ -292,8 +298,26 @@ class Expr(AST):

ExprType: TypeAlias = Type[Expr]

PrimitivesStruct: TypeAlias = Union[
int,
str,
float,
bool,
"astx.base.Undefined", # type: ignore[name-defined] # noqa: F821
]
DataTypesStruct: TypeAlias = Union[
PrimitivesStruct, Dict[str, "DataTypesStruct"], List["DataTypesStruct"]
]
DictDataTypesStruct: TypeAlias = Dict[str, DataTypesStruct]
ReprStruct: TypeAlias = Union[
List[DataTypesStruct],
DictDataTypesStruct,
"astx.base.Undefined", # type: ignore[name-defined] # noqa: F821
]


@public
@typechecked
class Undefined(Expr):
"""Undefined expression class."""

Expand All @@ -305,14 +329,14 @@ def get_struct(self, simplified: bool = False) -> ReprStruct:


@public
@typechecked
class DataType(Expr):
"""AST main expression class."""

type_: ExprType
name: str
_tmp_id: ClassVar[int] = 0

@typechecked
def __init__(
self,
loc: SourceLocation = NO_SOURCE_LOCATION,
Expand All @@ -337,10 +361,12 @@ def get_struct(self, simplified: bool = False) -> ReprStruct:


@public
@typechecked
class OperatorType(DataType):
"""AST main expression class."""


@public
@typechecked
class StatementType(AST):
"""AST main expression class."""
4 changes: 3 additions & 1 deletion src/astx/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,16 @@
from typing import cast

from public import public
from typeguard import typechecked

from astx.base import (
ASTNodes,
ReprStruct,
)
from astx.types import ReprStruct


@public
@typechecked
class Block(ASTNodes):
"""The AST tree."""

Expand Down
16 changes: 8 additions & 8 deletions src/astx/callables.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,20 @@
DataType,
Expr,
ExprType,
ReprStruct,
SourceLocation,
StatementType,
Undefined,
)
from astx.blocks import Block
from astx.modifiers import MutabilityKind, ScopeKind, VisibilityKind
from astx.types import ReprStruct
from astx.variables import Variable

UNDEFINED = Undefined()


@public
@typechecked
class Argument(Variable):
"""AST class for argument definition."""

Expand All @@ -35,7 +36,6 @@ class Argument(Variable):
type_: ExprType
default: Expr

@typechecked
def __init__(
self,
name: str,
Expand Down Expand Up @@ -65,10 +65,10 @@ def get_struct(self, simplified: bool = False) -> ReprStruct:


@public
@typechecked
class Arguments(ASTNodes):
"""AST class for argument definition."""

@typechecked
def __init__(self, *args: Argument, **kwargs: Any) -> None:
super().__init__(**kwargs)
for arg in args:
Expand All @@ -91,13 +91,13 @@ def get_struct(self, simplified: bool = False) -> ReprStruct:


@public
@typechecked
class FunctionCall(DataType):
"""AST class for function call."""

fn: Function
args: tuple[DataType, ...]

@typechecked
def __init__(
self,
fn: Function,
Expand Down Expand Up @@ -138,6 +138,7 @@ def get_struct(self, simplified: bool = False) -> ReprStruct:


@public
@typechecked
class FunctionPrototype(StatementType):
"""AST class for function prototype declaration."""

Expand All @@ -147,7 +148,6 @@ class FunctionPrototype(StatementType):
scope: ScopeKind
visibility: VisibilityKind

@typechecked
def __init__(
self,
name: str,
Expand All @@ -174,12 +174,12 @@ def get_struct(self, simplified: bool = False) -> ReprStruct:


@public
@typechecked
class FunctionReturn(StatementType):
"""AST class for function `return` statement."""

value: DataType

@typechecked
def __init__(
self,
value: DataType,
Expand All @@ -203,13 +203,13 @@ def get_struct(self, simplified: bool = False) -> ReprStruct:


@public
@typechecked
class Function(StatementType):
"""AST class for function definition."""

prototype: FunctionPrototype
body: Block

@typechecked
def __init__(
self,
prototype: FunctionPrototype,
Expand Down Expand Up @@ -255,13 +255,13 @@ def get_struct(self, simplified: bool = False) -> ReprStruct:


@public
@typechecked
class LambdaExpr(Expr):
"""AST class for lambda expressions."""

params: Arguments = Arguments()
body: Expr

@typechecked
def __init__(
self,
body: Expr,
Expand Down
Loading

0 comments on commit 97fb87c

Please sign in to comment.