From b794caaf4b02f353aea518782c889031b6a29b2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexandre=20P=C3=A9r=C3=A9?= Date: Fri, 1 Mar 2024 16:50:12 +0100 Subject: [PATCH] wip --- .../include/concretelang/Common/Compat.h | 23 +- .../concretelang/ServerLib/ServerLib.h | 1 + .../lib/Bindings/Python/CMakeLists.txt | 2 + .../lib/Bindings/Python/CompilerAPIModule.cpp | 60 +- .../Python/concrete/compiler/__init__.py | 1 + .../concrete/compiler/server_circuit.py | 92 +++ .../concrete/compiler/server_program.py | 84 +++ .../concrete-python/concrete/fhe/__init__.py | 2 +- .../concrete/fhe/compilation/__init__.py | 4 +- .../concrete/fhe/compilation/artifacts.py | 168 ++++- .../concrete/fhe/compilation/client.py | 12 +- .../concrete/fhe/compilation/configuration.py | 5 +- .../concrete/fhe/compilation/decorators.py | 193 ++--- .../concrete/fhe/compilation/program.py | 660 ++++++++++++++++++ .../fhe/compilation/program_compiler.py | 481 +++++++++++++ .../concrete/fhe/compilation/server.py | 211 +++--- .../concrete/fhe/compilation/utils.py | 5 +- .../concrete/fhe/mlir/converter.py | 109 ++- .../fhe/mlir/processors/assign_bit_widths.py | 77 +- .../concrete/fhe/representation/__init__.py | 2 +- .../concrete/fhe/representation/graph.py | 23 + 21 files changed, 1885 insertions(+), 330 deletions(-) create mode 100644 compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/server_circuit.py create mode 100644 compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/server_program.py create mode 100644 frontends/concrete-python/concrete/fhe/compilation/program.py create mode 100644 frontends/concrete-python/concrete/fhe/compilation/program_compiler.py diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Common/Compat.h b/compilers/concrete-compiler/compiler/include/concretelang/Common/Compat.h index ae26644897..d92df6d013 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Common/Compat.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Common/Compat.h @@ -42,15 +42,19 @@ using concretelang::serverlib::ServerProgram; using concretelang::values::TransportValue; using concretelang::values::Value; -#define GET_OR_THROW_LLVM_EXPECTED(VARNAME, EXPECTED) \ - auto VARNAME = EXPECTED; \ - if (auto err = VARNAME.takeError()) { \ - throw std::runtime_error(llvm::toString(std::move(err))); \ - } - #define CONCAT(a, b) CONCAT_INNER(a, b) #define CONCAT_INNER(a, b) a##b +#define GET_OR_THROW_EXPECTED_(VARNAME, RESULT, MAYBE) \ + auto MAYBE = RESULT; \ + if (auto err = MAYBE.takeError()){ \ + throw std::runtime_error(llvm::toString(std::move(err))); \ + } \ + VARNAME = std::move(MAYBE); \ + +#define GET_OR_THROW_EXPECTED(VARNAME, RESULT) \ + GET_OR_THROW_EXPECTED_(VARNAME, RESULT, CONCAT(maybe, __COUNTER__)) + #define GET_OR_THROW_RESULT_(VARNAME, RESULT, MAYBE) \ auto MAYBE = RESULT; \ if (MAYBE.has_failure()) { \ @@ -370,6 +374,13 @@ class LibrarySupport { useSimulation}; } + llvm::Expected loadServerProgram(LibraryCompilationResult &result, bool useSimulation){ + EXPECTED_TRY(auto programInfo, getProgramInfo()); + return outcomeToExpected(ServerProgram::load(programInfo.asReader(), + getSharedLibPath(), + useSimulation)); + } + /// Load the client parameters from the compilation result. llvm::Expected<::concretelang::clientlib::ClientParameters> loadClientParameters(LibraryCompilationResult &result) { diff --git a/compilers/concrete-compiler/compiler/include/concretelang/ServerLib/ServerLib.h b/compilers/concrete-compiler/compiler/include/concretelang/ServerLib/ServerLib.h index c7a56a533e..8e8a68eae6 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/ServerLib/ServerLib.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/ServerLib/ServerLib.h @@ -50,6 +50,7 @@ class ServerCircuit { Result> call(const ServerKeyset &serverKeyset, std::vector &args); + /// Simulate the circuit with public arguments. Result> simulate(std::vector &args); diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CMakeLists.txt b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CMakeLists.txt index 470c0b5378..d3c2e4a7c3 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CMakeLists.txt +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CMakeLists.txt @@ -49,6 +49,8 @@ declare_mlir_python_sources( concrete/compiler/parameter.py concrete/compiler/public_arguments.py concrete/compiler/public_result.py + concrete/compiler/server_circuit.py + concrete/compiler/server_program.py concrete/compiler/evaluation_keys.py concrete/compiler/simulated_value_decrypter.py concrete/compiler/simulated_value_exporter.py diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index d850dadfae..11d7245312 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -11,6 +11,7 @@ #include "concretelang/Common/Keysets.h" #include "concretelang/Dialect/FHE/IR/FHEOpsDialect.h.inc" #include "concretelang/Runtime/DFRuntime.hpp" +#include "concretelang/ServerLib/ServerLib.h" #include "concretelang/Support/logging.h" #include #include @@ -79,7 +80,7 @@ library_compile(LibrarySupport_Py support, const char *module, llvm::SourceMgr sm; sm.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(module), llvm::SMLoc()); - GET_OR_THROW_LLVM_EXPECTED(compilationResult, + GET_OR_THROW_EXPECTED(auto compilationResult, support.support.compile(sm, options)); return std::move(*compilationResult); } @@ -89,7 +90,7 @@ library_compile_module( LibrarySupport_Py support, mlir::ModuleOp module, mlir::concretelang::CompilationOptions options, std::shared_ptr cctx) { - GET_OR_THROW_LLVM_EXPECTED(compilationResult, + GET_OR_THROW_EXPECTED(auto compilationResult, support.support.compile(module, cctx, options)); return std::move(*compilationResult); } @@ -97,7 +98,7 @@ library_compile_module( concretelang::clientlib::ClientParameters library_load_client_parameters( LibrarySupport_Py support, mlir::concretelang::LibraryCompilationResult &result) { - GET_OR_THROW_LLVM_EXPECTED(clientParameters, + GET_OR_THROW_EXPECTED(auto clientParameters, support.support.loadClientParameters(result)); return *clientParameters; } @@ -106,7 +107,7 @@ mlir::concretelang::ProgramCompilationFeedback library_load_compilation_feedback( LibrarySupport_Py support, mlir::concretelang::LibraryCompilationResult &result) { - GET_OR_THROW_LLVM_EXPECTED(compilationFeedback, + GET_OR_THROW_EXPECTED(auto compilationFeedback, support.support.loadCompilationFeedback(result)); return *compilationFeedback; } @@ -115,8 +116,8 @@ concretelang::serverlib::ServerLambda library_load_server_lambda(LibrarySupport_Py support, mlir::concretelang::LibraryCompilationResult &result, std::string circuitName, bool useSimulation) { - GET_OR_THROW_LLVM_EXPECTED( - serverLambda, + GET_OR_THROW_EXPECTED( + auto serverLambda, support.support.loadServerLambda(result, circuitName, useSimulation)); return *serverLambda; } @@ -126,8 +127,8 @@ library_server_call(LibrarySupport_Py support, concretelang::serverlib::ServerLambda lambda, concretelang::clientlib::PublicArguments &args, concretelang::clientlib::EvaluationKeys &evaluationKeys) { - GET_OR_THROW_LLVM_EXPECTED( - publicResult, support.support.serverCall(lambda, args, evaluationKeys)); + GET_OR_THROW_EXPECTED( + auto publicResult, support.support.serverCall(lambda, args, evaluationKeys)); return std::move(*publicResult); } @@ -135,7 +136,7 @@ std::unique_ptr library_simulate(LibrarySupport_Py support, concretelang::serverlib::ServerLambda lambda, concretelang::clientlib::PublicArguments &args) { - GET_OR_THROW_LLVM_EXPECTED(publicResult, + GET_OR_THROW_EXPECTED(auto publicResult, support.support.simulate(lambda, args)); return std::move(*publicResult); } @@ -1186,6 +1187,47 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( return pybind11::bytes(valueSerialize(value)); }); + + pybind11::class_(m, "ServerProgram") + .def_static( + "load", + [](LibrarySupport_Py &support, + bool useSimulation) { + GET_OR_THROW_EXPECTED(auto programInfo, support.support.getProgramInfo()); + auto sharedLibPath = support.support.getSharedLibPath(); + GET_OR_THROW_RESULT(auto result, ServerProgram::load((*programInfo).asReader(), sharedLibPath, useSimulation)); + return result; + }) + .def( + "get_server_circuit", + [](ServerProgram &program, + const std::string &circuitName) { + GET_OR_THROW_RESULT(auto result, program.getServerCircuit(circuitName)); + return result; + }); + + pybind11::class_(m, "ServerCircuit") + .def("call", + [](ServerCircuit &circuit, + ::concretelang::clientlib::PublicArguments &publicArguments, + ::concretelang::clientlib::EvaluationKeys &evaluationKeys) { + SignalGuard signalGuard; + auto keyset = evaluationKeys.keyset; + auto values = publicArguments.values; + GET_OR_THROW_RESULT(auto output, circuit.call(keyset, values)); + ::concretelang::clientlib::PublicResult res{output}; + return std::make_unique<::concretelang::clientlib::PublicResult>(std::move(res)); + }) + .def("simulate", + [](ServerCircuit &circuit, + ::concretelang::clientlib::PublicArguments &publicArguments) { + pybind11::gil_scoped_release release; + auto values = publicArguments.values; + GET_OR_THROW_RESULT(auto output, circuit.simulate(values)); + ::concretelang::clientlib::PublicResult res{output}; + return std::make_unique<::concretelang::clientlib::PublicResult>(std::move(res)); + }); + pybind11::class_<::concretelang::clientlib::ValueExporter>(m, "ValueExporter") .def_static( "create", diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/__init__.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/__init__.py index 7a712cbd3a..4262354e4b 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/__init__.py +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/__init__.py @@ -37,6 +37,7 @@ from .simulated_value_decrypter import SimulatedValueDecrypter from .simulated_value_exporter import SimulatedValueExporter from .parameter import Parameter +from .server_program import ServerProgram def init_dfr(): diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/server_circuit.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/server_circuit.py new file mode 100644 index 0000000000..f4b3cd5cdb --- /dev/null +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/server_circuit.py @@ -0,0 +1,92 @@ +# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions. +# See https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt for license information. + +"""ServerCircuit.""" + +# pylint: disable=no-name-in-module,import-error +from mlir._mlir_libs._concretelang._compiler import ( + ServerCircuit as _ServerCircuit, +) + +# pylint: enable=no-name-in-module,import-error +from .wrapper import WrapperCpp +from .public_arguments import PublicArguments +from .public_result import PublicResult +from .evaluation_keys import EvaluationKeys + + +class ServerCircuit(WrapperCpp): + """ServerCircuit references a circuit that can be called for execution and simulation.""" + + def __init__(self, server_circuit: _ServerCircuit): + """Wrap the native Cpp object. + + Args: + server_circuit (_ServerCircuit): object to wrap + + Raises: + TypeError: if server_circuit is not of type _ServerCircuit + """ + if not isinstance(server_circuit, _ServerCircuit): + raise TypeError( + f"server_circuit must be of type _ServerCircuit, not {type(server_circuit)}" + ) + super().__init__(server_circuit) + + def call( + self, + public_arguments: PublicArguments, + evaluation_keys: EvaluationKeys, + ) -> PublicResult: + """Executes the circuit on the public arguments. + + Args: + public_arguments (PublicArguments): public arguments to execute on + execution_keys (EvaluationKeys): evaluation keys to use for execution. + + Raises: + TypeError: if public_arguments is not of type PublicArguments, or if evaluation_keys is not of type EvaluationKeys + + Returns: + PublicResult: A public result object containing the results. + """ + if not isinstance(public_arguments, PublicArguments): + raise TypeError( + f"public_arguments must be of type PublicArguments, not " + f"{type(public_arguments)}" + ) + if not isinstance(evaluation_keys, EvaluationKeys): + raise TypeError( + f"simulation must be of type EvaluationKeys, not " f"{type(evaluation_keys)}" + ) + return PublicResult.wrap( + self.cpp().call( + public_arguments.cpp(), evaluation_keys.cpp() + ) + ) + + def simulate( + self, + public_arguments: PublicArguments, + ) -> PublicResult: + """Simulates the circuit on the public arguments. + + Args: + public_arguments (PublicArguments): public arguments to execute on + + Raises: + TypeError: if public_arguments is not of type PublicArguments + + Returns: + PublicResult: A public result object containing the results. + """ + if not isinstance(public_arguments, PublicArguments): + raise TypeError( + f"public_arguments must be of type PublicArguments, not " + f"{type(public_arguments)}" + ) + return PublicResult.wrap( + self.cpp().simulate( + public_arguments.cpp() + ) + ) \ No newline at end of file diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/server_program.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/server_program.py new file mode 100644 index 0000000000..97b399531d --- /dev/null +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/server_program.py @@ -0,0 +1,84 @@ +# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions. +# See https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt for license information. + +"""ServerProgram.""" + +# pylint: disable=no-name-in-module,import-error +from mlir._mlir_libs._concretelang._compiler import ( + ServerProgram as _ServerProgram, +) + +# pylint: enable=no-name-in-module,import-error +from .wrapper import WrapperCpp +from .library_support import LibrarySupport +from .server_circuit import ServerCircuit + + +class ServerProgram(WrapperCpp): + """ServerProgram references compiled circuit objects.""" + + def __init__(self, server_program: _ServerProgram): + """Wrap the native Cpp object. + + Args: + server_program (_ServerProgram): object to wrap + + Raises: + TypeError: if server_program is not of type _ServerProgram + """ + if not isinstance(server_program, _ServerProgram): + raise TypeError( + f"server_program must be of type _ServerProgram, not {type(server_program)}" + ) + super().__init__(server_program) + + def load( + library_support: LibrarySupport, + simulation: bool, + ) -> "ServerProgram": + """Loads the server program from a library support. + + Args: + library_support (LibrarySupport): library support + simulation (bool): use simulation for execution + + Raises: + TypeError: if library_support is not of type LibrarySupport, or if simulation is not of type bool + + Returns: + ServerProgram: A server program object containing references to circuits for calls. + """ + if not isinstance(library_support, LibrarySupport): + raise TypeError( + f"library_support must be of type LibrarySupport, not " + f"{type(library_support)}" + ) + if not isinstance(simulation, bool): + raise TypeError( + f"simulation must be of type bool, not " f"{type(simulation)}" + ) + return ServerProgram.wrap( + _ServerProgram.load( + library_support.cpp(), simulation + ) + ) + + + def get_server_circuit(self, circuit_name: str) -> ServerCircuit: + """Returns a given circuit if it is part of the program. + + Args: + circuit_name (str): name of the circuit to retrieve. + + Raises: + TypeError: if circuit_name is not of type str + RuntimeError: if the circuit is not part of the program + """ + if not isinstance(circuit_name, str): + raise TypeError( + f"circuit_name must be of type str, not {type(circuit_name)}" + ) + + return ServerCircuit.wrap( + self.cpp().get_server_circuit(circuit_name) + ) diff --git a/frontends/concrete-python/concrete/fhe/__init__.py b/frontends/concrete-python/concrete/fhe/__init__.py index 0e8386ca99..65c3cb042a 100644 --- a/frontends/concrete-python/concrete/fhe/__init__.py +++ b/frontends/concrete-python/concrete/fhe/__init__.py @@ -28,7 +28,7 @@ Server, Value, ) -from .compilation.decorators import circuit, compiler +from .compilation.decorators import circuit, compiler, program from .extensions import ( AutoRounder, AutoTruncator, diff --git a/frontends/concrete-python/concrete/fhe/compilation/__init__.py b/frontends/concrete-python/concrete/fhe/compilation/__init__.py index 5f134c2a99..23842a5e6d 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/__init__.py +++ b/frontends/concrete-python/concrete/fhe/compilation/__init__.py @@ -2,7 +2,7 @@ Glue the compilation process together. """ -from .artifacts import DebugArtifacts +from .artifacts import CircuitDebugArtifacts, DebugArtifacts, ProgramDebugArtifacts from .circuit import Circuit from .client import Client from .compiler import Compiler, EncryptionStatus @@ -20,6 +20,8 @@ ParameterSelectionStrategy, ) from .keys import Keys +from .program import Program, ProgramCircuit +from .program_compiler import CircuitDef, ProgramCompiler from .server import Server from .specs import ClientSpecs from .value import Value diff --git a/frontends/concrete-python/concrete/fhe/compilation/artifacts.py b/frontends/concrete-python/concrete/fhe/compilation/artifacts.py index a1b7994813..8e3b10792e 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/artifacts.py +++ b/frontends/concrete-python/concrete/fhe/compilation/artifacts.py @@ -14,29 +14,21 @@ DEFAULT_OUTPUT_DIRECTORY: Path = Path(".artifacts") -class DebugArtifacts: +class CircuitDebugArtifacts: """ - DebugArtifacts class, to export information about the compilation process. + An object containing debug artifacts for a certain circuit in a multi-circuit program. """ - output_directory: Path - source_code: Optional[str] parameter_encryption_statuses: Dict[str, str] textual_representations_of_graphs: Dict[str, List[str]] final_graph: Optional[Graph] - mlir_to_compile: Optional[str] - client_parameters: Optional[bytes] - - def __init__(self, output_directory: Union[str, Path] = DEFAULT_OUTPUT_DIRECTORY): - self.output_directory = Path(output_directory) + def __init__(self): self.source_code = None self.parameter_encryption_statuses = {} self.textual_representations_of_graphs = {} self.final_graph = None - self.mlir_to_compile = None - self.client_parameters = None def add_source_code(self, function: Union[str, Callable]): """ @@ -46,7 +38,6 @@ def add_source_code(self, function: Union[str, Callable]): function (Union[str, Callable]): either the source code of the function or the function itself """ - try: self.source_code = ( function if isinstance(function, str) else inspect.getsource(function) @@ -65,7 +56,6 @@ def add_parameter_encryption_status(self, name: str, encryption_status: str): encryption_status (str): encryption status of the parameter """ - self.parameter_encryption_statuses[name] = encryption_status def add_graph(self, name: str, graph: Graph): @@ -79,15 +69,33 @@ def add_graph(self, name: str, graph: Graph): graph (Graph): a representation of the function being compiled """ - if name not in self.textual_representations_of_graphs: self.textual_representations_of_graphs[name] = [] - textual_representation = graph.format() self.textual_representations_of_graphs[name].append(textual_representation) - self.final_graph = graph + +class ProgramDebugArtifacts: + """ + An object containing debug artifacts for a whole multi-circuit program. + """ + + output_directory: Path + mlir_to_compile: Optional[str] + client_parameters: Optional[bytes] + circuits: Dict[str, CircuitDebugArtifacts] + + def __init__( + self, + circuit_names: List[str], + output_directory: Union[str, Path] = DEFAULT_OUTPUT_DIRECTORY, + ): + self.output_directory = Path(output_directory) + self.mlir_to_compile = None + self.client_parameters = None + self.circuits = {name: CircuitDebugArtifacts() for name in circuit_names} + def add_mlir_to_compile(self, mlir: str): """ Add textual representation of the resulting MLIR. @@ -96,7 +104,6 @@ def add_mlir_to_compile(self, mlir: str): mlir (str): textual representation of the resulting MLIR """ - self.mlir_to_compile = mlir def add_client_parameters(self, client_parameters: bytes): @@ -113,7 +120,6 @@ def export(self): """ Export the collected information to `self.output_directory`. """ - # pylint: disable=too-many-branches output_directory = self.output_directory @@ -164,27 +170,35 @@ def export(self): f.write(f"{name}=={version}\n") - if self.source_code is not None: - with open(output_directory.joinpath("function.txt"), "w", encoding="utf-8") as f: - f.write(self.source_code) - - if len(self.parameter_encryption_statuses) > 0: - with open(output_directory.joinpath("parameters.txt"), "w", encoding="utf-8") as f: - for name, parameter in self.parameter_encryption_statuses.items(): - f.write(f"{name} :: {parameter}\n") - - identifier = 0 - - textual_representations = self.textual_representations_of_graphs.items() - for name, representations in textual_representations: - for representation in representations: - identifier += 1 - output_path = output_directory.joinpath(f"{identifier}.{name}.graph.txt") - with open(output_path, "w", encoding="utf-8") as f: - f.write(f"{representation}\n") + for circuit_name, circuit in self.circuits.items(): + if circuit.source_code is not None: + with open( + output_directory.joinpath(f"{circuit_name}.txt"), "w", encoding="utf-8" + ) as f: + f.write(circuit.source_code) + + if len(circuit.parameter_encryption_statuses) > 0: + with open( + output_directory.joinpath(f"{circuit_name}.parameters.txt"), + "w", + encoding="utf-8", + ) as f: + for name, parameter in self.parameter_encryption_statuses.items(): + f.write(f"{name} :: {parameter}\n") + + identifier = 0 + + textual_representations = circuit.textual_representations_of_graphs.items() + for name, representations in textual_representations: + for representation in representations: + identifier += 1 + output_path = output_directory.joinpath( + f"{circuit_name}.{identifier}.{name}.graph.txt" + ) + with open(output_path, "w", encoding="utf-8") as f: + f.write(f"{representation}\n") if self.mlir_to_compile is not None: - assert self.final_graph is not None with open(output_directory.joinpath("mlir.txt"), "w", encoding="utf-8") as f: f.write(f"{self.mlir_to_compile}\n") @@ -193,3 +207,83 @@ def export(self): f.write(self.client_parameters) # pylint: enable=too-many-branches + + +class DebugArtifacts: + """ + DebugArtifacts class, to export information about the compilation process for single circuit + programs. + """ + + program_artifacts: ProgramDebugArtifacts + + def __init__(self, output_directory: Union[str, Path] = DEFAULT_OUTPUT_DIRECTORY): + self.program_artifacts = ProgramDebugArtifacts(["main"], output_directory) + + def add_source_code(self, function: Union[str, Callable]): + """ + Add source code of the function being compiled. + + Args: + function (Union[str, Callable]): + either the source code of the function or the function itself + """ + self.program_artifacts.circuits["main"].add_source_code(function) + + def add_parameter_encryption_status(self, name: str, encryption_status: str): + """ + Add parameter encryption status of a parameter of the function being compiled. + + Args: + name (str): + name of the parameter + + encryption_status (str): + encryption status of the parameter + """ + + self.program_artifacts.circuits["main"].add_parameter_encryption_status( + name, encryption_status + ) + + def add_graph(self, name: str, graph: Graph): + """ + Add a representation of the function being compiled. + + Args: + name (str): + name of the graph (e.g., initial, optimized, final) + + graph (Graph): + a representation of the function being compiled + """ + + self.program_artifacts.circuits["main"].add_graph(name, graph) + + def add_mlir_to_compile(self, mlir: str): + """ + Add textual representation of the resulting MLIR. + + Args: + mlir (str): + textual representation of the resulting MLIR + """ + + self.program_artifacts.circuits.add_mlit_to_compile(mlir) + + def add_client_parameters(self, client_parameters: bytes): + """ + Add client parameters used. + + Args: + client_parameters (bytes): client parameters + """ + + self.program_artifacts.add_client_parameters(client_parameters) + + def export(self): + """ + Export the collected information to `self.output_directory`. + """ + + self.program_artifacts.export() diff --git a/frontends/concrete-python/concrete/fhe/compilation/client.py b/frontends/concrete-python/concrete/fhe/compilation/client.py index 83ed7a1a45..d07fb94628 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/client.py +++ b/frontends/concrete-python/concrete/fhe/compilation/client.py @@ -120,6 +120,7 @@ def keygen( def encrypt( self, *args: Optional[Union[int, np.ndarray, List]], + circuit_name: str = "main", ) -> Optional[Union[Value, Tuple[Optional[Value], ...]]]: """ Encrypt argument(s) to for evaluation. @@ -127,18 +128,20 @@ def encrypt( Args: *args (Optional[Union[int, np.ndarray, List]]): argument(s) for evaluation + circuit_name (str): + name of the circuit to encrypt Returns: Optional[Union[Value, Tuple[Optional[Value], ...]]]: encrypted argument(s) for evaluation """ - ordered_sanitized_args = validate_input_args(self.specs, *args) + ordered_sanitized_args = validate_input_args(self.specs, circuit_name=circuit_name, *args) self.keygen(force=False) keyset = self.keys._keyset # pylint: disable=protected-access - exporter = ValueExporter.new(keyset, self.specs.client_parameters) + exporter = ValueExporter.new(keyset, self.specs.client_parameters, circuit_name) exported = [ None if arg is None @@ -155,6 +158,7 @@ def encrypt( def decrypt( self, *results: Union[Value, Tuple[Value, ...]], + circuit_name: str = "main", ) -> Optional[Union[int, np.ndarray, Tuple[Optional[Union[int, np.ndarray]], ...]]]: """ Decrypt result(s) of evaluation. @@ -162,6 +166,8 @@ def decrypt( Args: *results (Union[Value, Tuple[Value, ...]]): result(s) of evaluation + circuit_name (str): + name of the circuit to decrypt for Returns: Optional[Union[int, np.ndarray, Tuple[Optional[Union[int, np.ndarray]], ...]]]: @@ -179,7 +185,7 @@ def decrypt( self.keygen(force=False) keyset = self.keys._keyset # pylint: disable=protected-access - decrypter = ValueDecrypter.new(keyset, self.specs.client_parameters) + decrypter = ValueDecrypter.new(keyset, self.specs.client_parameters, circuit_name) decrypted = tuple( decrypter.decrypt(position, result.inner) for position, result in enumerate(flattened_results) diff --git a/frontends/concrete-python/concrete/fhe/compilation/configuration.py b/frontends/concrete-python/concrete/fhe/compilation/configuration.py index f01b34bc4b..06541b4016 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/configuration.py +++ b/frontends/concrete-python/concrete/fhe/compilation/configuration.py @@ -1268,10 +1268,9 @@ def _validate(self): raise RuntimeError(message) if ( - self.composable - and self.parameter_selection_strategy != ParameterSelectionStrategy.MULTI + self.composable and self.parameter_selection_strategy == ParameterSelectionStrategy.MONO ): # pragma: no cover - message = "Composition can only be used with MULTI parameter selection strategy" + message = "Composition can not be used with MONO parameter selection strategy" raise RuntimeError(message) diff --git a/frontends/concrete-python/concrete/fhe/compilation/decorators.py b/frontends/concrete-python/concrete/fhe/compilation/decorators.py index 1e4066effe..aca865acbf 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/decorators.py +++ b/frontends/concrete-python/concrete/fhe/compilation/decorators.py @@ -13,6 +13,7 @@ from .circuit import Circuit from .compiler import Compiler, EncryptionStatus from .configuration import Configuration +from .program_compiler import CircuitDef, ProgramCompiler def circuit( @@ -22,7 +23,7 @@ def circuit( **kwargs, ): """ - Provide a direct interface for compilation. + Provide a direct interface for compilation of single circuit programs. Args: parameters (Mapping[str, Union[str, EncryptionStatus]]): @@ -66,18 +67,120 @@ def decoration(function: Callable): if is_value else ValueDescription(deepcopy(annotation.dtype), shape=(), is_encrypted=False) ) + print(inspect.getmembers_static(classs)) - status = EncryptionStatus(parameters[name].lower()) - parameter_values[name].is_encrypted = status == "encrypted" + return decoration - return Compiler.assemble(function, parameter_values, configuration, artifacts, **kwargs) - return decoration +class Compilable: + """ + Compilable class, to wrap a function and provide methods to trace and compile it. + """ + + function: Callable + compiler: Compiler + + def __init__(self, function: Callable, parameters): + self.function = function # type: ignore + self.compiler = Compiler(self.function, dict(parameters)) + + def __call__(self, *args, **kwargs) -> Any: + self.compiler(*args, **kwargs) + return self.function(*args, **kwargs) + + def trace( + self, + inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]] = None, + configuration: Optional[Configuration] = None, + artifacts: Optional[DebugArtifacts] = None, + **kwargs, + ) -> Graph: + """ + Trace the function into computation graph. + + Args: + inputset (Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]): + optional inputset to extend accumulated inputset before bounds measurement + + configuration(Optional[Configuration], default = None): + configuration to use + + artifacts (Optional[DebugArtifacts], default = None): + artifacts to store information about the process + + kwargs (Dict[str, Any]): + configuration options to overwrite + + Returns: + Graph: + computation graph representing the function prior to MLIR conversion + """ + + return self.compiler.trace(inputset, configuration, artifacts, **kwargs) + + def compile( + self, + inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]] = None, + configuration: Optional[Configuration] = None, + artifacts: Optional[DebugArtifacts] = None, + **kwargs, + ) -> Circuit: + """ + Compile the function into a circuit. + + Args: + inputset (Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]): + optional inputset to extend accumulated inputset before bounds measurement + + configuration(Optional[Configuration], default = None): + configuration to use + + artifacts (Optional[DebugArtifacts], default = None): + artifacts to store information about the process + + kwargs (Dict[str, Any]): + configuration options to overwrite + + Returns: + Circuit: + compiled circuit + """ + + return self.compiler.compile(inputset, configuration, artifacts, **kwargs) def compiler(parameters: Mapping[str, Union[str, EncryptionStatus]]): """ - Provide an easy interface for compilation. + Provides an easy interface for the compilation of single-circuit programs. + + Args: + parameters (Mapping[str, Union[str, EncryptionStatus]]): + encryption statuses of the parameters of the function to compile + """ + + def decoration(function: Callable): + return Compilable(function, parameters) + + return decoration + + +def program(): + """ + Provides an easy interface for the compilation of multi-circuit programs. + """ + + def decoration(classs): + circuits = inspect.getmembers_static(classs, lambda x: isinstance(x, CircuitDef)) + circuits_map = {name: circ for (name, circ) in circuits} + returned = ProgramCompiler(circuits_map) + return returned + + return decoration + + +def circuit(parameters: Mapping[str, Union[str, EncryptionStatus]]): + """ + Provides an easy interface to define a circuit within a multi-circuit program. Args: parameters (Mapping[str, Union[str, EncryptionStatus]]): @@ -85,82 +188,6 @@ def compiler(parameters: Mapping[str, Union[str, EncryptionStatus]]): """ def decoration(function: Callable): - class Compilable: - """ - Compilable class, to wrap a function and provide methods to trace and compile it. - """ - - function: Callable - compiler: Compiler - - def __init__(self, function: Callable): - self.function = function # type: ignore - self.compiler = Compiler(self.function, dict(parameters)) - - def __call__(self, *args, **kwargs) -> Any: - self.compiler(*args, **kwargs) - return self.function(*args, **kwargs) - - def trace( - self, - inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]] = None, - configuration: Optional[Configuration] = None, - artifacts: Optional[DebugArtifacts] = None, - **kwargs, - ) -> Graph: - """ - Trace the function into computation graph. - - Args: - inputset (Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]): - optional inputset to extend accumulated inputset before bounds measurement - - configuration(Optional[Configuration], default = None): - configuration to use - - artifacts (Optional[DebugArtifacts], default = None): - artifacts to store information about the process - - kwargs (Dict[str, Any]): - configuration options to overwrite - - Returns: - Graph: - computation graph representing the function prior to MLIR conversion - """ - - return self.compiler.trace(inputset, configuration, artifacts, **kwargs) - - def compile( - self, - inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]] = None, - configuration: Optional[Configuration] = None, - artifacts: Optional[DebugArtifacts] = None, - **kwargs, - ) -> Circuit: - """ - Compile the function into a circuit. - - Args: - inputset (Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]): - optional inputset to extend accumulated inputset before bounds measurement - - configuration(Optional[Configuration], default = None): - configuration to use - - artifacts (Optional[DebugArtifacts], default = None): - artifacts to store information about the process - - kwargs (Dict[str, Any]): - configuration options to overwrite - - Returns: - Circuit: - compiled circuit - """ - - return self.compiler.compile(inputset, configuration, artifacts, **kwargs) - - return Compilable(function) + return CircuitDef(function, parameters) return decoration diff --git a/frontends/concrete-python/concrete/fhe/compilation/program.py b/frontends/concrete-python/concrete/fhe/compilation/program.py new file mode 100644 index 0000000000..a2b2eb6232 --- /dev/null +++ b/frontends/concrete-python/concrete/fhe/compilation/program.py @@ -0,0 +1,660 @@ +""" +Declaration of `Program` classes. +""" + +# pylint: disable=import-error,no-member,no-name-in-module + +from pathlib import Path +from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union + +import numpy as np +from concrete.compiler import ( + CompilationContext, + Parameter, + SimulatedValueDecrypter, + SimulatedValueExporter, +) +from mlir.ir import Module as MlirModule + +from ..internal.utils import assert_that +from ..representation import Graph +from .client import Client +from .configuration import Configuration +from .keys import Keys +from .server import Server +from .utils import validate_input_args +from .value import Value + +# pylint: enable=import-error,no-member,no-name-in-module + + +class ExecutionRt(NamedTuple): + """ + Runtime object class for execution. + """ + + client: Client + server: Server + + +class SimulationRt(NamedTuple): + """ + Runtime object class for simulation. + """ + + server: Server + + +class ProgramCircuit: + """ + Circuit class, allowing to run or simulate one circuit of a program. + """ + + rt: Union[ExecutionRt, SimulationRt] + graph: Graph + name: str + + def __init__(self, name: str, rt: Union[ExecutionRt, SimulationRt], graph: Graph): + self.name = name + self.rt = rt + self.graph = graph + + def draw( + self, + *, + horizontal: bool = False, + save_to: Optional[Union[Path, str]] = None, + show: bool = False, + ) -> Path: + """ + Draw the graph of the circuit. + + That this function requires the python `pygraphviz` package + which itself requires the installation of `graphviz` packages + + (see https://pygraphviz.github.io/documentation/stable/install.html) + + Args: + horizontal (bool, default = False): + whether to draw horizontally + + save_to (Optional[Path], default = None): + path to save the drawing + a temporary file will be used if it's None + + show (bool, default = False): + whether to show the drawing using matplotlib + + Returns: + Path: + path to the drawing + """ + return self.graph.draw(horizontal=horizontal, save_to=save_to, show=show) + + def __str__(self): + return self.graph.format() + + def simulate(self, *args: Any) -> Any: + """ + Simulate execution of the circuit. + + Args: + *args (Any): + inputs to the circuit + + Returns: + Any: + result of the simulation + """ + assert isinstance(self.rt, SimulationRt) + + ordered_validated_args = validate_input_args( + self.rt.server.client_specs, *args, circuit_name=self.name + ) + + exporter = SimulatedValueExporter.new( + self.rt.server.client_specs.client_parameters, self.name + ) + exported = [ + None + if arg is None + else Value( + exporter.export_tensor(position, arg.flatten().tolist(), list(arg.shape)) + if isinstance(arg, np.ndarray) and arg.shape != () + else exporter.export_scalar(position, int(arg)) + ) + for position, arg in enumerate(ordered_validated_args) + ] + + results = self.rt.server.run(*exported, circuit_name=self.name) + if not isinstance(results, tuple): + results = (results,) + + decrypter = SimulatedValueDecrypter.new( + self.rt.server.client_specs.client_parameters, self.name + ) + decrypted = tuple( + decrypter.decrypt(position, result.inner) for position, result in enumerate(results) + ) + + return decrypted if len(decrypted) != 1 else decrypted[0] + + def encrypt( + self, + *args: Optional[Union[int, np.ndarray, List]], + ) -> Optional[Union[Value, Tuple[Optional[Value], ...]]]: + """ + Encrypt argument(s) to for evaluation. + + Args: + *args (Optional[Union[int, numpy.ndarray, List]]): + argument(s) for evaluation + + Returns: + Optional[Union[Value, Tuple[Optional[Value], ...]]]: + encrypted argument(s) for evaluation + """ + assert isinstance(self.rt, ExecutionRt) + return self.rt.client.encrypt(*args, circuit_name=self.name) + + def run( + self, + *args: Optional[Union[Value, Tuple[Optional[Value], ...]]], + ) -> Union[Value, Tuple[Value, ...]]: + """ + Evaluate the circuit. + + Args: + *args (Value): + argument(s) for evaluation + + Returns: + Union[Value, Tuple[Value, ...]]: + result(s) of evaluation + """ + assert isinstance(self.rt, ExecutionRt) + return self.rt.server.run( + *args, evaluation_keys=self.rt.client.evaluation_keys, circuit_name=self.name + ) + + def decrypt( + self, + *results: Union[Value, Tuple[Value, ...]], + ) -> Optional[Union[int, np.ndarray, Tuple[Optional[Union[int, np.ndarray]], ...]]]: + """ + Decrypt result(s) of evaluation. + + Args: + *results (Union[Value, Tuple[Value, ...]]): + result(s) of evaluation + + Returns: + Optional[Union[int, np.ndarray, Tuple[Optional[Union[int, np.ndarray]], ...]]]: + decrypted result(s) of evaluation + """ + assert isinstance(self.rt, ExecutionRt) + return self.rt.client.decrypt(*results, circuit_name=self.name) + + def encrypt_run_decrypt(self, *args: Any) -> Any: + """ + Encrypt inputs, run the circuit, and decrypt the outputs in one go. + + Args: + *args (Union[int, numpy.ndarray]): + inputs to the circuit + + Returns: + Union[int, np.ndarray, Tuple[Union[int, np.ndarray], ...]]: + clear result of homomorphic evaluation + """ + return self.decrypt(self.run(self.encrypt(*args))) + + @property + def size_of_inputs(self) -> int: + """ + Get size of the inputs of the circuit. + """ + return self.rt.server.size_of_inputs(self.name) # pragma: no cover + + @property + def size_of_outputs(self) -> int: + """ + Get size of the outputs of the circuit. + """ + return self.rt.server.size_of_outputs(self.name) # pragma: no cover + + # Programmable Bootstrap Statistics + + @property + def programmable_bootstrap_count(self) -> int: + """ + Get the number of programmable bootstraps in the circuit. + """ + return self.rt.server.programmable_bootstrap_count(self.name) # pragma: no cover + + @property + def programmable_bootstrap_count_per_parameter(self) -> Dict[Parameter, int]: + """ + Get the number of programmable bootstraps per bit width in the circuit. + """ + return self.rt.server.programmable_bootstrap_count_per_parameter( + self.name + ) # pragma: no cover + + @property + def programmable_bootstrap_count_per_tag(self) -> Dict[str, int]: + """ + Get the number of programmable bootstraps per tag in the circuit. + """ + return self.rt.server.programmable_bootstrap_count_per_tag(self.name) # pragma: no cover + + @property + def programmable_bootstrap_count_per_tag_per_parameter(self) -> Dict[str, Dict[int, int]]: + """ + Get the number of programmable bootstraps per tag per bit width in the circuit. + """ + return self.rt.server.programmable_bootstrap_count_per_tag_per_parameter( + self.name + ) # pragma: no cover + + # Key Switch Statistics + + @property + def key_switch_count(self) -> int: + """ + Get the number of key switches in the circuit. + """ + return self.rt.server.key_switch_count(self.name) # pragma: no cover + + @property + def key_switch_count_per_parameter(self) -> Dict[Parameter, int]: + """ + Get the number of key switches per parameter in the circuit. + """ + return self.rt.server.key_switch_count_per_parameter(self.name) # pragma: no cover + + @property + def key_switch_count_per_tag(self) -> Dict[str, int]: + """ + Get the number of key switches per tag in the circuit. + """ + return self.rt.server.key_switch_count_per_tag(self.name) # pragma: no cover + + @property + def key_switch_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]: + """ + Get the number of key switches per tag per parameter in the circuit. + """ + return self.rt.server.key_switch_count_per_tag_per_parameter(self.name) # pragma: no cover + + # Packing Key Switch Statistics + + @property + def packing_key_switch_count(self) -> int: + """ + Get the number of packing key switches in the circuit. + """ + return self.rt.server.packing_key_switch_count(self.name) # pragma: no cover + + @property + def packing_key_switch_count_per_parameter(self) -> Dict[Parameter, int]: + """ + Get the number of packing key switches per parameter in the circuit. + """ + return self.rt.server.packing_key_switch_count_per_parameter(self.name) # pragma: no cover + + @property + def packing_key_switch_count_per_tag(self) -> Dict[str, int]: + """ + Get the number of packing key switches per tag in the circuit. + """ + return self.rt.server.packing_key_switch_count_per_tag(self.name) # pragma: no cover + + @property + def packing_key_switch_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]: + """ + Get the number of packing key switches per tag per parameter in the circuit. + """ + return self.rt.server.packing_key_switch_count_per_tag_per_parameter( + self.name + ) # pragma: no cover + + # Clear Addition Statistics + + @property + def clear_addition_count(self) -> int: + """ + Get the number of clear additions in the circuit. + """ + return self.rt.server.clear_addition_count(self.name) # pragma: no cover + + @property + def clear_addition_count_per_parameter(self) -> Dict[Parameter, int]: + """ + Get the number of clear additions per parameter in the circuit. + """ + return self.rt.server.clear_addition_count_per_parameter(self.name) # pragma: no cover + + @property + def clear_addition_count_per_tag(self) -> Dict[str, int]: + """ + Get the number of clear additions per tag in the circuit. + """ + return self.rt.server.clear_addition_count_per_tag(self.name) # pragma: no cover + + @property + def clear_addition_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]: + """ + Get the number of clear additions per tag per parameter in the circuit. + """ + return self.rt.server.clear_addition_count_per_tag_per_parameter( + self.name + ) # pragma: no cover + + # Encrypted Addition Statistics + + @property + def encrypted_addition_count(self) -> int: + """ + Get the number of encrypted additions in the circuit. + """ + return self.rt.server.encrypted_addition_count(self.name) # pragma: no cover + + @property + def encrypted_addition_count_per_parameter(self) -> Dict[Parameter, int]: + """ + Get the number of encrypted additions per parameter in the circuit. + """ + return self.rt.server.encrypted_addition_count_per_parameter(self.name) # pragma: no cover + + @property + def encrypted_addition_count_per_tag(self) -> Dict[str, int]: + """ + Get the number of encrypted additions per tag in the circuit. + """ + return self.rt.server.encrypted_addition_count_per_tag(self.name) # pragma: no cover + + @property + def encrypted_addition_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]: + """ + Get the number of encrypted additions per tag per parameter in the circuit. + """ + return self.rt.server.encrypted_addition_count_per_tag_per_parameter( + self.name + ) # pragma: no cover + + # Clear Multiplication Statistics + + @property + def clear_multiplication_count(self) -> int: + """ + Get the number of clear multiplications in the circuit. + """ + return self.rt.server.clear_multiplication_count(self.name) # pragma: no cover + + @property + def clear_multiplication_count_per_parameter(self) -> Dict[Parameter, int]: + """ + Get the number of clear multiplications per parameter in the circuit. + """ + return self.rt.server.clear_multiplication_count_per_parameter( + self.name + ) # pragma: no cover + + @property + def clear_multiplication_count_per_tag(self) -> Dict[str, int]: + """ + Get the number of clear multiplications per tag in the circuit. + """ + return self.rt.server.clear_multiplication_count_per_tag(self.name) # pragma: no cover + + @property + def clear_multiplication_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]: + """ + Get the number of clear multiplications per tag per parameter in the circuit. + """ + return self.rt.server.clear_multiplication_count_per_tag_per_parameter( + self.name + ) # pragma: no cover + + # Encrypted Negation Statistics + + @property + def encrypted_negation_count(self) -> int: + """ + Get the number of encrypted negations in the circuit. + """ + return self.rt.server.encrypted_negation_count(self.name) # pragma: no cover + + @property + def encrypted_negation_count_per_parameter(self) -> Dict[Parameter, int]: + """ + Get the number of encrypted negations per parameter in the circuit. + """ + return self.rt.server.encrypted_negation_count_per_parameter(self.name) # pragma: no cover + + @property + def encrypted_negation_count_per_tag(self) -> Dict[str, int]: + """ + Get the number of encrypted negations per tag in the circuit. + """ + return self.rt.server.encrypted_negation_count_per_tag(self.name) # pragma: no cover + + @property + def encrypted_negation_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]: + """ + Get the number of encrypted negations per tag per parameter in the circuit. + """ + return self.rt.server.encrypted_negation_count_per_tag_per_parameter( + self.name + ) # pragma: no cover + + @property + def statistics(self) -> Dict: + """ + Get all statistics of the circuit. + """ + attributes = [ + "size_of_inputs", + "size_of_outputs", + "programmable_bootstrap_count", + "programmable_bootstrap_count_per_parameter", + "programmable_bootstrap_count_per_tag", + "programmable_bootstrap_count_per_tag_per_parameter", + "key_switch_count", + "key_switch_count_per_parameter", + "key_switch_count_per_tag", + "key_switch_count_per_tag_per_parameter", + "packing_key_switch_count", + "packing_key_switch_count_per_parameter", + "packing_key_switch_count_per_tag", + "packing_key_switch_count_per_tag_per_parameter", + "clear_addition_count", + "clear_addition_count_per_parameter", + "clear_addition_count_per_tag", + "clear_addition_count_per_tag_per_parameter", + "encrypted_addition_count", + "encrypted_addition_count_per_parameter", + "encrypted_addition_count_per_tag", + "encrypted_addition_count_per_tag_per_parameter", + "clear_multiplication_count", + "clear_multiplication_count_per_parameter", + "clear_multiplication_count_per_tag", + "clear_multiplication_count_per_tag_per_parameter", + "encrypted_negation_count", + "encrypted_negation_count_per_parameter", + "encrypted_negation_count_per_tag", + "encrypted_negation_count_per_tag_per_parameter", + ] + return {attribute: getattr(self, attribute) for attribute in attributes} + + +class Program: + """ + Program class, to combine computation graphs, mlir, runtime objects into a single object. + """ + + configuration: Configuration + graphs: Dict[str, Graph] + mlir_module: MlirModule + compilation_context: CompilationContext + rt: Union[ExecutionRt, SimulationRt] + + def __init__( + self, + graphs: Dict[str, Graph], + mlir: MlirModule, + compilation_context: CompilationContext, + configuration: Optional[Configuration] = None, + ): + assert configuration and (configuration.fhe_simulation or configuration.fhe_execution) + + self.configuration = configuration if configuration is not None else Configuration() + self.graphs = graphs + self.mlir_module = mlir + self.compilation_context = compilation_context + + if self.configuration.fhe_simulation: + server = Server.create( + self.mlir_module, + self.configuration, + is_simulated=True, + compilation_context=self.compilation_context, + ) + self.rt = SimulationRt(server) + else: + server = Server.create( + self.mlir_module, self.configuration, compilation_context=self.compilation_context + ) + + keyset_cache_directory = None + if self.configuration.use_insecure_key_cache: + assert_that(self.configuration.enable_unsafe_features) + assert_that(self.configuration.insecure_key_cache_location is not None) + keyset_cache_directory = self.configuration.insecure_key_cache_location + + client = Client(server.client_specs, keyset_cache_directory) + self.rt = ExecutionRt(client, server) + + @property + def mlir(self) -> str: + """Textual representation of the MLIR module. + + Returns: + str: textual representation of the MLIR module + """ + return str(self.mlir_module).strip() + + @property + def keys(self) -> Keys: + """ + Get the keys of the circuit. + """ + if isinstance(self.rt, ExecutionRt): + return self.rt.client.keys + else: + return None + + @keys.setter + def keys(self, new_keys: Keys): + """ + Set the keys of the circuit. + """ + if isinstance(self.rt, ExecutionRt): + self.rt.client.keys = new_keys + + def keygen( + self, force: bool = False, seed: Optional[int] = None, encryption_seed: Optional[int] = None + ): + """ + Generate keys required for homomorphic evaluation. + + Args: + force (bool, default = False): + whether to generate new keys even if keys are already generated + + seed (Optional[int], default = None): + seed for private keys randomness + + encryption_seed (Optional[int], default = None): + seed for encryption randomness + """ + if isinstance(self.rt, ExecutionRt): + self.rt.client.keygen(force, seed, encryption_seed) + + def cleanup(self): + """ + Cleanup the temporary library output directory. + """ + self.rt.server.cleanup() + + @property + def size_of_secret_keys(self) -> int: + """ + Get size of the secret keys of the program. + """ + return self.rt.server.size_of_secret_keys # pragma: no cover + + @property + def size_of_bootstrap_keys(self) -> int: + """ + Get size of the bootstrap keys of the program. + """ + return self.rt.server.size_of_bootstrap_keys # pragma: no cover + + @property + def size_of_keyswitch_keys(self) -> int: + """ + Get size of the key switch keys of the program. + """ + return self.rt.server.size_of_keyswitch_keys # pragma: no cover + + @property + def p_error(self) -> int: + """ + Get probability of error for each simple TLU (on a scalar). + """ + return self.rt.server.p_error # pragma: no cover + + @property + def global_p_error(self) -> int: + """ + Get the probability of having at least one simple TLU error during the entire execution. + """ + return self.rt.server.global_p_error # pragma: no cover + + @property + def complexity(self) -> float: + """ + Get complexity of the program. + """ + return self.rt.server.complexity # pragma: no cover + + @property + def statistics(self) -> Dict: + """ + Get all statistics of the program. + """ + attributes = [ + "size_of_secret_keys", + "size_of_bootstrap_keys", + "size_of_keyswitch_keys", + "p_error", + "global_p_error", + "complexity", + ] + statistics = {attribute: getattr(self, attribute) for attribute in attributes} + statistics["circuits"] = {name: circuit.statistics() for (name, circuit) in self.circuits()} + return statistics + + def circuits(self) -> Dict[str, ProgramCircuit]: + """ + Returns a dictionnary containing all the circuits of the program. + """ + return {name: self.__getattr__(name) for name in self.graphs.keys()} + + def __getattr__(self, item): + if not item in list(self.graphs.keys()): + raise AttributeError(f"No attribute {item}") + else: + return ProgramCircuit(item, self.rt, self.graphs[item]) diff --git a/frontends/concrete-python/concrete/fhe/compilation/program_compiler.py b/frontends/concrete-python/concrete/fhe/compilation/program_compiler.py new file mode 100644 index 0000000000..5b55fe7f20 --- /dev/null +++ b/frontends/concrete-python/concrete/fhe/compilation/program_compiler.py @@ -0,0 +1,481 @@ +""" +Declaration of `MultiCompiler` class. +""" + +# pylint: disable=import-error,no-name-in-module + +import inspect +import os +import traceback +from copy import deepcopy +from pathlib import Path +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union + +import numpy as np +from concrete.compiler import CompilationContext + +from ..extensions import AutoRounder, AutoTruncator +from ..mlir import GraphConverter +from ..representation import Graph +from ..tracing import Tracer +from ..values import ValueDescription +from .artifacts import CircuitDebugArtifacts, ProgramDebugArtifacts +from .circuit import Circuit +from .compiler import EncryptionStatus +from .configuration import Configuration +from .program import ExecutionRt, Program, ProgramCircuit +from .utils import fuse, get_terminal_size + +DEFAULT_OUTPUT_DIRECTORY: Path = Path(".artifacts") + +# pylint: enable=import-error,no-name-in-module + + + +class CircuitDef: + """ + An object representing the definition of a circuit as used in a multi-circuit program. + """ + + name: str + function: Callable + parameter_encryption_statuses: Dict[str, EncryptionStatus] + inputset: List[Any] + graph: Optional[Graph] + artifacts: Optional[CircuitDebugArtifacts] + _is_direct: bool + _parameter_values: Dict[str, ValueDescription] + + def __init__( + self, + function: Callable, + parameter_encryption_statuses: Dict[str, Union[str, EncryptionStatus]], + ): + signature = inspect.signature(function) + + missing_args = list(signature.parameters) + for arg in parameter_encryption_statuses.keys(): + if arg in signature.parameters: + missing_args.remove(arg) + + if len(missing_args) != 0: + parameter_str = repr(missing_args[0]) + for arg in missing_args[1:-1]: + parameter_str += f", {repr(arg)}" + if len(missing_args) != 1: + parameter_str += f" and {repr(missing_args[-1])}" + + message = ( + f"Encryption status{'es' if len(missing_args) > 1 else ''} " + f"of parameter{'s' if len(missing_args) > 1 else ''} " + f"{parameter_str} of function '{function.__name__}' " + f"{'are' if len(missing_args) > 1 else 'is'} not provided" + ) + raise ValueError(message) + + additional_args = list(parameter_encryption_statuses) + for arg in signature.parameters.keys(): + if arg in parameter_encryption_statuses: + additional_args.remove(arg) + + if len(additional_args) != 0: + parameter_str = repr(additional_args[0]) + for arg in additional_args[1:-1]: + parameter_str += f", {repr(arg)}" + if len(additional_args) != 1: + parameter_str += f" and {repr(additional_args[-1])}" + + message = ( + f"Encryption status{'es' if len(additional_args) > 1 else ''} " + f"of {parameter_str} {'are' if len(additional_args) > 1 else 'is'} provided but " + f"{'they are' if len(additional_args) > 1 else 'it is'} not a parameter " + f"of function '{function.__name__}'" + ) + raise ValueError(message) + + self.function = function # type: ignore + self.parameter_encryption_statuses = { + param: EncryptionStatus(status.lower()) + for param, status in parameter_encryption_statuses.items() + } + self.artifacts = None + self.inputset = [] + self.graph = None + self.name = function.__name__ + self._is_direct = False + self._parameter_values = {} + + def _trace(self, sample: Union[Any, Tuple[Any, ...]]): + """ + Trace the function and fuse the resulting graph with a sample input. + + Args: + sample (Union[Any, Tuple[Any, ...]]): + sample to use for tracing + """ + + if self.artifacts is not None: + self.artifacts.add_source_code(self.function) + for param, encryption_status in self.parameter_encryption_statuses.items(): + self.artifacts.add_parameter_encryption_status(param, encryption_status) + + parameters = { + param: ValueDescription.of(arg, is_encrypted=(status == EncryptionStatus.ENCRYPTED)) + for arg, (param, status) in zip( + ( + sample + if len(self.parameter_encryption_statuses) > 1 or isinstance(sample, tuple) + else (sample,) + ), + self.parameter_encryption_statuses.items(), + ) + } + + self.graph = Tracer.trace(self.function, parameters) + if self.artifacts is not None: + self.artifacts.add_graph("initial", self.graph) + + fuse(self.graph, self.artifacts) + + def _evaluate( + self, + action: str, + inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]], + configuration: Configuration, + artifacts: CircuitDebugArtifacts, + ): + """ + Trace, fuse, measure bounds, and update values in the resulting graph in one go. + + Args: + action (str): + action being performed (e.g., "trace", "compile") + + inputset (Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]): + optional inputset to extend accumulated inputset before bounds measurement + """ + + if self._is_direct: + self.graph = Tracer.trace(self.function, self._parameter_values, is_direct=True) + artifacts.add_graph("initial", self.graph) # pragma: no cover + fuse(self.graph, artifacts) + artifacts.add_graph("final", self.graph) # pragma: no cover + return + + if inputset is not None: + previous_inputset_length = len(self.inputset) + for index, sample in enumerate(iter(inputset)): + self.inputset.append(sample) + + if not isinstance(sample, tuple): + sample = (sample,) + + if len(sample) != len(self.parameter_encryption_statuses): + self.inputset = self.inputset[:previous_inputset_length] + + expected = ( + "a single value" + if len(self.parameter_encryption_statuses) == 1 + else f"a tuple of {len(self.parameter_encryption_statuses)} values" + ) + actual = ( + "a single value" if len(sample) == 1 else f"a tuple of {len(sample)} values" + ) + + message = ( + f"Input #{index} of your inputset is not well formed " + f"(expected {expected} got {actual})" + ) + raise ValueError(message) + + if configuration.auto_adjust_rounders: + AutoRounder.adjust(self.function, self.inputset) + + if configuration.auto_adjust_truncators: + AutoTruncator.adjust(self.function, self.inputset) + + if self.graph is None: + try: + first_sample = next(iter(self.inputset)) + except StopIteration as error: + message = ( + f"{action} function '{self.function.__name__}' " + f"without an inputset is not supported" + ) + raise RuntimeError(message) from error + + self._trace(first_sample) + assert self.graph is not None + + bounds = self.graph.measure_bounds(self.inputset) + self.graph.update_with_bounds(bounds) + + artifacts.add_graph("final", self.graph) + + +class DebugManager: + """ + A debug manager, allowing streamlined debugging. + """ + + configuration: Configuration + begin_call: Callable + + def __init__(self, config: Configuration): + self.configuration = config + is_first = [True] + + def begin_call(): + if is_first[0]: + print() + is_first[0] = False + + self.begin_call = begin_call + + def debug_table(self, title: str, activate: bool = True): + """ + Returns a context manager that prints a table around what is printed inside the scope. + """ + + class DebugTableCm: + def __init__(self, title): + self.title = title + self.columns = get_terminal_size() + + def __enter__(self): + print(f"{self.title}") + print("-" * self.columns) + + def __exit__(self, _exc_type, _exc_value, _exc_tb): + print("-" * self.columns) + print() + + class EmptyCm: + def __enter__(self): + pass + + def __exit__(self, _exc_type, _exc_value, _exc_tb): + pass + + if activate: + self.begin_call() + return DebugTableCm(title) + else: + return EmptyCm() + + def show_graph(self): + return ( + self.configuration.show_graph + if self.configuration.show_graph is not None + else self.configuration.verbose + ) + + def show_bit_width_constraints(self): + return ( + self.configuration.show_bit_width_constraints + if self.configuration.show_bit_width_constraints is not None + else self.configuration.verbose + ) + + def show_bit_width_assignments(self): + return ( + self.configuration.show_bit_width_assignments + if self.configuration.show_bit_width_assignments is not None + else self.configuration.verbose + ) + + def show_assigned_graph(self): + return ( + self.configuration.show_assigned_graph + if self.configuration.show_assigned_graph is not None + else self.configuration.verbose + ) + + def show_mlir(self): + return ( + self.configuration.show_mlir + if self.configuration.show_mlir is not None + else self.configuration.verbose + ) + + def show_optimizer(self): + return ( + self.configuration.show_optimizer + if self.configuration.show_optimizer is not None + else self.configuration.verbose + ) + + def show_statistics(self): + return ( + self.configuration.show_statistics + if self.configuration.show_statistics is not None + else self.configuration.verbose + ) + + def debug_computation_graph(self, name, circuit_graph): + if ( + self.show_graph() + or self.show_bit_width_constraints() + or self.show_bit_width_assignments() + or self.show_assigned_graph() + or self.show_mlir() + or self.show_optimizer() + or self.show_statistics() + ): + if self.show_graph(): + with self.debug_table(f"Computation Graph for {name}"): + print(circuit_graph.format()) + + def debug_bit_width_constaints(self, name, circuit_graph): + if self.show_bit_width_constraints(): + with self.debug_table(f"Bit-Width Constraints for {name}"): + print(circuit_graph.format_bit_width_constraints()) + + def debug_bit_width_assignments(self, name, circuit_graph): + if self.show_bit_width_assignments(): + with self.debug_table(f"Bit-Width Assignments for {name}"): + print(circuit_graph.format_bit_width_assignments()) + + def debug_assigned_graph(self, name, circuit_graph): + if self.show_assigned_graph(): + with self.debug_table(f"Bit-Width Assigned Computation Graph for {name}"): + print(circuit_graph.format(show_assigned_bit_widths=True)) + + def debug_mlir(self, mlir_str): + if self.show_mlir(): + with self.debug_table("MLIR"): + print(mlir_str) + + def debug_statistics(self, program): + if self.show_statistics(): + + def pretty(d, indent=0): # pragma: no cover + if indent > 0: + print("{") + + for key, value in d.items(): + if isinstance(value, dict) and len(value) == 0: + continue + print(" " * indent + str(key) + ": ", end="") + if isinstance(value, dict): + pretty(value, indent + 1) + else: + print(value) + if indent > 0: + print(" " * (indent - 1) + "}") + + with self.debug_table("Whole program statistics"): + pretty(program.statistics) + + for name, circuit in program.circuits().items(): + with self.debug_table(f"{name} circuit statistics"): + pretty(circuit.statistics) + + +class ProgramCompiler: + """ + Compiler class for multiple circuits, to glue the compilation pipeline. + """ + + default_configuration: Configuration + circuits: Dict[str, CircuitDef] + compilation_context: CompilationContext + + def __init__(self, circuits): + self.configuration = Configuration() + self.circuits = circuits + self.compilation_context = CompilationContext.new() + + # pylint: disable=too-many-branches,too-many-statements + + def compile( + self, + inputsets: Optional[Dict[str, Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]] = None, + configuration: Optional[Configuration] = None, + program_artifacts: Optional[ProgramDebugArtifacts] = None, + **kwargs, + ) -> Program: + """ + Compile the circuits using an ensemble of inputsets. + + Args: + inputsets (Optional[Dict[str, Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]]): + optional inputsets to extend accumulated inputsets before bounds measurement + + configuration(Optional[Configuration], default = None): + configuration to use + + artifacts (Optional[ProgramDebugArtifacts], default = None): + artifacts to store information about the process + + kwargs (Dict[str, Any]): + configuration options to overwrite + + Returns: + Program: + compiled program + """ + + configuration = configuration if configuration is not None else self.default_configuration + configuration = deepcopy(configuration) + if len(kwargs) != 0: + configuration = configuration.fork(**kwargs) + program_artifacts = ( + program_artifacts + if program_artifacts is not None + else ProgramDebugArtifacts(list(self.circuits.keys())) + ) + + dbg = DebugManager(configuration) + + try: + # Trace and fuse the circuits + for name, circuit in self.circuits.items(): + inputset = inputsets[name] + circuit_artifacts = program_artifacts.circuits[name] + circuit._evaluate("Compiling", inputset, self.configuration, circuit_artifacts) + assert circuit.graph is not None + dbg.debug_computation_graph(name, circuit.graph) + + # Convert the graphs to an mlir module + mlir_context = self.compilation_context.mlir_context() + graphs = {name: circuit.graph for (name, circuit) in self.circuits.items()} + mlir_module = GraphConverter().convert_many(graphs, self.configuration, mlir_context) + mlir_str = str(mlir_module).strip() + dbg.debug_mlir(mlir_str) + program_artifacts.add_mlir_to_compile(mlir_str) + + # Debug some circuit informations + for name, circuit in self.circuits.items(): + dbg.debug_bit_width_constaints(name, circuit.graph) + dbg.debug_bit_width_assignments(name, circuit.graph) + dbg.debug_assigned_graph(name, circuit.graph) + + # Compile to a program! + with dbg.debug_table("Optimizer", activate=dbg.show_optimizer()): + output = Program(graphs, mlir_module, self.compilation_context, configuration) + if isinstance(output.rt, ExecutionRt): + client_parameters = output.rt.client.specs.client_parameters + program_artifacts.add_client_parameters(client_parameters.serialize()) + + dbg.debug_statistics(output) + + except Exception: # pragma: no cover + # this branch is reserved for unexpected issues and hence it shouldn't be tested + # if it could be tested, we would have fixed the underlying issue + + # if the user desires so, + # we need to export all the information we have about the compilation + + if self.configuration.dump_artifacts_on_unexpected_failures: + program_artifacts.export() + + # traceback_path = self.artifacts.output_directory.joinpath("traceback.txt") + # with open(traceback_path, "w", encoding="utf-8") as f: + # f.write(traceback.format_exc()) + + raise + + return output + + # pylint: enable=too-many-branches,too-many-statements diff --git a/frontends/concrete-python/concrete/fhe/compilation/server.py b/frontends/concrete-python/concrete/fhe/compilation/server.py index 8ec8ea007a..0d2f173ae1 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/server.py +++ b/frontends/concrete-python/concrete/fhe/compilation/server.py @@ -17,11 +17,11 @@ CompilationOptions, EvaluationKeys, LibraryCompilationResult, - LibraryLambda, LibrarySupport, Parameter, ProgramCompilationFeedback, PublicArguments, + ServerProgram, set_compiler_logging, set_llvm_debug_flag, ) @@ -60,7 +60,7 @@ class Server: _support: LibrarySupport _compilation_result: LibraryCompilationResult _compilation_feedback: ProgramCompilationFeedback - _server_lambda: LibraryLambda + _server_program: ServerProgram _mlir: Optional[str] _configuration: Optional[Configuration] @@ -71,7 +71,7 @@ def __init__( output_dir: Optional[tempfile.TemporaryDirectory], support: LibrarySupport, compilation_result: LibraryCompilationResult, - server_lambda: LibraryLambda, + server_program: ServerProgram, is_simulated: bool, ): self.client_specs = client_specs @@ -81,7 +81,7 @@ def __init__( self._support = support self._compilation_result = compilation_result self._compilation_feedback = self._support.load_compilation_feedback(compilation_result) - self._server_lambda = server_lambda + self._server_program = server_program self._mlir = None assert_that( @@ -199,7 +199,7 @@ def create( compilation_context is not None ), "must provide compilation context when compiling MlirModule" compilation_result = support.compile(mlir, options, compilation_context) - server_lambda = support.load_server_lambda(compilation_result, is_simulated) + server_program = ServerProgram.load(support, is_simulated) finally: set_llvm_debug_flag(False) set_compiler_logging(False) @@ -212,7 +212,7 @@ def create( output_dir, support, compilation_result, - server_lambda, + server_program, is_simulated, ) @@ -313,16 +313,17 @@ def load(path: Union[str, Path]) -> "Server": generateStaticLib=False, ) compilation_result = support.reload() - server_lambda = support.load_server_lambda(compilation_result, is_simulated) + server_program = ServerProgram.load(support, is_simulated) return Server( - client_specs, output_dir, support, compilation_result, server_lambda, is_simulated + client_specs, output_dir, support, compilation_result, server_program, is_simulated ) def run( self, *args: Optional[Union[Value, Tuple[Optional[Value], ...]]], evaluation_keys: Optional[EvaluationKeys] = None, + circuit_name: str = "main", ) -> Union[Value, Tuple[Value, ...]]: """ Evaluate. @@ -334,6 +335,9 @@ def run( evaluation_keys (Optional[EvaluationKeys], default = None): evaluation keys required for fhe execution + circuit_name (str): + The name of the circuit to run + Returns: Union[Value, Tuple[Value, ...]]: result(s) of evaluation @@ -363,13 +367,12 @@ def run( buffers.append(arg.inner) public_args = PublicArguments.new(self.client_specs.client_parameters, buffers) + server_circuit = self._server_program.get_server_circuit(circuit_name) if self.is_simulated: - public_result = self._support.simulate(self._server_lambda, public_args) + public_result = server_circuit.simulate(public_args) else: - public_result = self._support.server_call( - self._server_lambda, public_args, evaluation_keys - ) + public_result = server_circuit.call(public_args, evaluation_keys) result = tuple(Value(public_result.get_value(i)) for i in range(public_result.n_values())) return result if len(result) > 1 else result[0] @@ -403,20 +406,6 @@ def size_of_keyswitch_keys(self) -> int: """ return self._compilation_feedback.total_keyswitch_keys_size - @property - def size_of_inputs(self) -> int: - """ - Get size of the inputs of the compiled program. - """ - return self._compilation_feedback.circuit("main").total_inputs_size - - @property - def size_of_outputs(self) -> int: - """ - Get size of the outputs of the compiled program. - """ - return self._compilation_feedback.circuit("main").total_output_size - @property def p_error(self) -> int: """ @@ -438,43 +427,55 @@ def complexity(self) -> float: """ return self._compilation_feedback.complexity + def size_of_inputs(self, circuit: str = "main") -> int: + """ + Get size of the inputs of the compiled program. + """ + return self._compilation_feedback.circuit(circuit).total_inputs_size + + def size_of_outputs(self, circuit: str = "main") -> int: + """ + Get size of the outputs of the compiled program. + """ + return self._compilation_feedback.circuit(circuit).total_output_size + # Programmable Bootstrap Statistics - @property - def programmable_bootstrap_count(self) -> int: + def programmable_bootstrap_count(self, circuit: str = "main") -> int: """ Get the number of programmable bootstraps in the compiled program. """ - return self._compilation_feedback.circuit("main").count( + return self._compilation_feedback.circuit(circuit).count( operations={PrimitiveOperation.PBS, PrimitiveOperation.WOP_PBS}, ) - @property - def programmable_bootstrap_count_per_parameter(self) -> Dict[Parameter, int]: + def programmable_bootstrap_count_per_parameter( + self, circuit: str = "main" + ) -> Dict[Parameter, int]: """ Get the number of programmable bootstraps per parameter in the compiled program. """ - return self._compilation_feedback.circuit("main").count_per_parameter( + return self._compilation_feedback.circuit(circuit).count_per_parameter( operations={PrimitiveOperation.PBS, PrimitiveOperation.WOP_PBS}, key_types={KeyType.BOOTSTRAP}, client_parameters=self.client_specs.client_parameters, ) - @property - def programmable_bootstrap_count_per_tag(self) -> Dict[str, int]: + def programmable_bootstrap_count_per_tag(self, circuit: str = "main") -> Dict[str, int]: """ Get the number of programmable bootstraps per tag in the compiled program. """ - return self._compilation_feedback.circuit("main").count_per_tag( + return self._compilation_feedback.circuit(circuit).count_per_tag( operations={PrimitiveOperation.PBS, PrimitiveOperation.WOP_PBS}, ) - @property - def programmable_bootstrap_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]: + def programmable_bootstrap_count_per_tag_per_parameter( + self, circuit: str = "main" + ) -> Dict[str, Dict[Parameter, int]]: """ Get the number of programmable bootstraps per tag per parameter in the compiled program. """ - return self._compilation_feedback.circuit("main").count_per_tag_per_parameter( + return self._compilation_feedback.circuit(circuit).count_per_tag_per_parameter( operations={PrimitiveOperation.PBS, PrimitiveOperation.WOP_PBS}, key_types={KeyType.BOOTSTRAP}, client_parameters=self.client_specs.client_parameters, @@ -482,41 +483,39 @@ def programmable_bootstrap_count_per_tag_per_parameter(self) -> Dict[str, Dict[P # Key Switch Statistics - @property - def key_switch_count(self) -> int: + def key_switch_count(self, circuit: str = "main") -> int: """ Get the number of key switches in the compiled program. """ - return self._compilation_feedback.circuit("main").count( + return self._compilation_feedback.circuit(circuit).count( operations={PrimitiveOperation.KEY_SWITCH, PrimitiveOperation.WOP_PBS}, ) - @property - def key_switch_count_per_parameter(self) -> Dict[Parameter, int]: + def key_switch_count_per_parameter(self, circuit: str = "main") -> Dict[Parameter, int]: """ Get the number of key switches per parameter in the compiled program. """ - return self._compilation_feedback.circuit("main").count_per_parameter( + return self._compilation_feedback.circuit(circuit).count_per_parameter( operations={PrimitiveOperation.KEY_SWITCH, PrimitiveOperation.WOP_PBS}, key_types={KeyType.KEY_SWITCH}, client_parameters=self.client_specs.client_parameters, ) - @property - def key_switch_count_per_tag(self) -> Dict[str, int]: + def key_switch_count_per_tag(self, circuit: str = "main") -> Dict[str, int]: """ Get the number of key switches per tag in the compiled program. """ - return self._compilation_feedback.circuit("main").count_per_tag( + return self._compilation_feedback.circuit(circuit).count_per_tag( operations={PrimitiveOperation.KEY_SWITCH, PrimitiveOperation.WOP_PBS}, ) - @property - def key_switch_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]: + def key_switch_count_per_tag_per_parameter( + self, circuit: str = "main" + ) -> Dict[str, Dict[Parameter, int]]: """ Get the number of key switches per tag per parameter in the compiled program. """ - return self._compilation_feedback.circuit("main").count_per_tag_per_parameter( + return self._compilation_feedback.circuit(circuit).count_per_tag_per_parameter( operations={PrimitiveOperation.KEY_SWITCH, PrimitiveOperation.WOP_PBS}, key_types={KeyType.KEY_SWITCH}, client_parameters=self.client_specs.client_parameters, @@ -524,41 +523,39 @@ def key_switch_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, in # Packing Key Switch Statistics - @property - def packing_key_switch_count(self) -> int: + def packing_key_switch_count(self, circuit: str = "main") -> int: """ Get the number of packing key switches in the compiled program. """ - return self._compilation_feedback.circuit("main").count( + return self._compilation_feedback.circuit(circuit).count( operations={PrimitiveOperation.WOP_PBS} ) - @property - def packing_key_switch_count_per_parameter(self) -> Dict[Parameter, int]: + def packing_key_switch_count_per_parameter(self, circuit: str = "main") -> Dict[Parameter, int]: """ Get the number of packing key switches per parameter in the compiled program. """ - return self._compilation_feedback.circuit("main").count_per_parameter( + return self._compilation_feedback.circuit(circuit).count_per_parameter( operations={PrimitiveOperation.WOP_PBS}, key_types={KeyType.PACKING_KEY_SWITCH}, client_parameters=self.client_specs.client_parameters, ) - @property - def packing_key_switch_count_per_tag(self) -> Dict[str, int]: + def packing_key_switch_count_per_tag(self, circuit: str = "main") -> Dict[str, int]: """ Get the number of packing key switches per tag in the compiled program. """ - return self._compilation_feedback.circuit("main").count_per_tag( + return self._compilation_feedback.circuit(circuit).count_per_tag( operations={PrimitiveOperation.WOP_PBS} ) - @property - def packing_key_switch_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]: + def packing_key_switch_count_per_tag_per_parameter( + self, circuit: str = "main" + ) -> Dict[str, Dict[Parameter, int]]: """ Get the number of packing key switches per tag per parameter in the compiled program. """ - return self._compilation_feedback.circuit("main").count_per_tag_per_parameter( + return self._compilation_feedback.circuit(circuit).count_per_tag_per_parameter( operations={PrimitiveOperation.WOP_PBS}, key_types={KeyType.PACKING_KEY_SWITCH}, client_parameters=self.client_specs.client_parameters, @@ -566,41 +563,39 @@ def packing_key_switch_count_per_tag_per_parameter(self) -> Dict[str, Dict[Param # Clear Addition Statistics - @property - def clear_addition_count(self) -> int: + def clear_addition_count(self, circuit: str = "main") -> int: """ Get the number of clear additions in the compiled program. """ - return self._compilation_feedback.circuit("main").count( + return self._compilation_feedback.circuit(circuit).count( operations={PrimitiveOperation.CLEAR_ADDITION} ) - @property - def clear_addition_count_per_parameter(self) -> Dict[Parameter, int]: + def clear_addition_count_per_parameter(self, circuit: str = "main") -> Dict[Parameter, int]: """ Get the number of clear additions per parameter in the compiled program. """ - return self._compilation_feedback.circuit("main").count_per_parameter( + return self._compilation_feedback.circuit(circuit).count_per_parameter( operations={PrimitiveOperation.CLEAR_ADDITION}, key_types={KeyType.SECRET}, client_parameters=self.client_specs.client_parameters, ) - @property - def clear_addition_count_per_tag(self) -> Dict[str, int]: + def clear_addition_count_per_tag(self, circuit: str = "main") -> Dict[str, int]: """ Get the number of clear additions per tag in the compiled program. """ - return self._compilation_feedback.circuit("main").count_per_tag( + return self._compilation_feedback.circuit(circuit).count_per_tag( operations={PrimitiveOperation.CLEAR_ADDITION}, ) - @property - def clear_addition_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]: + def clear_addition_count_per_tag_per_parameter( + self, circuit: str = "main" + ) -> Dict[str, Dict[Parameter, int]]: """ Get the number of clear additions per tag per parameter in the compiled program. """ - return self._compilation_feedback.circuit("main").count_per_tag_per_parameter( + return self._compilation_feedback.circuit(circuit).count_per_tag_per_parameter( operations={PrimitiveOperation.CLEAR_ADDITION}, key_types={KeyType.SECRET}, client_parameters=self.client_specs.client_parameters, @@ -608,41 +603,39 @@ def clear_addition_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter # Encrypted Addition Statistics - @property - def encrypted_addition_count(self) -> int: + def encrypted_addition_count(self, circuit: str = "main") -> int: """ Get the number of encrypted additions in the compiled program. """ - return self._compilation_feedback.circuit("main").count( + return self._compilation_feedback.circuit(circuit).count( operations={PrimitiveOperation.ENCRYPTED_ADDITION} ) - @property - def encrypted_addition_count_per_parameter(self) -> Dict[Parameter, int]: + def encrypted_addition_count_per_parameter(self, circuit: str = "main") -> Dict[Parameter, int]: """ Get the number of encrypted additions per parameter in the compiled program. """ - return self._compilation_feedback.circuit("main").count_per_parameter( + return self._compilation_feedback.circuit(circuit).count_per_parameter( operations={PrimitiveOperation.ENCRYPTED_ADDITION}, key_types={KeyType.SECRET}, client_parameters=self.client_specs.client_parameters, ) - @property - def encrypted_addition_count_per_tag(self) -> Dict[str, int]: + def encrypted_addition_count_per_tag(self, circuit: str = "main") -> Dict[str, int]: """ Get the number of encrypted additions per tag in the compiled program. """ - return self._compilation_feedback.circuit("main").count_per_tag( + return self._compilation_feedback.circuit(circuit).count_per_tag( operations={PrimitiveOperation.ENCRYPTED_ADDITION}, ) - @property - def encrypted_addition_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]: + def encrypted_addition_count_per_tag_per_parameter( + self, circuit: str = "main" + ) -> Dict[str, Dict[Parameter, int]]: """ Get the number of encrypted additions per tag per parameter in the compiled program. """ - return self._compilation_feedback.circuit("main").count_per_tag_per_parameter( + return self._compilation_feedback.circuit(circuit).count_per_tag_per_parameter( operations={PrimitiveOperation.ENCRYPTED_ADDITION}, key_types={KeyType.SECRET}, client_parameters=self.client_specs.client_parameters, @@ -650,41 +643,41 @@ def encrypted_addition_count_per_tag_per_parameter(self) -> Dict[str, Dict[Param # Clear Multiplication Statistics - @property - def clear_multiplication_count(self) -> int: + def clear_multiplication_count(self, circuit: str = "main") -> int: """ Get the number of clear multiplications in the compiled program. """ - return self._compilation_feedback.circuit("main").count( + return self._compilation_feedback.circuit(circuit).count( operations={PrimitiveOperation.CLEAR_MULTIPLICATION}, ) - @property - def clear_multiplication_count_per_parameter(self) -> Dict[Parameter, int]: + def clear_multiplication_count_per_parameter( + self, circuit: str = "main" + ) -> Dict[Parameter, int]: """ Get the number of clear multiplications per parameter in the compiled program. """ - return self._compilation_feedback.circuit("main").count_per_parameter( + return self._compilation_feedback.circuit(circuit).count_per_parameter( operations={PrimitiveOperation.CLEAR_MULTIPLICATION}, key_types={KeyType.SECRET}, client_parameters=self.client_specs.client_parameters, ) - @property - def clear_multiplication_count_per_tag(self) -> Dict[str, int]: + def clear_multiplication_count_per_tag(self, circuit: str = "main") -> Dict[str, int]: """ Get the number of clear multiplications per tag in the compiled program. """ - return self._compilation_feedback.circuit("main").count_per_tag( + return self._compilation_feedback.circuit(circuit).count_per_tag( operations={PrimitiveOperation.CLEAR_MULTIPLICATION}, ) - @property - def clear_multiplication_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]: + def clear_multiplication_count_per_tag_per_parameter( + self, circuit: str = "main" + ) -> Dict[str, Dict[Parameter, int]]: """ Get the number of clear multiplications per tag per parameter in the compiled program. """ - return self._compilation_feedback.circuit("main").count_per_tag_per_parameter( + return self._compilation_feedback.circuit(circuit).count_per_tag_per_parameter( operations={PrimitiveOperation.CLEAR_MULTIPLICATION}, key_types={KeyType.SECRET}, client_parameters=self.client_specs.client_parameters, @@ -692,41 +685,39 @@ def clear_multiplication_count_per_tag_per_parameter(self) -> Dict[str, Dict[Par # Encrypted Negation Statistics - @property - def encrypted_negation_count(self) -> int: + def encrypted_negation_count(self, circuit: str = "main") -> int: """ Get the number of encrypted negations in the compiled program. """ - return self._compilation_feedback.circuit("main").count( + return self._compilation_feedback.circuit(circuit).count( operations={PrimitiveOperation.ENCRYPTED_NEGATION} ) - @property - def encrypted_negation_count_per_parameter(self) -> Dict[Parameter, int]: + def encrypted_negation_count_per_parameter(self, circuit: str = "main") -> Dict[Parameter, int]: """ Get the number of encrypted negations per parameter in the compiled program. """ - return self._compilation_feedback.circuit("main").count_per_parameter( + return self._compilation_feedback.circuit(circuit).count_per_parameter( operations={PrimitiveOperation.ENCRYPTED_NEGATION}, key_types={KeyType.SECRET}, client_parameters=self.client_specs.client_parameters, ) - @property - def encrypted_negation_count_per_tag(self) -> Dict[str, int]: + def encrypted_negation_count_per_tag(self, circuit: str = "main") -> Dict[str, int]: """ Get the number of encrypted negations per tag in the compiled program. """ - return self._compilation_feedback.circuit("main").count_per_tag( + return self._compilation_feedback.circuit(circuit).count_per_tag( operations={PrimitiveOperation.ENCRYPTED_NEGATION}, ) - @property - def encrypted_negation_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]: + def encrypted_negation_count_per_tag_per_parameter( + self, circuit: str = "main" + ) -> Dict[str, Dict[Parameter, int]]: """ Get the number of encrypted negations per tag per parameter in the compiled program. """ - return self._compilation_feedback.circuit("main").count_per_tag_per_parameter( + return self._compilation_feedback.circuit(circuit).count_per_tag_per_parameter( operations={PrimitiveOperation.ENCRYPTED_NEGATION}, key_types={KeyType.SECRET}, client_parameters=self.client_specs.client_parameters, diff --git a/frontends/concrete-python/concrete/fhe/compilation/utils.py b/frontends/concrete-python/concrete/fhe/compilation/utils.py index f903b80f4d..df51fafc14 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/utils.py +++ b/frontends/concrete-python/concrete/fhe/compilation/utils.py @@ -23,6 +23,7 @@ def validate_input_args( client_specs: ClientSpecs, *args: Optional[Union[int, np.ndarray, List]], + circuit_name: str = "main", ) -> List[Optional[Union[int, np.ndarray]]]: """Validate input arguments. @@ -31,11 +32,13 @@ def validate_input_args( client specification *args (Optional[Union[int, np.ndarray, List]]): argument(s) for evaluation + circuit_name (str): name of the circuit to verify Returns: List[Optional[Union[int, np.ndarray]]]: ordered validated args """ - client_parameters_json = json.loads(client_specs.client_parameters.serialize())["circuits"][0] + circuits_parameters = json.loads(client_specs.client_parameters.serialize())["circuits"] + client_parameters_json = next(filter(lambda x: x["name"] == circuit_name, circuits_parameters)) assert "inputs" in client_parameters_json input_specs = client_parameters_json["inputs"] if len(args) != len(input_specs): diff --git a/frontends/concrete-python/concrete/fhe/mlir/converter.py b/frontends/concrete-python/concrete/fhe/mlir/converter.py index 8271cb7428..74ba915952 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/converter.py +++ b/frontends/concrete-python/concrete/fhe/mlir/converter.py @@ -19,7 +19,7 @@ from mlir.ir import Module as MlirModule from ..compilation.configuration import Configuration, Exactness -from ..representation import Graph, Node, Operation +from ..representation import Graph, GraphProcessor, MultiGraphProcessor, Node, Operation from .context import Context from .conversion import Conversion from .processors import * # pylint: disable=wildcard-import @@ -33,18 +33,18 @@ class Converter: Converter class, to convert a computation graph to MLIR. """ - def convert( + def convert_many( self, - graph: Graph, + graphs: Dict[str, Graph], configuration: Configuration, mlir_context: MlirContext, ) -> MlirModule: """ - Convert a computation graph to MLIR. + Convert multiple computation graphs to an MLIR module. Args: - graph (Graph): - graph to convert + graphs (Dict[str, Graph]): + graphs to convert configuration (Configuration): configuration to use @@ -56,46 +56,75 @@ def convert( MlirModule: In-memory MLIR module corresponding to the graph """ - - self.process(graph, configuration) + self.process(graphs, configuration) with mlir_context as context, MlirLocation.unknown(): concrete.lang.register_dialects(context) # pylint: disable=no-member module = MlirModule.create() with MlirInsertionPoint(module.body): - ctx = Context(context, graph, configuration) + for name, graph in graphs.items(): + ctx = Context(context, graph, configuration) + + input_types = [ctx.typeof(node).mlir for node in graph.ordered_inputs()] + + @func.FuncOp.from_py_func(*input_types, name=name) + def main(*args): + for index, node in enumerate(graph.ordered_inputs()): + conversion = Conversion(node, args[index]) + if "original_bit_width" in node.properties: + conversion.set_original_bit_width( + node.properties["original_bit_width"] + ) + ctx.conversions[node] = conversion + + ordered_nodes = [ + node + for node in nx.lexicographical_topological_sort(graph.graph) + if node.operation != Operation.Input + ] + + for progress_index, node in enumerate(ordered_nodes): + self.trace_progress(configuration, progress_index, ordered_nodes) + preds = [ctx.conversions[pred] for pred in graph.ordered_preds_of(node)] + self.node(ctx, node, preds) + self.trace_progress(configuration, len(ordered_nodes), ordered_nodes) + + outputs = [] + for node in graph.ordered_outputs(): + assert node in ctx.conversions + outputs.append(ctx.conversions[node].result) + + return tuple(outputs) - input_types = [ctx.typeof(node).mlir for node in graph.ordered_inputs()] + return module - @func.FuncOp.from_py_func(*input_types) - def main(*args): - for index, node in enumerate(graph.ordered_inputs()): - conversion = Conversion(node, args[index]) - if "original_bit_width" in node.properties: - conversion.set_original_bit_width(node.properties["original_bit_width"]) - ctx.conversions[node] = conversion + def convert( + self, + graph: Graph, + configuration: Configuration, + mlir_context: MlirContext, + name: str = "main", + ) -> MlirModule: + """ + Convert a computation graph to MLIR. - ordered_nodes = [ - node - for node in nx.lexicographical_topological_sort(graph.graph) - if node.operation != Operation.Input - ] + Args: + graph (Graph): + graph to convert - for progress_index, node in enumerate(ordered_nodes): - self.trace_progress(configuration, progress_index, ordered_nodes) - preds = [ctx.conversions[pred] for pred in graph.ordered_preds_of(node)] - self.node(ctx, node, preds) - self.trace_progress(configuration, len(ordered_nodes), ordered_nodes) + configuration (Configuration): + configuration to use - outputs = [] - for node in graph.ordered_outputs(): - assert node in ctx.conversions - outputs.append(ctx.conversions[node].result) + mlir_context (MlirContext): + MLIR Context to use for module generation - return tuple(outputs) + Return: + MlirModule: + In-memory MLIR module corresponding to the graph + """ - return module + return self.convert_many({name: graph}, configuration, mlir_context) @staticmethod def stdout_with_ansi_support() -> bool: @@ -172,13 +201,13 @@ def trace_progress(cls, configuration: Configuration, progress_index: int, nodes return concrete.lang.dialects.tracing.TraceMessageOp(msg=msg) # pylint: disable=no-member - def process(self, graph: Graph, configuration: Configuration): + def process(self, graphs: Dict[str, Graph], configuration: Configuration): """ Process a computation graph for MLIR conversion. Args: - graph (Graph): - graph to process + graph (Dict[str, Graph]): + graphs to process configuration (Configuration): configuration to use @@ -201,7 +230,13 @@ def process(self, graph: Graph, configuration: Configuration): ] + configuration.additional_processors for processor in pipeline: - processor.apply(graph) + if isinstance(processor, MultiGraphProcessor): + processor.apply_many(graphs) + elif isinstance(processor, GraphProcessor): + for graph in graphs.values(): + processor.apply(graph) + else: + raise RuntimeError("Unknown processor type.") def node(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion: """ diff --git a/frontends/concrete-python/concrete/fhe/mlir/processors/assign_bit_widths.py b/frontends/concrete-python/concrete/fhe/mlir/processors/assign_bit_widths.py index 6466f0ac57..676d49683b 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/processors/assign_bit_widths.py +++ b/frontends/concrete-python/concrete/fhe/mlir/processors/assign_bit_widths.py @@ -14,10 +14,10 @@ MultivariateStrategy, ) from ...dtypes import Integer -from ...representation import Graph, GraphProcessor, Node, Operation +from ...representation import Graph, MultiGraphProcessor, Node, Operation -class AssignBitWidths(GraphProcessor): +class AssignBitWidths(MultiGraphProcessor): """ AssignBitWidths graph processor, to assign proper bit-widths to be compatible with FHE. @@ -56,52 +56,53 @@ def __init__( self.multivariate_strategy_preference = multivariate_strategy_preference self.min_max_strategy_preference = min_max_strategy_preference - def apply(self, graph: Graph): + def apply_many(self, graphs: Dict[str, Graph]): optimizer = z3.Optimize() max_bit_width: z3.Int = z3.Int("max") bit_widths: Dict[Node, z3.Int] = {} - additional_constraints = AdditionalConstraints( - optimizer, - graph, - bit_widths, - self.comparison_strategy_preference, - self.bitwise_strategy_preference, - self.shifts_with_promotion, - self.multivariate_strategy_preference, - self.min_max_strategy_preference, - ) + for _, graph in graphs.items(): + additional_constraints = AdditionalConstraints( + optimizer, + graph, + bit_widths, + self.comparison_strategy_preference, + self.bitwise_strategy_preference, + self.shifts_with_promotion, + self.multivariate_strategy_preference, + self.min_max_strategy_preference, + ) - nodes = graph.query_nodes(ordered=True) - for i, node in enumerate(nodes): - assert isinstance(node.output.dtype, Integer) - required_bit_width = node.output.dtype.bit_width + nodes = graph.query_nodes(ordered=True) + for i, node in enumerate(nodes): + assert isinstance(node.output.dtype, Integer) + required_bit_width = node.output.dtype.bit_width - bit_width_hint = node.properties.get("bit_width_hint") - if bit_width_hint is not None: - required_bit_width = max(required_bit_width, bit_width_hint) + bit_width_hint = node.properties.get("bit_width_hint") + if bit_width_hint is not None: + required_bit_width = max(required_bit_width, bit_width_hint) - bit_width = z3.Int(f"%{i}") - bit_widths[node] = bit_width + bit_width = z3.Int(f"%{i}") + bit_widths[node] = bit_width - base_constraint = bit_width >= required_bit_width - node.bit_width_constraints.append(base_constraint) + base_constraint = bit_width >= required_bit_width + node.bit_width_constraints.append(base_constraint) - optimizer.add(base_constraint) - optimizer.add(max_bit_width >= bit_width) + optimizer.add(base_constraint) + optimizer.add(max_bit_width >= bit_width) - additional_constraints.generate_for(node, bit_width) + additional_constraints.generate_for(node, bit_width) - if self.single_precision: - for bit_width in bit_widths.values(): - optimizer.add(bit_width == max_bit_width) + if self.single_precision: + for bit_width in bit_widths.values(): + optimizer.add(bit_width == max_bit_width) - if self.composable: - input_output_bitwidth = z3.Int("input_output") - for node in chain(graph.input_nodes.values(), graph.output_nodes.values()): - bit_width = bit_widths[node] - optimizer.add(bit_width == input_output_bitwidth) + if self.composable: + input_output_bitwidth = z3.Int("input_output") + for node in chain(graph.input_nodes.values(), graph.output_nodes.values()): + bit_width = bit_widths[node] + optimizer.add(bit_width == input_output_bitwidth) optimizer.minimize(sum(bit_width for bit_width in bit_widths.values())) @@ -117,9 +118,9 @@ def apply(self, graph: Graph): node.properties["original_bit_width"] = node.output.dtype.bit_width node.output.dtype.bit_width = new_bit_width - - graph.bit_width_constraints = optimizer - graph.bit_width_assignments = model + for graph in graphs.values(): + graph.bit_width_constraints = optimizer + graph.bit_width_assignments = model class AdditionalConstraints: diff --git a/frontends/concrete-python/concrete/fhe/representation/__init__.py b/frontends/concrete-python/concrete/fhe/representation/__init__.py index 34640a94e1..30825b1f7a 100644 --- a/frontends/concrete-python/concrete/fhe/representation/__init__.py +++ b/frontends/concrete-python/concrete/fhe/representation/__init__.py @@ -2,6 +2,6 @@ Define structures used to represent computation. """ -from .graph import Graph, GraphProcessor +from .graph import Graph, GraphProcessor, MultiGraphProcessor from .node import Node from .operation import Operation diff --git a/frontends/concrete-python/concrete/fhe/representation/graph.py b/frontends/concrete-python/concrete/fhe/representation/graph.py index aa278f43a1..e1bf10c4f1 100644 --- a/frontends/concrete-python/concrete/fhe/representation/graph.py +++ b/frontends/concrete-python/concrete/fhe/representation/graph.py @@ -591,6 +591,9 @@ def sorter(line: str) -> int: if line.startswith("max"): # we won't have 4 million nodes... return 2**32 + if line.startswith("input_output"): + # this is the composable constraint + return 2**32 assert line.startswith("%") @@ -967,6 +970,7 @@ def integer_range( class GraphProcessor(ABC): """ GraphProcessor base class, to define the API for a graph processing pipeline. + Process a single graph at once. """ @abstractmethod @@ -998,3 +1002,22 @@ def error(graph: Graph, highlights: Mapping[Node, Union[str, List[str]]]): highlighted_nodes=highlights_with_location ) raise RuntimeError(message) + + +class MultiGraphProcessor(GraphProcessor): + """ + MultiGraphProcessor base class, to define the API for a multiple graph processing pipeline. + Processes multiple graphs at once. + """ + + @abstractmethod + def apply_many(self, graphs: Dict[str, Graph]): + """ + Process the graphs. + """ + + def apply(self, graph: Graph): + """ + Process a single graph. + """ + return self.apply_many({"main": graph})