Skip to content

Commit

Permalink
feat: extract module manager from runtime
Browse files Browse the repository at this point in the history
  • Loading branch information
vberlier committed Jul 10, 2022
1 parent 87df070 commit 20e7499
Show file tree
Hide file tree
Showing 49 changed files with 457 additions and 398 deletions.
1 change: 1 addition & 0 deletions bolt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .codegen import *
from .helpers import *
from .loop_info import *
from .module import *
from .parse import *
from .plugin import *
from .runtime import *
9 changes: 5 additions & 4 deletions bolt/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
"AstMacroMatchLiteral",
"AstMacroMatchArgument",
"AstInterpolation",
"AstImportedIdentifier",
"AstImportedItem",
]


Expand Down Expand Up @@ -320,7 +320,8 @@ class AstInterpolation(AstNode):


@dataclass(frozen=True)
class AstImportedIdentifier(AstNode):
"""Ast imported identifier node."""
class AstImportedItem(AstNode):
"""Ast imported item node."""

value: str = required_field()
name: str = required_field()
identifier: bool = True
15 changes: 8 additions & 7 deletions bolt/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
AstFormatString,
AstFunctionSignature,
AstIdentifier,
AstImportedIdentifier,
AstImportedItem,
AstInterpolation,
AstKeyword,
AstList,
Expand All @@ -67,6 +67,7 @@
AstUnpack,
AstValue,
)
from .module import CodegenResult


@dataclass
Expand Down Expand Up @@ -408,18 +409,18 @@ def visit_binding(
class Codegen(Visitor):
"""Code generator."""

def __call__(self, node: AstRoot) -> Tuple[Optional[str], Optional[str], List[Any]]: # type: ignore
def __call__(self, node: AstRoot) -> CodegenResult: # type: ignore
acc = Accumulator()
result = self.invoke(node, acc)
if result is None:
return None, None, acc.refs
return CodegenResult(refs=acc.refs)
elif len(result) != 1:
raise ValueError(
f"Expected single result for {node.__class__.__name__} {result!r}."
)
output = acc.make_variable()
acc.statement(f"{output} = {result[0]}")
return acc.get_source(), output, acc.refs
return CodegenResult(source=acc.get_source(), output=output, refs=acc.refs)

@rule(AstNode)
def fallback(
Expand Down Expand Up @@ -678,8 +679,8 @@ def import_statement(
subcommand = cast(AstCommand, node.arguments[1])

while True:
if isinstance(name := subcommand.arguments[0], AstImportedIdentifier):
names.append(name.value)
if isinstance(item := subcommand.arguments[0], AstImportedItem):
names.append(item.name)
if subcommand.identifier == "from:module:import:name:subcommand":
subcommand = cast(AstCommand, subcommand.arguments[1])
else:
Expand All @@ -696,7 +697,7 @@ def import_statement(
acc.statement(f"{name} = {rhs}", lineno=node)

elif node.identifier == "import:module:as:alias":
alias = cast(AstImportedIdentifier, node.arguments[1]).value
alias = cast(AstImportedItem, node.arguments[1]).name

if module.namespace:
acc.statement(
Expand Down
10 changes: 5 additions & 5 deletions bolt/contrib/debug_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,25 @@

from beet import Context
from beet.core.utils import required_field
from mecha import AstRoot, CompilationDatabase, Mecha, Visitor, rule
from mecha import AstRoot, Mecha, Visitor, rule

from bolt import Runtime
from bolt.module import ModuleManager


def beet_default(ctx: Context):
mc = ctx.inject(Mecha)
runtime = ctx.inject(Runtime)
mc.steps[:] = [DebugCodegenEmitter(runtime=runtime, database=mc.database)]
mc.steps[:] = [DebugCodegenEmitter(modules=runtime.modules)]


@dataclass
class DebugCodegenEmitter(Visitor):
"""Visitor that interrupts the compilation process and dumps the generated code."""

runtime: Runtime = required_field()
database: CompilationDatabase = required_field()
modules: ModuleManager = required_field()

@rule(AstRoot)
def debug_codegen(self, node: AstRoot):
self.database.current.text = self.runtime.codegen(node)[0] or ""
self.modules.database.current.text = self.modules.codegen(node).source or ""
return None
19 changes: 9 additions & 10 deletions bolt/contrib/lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@

from beet import Context, ListOption, TextFileBase, configurable
from beet.core.utils import extra_field, required_field
from mecha import CompilationDatabase, Mecha, Visitor, rule
from mecha import Mecha, Visitor, rule
from pathspec import PathSpec
from pydantic import BaseModel

from bolt import AstModuleRoot, Runtime
from bolt import AstModuleRoot, ModuleManager, Runtime


class BoltLazyOptions(BaseModel):
Expand Down Expand Up @@ -53,7 +53,7 @@ def __post_init__(self, ctx: Context):

mc.steps.insert(
mc.steps.index(runtime.evaluate),
LazyFilter(runtime=runtime, database=mc.database, lazy=self),
LazyFilter(modules=runtime.modules, lazy=self),
)

def register(self, *match: str):
Expand All @@ -70,26 +70,25 @@ def check(self, resource_location: str) -> bool:
class LazyFilter(Visitor):
"""Visitor that filters lazy modules from the compilation by matching resource location."""

runtime: Runtime = required_field()
database: CompilationDatabase = required_field()
modules: ModuleManager = required_field()
lazy: LazyExecution = required_field()

@rule(AstModuleRoot)
def lazy_module(self, node: AstModuleRoot):
module = self.runtime.get_module(node)
module = self.modules.get(node)
if module.resource_location and self.lazy.check(module.resource_location):
module.execution_hooks.append(
partial(
self.restore_lazy,
self.database.current,
self.modules.database.current,
node,
self.database.step + 1,
self.modules.database.step + 1,
)
)
return None
return node

def restore_lazy(self, key: TextFileBase[Any], node: AstModuleRoot, step: int):
compilation_unit = self.database[key]
compilation_unit = self.modules.database[key]
compilation_unit.ast = node
self.database.enqueue(key, step)
self.modules.database.enqueue(key, step)
Loading

0 comments on commit 20e7499

Please sign in to comment.