Skip to content

Commit

Permalink
feat: add macro imports
Browse files Browse the repository at this point in the history
  • Loading branch information
vberlier committed Jul 13, 2022
1 parent 52916e0 commit e8ca750
Show file tree
Hide file tree
Showing 22 changed files with 474 additions and 236 deletions.
17 changes: 17 additions & 0 deletions bolt/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@
"AstMacroMatchLiteral",
"AstMacroMatchArgument",
"AstInterpolation",
"AstFromImport",
"AstImportedItem",
"AstImportedMacro",
]


Expand Down Expand Up @@ -319,9 +321,24 @@ class AstInterpolation(AstNode):
value: AstExpression = required_field()


@dataclass(frozen=True)
class AstFromImport(AstCommand):
"""Ast from import node."""

identifier: str = ""


@dataclass(frozen=True)
class AstImportedItem(AstNode):
"""Ast imported item node."""

name: str = required_field()
identifier: bool = True


@dataclass(frozen=True)
class AstImportedMacro(AstNode):
"""Ast imported macro node."""

name: str = required_field()
declaration: AstMacro = required_field()
103 changes: 73 additions & 30 deletions bolt/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


from contextlib import contextmanager
from dataclasses import dataclass, field, fields
from dataclasses import dataclass, field, fields, replace
from typing import (
Any,
Dict,
Expand All @@ -24,6 +24,7 @@
Optional,
Tuple,
Type,
Union,
cast,
overload,
)
Expand All @@ -47,9 +48,11 @@
AstExpressionBinary,
AstExpressionUnary,
AstFormatString,
AstFromImport,
AstFunctionSignature,
AstIdentifier,
AstImportedItem,
AstImportedMacro,
AstInterpolation,
AstKeyword,
AstList,
Expand All @@ -67,7 +70,7 @@
AstUnpack,
AstValue,
)
from .module import CodegenResult
from .module import CodegenResult, MacroLibrary


@dataclass
Expand All @@ -76,6 +79,7 @@ class Accumulator:

indentation: str = ""
refs: List[Any] = field(default_factory=list)
macros: MacroLibrary = field(default_factory=dict)
macro_ids: Dict[str, int] = field(default_factory=dict)
lines: List[str] = field(default_factory=list)
counter: int = 0
Expand Down Expand Up @@ -156,6 +160,23 @@ def make_variable(self) -> str:
self.counter += 1
return name

def make_macro(self, node: AstMacro) -> str:
"""Add macro to macro library."""
macro = self.get_macro(node.identifier)
group = node.identifier.partition(":")[0]
self.macros.setdefault(group, {})[macro, node] = None
return macro

def import_macro(self, resource_location: str, node: AstImportedMacro) -> str:
"""Import macro into macro library."""
macro = self.get_macro(node.declaration.identifier)
group = node.declaration.identifier.partition(":")[0]
self.macros.setdefault(group, {})[macro, node.declaration] = (
resource_location,
node.name,
)
return macro

def get_macro(self, name: str) -> str:
"""Get macro."""
if name not in self.macro_ids:
Expand Down Expand Up @@ -189,7 +210,10 @@ def function(self, name: str, *args: str):
"""Emit function."""
self.statement(f"def {name}({', '.join(args)}):")
with self.block():
previous_macros = self.macros
self.macros = {k: dict(v) for k, v in self.macros.items()}
yield
self.macros = previous_macros

@contextmanager
def if_statement(self, condition: str):
Expand Down Expand Up @@ -412,15 +436,24 @@ class Codegen(Visitor):
def __call__(self, node: AstRoot) -> CodegenResult: # type: ignore
acc = Accumulator()
result = self.invoke(node, acc)

if result is None:
return CodegenResult(refs=acc.refs)
elif len(result) != 1:
return CodegenResult()

if len(result) != 1:
raise ValueError(
f"Expected single result for {node.__class__.__name__} {result!r}."
)

output = acc.make_variable()
acc.statement(f"{output} = {result[0]}")
return CodegenResult(source=acc.get_source(), output=output, refs=acc.refs)

return CodegenResult(
source=acc.get_source(),
output=output,
refs=acc.refs,
macros=acc.macros,
)

@rule(AstNode)
def fallback(
Expand Down Expand Up @@ -524,7 +557,9 @@ def macro(
node: AstMacro,
acc: Accumulator,
) -> Generator[AstNode, Optional[List[str]], Optional[List[str]]]:
macro = acc.get_macro(node.identifier)
macro = acc.make_macro(
replace(node, arguments=AstChildren(node.arguments[:-1]))
)
arguments = [
arg.match_identifier.value
for arg in node.arguments
Expand Down Expand Up @@ -666,37 +701,14 @@ def with_statement(

@rule(AstCommand, identifier="import:module")
@rule(AstCommand, identifier="import:module:as:alias")
@rule(AstCommand, identifier="from:module:import:subcommand")
def import_statement(
self,
node: AstCommand,
acc: Accumulator,
) -> Optional[List[str]]:
module = cast(AstResourceLocation, node.arguments[0])

if node.identifier == "from:module:import:subcommand":
names: List[str] = []
subcommand = cast(AstCommand, node.arguments[1])

while True:
if isinstance(item := subcommand.arguments[0], AstImportedItem):
names.append(item.name)
if subcommand.identifier == "from:module:import:name:subcommand":
subcommand = cast(AstCommand, subcommand.arguments[1])
else:
break

if module.namespace:
acc.statement(
f"{', '.join(names)} = _bolt_runtime.from_module_import({module.get_value()!r}, {', '.join(map(repr, names))})",
lineno=node,
)
else:
for name in names:
rhs = acc.get_attribute(acc.import_module(module.path), name)
acc.statement(f"{name} = {rhs}", lineno=node)

elif node.identifier == "import:module:as:alias":
if node.identifier == "import:module:as:alias":
alias = cast(AstImportedItem, node.arguments[1]).name

if module.namespace:
Expand All @@ -719,6 +731,37 @@ def import_statement(

return []

@rule(AstFromImport)
def from_import_statement(
self,
node: AstFromImport,
acc: Accumulator,
) -> Optional[List[str]]:
module = cast(AstResourceLocation, node.arguments[0])
items = cast(
AstChildren[Union[AstImportedItem, AstImportedMacro]], node.arguments[1:]
)
names = [item.name for item in items]

if module.namespace:
full_path = module.get_value()
targets: List[str] = []
for item in items:
if isinstance(item, AstImportedMacro):
targets.append(acc.import_macro(full_path, item))
else:
targets.append(item.name)
stmt = f"{', '.join(targets)} = _bolt_runtime.from_module_import({full_path!r}, {', '.join(map(repr, names))})"
acc.statement(stmt, lineno=node)
else:
stmt = f"_bolt_from_import = {acc.import_module(module.path)}"
acc.statement(stmt, lineno=node)
for name in names:
rhs = acc.get_attribute("_bolt_from_import", name)
acc.statement(f"{name} = {rhs}", lineno=node)

return []

@rule(AstCommand, identifier="global:subcommand")
@rule(AstCommand, identifier="nonlocal:subcommand")
def global_nonlocal_statement(
Expand Down
2 changes: 1 addition & 1 deletion bolt/contrib/lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class LazyFilter(Visitor):

@rule(AstModuleRoot)
def lazy_module(self, node: AstModuleRoot):
module = self.modules.for_current_ast(node)
module = self.modules.match_ast(node)
if module.resource_location and self.lazy.check(module.resource_location):
module.execution_hooks.append(
partial(
Expand Down
Loading

0 comments on commit e8ca750

Please sign in to comment.