Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Improve the usage of typechecked #127

Merged
merged 3 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 35 additions & 9 deletions src/astx/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,10 @@
from abc import abstractmethod
from enum import Enum
from hashlib import sha256
from typing import ClassVar, Optional, Type, Union, cast
from typing import ClassVar, Dict, List, Optional, Type, Union, cast

from typeguard import typechecked

from astx.types import ReprStruct
from astx.viz import graph_to_ascii, traverse_ast_ascii

try:
from typing_extensions import TypeAlias
except ImportError:
Expand All @@ -30,7 +27,13 @@

from public import public

__all__ = ["ExprType"]
__all__ = [
"ExprType",
"PrimitivesStruct",
"DataTypesStruct",
"DictDataTypesStruct",
"ReprStruct",
]


def is_using_jupyter_notebook() -> bool:
Expand All @@ -46,11 +49,11 @@ def is_using_jupyter_notebook() -> bool:


@public
@typechecked
class SourceLocation:
line: int
col: int

@typechecked
def __init__(self, line: int, col: int):
self.line = line
self.col = col
Expand Down Expand Up @@ -138,6 +141,7 @@ def __str__(cls) -> str:


@public
@typechecked
class AST(metaclass=ASTMeta):
"""AST main expression class."""

Expand All @@ -147,7 +151,6 @@ class AST(metaclass=ASTMeta):
parent: Optional[ASTNodes] = None
ref: str

@typechecked
def __init__(
self,
loc: SourceLocation = NO_SOURCE_LOCATION,
Expand All @@ -172,6 +175,8 @@ def __str__(self) -> str:
def __repr__(self) -> str:
"""Return an string that represents the object."""
if not is_using_jupyter_notebook():
from astx.viz import graph_to_ascii, traverse_ast_ascii

graph = traverse_ast_ascii(self.get_struct(simplified=True))
return graph_to_ascii(graph)
return ""
Expand Down Expand Up @@ -233,14 +238,14 @@ def to_json(self, simplified: bool = False) -> str:


@public
@typechecked
class ASTNodes(AST):
"""AST with a list of nodes."""

name: str
nodes: list[AST]
position: int = 0

@typechecked
def __init__(
self,
name: str = "entry",
Expand Down Expand Up @@ -284,6 +289,7 @@ def __len__(self) -> int:


@public
@typechecked
class Expr(AST):
"""AST main expression class."""

Expand All @@ -292,8 +298,26 @@ class Expr(AST):

ExprType: TypeAlias = Type[Expr]

PrimitivesStruct: TypeAlias = Union[
int,
str,
float,
bool,
"astx.base.Undefined", # type: ignore[name-defined] # noqa: F821
]
DataTypesStruct: TypeAlias = Union[
PrimitivesStruct, Dict[str, "DataTypesStruct"], List["DataTypesStruct"]
]
DictDataTypesStruct: TypeAlias = Dict[str, DataTypesStruct]
ReprStruct: TypeAlias = Union[
List[DataTypesStruct],
DictDataTypesStruct,
"astx.base.Undefined", # type: ignore[name-defined] # noqa: F821
]


@public
@typechecked
class Undefined(Expr):
"""Undefined expression class."""

Expand All @@ -305,14 +329,14 @@ def get_struct(self, simplified: bool = False) -> ReprStruct:


@public
@typechecked
class DataType(Expr):
"""AST main expression class."""

type_: ExprType
name: str
_tmp_id: ClassVar[int] = 0

@typechecked
def __init__(
self,
loc: SourceLocation = NO_SOURCE_LOCATION,
Expand All @@ -337,10 +361,12 @@ def get_struct(self, simplified: bool = False) -> ReprStruct:


@public
@typechecked
class OperatorType(DataType):
"""AST main expression class."""


@public
@typechecked
class StatementType(AST):
"""AST main expression class."""
4 changes: 3 additions & 1 deletion src/astx/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,16 @@
from typing import cast

from public import public
from typeguard import typechecked

from astx.base import (
ASTNodes,
ReprStruct,
)
from astx.types import ReprStruct


@public
@typechecked
class Block(ASTNodes):
"""The AST tree."""

Expand Down
16 changes: 8 additions & 8 deletions src/astx/callables.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,20 @@
DataType,
Expr,
ExprType,
ReprStruct,
SourceLocation,
StatementType,
Undefined,
)
from astx.blocks import Block
from astx.modifiers import MutabilityKind, ScopeKind, VisibilityKind
from astx.types import ReprStruct
from astx.variables import Variable

UNDEFINED = Undefined()


@public
@typechecked
class Argument(Variable):
"""AST class for argument definition."""

Expand All @@ -35,7 +36,6 @@ class Argument(Variable):
type_: ExprType
default: Expr

@typechecked
def __init__(
self,
name: str,
Expand Down Expand Up @@ -65,10 +65,10 @@ def get_struct(self, simplified: bool = False) -> ReprStruct:


@public
@typechecked
class Arguments(ASTNodes):
"""AST class for argument definition."""

@typechecked
def __init__(self, *args: Argument, **kwargs: Any) -> None:
super().__init__(**kwargs)
for arg in args:
Expand All @@ -91,13 +91,13 @@ def get_struct(self, simplified: bool = False) -> ReprStruct:


@public
@typechecked
class FunctionCall(DataType):
"""AST class for function call."""

fn: Function
args: tuple[DataType, ...]

@typechecked
def __init__(
self,
fn: Function,
Expand Down Expand Up @@ -138,6 +138,7 @@ def get_struct(self, simplified: bool = False) -> ReprStruct:


@public
@typechecked
class FunctionPrototype(StatementType):
"""AST class for function prototype declaration."""

Expand All @@ -147,7 +148,6 @@ class FunctionPrototype(StatementType):
scope: ScopeKind
visibility: VisibilityKind

@typechecked
def __init__(
self,
name: str,
Expand All @@ -174,12 +174,12 @@ def get_struct(self, simplified: bool = False) -> ReprStruct:


@public
@typechecked
class FunctionReturn(StatementType):
"""AST class for function `return` statement."""

value: DataType

@typechecked
def __init__(
self,
value: DataType,
Expand All @@ -203,13 +203,13 @@ def get_struct(self, simplified: bool = False) -> ReprStruct:


@public
@typechecked
class Function(StatementType):
"""AST class for function definition."""

prototype: FunctionPrototype
body: Block

@typechecked
def __init__(
self,
prototype: FunctionPrototype,
Expand Down Expand Up @@ -255,13 +255,13 @@ def get_struct(self, simplified: bool = False) -> ReprStruct:


@public
@typechecked
class LambdaExpr(Expr):
"""AST class for lambda expressions."""

params: Arguments = Arguments()
body: Expr

@typechecked
def __init__(
self,
body: Expr,
Expand Down
Loading
Loading