Skip to content

Commit

Permalink
feat(frontend): compile using in-memory module
Browse files Browse the repository at this point in the history
  • Loading branch information
youben11 committed Jul 20, 2023
1 parent 6fcbb6d commit 636c338
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 14 deletions.
5 changes: 3 additions & 2 deletions frontends/concrete-python/concrete/fhe/compilation/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import numpy as np
from concrete.compiler import SimulatedValueDecrypter, SimulatedValueExporter
from mlir.ir import Module as MlirModule

from ..internal.utils import assert_that
from ..representation import Graph
Expand All @@ -29,13 +30,13 @@ class Circuit:
configuration: Configuration

graph: Graph
mlir: str
mlir: MlirModule

client: Client
server: Server
simulator: Server

def __init__(self, graph: Graph, mlir: str, configuration: Optional[Configuration] = None):
def __init__(self, graph: Graph, mlir: MlirModule, configuration: Optional[Configuration] = None):
self.configuration = configuration if configuration is not None else Configuration()

self.graph = graph
Expand Down
13 changes: 8 additions & 5 deletions frontends/concrete-python/concrete/fhe/compilation/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,9 +446,12 @@ def compile(
message = "Function you are trying to compile cannot be compiled\n\n" + fmtd_graph
raise RuntimeError(message)

mlir = GraphConverter().convert(self.graph, self.configuration)
# in-memory MLIR module
mlir_module = GraphConverter().convert(self.graph, self.configuration)
# textual representation of the MLIR module
mlir_str = str(mlir_module).strip()
if self.artifacts is not None:
self.artifacts.add_mlir_to_compile(mlir)
self.artifacts.add_mlir_to_compile(mlir_str)

show_graph = (
self.configuration.show_graph
Expand All @@ -475,7 +478,7 @@ def compile(
)

longest_graph_line = max(len(line) for line in graph.split("\n"))
longest_mlir_line = max(len(line) for line in mlir.split("\n"))
longest_mlir_line = max(len(line) for line in mlir_str.split("\n"))
longest_line = max(longest_graph_line, longest_mlir_line)

try: # pragma: no cover
Expand Down Expand Up @@ -506,7 +509,7 @@ def compile(

print("MLIR")
print("-" * columns)
print(mlir)
print(mlir_str)
print("-" * columns)

print()
Expand All @@ -517,7 +520,7 @@ def compile(
print("Optimizer")
print("-" * columns)

circuit = Circuit(self.graph, mlir, self.configuration)
circuit = Circuit(self.graph, mlir_module, self.configuration)

if hasattr(circuit, "client"):
client_parameters = circuit.client.specs.client_parameters
Expand Down
7 changes: 4 additions & 3 deletions frontends/concrete-python/concrete/fhe/compilation/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
PublicArguments,
)
from mlir._mlir_libs._concretelang._compiler import OptimizerStrategy
from mlir.ir import Module as MlirModule

from ..internal.utils import assert_that
from .configuration import (
Expand Down Expand Up @@ -81,12 +82,12 @@ def __init__(
)

@staticmethod
def create(mlir: str, configuration: Configuration, is_simulated: bool = False) -> "Server":
def create(mlir: MlirModule, configuration: Configuration, is_simulated: bool = False) -> "Server":
"""
Create a server using MLIR and output sign information.
Args:
mlir (str):
mlir (MlirModule):
mlir to compile
configuration (Configuration):
Expand Down Expand Up @@ -182,7 +183,7 @@ def create(mlir: str, configuration: Configuration, is_simulated: bool = False)
)

# pylint: disable=protected-access
result._mlir = mlir
result._mlir = str(mlir).strip()
result._configuration = configuration
# pylint: enable=protected-access

Expand Down
8 changes: 4 additions & 4 deletions frontends/concrete-python/concrete/fhe/mlir/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class Converter:
Converter class, to convert a computation graph to MLIR.
"""

def convert(self, graph: Graph, configuration: Configuration) -> str:
def convert(self, graph: Graph, configuration: Configuration) -> MlirModule:
"""
Convert a computation graph to MLIR.
Expand All @@ -46,8 +46,8 @@ def convert(self, graph: Graph, configuration: Configuration) -> str:
configuration to use
Return:
str:
MLIR corresponding to graph
MlirModule:
In-memory MLIR module corresponding to the graph
"""

graph = self.process(graph, configuration)
Expand Down Expand Up @@ -88,7 +88,7 @@ def main(*args):

return tuple(outputs)

return str(module).strip()
return module

@staticmethod
def stdout_with_ansi_support() -> bool:
Expand Down

0 comments on commit 636c338

Please sign in to comment.