Skip to content

Commit

Permalink
feat(frontend): compile using in-memory module
Browse files Browse the repository at this point in the history
  • Loading branch information
youben11 committed Jul 20, 2023
1 parent 18aceec commit e66cd37
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 22 deletions.
26 changes: 21 additions & 5 deletions frontends/concrete-python/concrete/fhe/compilation/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from typing import Any, List, Optional, Tuple, Union

import numpy as np
from concrete.compiler import SimulatedValueDecrypter, SimulatedValueExporter
from concrete.compiler import CompilationContext, SimulatedValueDecrypter, SimulatedValueExporter
from mlir.ir import Module as MlirModule

from ..internal.utils import assert_that
from ..representation import Graph
Expand All @@ -29,17 +30,25 @@ class Circuit:
configuration: Configuration

graph: Graph
mlir: str
mlir: MlirModule
compilation_context: CompilationContext

client: Client
server: Server
simulator: Server

def __init__(self, graph: Graph, mlir: str, configuration: Optional[Configuration] = None):
def __init__(
self,
graph: Graph,
mlir: MlirModule,
compilation_context: CompilationContext,
configuration: Optional[Configuration] = None,
):
self.configuration = configuration if configuration is not None else Configuration()

self.graph = graph
self.mlir = mlir
self.compilation_context = compilation_context

if self.configuration.fhe_simulation:
self.enable_fhe_simulation()
Expand All @@ -56,15 +65,22 @@ def enable_fhe_simulation(self):
"""

if not hasattr(self, "simulator"):
self.simulator = Server.create(self.mlir, self.configuration, is_simulated=True)
self.simulator = Server.create(
self.mlir,
self.configuration,
is_simulated=True,
compilation_context=self.compilation_context,
)

def enable_fhe_execution(self):
"""
Enable FHE execution.
"""

if not hasattr(self, "server"):
self.server = Server.create(self.mlir, self.configuration)
self.server = Server.create(
self.mlir, self.configuration, compilation_context=self.compilation_context
)

keyset_cache_directory = None
if self.configuration.use_insecure_key_cache:
Expand Down
24 changes: 19 additions & 5 deletions frontends/concrete-python/concrete/fhe/compilation/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
from concrete.compiler import CompilationContext

from ..extensions import AutoRounder
from ..mlir import GraphConverter
Expand Down Expand Up @@ -46,6 +47,8 @@ class Compiler:
inputset: List[Any]
graph: Optional[Graph]

compilation_context: CompilationContext

_is_direct: bool
_parameter_values: Dict[str, ValueDescription]

Expand Down Expand Up @@ -155,6 +158,8 @@ def __init__(
self.inputset = []
self.graph = None

self.compilation_context = CompilationContext.new()

self._is_direct = False
self._parameter_values = {}

Expand Down Expand Up @@ -446,9 +451,13 @@ def compile(
message = "Function you are trying to compile cannot be compiled\n\n" + fmtd_graph
raise RuntimeError(message)

mlir = GraphConverter().convert(self.graph, self.configuration)
# in-memory MLIR module
mlir_context = self.compilation_context.mlir_context()
mlir_module = GraphConverter().convert(self.graph, self.configuration, mlir_context)
# textual representation of the MLIR module
mlir_str = str(mlir_module).strip()
if self.artifacts is not None:
self.artifacts.add_mlir_to_compile(mlir)
self.artifacts.add_mlir_to_compile(mlir_str)

show_graph = (
self.configuration.show_graph
Expand All @@ -475,7 +484,7 @@ def compile(
)

longest_graph_line = max(len(line) for line in graph.split("\n"))
longest_mlir_line = max(len(line) for line in mlir.split("\n"))
longest_mlir_line = max(len(line) for line in mlir_str.split("\n"))
longest_line = max(longest_graph_line, longest_mlir_line)

try: # pragma: no cover
Expand Down Expand Up @@ -506,7 +515,7 @@ def compile(

print("MLIR")
print("-" * columns)
print(mlir)
print(mlir_str)
print("-" * columns)

print()
Expand All @@ -517,7 +526,12 @@ def compile(
print("Optimizer")
print("-" * columns)

circuit = Circuit(self.graph, mlir, self.configuration)
circuit = Circuit(
self.graph,
mlir_module,
self.compilation_context,
self.configuration,
)

if hasattr(circuit, "client"):
client_parameters = circuit.client.specs.client_parameters
Expand Down
36 changes: 29 additions & 7 deletions frontends/concrete-python/concrete/fhe/compilation/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# mypy: disable-error-code=attr-defined
import concrete.compiler
from concrete.compiler import (
CompilationContext,
CompilationFeedback,
CompilationOptions,
EvaluationKeys,
Expand All @@ -25,6 +26,7 @@
PublicArguments,
)
from mlir._mlir_libs._concretelang._compiler import OptimizerStrategy
from mlir.ir import Module as MlirModule

from ..internal.utils import assert_that
from .configuration import (
Expand Down Expand Up @@ -81,19 +83,27 @@ def __init__(
)

@staticmethod
def create(mlir: str, configuration: Configuration, is_simulated: bool = False) -> "Server":
def create(
mlir: Union[str, MlirModule],
configuration: Configuration,
is_simulated: bool = False,
compilation_context: Optional[CompilationContext] = None,
) -> "Server":
"""
Create a server using MLIR and output sign information.
Args:
mlir (str):
mlir (MlirModule):
mlir to compile
configuration (Configuration):
configuration to use
compilation_context (CompilationContext):
context to use for the Compiler
is_simulated (bool, default = False):
whether to compile in simulation mode or not
configuration (Optional[Configuration]):
configuration to use
"""

options = CompilationOptions.new("main")
Expand Down Expand Up @@ -154,7 +164,13 @@ def create(mlir: str, configuration: Configuration, is_simulated: bool = False)
output_dir = None

support = JITSupport.new()
compilation_result = support.compile(mlir, options)
if isinstance(mlir, str):
compilation_result = support.compile(mlir, options)
else: # MlirModule
assert (
compilation_context is not None
), "must provide compilation context when compiling MlirModule"
compilation_result = support.compile(mlir, options, compilation_context)
server_lambda = support.load_server_lambda(compilation_result)

else:
Expand All @@ -166,7 +182,13 @@ def create(mlir: str, configuration: Configuration, is_simulated: bool = False)
support = LibrarySupport.new(
str(output_dir_path), generateCppHeader=False, generateStaticLib=False
)
compilation_result = support.compile(mlir, options)
if isinstance(mlir, str):
compilation_result = support.compile(mlir, options)
else: # MlirModule
assert (
compilation_context is not None
), "must provide compilation context when compiling MlirModule"
compilation_result = support.compile(mlir, options, compilation_context)
server_lambda = support.load_server_lambda(compilation_result)

client_parameters = support.load_client_parameters(compilation_result)
Expand All @@ -182,7 +204,7 @@ def create(mlir: str, configuration: Configuration, is_simulated: bool = False)
)

# pylint: disable=protected-access
result._mlir = mlir
result._mlir = str(mlir).strip()
result._configuration = configuration
# pylint: enable=protected-access

Expand Down
16 changes: 11 additions & 5 deletions frontends/concrete-python/concrete/fhe/mlir/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ class Converter:
Converter class, to convert a computation graph to MLIR.
"""

def convert(self, graph: Graph, configuration: Configuration) -> str:
def convert(
self, graph: Graph, configuration: Configuration, mlir_context: MlirContext
) -> MlirModule:
"""
Convert a computation graph to MLIR.
Expand All @@ -45,14 +47,18 @@ def convert(self, graph: Graph, configuration: Configuration) -> str:
configuration (Configuration):
configuration to use
mlir_context (MlirContext):
MLIR Context to use for module generation
Return:
str:
MLIR corresponding to graph
MlirModule:
In-memory MLIR module corresponding to the graph
"""

graph = self.process(graph, configuration)

with MlirContext() as context, MlirLocation.unknown():
context = mlir_context
with MlirLocation.unknown():
concrete.lang.register_dialects(context) # pylint: disable=no-member

module = MlirModule.create()
Expand Down Expand Up @@ -88,7 +94,7 @@ def main(*args):

return tuple(outputs)

return str(module).strip()
return module

@staticmethod
def stdout_with_ansi_support() -> bool:
Expand Down

0 comments on commit e66cd37

Please sign in to comment.