Skip to content

Commit

Permalink
chore(frontends): enable simultaneous execution and simulation in mod…
Browse files Browse the repository at this point in the history
…ules
  • Loading branch information
aPere3 committed Aug 30, 2024
1 parent 2afc3d2 commit 7389ed2
Show file tree
Hide file tree
Showing 5 changed files with 295 additions and 104 deletions.
40 changes: 32 additions & 8 deletions frontends/concrete-python/concrete/fhe/compilation/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,14 @@
import shutil
import subprocess
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union

from ..representation import Graph

if TYPE_CHECKING:
from .module import ExecutionRt
from .utils import Lazy

DEFAULT_OUTPUT_DIRECTORY: Path = Path(".artifacts")


Expand Down Expand Up @@ -83,7 +87,7 @@ class ModuleDebugArtifacts:

output_directory: Path
mlir_to_compile: Optional[str]
client_parameters: Optional[bytes]
_execution_runtime: Optional["Lazy[ExecutionRt]"]
functions: Dict[str, FunctionDebugArtifacts]

def __init__(
Expand All @@ -93,7 +97,7 @@ def __init__(
):
self.output_directory = Path(output_directory)
self.mlir_to_compile = None
self.client_parameters = None
self._execution_runtime = None
self.functions = (
{name: FunctionDebugArtifacts() for name in function_names} if function_names else {}
)
Expand All @@ -108,15 +112,28 @@ def add_mlir_to_compile(self, mlir: str):
"""
self.mlir_to_compile = mlir

def add_client_parameters(self, client_parameters: bytes):
def add_execution_runtime(self, execution_runtime: "Lazy[ExecutionRt]"):
"""
Add client parameters used.
Add the (lazy) execution runtime to get the client parameters if needed.
Args:
client_parameters (bytes): client parameters
execution_runtime (Lazy[ExecutionRt]):
The lazily initialized execution runtime.
"""

self.client_parameters = client_parameters
self._execution_runtime = execution_runtime

@property
def client_parameters(self) -> Optional[bytes]:
"""
The client parameters associated with the execution runtime.
"""

return (
self._execution_runtime.val.client.specs.client_parameters.serialize()
if self._execution_runtime is not None
else None
)

def export(self):
"""
Expand Down Expand Up @@ -217,6 +234,7 @@ class DebugArtifacts:
"""

module_artifacts: ModuleDebugArtifacts
_client_parameters: Optional[bytes]

def __init__(self, output_directory: Union[str, Path] = DEFAULT_OUTPUT_DIRECTORY):
self.module_artifacts = ModuleDebugArtifacts(["main"], output_directory)
Expand Down Expand Up @@ -280,13 +298,19 @@ def add_client_parameters(self, client_parameters: bytes):
client_parameters (bytes): client parameters
"""

self.module_artifacts.add_client_parameters(client_parameters)
self._client_parameters = client_parameters

def export(self):
"""
Export the collected information to `self.output_directory`.
"""

# This is a quick fix before we refactor compiler and module_compiler
# to use the same abstraction.
class _ModuleDebugArtifacts(ModuleDebugArtifacts):
client_parameters = self._client_parameters

self.module_artifacts.__class__ = _ModuleDebugArtifacts
self.module_artifacts.export()

@property
Expand Down
Loading

0 comments on commit 7389ed2

Please sign in to comment.