From 170002ff20b1bd12838c2f9a95e5f62b53e2f852 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexandre=20P=C3=A9r=C3=A9?= Date: Fri, 1 Mar 2024 16:50:12 +0100 Subject: [PATCH] feat(frontend-python): add support for multi-circuits --- .../include/concretelang/Common/Compat.h | 23 +- .../concretelang/ServerLib/ServerLib.h | 1 + .../lib/Bindings/Python/CMakeLists.txt | 2 + .../lib/Bindings/Python/CompilerAPIModule.cpp | 85 ++- .../Python/concrete/compiler/__init__.py | 1 + .../concrete/compiler/server_circuit.py | 88 +++ .../concrete/compiler/server_program.py | 80 ++ .../concrete-python/concrete/fhe/__init__.py | 2 +- .../concrete/fhe/compilation/__init__.py | 4 +- .../concrete/fhe/compilation/artifacts.py | 181 ++++- .../concrete/fhe/compilation/circuit.py | 66 +- .../concrete/fhe/compilation/client.py | 12 +- .../concrete/fhe/compilation/compiler.py | 12 +- .../concrete/fhe/compilation/configuration.py | 5 +- .../concrete/fhe/compilation/decorators.py | 173 +++-- .../concrete/fhe/compilation/module.py | 682 ++++++++++++++++++ .../fhe/compilation/module_compiler.py | 571 +++++++++++++++ .../concrete/fhe/compilation/server.py | 232 +++--- .../concrete/fhe/compilation/utils.py | 11 +- .../concrete/fhe/mlir/converter.py | 109 ++- .../fhe/mlir/processors/assign_bit_widths.py | 80 +- .../concrete/fhe/representation/__init__.py | 2 +- .../concrete/fhe/representation/graph.py | 41 +- .../concrete/fhe/tracing/tracer.py | 6 +- .../tests/compilation/test_artifacts.py | 24 +- .../tests/compilation/test_circuit.py | 5 +- .../tests/compilation/test_configuration.py | 38 +- .../tests/compilation/test_decorators.py | 18 +- .../tests/compilation/test_program.py | 149 ++++ .../tests/mlir/test_converter.py | 4 +- 30 files changed, 2285 insertions(+), 422 deletions(-) create mode 100644 compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/server_circuit.py create mode 100644 compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/server_program.py create mode 100644 frontends/concrete-python/concrete/fhe/compilation/module.py create mode 100644 frontends/concrete-python/concrete/fhe/compilation/module_compiler.py create mode 100644 frontends/concrete-python/tests/compilation/test_program.py diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Common/Compat.h b/compilers/concrete-compiler/compiler/include/concretelang/Common/Compat.h index 37bd436328..cc2579931d 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Common/Compat.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Common/Compat.h @@ -42,15 +42,19 @@ using concretelang::serverlib::ServerProgram; using concretelang::values::TransportValue; using concretelang::values::Value; -#define GET_OR_THROW_LLVM_EXPECTED(VARNAME, EXPECTED) \ - auto VARNAME = EXPECTED; \ - if (auto err = VARNAME.takeError()) { \ - throw std::runtime_error(llvm::toString(std::move(err))); \ - } - #define CONCAT(a, b) CONCAT_INNER(a, b) #define CONCAT_INNER(a, b) a##b +#define GET_OR_THROW_EXPECTED_(VARNAME, RESULT, MAYBE) \ + auto MAYBE = RESULT; \ + if (auto err = MAYBE.takeError()) { \ + throw std::runtime_error(llvm::toString(std::move(err))); \ + } \ + VARNAME = std::move(*MAYBE); + +#define GET_OR_THROW_EXPECTED(VARNAME, RESULT) \ + GET_OR_THROW_EXPECTED_(VARNAME, RESULT, CONCAT(maybe, __COUNTER__)) + #define GET_OR_THROW_RESULT_(VARNAME, RESULT, MAYBE) \ auto MAYBE = RESULT; \ if (MAYBE.has_failure()) { \ @@ -370,6 +374,13 @@ class LibrarySupport { useSimulation}; } + llvm::Expected + loadServerProgram(LibraryCompilationResult &result, bool useSimulation) { + EXPECTED_TRY(auto programInfo, getProgramInfo()); + return outcomeToExpected(ServerProgram::load( + programInfo.asReader(), getSharedLibPath(), useSimulation)); + } + /// Load the client parameters from the compilation result. llvm::Expected<::concretelang::clientlib::ClientParameters> loadClientParameters(LibraryCompilationResult &result) { diff --git a/compilers/concrete-compiler/compiler/include/concretelang/ServerLib/ServerLib.h b/compilers/concrete-compiler/compiler/include/concretelang/ServerLib/ServerLib.h index 26f2d05942..476b0bae70 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/ServerLib/ServerLib.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/ServerLib/ServerLib.h @@ -50,6 +50,7 @@ class ServerCircuit { Result> call(const ServerKeyset &serverKeyset, std::vector &args); + /// Simulate the circuit with public arguments. Result> simulate(std::vector &args); diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CMakeLists.txt b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CMakeLists.txt index 470c0b5378..d3c2e4a7c3 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CMakeLists.txt +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CMakeLists.txt @@ -49,6 +49,8 @@ declare_mlir_python_sources( concrete/compiler/parameter.py concrete/compiler/public_arguments.py concrete/compiler/public_result.py + concrete/compiler/server_circuit.py + concrete/compiler/server_program.py concrete/compiler/evaluation_keys.py concrete/compiler/simulated_value_decrypter.py concrete/compiler/simulated_value_exporter.py diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index fcf33d2047..5388396ff1 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -11,6 +11,7 @@ #include "concretelang/Common/Keysets.h" #include "concretelang/Dialect/FHE/IR/FHEOpsDialect.h.inc" #include "concretelang/Runtime/DFRuntime.hpp" +#include "concretelang/ServerLib/ServerLib.h" #include "concretelang/Support/logging.h" #include #include @@ -79,9 +80,9 @@ library_compile(LibrarySupport_Py support, const char *module, llvm::SourceMgr sm; sm.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(module), llvm::SMLoc()); - GET_OR_THROW_LLVM_EXPECTED(compilationResult, - support.support.compile(sm, options)); - return std::move(*compilationResult); + GET_OR_THROW_EXPECTED(auto compilationResult, + support.support.compile(sm, options)); + return compilationResult; } std::unique_ptr @@ -89,36 +90,36 @@ library_compile_module( LibrarySupport_Py support, mlir::ModuleOp module, mlir::concretelang::CompilationOptions options, std::shared_ptr cctx) { - GET_OR_THROW_LLVM_EXPECTED(compilationResult, - support.support.compile(module, cctx, options)); - return std::move(*compilationResult); + GET_OR_THROW_EXPECTED(auto compilationResult, + support.support.compile(module, cctx, options)); + return compilationResult; } concretelang::clientlib::ClientParameters library_load_client_parameters( LibrarySupport_Py support, mlir::concretelang::LibraryCompilationResult &result) { - GET_OR_THROW_LLVM_EXPECTED(clientParameters, - support.support.loadClientParameters(result)); - return *clientParameters; + GET_OR_THROW_EXPECTED(auto clientParameters, + support.support.loadClientParameters(result)); + return clientParameters; } mlir::concretelang::ProgramCompilationFeedback library_load_compilation_feedback( LibrarySupport_Py support, mlir::concretelang::LibraryCompilationResult &result) { - GET_OR_THROW_LLVM_EXPECTED(compilationFeedback, - support.support.loadCompilationFeedback(result)); - return *compilationFeedback; + GET_OR_THROW_EXPECTED(auto compilationFeedback, + support.support.loadCompilationFeedback(result)); + return compilationFeedback; } concretelang::serverlib::ServerLambda library_load_server_lambda(LibrarySupport_Py support, mlir::concretelang::LibraryCompilationResult &result, std::string circuitName, bool useSimulation) { - GET_OR_THROW_LLVM_EXPECTED( - serverLambda, + GET_OR_THROW_EXPECTED( + auto serverLambda, support.support.loadServerLambda(result, circuitName, useSimulation)); - return *serverLambda; + return serverLambda; } std::unique_ptr @@ -126,18 +127,18 @@ library_server_call(LibrarySupport_Py support, concretelang::serverlib::ServerLambda lambda, concretelang::clientlib::PublicArguments &args, concretelang::clientlib::EvaluationKeys &evaluationKeys) { - GET_OR_THROW_LLVM_EXPECTED( - publicResult, support.support.serverCall(lambda, args, evaluationKeys)); - return std::move(*publicResult); + GET_OR_THROW_EXPECTED(auto publicResult, support.support.serverCall( + lambda, args, evaluationKeys)); + return publicResult; } std::unique_ptr library_simulate(LibrarySupport_Py support, concretelang::serverlib::ServerLambda lambda, concretelang::clientlib::PublicArguments &args) { - GET_OR_THROW_LLVM_EXPECTED(publicResult, - support.support.simulate(lambda, args)); - return std::move(*publicResult); + GET_OR_THROW_EXPECTED(auto publicResult, + support.support.simulate(lambda, args)); + return publicResult; } std::string library_get_shared_lib_path(LibrarySupport_Py support) { @@ -1198,6 +1199,48 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( return pybind11::bytes(valueSerialize(value)); }); + pybind11::class_(m, "ServerProgram") + .def_static("load", + [](LibrarySupport_Py &support, bool useSimulation) { + GET_OR_THROW_EXPECTED(auto programInfo, + support.support.getProgramInfo()); + auto sharedLibPath = support.support.getSharedLibPath(); + GET_OR_THROW_RESULT( + auto result, + ServerProgram::load(programInfo.asReader(), + sharedLibPath, useSimulation)); + return result; + }) + .def("get_server_circuit", [](ServerProgram &program, + const std::string &circuitName) { + GET_OR_THROW_RESULT(auto result, program.getServerCircuit(circuitName)); + return result; + }); + + pybind11::class_(m, "ServerCircuit") + .def("call", + [](ServerCircuit &circuit, + ::concretelang::clientlib::PublicArguments &publicArguments, + ::concretelang::clientlib::EvaluationKeys &evaluationKeys) { + SignalGuard signalGuard; + auto keyset = evaluationKeys.keyset; + auto values = publicArguments.values; + GET_OR_THROW_RESULT(auto output, circuit.call(keyset, values)); + ::concretelang::clientlib::PublicResult res{output}; + return std::make_unique<::concretelang::clientlib::PublicResult>( + std::move(res)); + }) + .def("simulate", + [](ServerCircuit &circuit, + ::concretelang::clientlib::PublicArguments &publicArguments) { + pybind11::gil_scoped_release release; + auto values = publicArguments.values; + GET_OR_THROW_RESULT(auto output, circuit.simulate(values)); + ::concretelang::clientlib::PublicResult res{output}; + return std::make_unique<::concretelang::clientlib::PublicResult>( + std::move(res)); + }); + pybind11::class_<::concretelang::clientlib::ValueExporter>(m, "ValueExporter") .def_static( "create", diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/__init__.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/__init__.py index 9532b0a797..d122706dc4 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/__init__.py +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/__init__.py @@ -37,6 +37,7 @@ from .simulated_value_decrypter import SimulatedValueDecrypter from .simulated_value_exporter import SimulatedValueExporter from .parameter import Parameter +from .server_program import ServerProgram def init_dfr(): diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/server_circuit.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/server_circuit.py new file mode 100644 index 0000000000..cfd2be2686 --- /dev/null +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/server_circuit.py @@ -0,0 +1,88 @@ +# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions. +# See https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt for license information. + +"""ServerCircuit.""" + +# pylint: disable=no-name-in-module,import-error +from mlir._mlir_libs._concretelang._compiler import ( + ServerCircuit as _ServerCircuit, +) + +# pylint: enable=no-name-in-module,import-error +from .wrapper import WrapperCpp +from .public_arguments import PublicArguments +from .public_result import PublicResult +from .evaluation_keys import EvaluationKeys + + +class ServerCircuit(WrapperCpp): + """ServerCircuit references a circuit that can be called for execution and simulation.""" + + def __init__(self, server_circuit: _ServerCircuit): + """Wrap the native Cpp object. + + Args: + server_circuit (_ServerCircuit): object to wrap + + Raises: + TypeError: if server_circuit is not of type _ServerCircuit + """ + if not isinstance(server_circuit, _ServerCircuit): + raise TypeError( + f"server_circuit must be of type _ServerCircuit, not {type(server_circuit)}" + ) + super().__init__(server_circuit) + + def call( + self, + public_arguments: PublicArguments, + evaluation_keys: EvaluationKeys, + ) -> PublicResult: + """Executes the circuit on the public arguments. + + Args: + public_arguments (PublicArguments): public arguments to execute on + execution_keys (EvaluationKeys): evaluation keys to use for execution. + + Raises: + TypeError: if public_arguments is not of type PublicArguments, or if evaluation_keys is + not of type EvaluationKeys + + Returns: + PublicResult: A public result object containing the results. + """ + if not isinstance(public_arguments, PublicArguments): + raise TypeError( + f"public_arguments must be of type PublicArguments, not " + f"{type(public_arguments)}" + ) + if not isinstance(evaluation_keys, EvaluationKeys): + raise TypeError( + f"simulation must be of type EvaluationKeys, not " + f"{type(evaluation_keys)}" + ) + return PublicResult.wrap( + self.cpp().call(public_arguments.cpp(), evaluation_keys.cpp()) + ) + + def simulate( + self, + public_arguments: PublicArguments, + ) -> PublicResult: + """Simulates the circuit on the public arguments. + + Args: + public_arguments (PublicArguments): public arguments to execute on + + Raises: + TypeError: if public_arguments is not of type PublicArguments + + Returns: + PublicResult: A public result object containing the results. + """ + if not isinstance(public_arguments, PublicArguments): + raise TypeError( + f"public_arguments must be of type PublicArguments, not " + f"{type(public_arguments)}" + ) + return PublicResult.wrap(self.cpp().simulate(public_arguments.cpp())) diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/server_program.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/server_program.py new file mode 100644 index 0000000000..16b6138761 --- /dev/null +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/server_program.py @@ -0,0 +1,80 @@ +# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions. +# See https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt for license information. + +"""ServerProgram.""" + +# pylint: disable=no-name-in-module,import-error +from mlir._mlir_libs._concretelang._compiler import ( + ServerProgram as _ServerProgram, +) + +# pylint: enable=no-name-in-module,import-error +from .wrapper import WrapperCpp +from .library_support import LibrarySupport +from .server_circuit import ServerCircuit + + +class ServerProgram(WrapperCpp): + """ServerProgram references compiled circuit objects.""" + + def __init__(self, server_program: _ServerProgram): + """Wrap the native Cpp object. + + Args: + server_program (_ServerProgram): object to wrap + + Raises: + TypeError: if server_program is not of type _ServerProgram + """ + if not isinstance(server_program, _ServerProgram): + raise TypeError( + f"server_program must be of type _ServerProgram, not {type(server_program)}" + ) + super().__init__(server_program) + + @staticmethod + def load( + library_support: LibrarySupport, + simulation: bool, + ) -> "ServerProgram": + """Loads the server program from a library support. + + Args: + library_support (LibrarySupport): library support + simulation (bool): use simulation for execution + + Raises: + TypeError: if library_support is not of type LibrarySupport, or if simulation is not of type bool + + Returns: + ServerProgram: A server program object containing references to circuits for calls. + """ + if not isinstance(library_support, LibrarySupport): + raise TypeError( + f"library_support must be of type LibrarySupport, not " + f"{type(library_support)}" + ) + if not isinstance(simulation, bool): + raise TypeError( + f"simulation must be of type bool, not " f"{type(simulation)}" + ) + return ServerProgram.wrap( + _ServerProgram.load(library_support.cpp(), simulation) + ) + + def get_server_circuit(self, circuit_name: str) -> ServerCircuit: + """Returns a given circuit if it is part of the program. + + Args: + circuit_name (str): name of the circuit to retrieve. + + Raises: + TypeError: if circuit_name is not of type str + RuntimeError: if the circuit is not part of the program + """ + if not isinstance(circuit_name, str): + raise TypeError( + f"circuit_name must be of type str, not {type(circuit_name)}" + ) + + return ServerCircuit.wrap(self.cpp().get_server_circuit(circuit_name)) diff --git a/frontends/concrete-python/concrete/fhe/__init__.py b/frontends/concrete-python/concrete/fhe/__init__.py index 0f336c6d10..8bf94de5d3 100644 --- a/frontends/concrete-python/concrete/fhe/__init__.py +++ b/frontends/concrete-python/concrete/fhe/__init__.py @@ -29,7 +29,7 @@ Value, inputset, ) -from .compilation.decorators import circuit, compiler +from .compilation.decorators import circuit, compiler, function, module from .dtypes import Integer from .extensions import ( AutoRounder, diff --git a/frontends/concrete-python/concrete/fhe/compilation/__init__.py b/frontends/concrete-python/concrete/fhe/compilation/__init__.py index d739586171..ca7a5f02d1 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/__init__.py +++ b/frontends/concrete-python/concrete/fhe/compilation/__init__.py @@ -2,7 +2,7 @@ Glue the compilation process together. """ -from .artifacts import DebugArtifacts +from .artifacts import DebugArtifacts, FunctionDebugArtifacts, ModuleDebugArtifacts from .circuit import Circuit from .client import Client from .compiler import Compiler, EncryptionStatus @@ -20,6 +20,8 @@ ParameterSelectionStrategy, ) from .keys import Keys +from .module import FheFunction, FheModule +from .module_compiler import FunctionDef, ModuleCompiler from .server import Server from .specs import ClientSpecs from .utils import inputset diff --git a/frontends/concrete-python/concrete/fhe/compilation/artifacts.py b/frontends/concrete-python/concrete/fhe/compilation/artifacts.py index a1b7994813..5a182d084a 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/artifacts.py +++ b/frontends/concrete-python/concrete/fhe/compilation/artifacts.py @@ -14,29 +14,21 @@ DEFAULT_OUTPUT_DIRECTORY: Path = Path(".artifacts") -class DebugArtifacts: +class FunctionDebugArtifacts: """ - DebugArtifacts class, to export information about the compilation process. + An object containing debug artifacts for a certain function in an fhe module. """ - output_directory: Path - source_code: Optional[str] parameter_encryption_statuses: Dict[str, str] textual_representations_of_graphs: Dict[str, List[str]] final_graph: Optional[Graph] - mlir_to_compile: Optional[str] - client_parameters: Optional[bytes] - - def __init__(self, output_directory: Union[str, Path] = DEFAULT_OUTPUT_DIRECTORY): - self.output_directory = Path(output_directory) + def __init__(self): self.source_code = None self.parameter_encryption_statuses = {} self.textual_representations_of_graphs = {} self.final_graph = None - self.mlir_to_compile = None - self.client_parameters = None def add_source_code(self, function: Union[str, Callable]): """ @@ -46,7 +38,6 @@ def add_source_code(self, function: Union[str, Callable]): function (Union[str, Callable]): either the source code of the function or the function itself """ - try: self.source_code = ( function if isinstance(function, str) else inspect.getsource(function) @@ -65,7 +56,6 @@ def add_parameter_encryption_status(self, name: str, encryption_status: str): encryption_status (str): encryption status of the parameter """ - self.parameter_encryption_statuses[name] = encryption_status def add_graph(self, name: str, graph: Graph): @@ -79,15 +69,33 @@ def add_graph(self, name: str, graph: Graph): graph (Graph): a representation of the function being compiled """ - if name not in self.textual_representations_of_graphs: self.textual_representations_of_graphs[name] = [] - textual_representation = graph.format() self.textual_representations_of_graphs[name].append(textual_representation) - self.final_graph = graph + +class ModuleDebugArtifacts: + """ + An object containing debug artifacts for an fhe module. + """ + + output_directory: Path + mlir_to_compile: Optional[str] + client_parameters: Optional[bytes] + functions: Dict[str, FunctionDebugArtifacts] + + def __init__( + self, + function_names: List[str], + output_directory: Union[str, Path] = DEFAULT_OUTPUT_DIRECTORY, + ): + self.output_directory = Path(output_directory) + self.mlir_to_compile = None + self.client_parameters = None + self.functions = {name: FunctionDebugArtifacts() for name in function_names} + def add_mlir_to_compile(self, mlir: str): """ Add textual representation of the resulting MLIR. @@ -96,7 +104,6 @@ def add_mlir_to_compile(self, mlir: str): mlir (str): textual representation of the resulting MLIR """ - self.mlir_to_compile = mlir def add_client_parameters(self, client_parameters: bytes): @@ -113,7 +120,6 @@ def export(self): """ Export the collected information to `self.output_directory`. """ - # pylint: disable=too-many-branches output_directory = self.output_directory @@ -164,27 +170,35 @@ def export(self): f.write(f"{name}=={version}\n") - if self.source_code is not None: - with open(output_directory.joinpath("function.txt"), "w", encoding="utf-8") as f: - f.write(self.source_code) - - if len(self.parameter_encryption_statuses) > 0: - with open(output_directory.joinpath("parameters.txt"), "w", encoding="utf-8") as f: - for name, parameter in self.parameter_encryption_statuses.items(): - f.write(f"{name} :: {parameter}\n") - - identifier = 0 - - textual_representations = self.textual_representations_of_graphs.items() - for name, representations in textual_representations: - for representation in representations: - identifier += 1 - output_path = output_directory.joinpath(f"{identifier}.{name}.graph.txt") - with open(output_path, "w", encoding="utf-8") as f: - f.write(f"{representation}\n") + for function_name, function in self.functions.items(): + if function.source_code is not None: + with open( + output_directory.joinpath(f"{function_name}.txt"), "w", encoding="utf-8" + ) as f: + f.write(function.source_code) + + if len(function.parameter_encryption_statuses) > 0: + with open( + output_directory.joinpath(f"{function_name}.parameters.txt"), + "w", + encoding="utf-8", + ) as f: + for name, parameter in function.parameter_encryption_statuses.items(): + f.write(f"{name} :: {parameter}\n") + + identifier = 0 + + textual_representations = function.textual_representations_of_graphs.items() + for name, representations in textual_representations: + for representation in representations: + identifier += 1 + output_path = output_directory.joinpath( + f"{function_name}.{identifier}.{name}.graph.txt" + ) + with open(output_path, "w", encoding="utf-8") as f: + f.write(f"{representation}\n") if self.mlir_to_compile is not None: - assert self.final_graph is not None with open(output_directory.joinpath("mlir.txt"), "w", encoding="utf-8") as f: f.write(f"{self.mlir_to_compile}\n") @@ -193,3 +207,96 @@ def export(self): f.write(self.client_parameters) # pylint: enable=too-many-branches + + +class DebugArtifacts: + """ + DebugArtifacts class, to export information about the compilation process for single function. + """ + + module_artifacts: ModuleDebugArtifacts + + def __init__(self, output_directory: Union[str, Path] = DEFAULT_OUTPUT_DIRECTORY): + self.module_artifacts = ModuleDebugArtifacts(["main"], output_directory) + + def add_source_code(self, function: Union[str, Callable]): + """ + Add source code of the function being compiled. + + Args: + function (Union[str, Callable]): + either the source code of the function or the function itself + """ + self.module_artifacts.functions["main"].add_source_code(function) + + def add_parameter_encryption_status(self, name: str, encryption_status: str): + """ + Add parameter encryption status of a parameter of the function being compiled. + + Args: + name (str): + name of the parameter + + encryption_status (str): + encryption status of the parameter + """ + + self.module_artifacts.functions["main"].add_parameter_encryption_status( + name, encryption_status + ) + + def add_graph(self, name: str, graph: Graph): + """ + Add a representation of the function being compiled. + + Args: + name (str): + name of the graph (e.g., initial, optimized, final) + + graph (Graph): + a representation of the function being compiled + """ + + self.module_artifacts.functions["main"].add_graph(name, graph) + + def add_mlir_to_compile(self, mlir: str): + """ + Add textual representation of the resulting MLIR. + + Args: + mlir (str): + textual representation of the resulting MLIR + """ + + self.module_artifacts.add_mlir_to_compile(mlir) + + def add_client_parameters(self, client_parameters: bytes): + """ + Add client parameters used. + + Args: + client_parameters (bytes): client parameters + """ + + self.module_artifacts.add_client_parameters(client_parameters) + + def export(self): + """ + Export the collected information to `self.output_directory`. + """ + + self.module_artifacts.export() + + @property + def output_directory(self) -> Path: + """ + Return the directory to export artifacts to. + """ + return self.module_artifacts.output_directory + + @property + def mlir_to_compile(self) -> Optional[str]: + """ + Return the mlir string. + """ + return self.module_artifacts.mlir_to_compile diff --git a/frontends/concrete-python/concrete/fhe/compilation/circuit.py b/frontends/concrete-python/concrete/fhe/compilation/circuit.py index 9e457edd6a..d4fc79fb36 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/circuit.py +++ b/frontends/concrete-python/concrete/fhe/compilation/circuit.py @@ -359,14 +359,14 @@ def size_of_inputs(self) -> int: """ Get size of the inputs of the circuit. """ - return self._property("size_of_inputs") # pragma: no cover + return self._property("size_of_inputs")() # pragma: no cover @property def size_of_outputs(self) -> int: """ Get size of the outputs of the circuit. """ - return self._property("size_of_outputs") # pragma: no cover + return self._property("size_of_outputs")() # pragma: no cover @property def p_error(self) -> int: @@ -396,21 +396,21 @@ def programmable_bootstrap_count(self) -> int: """ Get the number of programmable bootstraps in the circuit. """ - return self._property("programmable_bootstrap_count") # pragma: no cover + return self._property("programmable_bootstrap_count")() # pragma: no cover @property def programmable_bootstrap_count_per_parameter(self) -> Dict[Parameter, int]: """ Get the number of programmable bootstraps per bit width in the circuit. """ - return self._property("programmable_bootstrap_count_per_parameter") # pragma: no cover + return self._property("programmable_bootstrap_count_per_parameter")() # pragma: no cover @property def programmable_bootstrap_count_per_tag(self) -> Dict[str, int]: """ Get the number of programmable bootstraps per tag in the circuit. """ - return self._property("programmable_bootstrap_count_per_tag") # pragma: no cover + return self._property("programmable_bootstrap_count_per_tag")() # pragma: no cover @property def programmable_bootstrap_count_per_tag_per_parameter(self) -> Dict[str, Dict[int, int]]: @@ -419,7 +419,7 @@ def programmable_bootstrap_count_per_tag_per_parameter(self) -> Dict[str, Dict[i """ return self._property( "programmable_bootstrap_count_per_tag_per_parameter" - ) # pragma: no cover + )() # pragma: no cover # Key Switch Statistics @@ -428,28 +428,28 @@ def key_switch_count(self) -> int: """ Get the number of key switches in the circuit. """ - return self._property("key_switch_count") # pragma: no cover + return self._property("key_switch_count")() # pragma: no cover @property def key_switch_count_per_parameter(self) -> Dict[Parameter, int]: """ Get the number of key switches per parameter in the circuit. """ - return self._property("key_switch_count_per_parameter") # pragma: no cover + return self._property("key_switch_count_per_parameter")() # pragma: no cover @property def key_switch_count_per_tag(self) -> Dict[str, int]: """ Get the number of key switches per tag in the circuit. """ - return self._property("key_switch_count_per_tag") # pragma: no cover + return self._property("key_switch_count_per_tag")() # pragma: no cover @property def key_switch_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]: """ Get the number of key switches per tag per parameter in the circuit. """ - return self._property("key_switch_count_per_tag_per_parameter") # pragma: no cover + return self._property("key_switch_count_per_tag_per_parameter")() # pragma: no cover # Packing Key Switch Statistics @@ -458,28 +458,30 @@ def packing_key_switch_count(self) -> int: """ Get the number of packing key switches in the circuit. """ - return self._property("packing_key_switch_count") # pragma: no cover + return self._property("packing_key_switch_count")() # pragma: no cover @property def packing_key_switch_count_per_parameter(self) -> Dict[Parameter, int]: """ Get the number of packing key switches per parameter in the circuit. """ - return self._property("packing_key_switch_count_per_parameter") # pragma: no cover + return self._property("packing_key_switch_count_per_parameter")() # pragma: no cover @property def packing_key_switch_count_per_tag(self) -> Dict[str, int]: """ Get the number of packing key switches per tag in the circuit. """ - return self._property("packing_key_switch_count_per_tag") # pragma: no cover + return self._property("packing_key_switch_count_per_tag")() # pragma: no cover @property def packing_key_switch_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]: """ Get the number of packing key switches per tag per parameter in the circuit. """ - return self._property("packing_key_switch_count_per_tag_per_parameter") # pragma: no cover + return self._property( + "packing_key_switch_count_per_tag_per_parameter" + )() # pragma: no cover # Clear Addition Statistics @@ -488,28 +490,28 @@ def clear_addition_count(self) -> int: """ Get the number of clear additions in the circuit. """ - return self._property("clear_addition_count") # pragma: no cover + return self._property("clear_addition_count")() # pragma: no cover @property def clear_addition_count_per_parameter(self) -> Dict[Parameter, int]: """ Get the number of clear additions per parameter in the circuit. """ - return self._property("clear_addition_count_per_parameter") # pragma: no cover + return self._property("clear_addition_count_per_parameter")() # pragma: no cover @property def clear_addition_count_per_tag(self) -> Dict[str, int]: """ Get the number of clear additions per tag in the circuit. """ - return self._property("clear_addition_count_per_tag") # pragma: no cover + return self._property("clear_addition_count_per_tag")() # pragma: no cover @property def clear_addition_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]: """ Get the number of clear additions per tag per parameter in the circuit. """ - return self._property("clear_addition_count_per_tag_per_parameter") # pragma: no cover + return self._property("clear_addition_count_per_tag_per_parameter")() # pragma: no cover # Encrypted Addition Statistics @@ -518,28 +520,30 @@ def encrypted_addition_count(self) -> int: """ Get the number of encrypted additions in the circuit. """ - return self._property("encrypted_addition_count") # pragma: no cover + return self._property("encrypted_addition_count")() # pragma: no cover @property def encrypted_addition_count_per_parameter(self) -> Dict[Parameter, int]: """ Get the number of encrypted additions per parameter in the circuit. """ - return self._property("encrypted_addition_count_per_parameter") # pragma: no cover + return self._property("encrypted_addition_count_per_parameter")() # pragma: no cover @property def encrypted_addition_count_per_tag(self) -> Dict[str, int]: """ Get the number of encrypted additions per tag in the circuit. """ - return self._property("encrypted_addition_count_per_tag") # pragma: no cover + return self._property("encrypted_addition_count_per_tag")() # pragma: no cover @property def encrypted_addition_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]: """ Get the number of encrypted additions per tag per parameter in the circuit. """ - return self._property("encrypted_addition_count_per_tag_per_parameter") # pragma: no cover + return self._property( + "encrypted_addition_count_per_tag_per_parameter" + )() # pragma: no cover # Clear Multiplication Statistics @@ -548,21 +552,21 @@ def clear_multiplication_count(self) -> int: """ Get the number of clear multiplications in the circuit. """ - return self._property("clear_multiplication_count") # pragma: no cover + return self._property("clear_multiplication_count")() # pragma: no cover @property def clear_multiplication_count_per_parameter(self) -> Dict[Parameter, int]: """ Get the number of clear multiplications per parameter in the circuit. """ - return self._property("clear_multiplication_count_per_parameter") # pragma: no cover + return self._property("clear_multiplication_count_per_parameter")() # pragma: no cover @property def clear_multiplication_count_per_tag(self) -> Dict[str, int]: """ Get the number of clear multiplications per tag in the circuit. """ - return self._property("clear_multiplication_count_per_tag") # pragma: no cover + return self._property("clear_multiplication_count_per_tag")() # pragma: no cover @property def clear_multiplication_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]: @@ -571,7 +575,7 @@ def clear_multiplication_count_per_tag_per_parameter(self) -> Dict[str, Dict[Par """ return self._property( "clear_multiplication_count_per_tag_per_parameter" - ) # pragma: no cover + )() # pragma: no cover # Encrypted Negation Statistics @@ -580,28 +584,30 @@ def encrypted_negation_count(self) -> int: """ Get the number of encrypted negations in the circuit. """ - return self._property("encrypted_negation_count") # pragma: no cover + return self._property("encrypted_negation_count")() # pragma: no cover @property def encrypted_negation_count_per_parameter(self) -> Dict[Parameter, int]: """ Get the number of encrypted negations per parameter in the circuit. """ - return self._property("encrypted_negation_count_per_parameter") # pragma: no cover + return self._property("encrypted_negation_count_per_parameter")() # pragma: no cover @property def encrypted_negation_count_per_tag(self) -> Dict[str, int]: """ Get the number of encrypted negations per tag in the circuit. """ - return self._property("encrypted_negation_count_per_tag") # pragma: no cover + return self._property("encrypted_negation_count_per_tag")() # pragma: no cover @property def encrypted_negation_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]: """ Get the number of encrypted negations per tag per parameter in the circuit. """ - return self._property("encrypted_negation_count_per_tag_per_parameter") # pragma: no cover + return self._property( + "encrypted_negation_count_per_tag_per_parameter" + )() # pragma: no cover # All Statistics diff --git a/frontends/concrete-python/concrete/fhe/compilation/client.py b/frontends/concrete-python/concrete/fhe/compilation/client.py index 83ed7a1a45..c77e13aa18 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/client.py +++ b/frontends/concrete-python/concrete/fhe/compilation/client.py @@ -120,6 +120,7 @@ def keygen( def encrypt( self, *args: Optional[Union[int, np.ndarray, List]], + function_name: str = "main", ) -> Optional[Union[Value, Tuple[Optional[Value], ...]]]: """ Encrypt argument(s) to for evaluation. @@ -127,18 +128,20 @@ def encrypt( Args: *args (Optional[Union[int, np.ndarray, List]]): argument(s) for evaluation + function_name (str): + name of the function to encrypt Returns: Optional[Union[Value, Tuple[Optional[Value], ...]]]: encrypted argument(s) for evaluation """ - ordered_sanitized_args = validate_input_args(self.specs, *args) + ordered_sanitized_args = validate_input_args(self.specs, *args, function_name=function_name) self.keygen(force=False) keyset = self.keys._keyset # pylint: disable=protected-access - exporter = ValueExporter.new(keyset, self.specs.client_parameters) + exporter = ValueExporter.new(keyset, self.specs.client_parameters, function_name) exported = [ None if arg is None @@ -155,6 +158,7 @@ def encrypt( def decrypt( self, *results: Union[Value, Tuple[Value, ...]], + function_name: str = "main", ) -> Optional[Union[int, np.ndarray, Tuple[Optional[Union[int, np.ndarray]], ...]]]: """ Decrypt result(s) of evaluation. @@ -162,6 +166,8 @@ def decrypt( Args: *results (Union[Value, Tuple[Value, ...]]): result(s) of evaluation + function_name (str): + name of the function to decrypt for Returns: Optional[Union[int, np.ndarray, Tuple[Optional[Union[int, np.ndarray]], ...]]]: @@ -179,7 +185,7 @@ def decrypt( self.keygen(force=False) keyset = self.keys._keyset # pylint: disable=protected-access - decrypter = ValueDecrypter.new(keyset, self.specs.client_parameters) + decrypter = ValueDecrypter.new(keyset, self.specs.client_parameters, function_name) decrypted = tuple( decrypter.decrypt(position, result.inner) for position, result in enumerate(flattened_results) diff --git a/frontends/concrete-python/concrete/fhe/compilation/compiler.py b/frontends/concrete-python/concrete/fhe/compilation/compiler.py index 20df707785..9125ad1124 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/compiler.py +++ b/frontends/concrete-python/concrete/fhe/compilation/compiler.py @@ -220,8 +220,10 @@ def _trace(self, sample: Union[Any, Tuple[Any, ...]]): self.graph = Tracer.trace(self.function, parameters) if self.artifacts is not None: self.artifacts.add_graph("initial", self.graph) - - fuse(self.graph, self.artifacts) + fuse( + self.graph, + self.artifacts.module_artifacts.functions["main"] if self.artifacts else None, + ) def _evaluate( self, @@ -244,7 +246,11 @@ def _evaluate( if self.artifacts is not None: self.artifacts.add_graph("initial", self.graph) # pragma: no cover - fuse(self.graph, self.artifacts) + fuse( + self.graph, + self.artifacts.module_artifacts.functions["main"] if self.artifacts else None, + ) + if self.artifacts is not None: self.artifacts.add_graph("final", self.graph) # pragma: no cover diff --git a/frontends/concrete-python/concrete/fhe/compilation/configuration.py b/frontends/concrete-python/concrete/fhe/compilation/configuration.py index 4b7c1b0778..ba0b643240 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/configuration.py +++ b/frontends/concrete-python/concrete/fhe/compilation/configuration.py @@ -1294,10 +1294,9 @@ def _validate(self): raise RuntimeError(message) if ( - self.composable - and self.parameter_selection_strategy != ParameterSelectionStrategy.MULTI + self.composable and self.parameter_selection_strategy == ParameterSelectionStrategy.MONO ): # pragma: no cover - message = "Composition can only be used with MULTI parameter selection strategy" + message = "Composition can not be used with MONO parameter selection strategy" raise RuntimeError(message) diff --git a/frontends/concrete-python/concrete/fhe/compilation/decorators.py b/frontends/concrete-python/concrete/fhe/compilation/decorators.py index 1e4066effe..3bdaaf23d6 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/decorators.py +++ b/frontends/concrete-python/concrete/fhe/compilation/decorators.py @@ -2,6 +2,7 @@ Declaration of `circuit` and `compiler` decorators. """ +import functools import inspect from copy import deepcopy from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Tuple, Union @@ -13,6 +14,7 @@ from .circuit import Circuit from .compiler import Compiler, EncryptionStatus from .configuration import Configuration +from .module_compiler import FunctionDef, ModuleCompiler def circuit( @@ -22,7 +24,7 @@ def circuit( **kwargs, ): """ - Provide a direct interface for compilation. + Provide a direct interface for compilation of single circuit programs. Args: parameters (Mapping[str, Union[str, EncryptionStatus]]): @@ -38,8 +40,8 @@ def circuit( configuration options to overwrite """ - def decoration(function: Callable): - signature = inspect.signature(function) + def decoration(function_: Callable): + signature = inspect.signature(function_) parameter_values: Dict[str, ValueDescription] = {} for name, details in signature.parameters.items(): @@ -70,97 +72,128 @@ def decoration(function: Callable): status = EncryptionStatus(parameters[name].lower()) parameter_values[name].is_encrypted = status == "encrypted" - return Compiler.assemble(function, parameter_values, configuration, artifacts, **kwargs) + return Compiler.assemble(function_, parameter_values, configuration, artifacts, **kwargs) return decoration -def compiler(parameters: Mapping[str, Union[str, EncryptionStatus]]): +class Compilable: """ - Provide an easy interface for compilation. - - Args: - parameters (Mapping[str, Union[str, EncryptionStatus]]): - encryption statuses of the parameters of the function to compile + Compilable class, to wrap a function and provide methods to trace and compile it. """ - def decoration(function: Callable): - class Compilable: - """ - Compilable class, to wrap a function and provide methods to trace and compile it. - """ + function: Callable + compiler: Compiler + + def __init__(self, function_: Callable, parameters): + self.function = function_ # type: ignore + self.compiler = Compiler(self.function, dict(parameters)) + + def __call__(self, *args, **kwargs) -> Any: + self.compiler(*args, **kwargs) + return self.function(*args, **kwargs) + + def trace( + self, + inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]] = None, + configuration: Optional[Configuration] = None, + artifacts: Optional[DebugArtifacts] = None, + **kwargs, + ) -> Graph: + """ + Trace the function into computation graph. - function: Callable - compiler: Compiler + Args: + inputset (Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]): + optional inputset to extend accumulated inputset before bounds measurement - def __init__(self, function: Callable): - self.function = function # type: ignore - self.compiler = Compiler(self.function, dict(parameters)) + configuration(Optional[Configuration], default = None): + configuration to use - def __call__(self, *args, **kwargs) -> Any: - self.compiler(*args, **kwargs) - return self.function(*args, **kwargs) + artifacts (Optional[DebugArtifacts], default = None): + artifacts to store information about the process - def trace( - self, - inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]] = None, - configuration: Optional[Configuration] = None, - artifacts: Optional[DebugArtifacts] = None, - **kwargs, - ) -> Graph: - """ - Trace the function into computation graph. + kwargs (Dict[str, Any]): + configuration options to overwrite - Args: - inputset (Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]): - optional inputset to extend accumulated inputset before bounds measurement + Returns: + Graph: + computation graph representing the function prior to MLIR conversion + """ - configuration(Optional[Configuration], default = None): - configuration to use + return self.compiler.trace(inputset, configuration, artifacts, **kwargs) - artifacts (Optional[DebugArtifacts], default = None): - artifacts to store information about the process + def compile( + self, + inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]] = None, + configuration: Optional[Configuration] = None, + artifacts: Optional[DebugArtifacts] = None, + **kwargs, + ) -> Circuit: + """ + Compile the function into a circuit. - kwargs (Dict[str, Any]): - configuration options to overwrite + Args: + inputset (Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]): + optional inputset to extend accumulated inputset before bounds measurement - Returns: - Graph: - computation graph representing the function prior to MLIR conversion - """ + configuration(Optional[Configuration], default = None): + configuration to use - return self.compiler.trace(inputset, configuration, artifacts, **kwargs) + artifacts (Optional[DebugArtifacts], default = None): + artifacts to store information about the process - def compile( - self, - inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]] = None, - configuration: Optional[Configuration] = None, - artifacts: Optional[DebugArtifacts] = None, - **kwargs, - ) -> Circuit: - """ - Compile the function into a circuit. + kwargs (Dict[str, Any]): + configuration options to overwrite - Args: - inputset (Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]): - optional inputset to extend accumulated inputset before bounds measurement + Returns: + Circuit: + compiled circuit + """ - configuration(Optional[Configuration], default = None): - configuration to use + return self.compiler.compile(inputset, configuration, artifacts, **kwargs) - artifacts (Optional[DebugArtifacts], default = None): - artifacts to store information about the process - kwargs (Dict[str, Any]): - configuration options to overwrite +def compiler(parameters: Mapping[str, Union[str, EncryptionStatus]]): + """ + Provide an easy interface for the compilation of single-circuit programs. + + Args: + parameters (Mapping[str, Union[str, EncryptionStatus]]): + encryption statuses of the parameters of the function to compile + """ - Returns: - Circuit: - compiled circuit - """ + def decoration(function_: Callable): + return Compilable(function_, parameters) - return self.compiler.compile(inputset, configuration, artifacts, **kwargs) + return decoration + + +def module(): + """ + Provide an easy interface for the compilation of multi functions modules. + """ + + def decoration(class_): + functions = inspect.getmembers(class_, lambda x: isinstance(x, FunctionDef)) + if not functions: + error = "Tried to define an @fhe.module without any @fhe.function" + raise RuntimeError(error) + return functools.wraps(class_)(ModuleCompiler([f for (_, f) in functions])) + + return decoration + + +def function(parameters: Dict[str, Union[str, EncryptionStatus]]): + """ + Provide an easy interface to define a function within an fhe module. + + Args: + parameters (Mapping[str, Union[str, EncryptionStatus]]): + encryption statuses of the parameters of the function to compile + """ - return Compilable(function) + def decoration(function_: Callable): + return functools.wraps(function_)(FunctionDef(function_, parameters)) return decoration diff --git a/frontends/concrete-python/concrete/fhe/compilation/module.py b/frontends/concrete-python/concrete/fhe/compilation/module.py new file mode 100644 index 0000000000..77b7a61a97 --- /dev/null +++ b/frontends/concrete-python/concrete/fhe/compilation/module.py @@ -0,0 +1,682 @@ +""" +Declaration of `FheModule` classes. +""" + +# pylint: disable=import-error,no-member,no-name-in-module + +from pathlib import Path +from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union + +import numpy as np +from concrete.compiler import ( + CompilationContext, + Parameter, + SimulatedValueDecrypter, + SimulatedValueExporter, +) +from mlir.ir import Module as MlirModule + +from ..internal.utils import assert_that +from ..representation import Graph +from .client import Client +from .configuration import Configuration +from .keys import Keys +from .server import Server +from .utils import validate_input_args +from .value import Value + +# pylint: enable=import-error,no-member,no-name-in-module + + +class ExecutionRt(NamedTuple): + """ + Runtime object class for execution. + """ + + client: Client + server: Server + + +class SimulationRt(NamedTuple): + """ + Runtime object class for simulation. + """ + + server: Server + + +class FheFunction: + """ + Fhe function class, allowing to run or simulate one function of an fhe module. + """ + + runtime: Union[ExecutionRt, SimulationRt] + graph: Graph + name: str + + def __init__(self, name: str, runtime: Union[ExecutionRt, SimulationRt], graph: Graph): + self.name = name + self.runtime = runtime + self.graph = graph + + def __call__( + self, + *args: Any, + ) -> Union[ + np.bool_, + np.integer, + np.floating, + np.ndarray, + Tuple[Union[np.bool_, np.integer, np.floating, np.ndarray], ...], + ]: + return self.graph(*args) + + def draw( + self, + *, + horizontal: bool = False, + save_to: Optional[Union[Path, str]] = None, + show: bool = False, + ) -> Path: + """ + Draw the graph of the function. + + That this function requires the python `pygraphviz` package + which itself requires the installation of `graphviz` packages + + (see https://pygraphviz.github.io/documentation/stable/install.html) + + Args: + horizontal (bool, default = False): + whether to draw horizontally + + save_to (Optional[Path], default = None): + path to save the drawing + a temporary file will be used if it's None + + show (bool, default = False): + whether to show the drawing using matplotlib + + Returns: + Path: + path to the drawing + """ + return self.graph.draw(horizontal=horizontal, save_to=save_to, show=show) + + def __str__(self): + return self.graph.format() + + def simulate(self, *args: Any) -> Any: + """ + Simulate execution of the function. + + Args: + *args (Any): + inputs to the function + + Returns: + Any: + result of the simulation + """ + assert isinstance(self.runtime, SimulationRt) + + ordered_validated_args = validate_input_args( + self.runtime.server.client_specs, *args, function_name=self.name + ) + + exporter = SimulatedValueExporter.new( + self.runtime.server.client_specs.client_parameters, self.name + ) + exported = [ + None + if arg is None + else Value( + exporter.export_tensor(position, arg.flatten().tolist(), list(arg.shape)) + if isinstance(arg, np.ndarray) and arg.shape != () + else exporter.export_scalar(position, int(arg)) + ) + for position, arg in enumerate(ordered_validated_args) + ] + + results = self.runtime.server.run(*exported, function_name=self.name) + if not isinstance(results, tuple): + results = (results,) + + decrypter = SimulatedValueDecrypter.new( + self.runtime.server.client_specs.client_parameters, self.name + ) + decrypted = tuple( + decrypter.decrypt(position, result.inner) for position, result in enumerate(results) + ) + + return decrypted if len(decrypted) != 1 else decrypted[0] + + def encrypt( + self, + *args: Optional[Union[int, np.ndarray, List]], + ) -> Optional[Union[Value, Tuple[Optional[Value], ...]]]: + """ + Encrypt argument(s) to for evaluation. + + Args: + *args (Optional[Union[int, numpy.ndarray, List]]): + argument(s) for evaluation + + Returns: + Optional[Union[Value, Tuple[Optional[Value], ...]]]: + encrypted argument(s) for evaluation + """ + assert isinstance(self.runtime, ExecutionRt) + return self.runtime.client.encrypt(*args, function_name=self.name) + + def run( + self, + *args: Optional[Union[Value, Tuple[Optional[Value], ...]]], + ) -> Union[Value, Tuple[Value, ...]]: + """ + Evaluate the function. + + Args: + *args (Value): + argument(s) for evaluation + + Returns: + Union[Value, Tuple[Value, ...]]: + result(s) of evaluation + """ + assert isinstance(self.runtime, ExecutionRt) + return self.runtime.server.run( + *args, evaluation_keys=self.runtime.client.evaluation_keys, function_name=self.name + ) + + def decrypt( + self, + *results: Union[Value, Tuple[Value, ...]], + ) -> Optional[Union[int, np.ndarray, Tuple[Optional[Union[int, np.ndarray]], ...]]]: + """ + Decrypt result(s) of evaluation. + + Args: + *results (Union[Value, Tuple[Value, ...]]): + result(s) of evaluation + + Returns: + Optional[Union[int, np.ndarray, Tuple[Optional[Union[int, np.ndarray]], ...]]]: + decrypted result(s) of evaluation + """ + assert isinstance(self.runtime, ExecutionRt) + return self.runtime.client.decrypt(*results, function_name=self.name) + + def encrypt_run_decrypt(self, *args: Any) -> Any: + """ + Encrypt inputs, run the function, and decrypt the outputs in one go. + + Args: + *args (Union[int, numpy.ndarray]): + inputs to the function + + Returns: + Union[int, np.ndarray, Tuple[Union[int, np.ndarray], ...]]: + clear result of homomorphic evaluation + """ + return self.decrypt(self.run(self.encrypt(*args))) + + @property + def size_of_inputs(self) -> int: + """ + Get size of the inputs of the function. + """ + return self.runtime.server.size_of_inputs(self.name) # pragma: no cover + + @property + def size_of_outputs(self) -> int: + """ + Get size of the outputs of the function. + """ + return self.runtime.server.size_of_outputs(self.name) # pragma: no cover + + # Programmable Bootstrap Statistics + + @property + def programmable_bootstrap_count(self) -> int: + """ + Get the number of programmable bootstraps in the function. + """ + return self.runtime.server.programmable_bootstrap_count(self.name) # pragma: no cover + + @property + def programmable_bootstrap_count_per_parameter(self) -> Dict[Parameter, int]: + """ + Get the number of programmable bootstraps per bit width in the function. + """ + return self.runtime.server.programmable_bootstrap_count_per_parameter( + self.name + ) # pragma: no cover + + @property + def programmable_bootstrap_count_per_tag(self) -> Dict[str, int]: + """ + Get the number of programmable bootstraps per tag in the function. + """ + return self.runtime.server.programmable_bootstrap_count_per_tag( + self.name + ) # pragma: no cover + + @property + def programmable_bootstrap_count_per_tag_per_parameter(self) -> Dict[str, Dict[int, int]]: + """ + Get the number of programmable bootstraps per tag per bit width in the function. + """ + return self.runtime.server.programmable_bootstrap_count_per_tag_per_parameter( + self.name + ) # pragma: no cover + + # Key Switch Statistics + + @property + def key_switch_count(self) -> int: + """ + Get the number of key switches in the function. + """ + return self.runtime.server.key_switch_count(self.name) # pragma: no cover + + @property + def key_switch_count_per_parameter(self) -> Dict[Parameter, int]: + """ + Get the number of key switches per parameter in the function. + """ + return self.runtime.server.key_switch_count_per_parameter(self.name) # pragma: no cover + + @property + def key_switch_count_per_tag(self) -> Dict[str, int]: + """ + Get the number of key switches per tag in the function. + """ + return self.runtime.server.key_switch_count_per_tag(self.name) # pragma: no cover + + @property + def key_switch_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]: + """ + Get the number of key switches per tag per parameter in the function. + """ + return self.runtime.server.key_switch_count_per_tag_per_parameter( + self.name + ) # pragma: no cover + + # Packing Key Switch Statistics + + @property + def packing_key_switch_count(self) -> int: + """ + Get the number of packing key switches in the function. + """ + return self.runtime.server.packing_key_switch_count(self.name) # pragma: no cover + + @property + def packing_key_switch_count_per_parameter(self) -> Dict[Parameter, int]: + """ + Get the number of packing key switches per parameter in the function. + """ + return self.runtime.server.packing_key_switch_count_per_parameter( + self.name + ) # pragma: no cover + + @property + def packing_key_switch_count_per_tag(self) -> Dict[str, int]: + """ + Get the number of packing key switches per tag in the function. + """ + return self.runtime.server.packing_key_switch_count_per_tag(self.name) # pragma: no cover + + @property + def packing_key_switch_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]: + """ + Get the number of packing key switches per tag per parameter in the function. + """ + return self.runtime.server.packing_key_switch_count_per_tag_per_parameter( + self.name + ) # pragma: no cover + + # Clear Addition Statistics + + @property + def clear_addition_count(self) -> int: + """ + Get the number of clear additions in the function. + """ + return self.runtime.server.clear_addition_count(self.name) # pragma: no cover + + @property + def clear_addition_count_per_parameter(self) -> Dict[Parameter, int]: + """ + Get the number of clear additions per parameter in the function. + """ + return self.runtime.server.clear_addition_count_per_parameter(self.name) # pragma: no cover + + @property + def clear_addition_count_per_tag(self) -> Dict[str, int]: + """ + Get the number of clear additions per tag in the function. + """ + return self.runtime.server.clear_addition_count_per_tag(self.name) # pragma: no cover + + @property + def clear_addition_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]: + """ + Get the number of clear additions per tag per parameter in the function. + """ + return self.runtime.server.clear_addition_count_per_tag_per_parameter( + self.name + ) # pragma: no cover + + # Encrypted Addition Statistics + + @property + def encrypted_addition_count(self) -> int: + """ + Get the number of encrypted additions in the function. + """ + return self.runtime.server.encrypted_addition_count(self.name) # pragma: no cover + + @property + def encrypted_addition_count_per_parameter(self) -> Dict[Parameter, int]: + """ + Get the number of encrypted additions per parameter in the function. + """ + return self.runtime.server.encrypted_addition_count_per_parameter( + self.name + ) # pragma: no cover + + @property + def encrypted_addition_count_per_tag(self) -> Dict[str, int]: + """ + Get the number of encrypted additions per tag in the function. + """ + return self.runtime.server.encrypted_addition_count_per_tag(self.name) # pragma: no cover + + @property + def encrypted_addition_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]: + """ + Get the number of encrypted additions per tag per parameter in the function. + """ + return self.runtime.server.encrypted_addition_count_per_tag_per_parameter( + self.name + ) # pragma: no cover + + # Clear Multiplication Statistics + + @property + def clear_multiplication_count(self) -> int: + """ + Get the number of clear multiplications in the function. + """ + return self.runtime.server.clear_multiplication_count(self.name) # pragma: no cover + + @property + def clear_multiplication_count_per_parameter(self) -> Dict[Parameter, int]: + """ + Get the number of clear multiplications per parameter in the function. + """ + return self.runtime.server.clear_multiplication_count_per_parameter( + self.name + ) # pragma: no cover + + @property + def clear_multiplication_count_per_tag(self) -> Dict[str, int]: + """ + Get the number of clear multiplications per tag in the function. + """ + return self.runtime.server.clear_multiplication_count_per_tag(self.name) # pragma: no cover + + @property + def clear_multiplication_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]: + """ + Get the number of clear multiplications per tag per parameter in the function. + """ + return self.runtime.server.clear_multiplication_count_per_tag_per_parameter( + self.name + ) # pragma: no cover + + # Encrypted Negation Statistics + + @property + def encrypted_negation_count(self) -> int: + """ + Get the number of encrypted negations in the function. + """ + return self.runtime.server.encrypted_negation_count(self.name) # pragma: no cover + + @property + def encrypted_negation_count_per_parameter(self) -> Dict[Parameter, int]: + """ + Get the number of encrypted negations per parameter in the function. + """ + return self.runtime.server.encrypted_negation_count_per_parameter( + self.name + ) # pragma: no cover + + @property + def encrypted_negation_count_per_tag(self) -> Dict[str, int]: + """ + Get the number of encrypted negations per tag in the function. + """ + return self.runtime.server.encrypted_negation_count_per_tag(self.name) # pragma: no cover + + @property + def encrypted_negation_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]: + """ + Get the number of encrypted negations per tag per parameter in the function. + """ + return self.runtime.server.encrypted_negation_count_per_tag_per_parameter( + self.name + ) # pragma: no cover + + @property + def statistics(self) -> Dict: + """ + Get all statistics of the function. + """ + attributes = [ + "size_of_inputs", + "size_of_outputs", + "programmable_bootstrap_count", + "programmable_bootstrap_count_per_parameter", + "programmable_bootstrap_count_per_tag", + "programmable_bootstrap_count_per_tag_per_parameter", + "key_switch_count", + "key_switch_count_per_parameter", + "key_switch_count_per_tag", + "key_switch_count_per_tag_per_parameter", + "packing_key_switch_count", + "packing_key_switch_count_per_parameter", + "packing_key_switch_count_per_tag", + "packing_key_switch_count_per_tag_per_parameter", + "clear_addition_count", + "clear_addition_count_per_parameter", + "clear_addition_count_per_tag", + "clear_addition_count_per_tag_per_parameter", + "encrypted_addition_count", + "encrypted_addition_count_per_parameter", + "encrypted_addition_count_per_tag", + "encrypted_addition_count_per_tag_per_parameter", + "clear_multiplication_count", + "clear_multiplication_count_per_parameter", + "clear_multiplication_count_per_tag", + "clear_multiplication_count_per_tag_per_parameter", + "encrypted_negation_count", + "encrypted_negation_count_per_parameter", + "encrypted_negation_count_per_tag", + "encrypted_negation_count_per_tag_per_parameter", + ] + return {attribute: getattr(self, attribute) for attribute in attributes} + + +class FheModule: + """ + Fhe module class, to combine computation graphs, mlir, runtime objects into a single object. + """ + + configuration: Configuration + graphs: Dict[str, Graph] + mlir_module: MlirModule + compilation_context: CompilationContext + runtime: Union[ExecutionRt, SimulationRt] + + def __init__( + self, + graphs: Dict[str, Graph], + mlir: MlirModule, + compilation_context: CompilationContext, + configuration: Optional[Configuration] = None, + ): + assert configuration and (configuration.fhe_simulation or configuration.fhe_execution) + + self.configuration = configuration if configuration is not None else Configuration() + self.graphs = graphs + self.mlir_module = mlir + self.compilation_context = compilation_context + + if self.configuration.fhe_simulation: + server = Server.create( + self.mlir_module, + self.configuration, + is_simulated=True, + compilation_context=self.compilation_context, + ) + self.runtime = SimulationRt(server) + else: + server = Server.create( + self.mlir_module, self.configuration, compilation_context=self.compilation_context + ) + + keyset_cache_directory = None + if self.configuration.use_insecure_key_cache: + assert_that(self.configuration.enable_unsafe_features) + assert_that(self.configuration.insecure_key_cache_location is not None) + keyset_cache_directory = self.configuration.insecure_key_cache_location + + client = Client(server.client_specs, keyset_cache_directory) + self.runtime = ExecutionRt(client, server) + + @property + def mlir(self) -> str: + """Textual representation of the MLIR module. + + Returns: + str: textual representation of the MLIR module + """ + return str(self.mlir_module).strip() + + @property + def keys(self) -> Optional[Keys]: + """ + Get the keys of the module. + """ + if isinstance(self.runtime, ExecutionRt): + return self.runtime.client.keys + return None + + @keys.setter + def keys(self, new_keys: Keys): + """ + Set the keys of the module. + """ + if isinstance(self.runtime, ExecutionRt): + self.runtime.client.keys = new_keys + + def keygen( + self, force: bool = False, seed: Optional[int] = None, encryption_seed: Optional[int] = None + ): + """ + Generate keys required for homomorphic evaluation. + + Args: + force (bool, default = False): + whether to generate new keys even if keys are already generated + + seed (Optional[int], default = None): + seed for private keys randomness + + encryption_seed (Optional[int], default = None): + seed for encryption randomness + """ + if isinstance(self.runtime, ExecutionRt): + self.runtime.client.keygen(force, seed, encryption_seed) + + def cleanup(self): + """ + Cleanup the temporary library output directory. + """ + self.runtime.server.cleanup() + + @property + def size_of_secret_keys(self) -> int: + """ + Get size of the secret keys of the module. + """ + return self.runtime.server.size_of_secret_keys # pragma: no cover + + @property + def size_of_bootstrap_keys(self) -> int: + """ + Get size of the bootstrap keys of the module. + """ + return self.runtime.server.size_of_bootstrap_keys # pragma: no cover + + @property + def size_of_keyswitch_keys(self) -> int: + """ + Get size of the key switch keys of the module. + """ + return self.runtime.server.size_of_keyswitch_keys # pragma: no cover + + @property + def p_error(self) -> int: + """ + Get probability of error for each simple TLU (on a scalar). + """ + return self.runtime.server.p_error # pragma: no cover + + @property + def global_p_error(self) -> int: + """ + Get the probability of having at least one simple TLU error during the entire execution. + """ + return self.runtime.server.global_p_error # pragma: no cover + + @property + def complexity(self) -> float: + """ + Get complexity of the module. + """ + return self.runtime.server.complexity # pragma: no cover + + @property + def statistics(self) -> Dict: + """ + Get all statistics of the module. + """ + attributes = [ + "size_of_secret_keys", + "size_of_bootstrap_keys", + "size_of_keyswitch_keys", + "p_error", + "global_p_error", + "complexity", + ] + statistics = {attribute: getattr(self, attribute) for attribute in attributes} + statistics["functions"] = { + name: function.statistics for (name, function) in self.functions().items() + } + return statistics + + def functions(self) -> Dict[str, FheFunction]: + """ + Return a dictionnary containing all the functions of the module. + """ + return {name: getattr(self, name) for name in self.graphs.keys()} + + def __getattr__(self, item): + if item not in list(self.graphs.keys()): + self.__getattribute__(item) + return FheFunction(item, self.runtime, self.graphs[item]) diff --git a/frontends/concrete-python/concrete/fhe/compilation/module_compiler.py b/frontends/concrete-python/concrete/fhe/compilation/module_compiler.py new file mode 100644 index 0000000000..751f2eb10c --- /dev/null +++ b/frontends/concrete-python/concrete/fhe/compilation/module_compiler.py @@ -0,0 +1,571 @@ +""" +Declaration of `MultiCompiler` class. +""" + +# pylint: disable=import-error,no-name-in-module + +import inspect +import traceback +from copy import deepcopy +from pathlib import Path +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union + +import numpy as np +from concrete.compiler import CompilationContext + +from ..extensions import AutoRounder, AutoTruncator +from ..mlir import GraphConverter +from ..representation import Graph +from ..tracing import Tracer +from ..values import ValueDescription +from .artifacts import FunctionDebugArtifacts, ModuleDebugArtifacts +from .compiler import EncryptionStatus +from .configuration import Configuration +from .module import ExecutionRt, FheModule +from .utils import fuse, get_terminal_size + +DEFAULT_OUTPUT_DIRECTORY: Path = Path(".artifacts") + +# pylint: enable=import-error,no-name-in-module + + +class FunctionDef: + """ + An object representing the definition of a function as used in an fhe module. + """ + + name: str + function: Callable + parameter_encryption_statuses: Dict[str, EncryptionStatus] + inputset: List[Any] + graph: Optional[Graph] + artifacts: Optional[FunctionDebugArtifacts] + _is_direct: bool + _parameter_values: Dict[str, ValueDescription] + + def __init__( + self, + function: Callable, + parameter_encryption_statuses: Dict[str, Union[str, EncryptionStatus]], + ): + signature = inspect.signature(function) + + missing_args = list(signature.parameters) + for arg in parameter_encryption_statuses.keys(): + if arg in signature.parameters: + missing_args.remove(arg) + + if len(missing_args) != 0: + parameter_str = repr(missing_args[0]) + for arg in missing_args[1:-1]: + parameter_str += f", {repr(arg)}" + if len(missing_args) != 1: + parameter_str += f" and {repr(missing_args[-1])}" + + message = ( + f"Encryption status{'es' if len(missing_args) > 1 else ''} " + f"of parameter{'s' if len(missing_args) > 1 else ''} " + f"{parameter_str} of function '{function.__name__}' " + f"{'are' if len(missing_args) > 1 else 'is'} not provided" + ) + raise ValueError(message) + + additional_args = list(parameter_encryption_statuses) + for arg in signature.parameters.keys(): + if arg in parameter_encryption_statuses: + additional_args.remove(arg) + + if len(additional_args) != 0: + parameter_str = repr(additional_args[0]) + for arg in additional_args[1:-1]: + parameter_str += f", {repr(arg)}" + if len(additional_args) != 1: + parameter_str += f" and {repr(additional_args[-1])}" + + message = ( + f"Encryption status{'es' if len(additional_args) > 1 else ''} " + f"of {parameter_str} {'are' if len(additional_args) > 1 else 'is'} provided but " + f"{'they are' if len(additional_args) > 1 else 'it is'} not a parameter " + f"of function '{function.__name__}'" + ) + raise ValueError(message) + + self.function = function # type: ignore + self.parameter_encryption_statuses = { + param: EncryptionStatus(status.lower()) + for param, status in parameter_encryption_statuses.items() + } + self.artifacts = None + self.inputset = [] + self.graph = None + self.name = function.__name__ + self._is_direct = False + self._parameter_values = {} + + def trace(self, sample: Union[Any, Tuple[Any, ...]]): + """ + Trace the function and fuse the resulting graph with a sample input. + + Args: + sample (Union[Any, Tuple[Any, ...]]): + sample to use for tracing + """ + + if self.artifacts is not None: + self.artifacts.add_source_code(self.function) + for param, encryption_status in self.parameter_encryption_statuses.items(): + self.artifacts.add_parameter_encryption_status(param, encryption_status) + + parameters = { + param: ValueDescription.of(arg, is_encrypted=(status == EncryptionStatus.ENCRYPTED)) + for arg, (param, status) in zip( + ( + sample + if len(self.parameter_encryption_statuses) > 1 or isinstance(sample, tuple) + else (sample,) + ), + self.parameter_encryption_statuses.items(), + ) + } + + self.graph = Tracer.trace(self.function, parameters, name=self.name) + if self.artifacts is not None: + self.artifacts.add_graph("initial", self.graph) + + fuse(self.graph, self.artifacts) + + def evaluate( + self, + action: str, + inputset: Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]], + configuration: Configuration, + artifacts: FunctionDebugArtifacts, + ): + """ + Trace, fuse, measure bounds, and update values in the resulting graph in one go. + + Args: + action (str): + action being performed (e.g., "trace", "compile") + + inputset (Optional[Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]): + optional inputset to extend accumulated inputset before bounds measurement + + configuration (Configuration): + configuration to be used + + artifacts (FunctionDebugArtifacts): + artifact object to store informations in + """ + + if self._is_direct: + self.graph = Tracer.trace( + self.function, self._parameter_values, is_direct=True, name=self.name + ) + artifacts.add_graph("initial", self.graph) # pragma: no cover + fuse(self.graph, artifacts) + artifacts.add_graph("final", self.graph) # pragma: no cover + return + + if inputset is not None: + previous_inputset_length = len(self.inputset) + for index, sample in enumerate(iter(inputset)): + self.inputset.append(sample) + + if not isinstance(sample, tuple): + sample = (sample,) + + if len(sample) != len(self.parameter_encryption_statuses): + self.inputset = self.inputset[:previous_inputset_length] + + expected = ( + "a single value" + if len(self.parameter_encryption_statuses) == 1 + else f"a tuple of {len(self.parameter_encryption_statuses)} values" + ) + actual = ( + "a single value" if len(sample) == 1 else f"a tuple of {len(sample)} values" + ) + + message = ( + f"Input #{index} of your inputset is not well formed " + f"(expected {expected} got {actual})" + ) + raise ValueError(message) + + if configuration.auto_adjust_rounders: + AutoRounder.adjust(self.function, self.inputset) + + if configuration.auto_adjust_truncators: + AutoTruncator.adjust(self.function, self.inputset) + + if self.graph is None: + try: + first_sample = next(iter(self.inputset)) + except StopIteration as error: + message = ( + f"{action} function '{self.function.__name__}' " + f"without an inputset is not supported" + ) + raise RuntimeError(message) from error + + self.trace(first_sample) + assert self.graph is not None + + bounds = self.graph.measure_bounds(self.inputset) + self.graph.update_with_bounds(bounds) + + artifacts.add_graph("final", self.graph) + + def __call__( + self, + *args: Any, + **kwargs: Any, + ) -> Union[ + np.bool_, + np.integer, + np.floating, + np.ndarray, + Tuple[Union[np.bool_, np.integer, np.floating, np.ndarray], ...], + ]: + if len(kwargs) != 0: + message = f"Calling function '{self.function.__name__}' with kwargs is not supported" + raise RuntimeError(message) + + sample = args[0] if len(args) == 1 else args + + if self.graph is None: + self.trace(sample) + assert self.graph is not None + + self.inputset.append(sample) + return self.graph(*args) + + +class DebugManager: + """ + A debug manager, allowing streamlined debugging. + """ + + configuration: Configuration + begin_call: Callable + + def __init__(self, config: Configuration): + self.configuration = config + is_first = [True] + + def begin_call(): + if is_first[0]: + print() + is_first[0] = False + + self.begin_call = begin_call + + def debug_table(self, title: str, activate: bool = True): + """ + Return a context manager that prints a table around what is printed inside the scope. + """ + + # pylint: disable=missing-class-docstring + class DebugTableCm: + def __init__(self, title): + self.title = title + self.columns = get_terminal_size() + + def __enter__(self): + print(f"{self.title}") + print("-" * self.columns) + + def __exit__(self, _exc_type, _exc_value, _exc_tb): + print("-" * self.columns) + print() + + class EmptyCm: + def __enter__(self): + pass + + def __exit__(self, _exc_type, _exc_value, _exc_tb): + pass + + if activate: + self.begin_call() + return DebugTableCm(title) + return EmptyCm() + + def show_graph(self) -> bool: + """ + Tell if the configuration involves showing graph. + """ + + return ( + self.configuration.show_graph + if self.configuration.show_graph is not None + else self.configuration.verbose + ) + + def show_bit_width_constraints(self) -> bool: + """ + Tell if the configuration involves showing bitwidth constraints. + """ + + return ( + self.configuration.show_bit_width_constraints + if self.configuration.show_bit_width_constraints is not None + else self.configuration.verbose + ) + + def show_bit_width_assignments(self) -> bool: + """ + Tell if the configuration involves showing bitwidth assignments. + """ + + return ( + self.configuration.show_bit_width_assignments + if self.configuration.show_bit_width_assignments is not None + else self.configuration.verbose + ) + + def show_assigned_graph(self) -> bool: + """ + Tell if the configuration involves showing assigned graph. + """ + + return ( + self.configuration.show_assigned_graph + if self.configuration.show_assigned_graph is not None + else self.configuration.verbose + ) + + def show_mlir(self) -> bool: + """ + Tell if the configuration involves showing mlir. + """ + + return ( + self.configuration.show_mlir + if self.configuration.show_mlir is not None + else self.configuration.verbose + ) + + def show_optimizer(self) -> bool: + """ + Tell if the configuration involves showing optimizer. + """ + + return ( + self.configuration.show_optimizer + if self.configuration.show_optimizer is not None + else self.configuration.verbose + ) + + def show_statistics(self) -> bool: + """ + Tell if the configuration involves showing statistics. + """ + + return ( + self.configuration.show_statistics + if self.configuration.show_statistics is not None + else self.configuration.verbose + ) + + def debug_computation_graph(self, name, function_graph): + """ + Print computation graph if configuration tells so. + """ + + if ( + self.show_graph() + or self.show_bit_width_constraints() + or self.show_bit_width_assignments() + or self.show_assigned_graph() + or self.show_mlir() + or self.show_optimizer() + or self.show_statistics() + ): + if self.show_graph(): + with self.debug_table(f"Computation Graph for {name}"): + print(function_graph.format()) + + def debug_bit_width_constaints(self, name, function_graph): + """ + Print bitwidth constraints if configuration tells so. + """ + + if self.show_bit_width_constraints(): + with self.debug_table(f"Bit-Width Constraints for {name}"): + print(function_graph.format_bit_width_constraints()) + + def debug_bit_width_assignments(self, name, function_graph): + """ + Print bitwidth assignments if configuration tells so. + """ + + if self.show_bit_width_assignments(): + with self.debug_table(f"Bit-Width Assignments for {name}"): + print(function_graph.format_bit_width_assignments()) + + def debug_assigned_graph(self, name, function_graph): + """ + Print assigned graphs if configuration tells so. + """ + + if self.show_assigned_graph(): + with self.debug_table(f"Bit-Width Assigned Computation Graph for {name}"): + print(function_graph.format(show_assigned_bit_widths=True)) + + def debug_mlir(self, mlir_str): + """ + Print mlir if configuration tells so. + """ + + if self.show_mlir(): + with self.debug_table("MLIR"): + print(mlir_str) + + def debug_statistics(self, module): + """ + Print statistics if configuration tells so. + """ + + if self.show_statistics(): + + def pretty(d, indent=0): # pragma: no cover + if indent > 0: + print("{") + + for key, value in d.items(): + if isinstance(value, dict) and len(value) == 0: + continue + print(" " * indent + str(key) + ": ", end="") + if isinstance(value, dict): + pretty(value, indent + 1) + else: + print(value) + if indent > 0: + print(" " * (indent - 1) + "}") + + with self.debug_table("Statistics"): + pretty(module.statistics) + + +class ModuleCompiler: + """ + Compiler class for multiple functions, to glue the compilation pipeline. + """ + + default_configuration: Configuration + functions: Dict[str, FunctionDef] + compilation_context: CompilationContext + + def __init__(self, functions: List[FunctionDef]): + self.configuration = Configuration(composable=True) + self.functions = {function.name: function for function in functions} + self.compilation_context = CompilationContext.new() + + def compile( + self, + inputsets: Optional[Dict[str, Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]] = None, + configuration: Optional[Configuration] = None, + module_artifacts: Optional[ModuleDebugArtifacts] = None, + **kwargs, + ) -> FheModule: + """ + Compile the module using an ensemble of inputsets. + + Args: + inputsets (Optional[Dict[str, Union[Iterable[Any], Iterable[Tuple[Any, ...]]]]]): + optional inputsets to extend accumulated inputsets before bounds measurement + + configuration(Optional[Configuration], default = None): + configuration to use + + artifacts (Optional[ModuleDebugArtifacts], default = None): + artifacts to store information about the process + + kwargs (Dict[str, Any]): + configuration options to overwrite + + Returns: + FheModule: + compiled module + """ + + configuration = configuration if configuration is not None else self.default_configuration + configuration = deepcopy(configuration) + if len(kwargs) != 0: + configuration = configuration.fork(**kwargs) + if not configuration.composable: + error = "Module can only be compiled with `composable` activated." + raise RuntimeError(error) + + module_artifacts = ( + module_artifacts + if module_artifacts is not None + else ModuleDebugArtifacts(list(self.functions.keys())) + ) + + dbg = DebugManager(configuration) + + try: + # Trace and fuse the functions + for name, function in self.functions.items(): + inputset = inputsets[name] if inputsets is not None else None + function_artifacts = module_artifacts.functions[name] + function.evaluate("Compiling", inputset, self.configuration, function_artifacts) + assert function.graph is not None + dbg.debug_computation_graph(name, function.graph) + + # Convert the graphs to an mlir module + mlir_context = self.compilation_context.mlir_context() + graphs = {} + for name, function in self.functions.items(): + if function.graph is None: + error = "Expected graph to be set." + raise RuntimeError(error) + graphs[name] = function.graph + mlir_module = GraphConverter(self.configuration).convert_many(graphs, mlir_context) + mlir_str = str(mlir_module).strip() + dbg.debug_mlir(mlir_str) + module_artifacts.add_mlir_to_compile(mlir_str) + + # Debug some function informations + for name, function in self.functions.items(): + dbg.debug_bit_width_constaints(name, function.graph) + dbg.debug_bit_width_assignments(name, function.graph) + dbg.debug_assigned_graph(name, function.graph) + + # Compile to a module! + with dbg.debug_table("Optimizer", activate=dbg.show_optimizer()): + output = FheModule(graphs, mlir_module, self.compilation_context, configuration) + if isinstance(output.runtime, ExecutionRt): + client_parameters = output.runtime.client.specs.client_parameters + module_artifacts.add_client_parameters(client_parameters.serialize()) + + dbg.debug_statistics(output) + + except Exception: # pragma: no cover + # this branch is reserved for unexpected issues and hence it shouldn't be tested + # if it could be tested, we would have fixed the underlying issue + + # if the user desires so, + # we need to export all the information we have about the compilation + + if self.configuration.dump_artifacts_on_unexpected_failures: + module_artifacts.export() + + traceback_path = self.artifacts.output_directory.joinpath("traceback.txt") + with open(traceback_path, "w", encoding="utf-8") as f: + f.write(traceback.format_exc()) + + raise + + return output + + # pylint: enable=too-many-branches,too-many-statements + + def __getattr__(self, item): + if item not in list(self.functions.keys()): + error = f"No attribute {item}" + raise AttributeError(error) + return self.functions[item] diff --git a/frontends/concrete-python/concrete/fhe/compilation/server.py b/frontends/concrete-python/concrete/fhe/compilation/server.py index 4965d413e4..5a13fdef97 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/server.py +++ b/frontends/concrete-python/concrete/fhe/compilation/server.py @@ -17,11 +17,11 @@ CompilationOptions, EvaluationKeys, LibraryCompilationResult, - LibraryLambda, LibrarySupport, Parameter, ProgramCompilationFeedback, PublicArguments, + ServerProgram, set_compiler_logging, set_llvm_debug_flag, ) @@ -60,7 +60,7 @@ class Server: _support: LibrarySupport _compilation_result: LibraryCompilationResult _compilation_feedback: ProgramCompilationFeedback - _server_lambda: LibraryLambda + _server_program: ServerProgram _mlir: Optional[str] _configuration: Optional[Configuration] @@ -71,7 +71,7 @@ def __init__( output_dir: Optional[tempfile.TemporaryDirectory], support: LibrarySupport, compilation_result: LibraryCompilationResult, - server_lambda: LibraryLambda, + server_program: ServerProgram, is_simulated: bool, ): self.client_specs = client_specs @@ -81,7 +81,7 @@ def __init__( self._support = support self._compilation_result = compilation_result self._compilation_feedback = self._support.load_compilation_feedback(compilation_result) - self._server_lambda = server_lambda + self._server_program = server_program self._mlir = None assert_that( @@ -204,7 +204,7 @@ def create( compilation_context is not None ), "must provide compilation context when compiling MlirModule" compilation_result = support.compile(mlir, options, compilation_context) - server_lambda = support.load_server_lambda(compilation_result, is_simulated) + server_program = ServerProgram.load(support, is_simulated) finally: set_llvm_debug_flag(False) set_compiler_logging(False) @@ -217,7 +217,7 @@ def create( output_dir, support, compilation_result, - server_lambda, + server_program, is_simulated, ) @@ -318,16 +318,17 @@ def load(path: Union[str, Path]) -> "Server": generateStaticLib=False, ) compilation_result = support.reload() - server_lambda = support.load_server_lambda(compilation_result, is_simulated) + server_program = ServerProgram.load(support, is_simulated) return Server( - client_specs, output_dir, support, compilation_result, server_lambda, is_simulated + client_specs, output_dir, support, compilation_result, server_program, is_simulated ) def run( self, *args: Optional[Union[Value, Tuple[Optional[Value], ...]]], evaluation_keys: Optional[EvaluationKeys] = None, + function_name: str = "main", ) -> Union[Value, Tuple[Value, ...]]: """ Evaluate. @@ -339,6 +340,9 @@ def run( evaluation_keys (Optional[EvaluationKeys], default = None): evaluation keys required for fhe execution + function_name (str): + The name of the function to run + Returns: Union[Value, Tuple[Value, ...]]: result(s) of evaluation @@ -368,13 +372,12 @@ def run( buffers.append(arg.inner) public_args = PublicArguments.new(self.client_specs.client_parameters, buffers) + server_circuit = self._server_program.get_server_circuit(function_name) if self.is_simulated: - public_result = self._support.simulate(self._server_lambda, public_args) + public_result = server_circuit.simulate(public_args) else: - public_result = self._support.server_call( - self._server_lambda, public_args, evaluation_keys - ) + public_result = server_circuit.call(public_args, evaluation_keys) result = tuple(Value(public_result.get_value(i)) for i in range(public_result.n_values())) return result if len(result) > 1 else result[0] @@ -408,20 +411,6 @@ def size_of_keyswitch_keys(self) -> int: """ return self._compilation_feedback.total_keyswitch_keys_size - @property - def size_of_inputs(self) -> int: - """ - Get size of the inputs of the compiled program. - """ - return self._compilation_feedback.circuit("main").total_inputs_size - - @property - def size_of_outputs(self) -> int: - """ - Get size of the outputs of the compiled program. - """ - return self._compilation_feedback.circuit("main").total_output_size - @property def p_error(self) -> int: """ @@ -443,43 +432,55 @@ def complexity(self) -> float: """ return self._compilation_feedback.complexity + def size_of_inputs(self, function: str = "main") -> int: + """ + Get size of the inputs of the compiled program. + """ + return self._compilation_feedback.circuit(function).total_inputs_size + + def size_of_outputs(self, function: str = "main") -> int: + """ + Get size of the outputs of the compiled program. + """ + return self._compilation_feedback.circuit(function).total_output_size + # Programmable Bootstrap Statistics - @property - def programmable_bootstrap_count(self) -> int: + def programmable_bootstrap_count(self, function: str = "main") -> int: """ Get the number of programmable bootstraps in the compiled program. """ - return self._compilation_feedback.circuit("main").count( + return self._compilation_feedback.circuit(function).count( operations={PrimitiveOperation.PBS, PrimitiveOperation.WOP_PBS}, ) - @property - def programmable_bootstrap_count_per_parameter(self) -> Dict[Parameter, int]: + def programmable_bootstrap_count_per_parameter( + self, function: str = "main" + ) -> Dict[Parameter, int]: """ Get the number of programmable bootstraps per parameter in the compiled program. """ - return self._compilation_feedback.circuit("main").count_per_parameter( + return self._compilation_feedback.circuit(function).count_per_parameter( operations={PrimitiveOperation.PBS, PrimitiveOperation.WOP_PBS}, key_types={KeyType.BOOTSTRAP}, client_parameters=self.client_specs.client_parameters, ) - @property - def programmable_bootstrap_count_per_tag(self) -> Dict[str, int]: + def programmable_bootstrap_count_per_tag(self, function: str = "main") -> Dict[str, int]: """ Get the number of programmable bootstraps per tag in the compiled program. """ - return self._compilation_feedback.circuit("main").count_per_tag( + return self._compilation_feedback.circuit(function).count_per_tag( operations={PrimitiveOperation.PBS, PrimitiveOperation.WOP_PBS}, ) - @property - def programmable_bootstrap_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]: + def programmable_bootstrap_count_per_tag_per_parameter( + self, function: str = "main" + ) -> Dict[str, Dict[Parameter, int]]: """ Get the number of programmable bootstraps per tag per parameter in the compiled program. """ - return self._compilation_feedback.circuit("main").count_per_tag_per_parameter( + return self._compilation_feedback.circuit(function).count_per_tag_per_parameter( operations={PrimitiveOperation.PBS, PrimitiveOperation.WOP_PBS}, key_types={KeyType.BOOTSTRAP}, client_parameters=self.client_specs.client_parameters, @@ -487,41 +488,39 @@ def programmable_bootstrap_count_per_tag_per_parameter(self) -> Dict[str, Dict[P # Key Switch Statistics - @property - def key_switch_count(self) -> int: + def key_switch_count(self, function: str = "main") -> int: """ Get the number of key switches in the compiled program. """ - return self._compilation_feedback.circuit("main").count( + return self._compilation_feedback.circuit(function).count( operations={PrimitiveOperation.KEY_SWITCH, PrimitiveOperation.WOP_PBS}, ) - @property - def key_switch_count_per_parameter(self) -> Dict[Parameter, int]: + def key_switch_count_per_parameter(self, function: str = "main") -> Dict[Parameter, int]: """ Get the number of key switches per parameter in the compiled program. """ - return self._compilation_feedback.circuit("main").count_per_parameter( + return self._compilation_feedback.circuit(function).count_per_parameter( operations={PrimitiveOperation.KEY_SWITCH, PrimitiveOperation.WOP_PBS}, key_types={KeyType.KEY_SWITCH}, client_parameters=self.client_specs.client_parameters, ) - @property - def key_switch_count_per_tag(self) -> Dict[str, int]: + def key_switch_count_per_tag(self, function: str = "main") -> Dict[str, int]: """ Get the number of key switches per tag in the compiled program. """ - return self._compilation_feedback.circuit("main").count_per_tag( + return self._compilation_feedback.circuit(function).count_per_tag( operations={PrimitiveOperation.KEY_SWITCH, PrimitiveOperation.WOP_PBS}, ) - @property - def key_switch_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]: + def key_switch_count_per_tag_per_parameter( + self, function: str = "main" + ) -> Dict[str, Dict[Parameter, int]]: """ Get the number of key switches per tag per parameter in the compiled program. """ - return self._compilation_feedback.circuit("main").count_per_tag_per_parameter( + return self._compilation_feedback.circuit(function).count_per_tag_per_parameter( operations={PrimitiveOperation.KEY_SWITCH, PrimitiveOperation.WOP_PBS}, key_types={KeyType.KEY_SWITCH}, client_parameters=self.client_specs.client_parameters, @@ -529,41 +528,41 @@ def key_switch_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, in # Packing Key Switch Statistics - @property - def packing_key_switch_count(self) -> int: + def packing_key_switch_count(self, function: str = "main") -> int: """ Get the number of packing key switches in the compiled program. """ - return self._compilation_feedback.circuit("main").count( + return self._compilation_feedback.circuit(function).count( operations={PrimitiveOperation.WOP_PBS} ) - @property - def packing_key_switch_count_per_parameter(self) -> Dict[Parameter, int]: + def packing_key_switch_count_per_parameter( + self, function: str = "main" + ) -> Dict[Parameter, int]: """ Get the number of packing key switches per parameter in the compiled program. """ - return self._compilation_feedback.circuit("main").count_per_parameter( + return self._compilation_feedback.circuit(function).count_per_parameter( operations={PrimitiveOperation.WOP_PBS}, key_types={KeyType.PACKING_KEY_SWITCH}, client_parameters=self.client_specs.client_parameters, ) - @property - def packing_key_switch_count_per_tag(self) -> Dict[str, int]: + def packing_key_switch_count_per_tag(self, function: str = "main") -> Dict[str, int]: """ Get the number of packing key switches per tag in the compiled program. """ - return self._compilation_feedback.circuit("main").count_per_tag( + return self._compilation_feedback.circuit(function).count_per_tag( operations={PrimitiveOperation.WOP_PBS} ) - @property - def packing_key_switch_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]: + def packing_key_switch_count_per_tag_per_parameter( + self, function: str = "main" + ) -> Dict[str, Dict[Parameter, int]]: """ Get the number of packing key switches per tag per parameter in the compiled program. """ - return self._compilation_feedback.circuit("main").count_per_tag_per_parameter( + return self._compilation_feedback.circuit(function).count_per_tag_per_parameter( operations={PrimitiveOperation.WOP_PBS}, key_types={KeyType.PACKING_KEY_SWITCH}, client_parameters=self.client_specs.client_parameters, @@ -571,41 +570,39 @@ def packing_key_switch_count_per_tag_per_parameter(self) -> Dict[str, Dict[Param # Clear Addition Statistics - @property - def clear_addition_count(self) -> int: + def clear_addition_count(self, function: str = "main") -> int: """ Get the number of clear additions in the compiled program. """ - return self._compilation_feedback.circuit("main").count( + return self._compilation_feedback.circuit(function).count( operations={PrimitiveOperation.CLEAR_ADDITION} ) - @property - def clear_addition_count_per_parameter(self) -> Dict[Parameter, int]: + def clear_addition_count_per_parameter(self, function: str = "main") -> Dict[Parameter, int]: """ Get the number of clear additions per parameter in the compiled program. """ - return self._compilation_feedback.circuit("main").count_per_parameter( + return self._compilation_feedback.circuit(function).count_per_parameter( operations={PrimitiveOperation.CLEAR_ADDITION}, key_types={KeyType.SECRET}, client_parameters=self.client_specs.client_parameters, ) - @property - def clear_addition_count_per_tag(self) -> Dict[str, int]: + def clear_addition_count_per_tag(self, function: str = "main") -> Dict[str, int]: """ Get the number of clear additions per tag in the compiled program. """ - return self._compilation_feedback.circuit("main").count_per_tag( + return self._compilation_feedback.circuit(function).count_per_tag( operations={PrimitiveOperation.CLEAR_ADDITION}, ) - @property - def clear_addition_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]: + def clear_addition_count_per_tag_per_parameter( + self, function: str = "main" + ) -> Dict[str, Dict[Parameter, int]]: """ Get the number of clear additions per tag per parameter in the compiled program. """ - return self._compilation_feedback.circuit("main").count_per_tag_per_parameter( + return self._compilation_feedback.circuit(function).count_per_tag_per_parameter( operations={PrimitiveOperation.CLEAR_ADDITION}, key_types={KeyType.SECRET}, client_parameters=self.client_specs.client_parameters, @@ -613,41 +610,41 @@ def clear_addition_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter # Encrypted Addition Statistics - @property - def encrypted_addition_count(self) -> int: + def encrypted_addition_count(self, function: str = "main") -> int: """ Get the number of encrypted additions in the compiled program. """ - return self._compilation_feedback.circuit("main").count( + return self._compilation_feedback.circuit(function).count( operations={PrimitiveOperation.ENCRYPTED_ADDITION} ) - @property - def encrypted_addition_count_per_parameter(self) -> Dict[Parameter, int]: + def encrypted_addition_count_per_parameter( + self, function: str = "main" + ) -> Dict[Parameter, int]: """ Get the number of encrypted additions per parameter in the compiled program. """ - return self._compilation_feedback.circuit("main").count_per_parameter( + return self._compilation_feedback.circuit(function).count_per_parameter( operations={PrimitiveOperation.ENCRYPTED_ADDITION}, key_types={KeyType.SECRET}, client_parameters=self.client_specs.client_parameters, ) - @property - def encrypted_addition_count_per_tag(self) -> Dict[str, int]: + def encrypted_addition_count_per_tag(self, function: str = "main") -> Dict[str, int]: """ Get the number of encrypted additions per tag in the compiled program. """ - return self._compilation_feedback.circuit("main").count_per_tag( + return self._compilation_feedback.circuit(function).count_per_tag( operations={PrimitiveOperation.ENCRYPTED_ADDITION}, ) - @property - def encrypted_addition_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]: + def encrypted_addition_count_per_tag_per_parameter( + self, function: str = "main" + ) -> Dict[str, Dict[Parameter, int]]: """ Get the number of encrypted additions per tag per parameter in the compiled program. """ - return self._compilation_feedback.circuit("main").count_per_tag_per_parameter( + return self._compilation_feedback.circuit(function).count_per_tag_per_parameter( operations={PrimitiveOperation.ENCRYPTED_ADDITION}, key_types={KeyType.SECRET}, client_parameters=self.client_specs.client_parameters, @@ -655,41 +652,41 @@ def encrypted_addition_count_per_tag_per_parameter(self) -> Dict[str, Dict[Param # Clear Multiplication Statistics - @property - def clear_multiplication_count(self) -> int: + def clear_multiplication_count(self, function: str = "main") -> int: """ Get the number of clear multiplications in the compiled program. """ - return self._compilation_feedback.circuit("main").count( + return self._compilation_feedback.circuit(function).count( operations={PrimitiveOperation.CLEAR_MULTIPLICATION}, ) - @property - def clear_multiplication_count_per_parameter(self) -> Dict[Parameter, int]: + def clear_multiplication_count_per_parameter( + self, function: str = "main" + ) -> Dict[Parameter, int]: """ Get the number of clear multiplications per parameter in the compiled program. """ - return self._compilation_feedback.circuit("main").count_per_parameter( + return self._compilation_feedback.circuit(function).count_per_parameter( operations={PrimitiveOperation.CLEAR_MULTIPLICATION}, key_types={KeyType.SECRET}, client_parameters=self.client_specs.client_parameters, ) - @property - def clear_multiplication_count_per_tag(self) -> Dict[str, int]: + def clear_multiplication_count_per_tag(self, function: str = "main") -> Dict[str, int]: """ Get the number of clear multiplications per tag in the compiled program. """ - return self._compilation_feedback.circuit("main").count_per_tag( + return self._compilation_feedback.circuit(function).count_per_tag( operations={PrimitiveOperation.CLEAR_MULTIPLICATION}, ) - @property - def clear_multiplication_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]: + def clear_multiplication_count_per_tag_per_parameter( + self, function: str = "main" + ) -> Dict[str, Dict[Parameter, int]]: """ Get the number of clear multiplications per tag per parameter in the compiled program. """ - return self._compilation_feedback.circuit("main").count_per_tag_per_parameter( + return self._compilation_feedback.circuit(function).count_per_tag_per_parameter( operations={PrimitiveOperation.CLEAR_MULTIPLICATION}, key_types={KeyType.SECRET}, client_parameters=self.client_specs.client_parameters, @@ -697,41 +694,41 @@ def clear_multiplication_count_per_tag_per_parameter(self) -> Dict[str, Dict[Par # Encrypted Negation Statistics - @property - def encrypted_negation_count(self) -> int: + def encrypted_negation_count(self, function: str = "main") -> int: """ Get the number of encrypted negations in the compiled program. """ - return self._compilation_feedback.circuit("main").count( + return self._compilation_feedback.circuit(function).count( operations={PrimitiveOperation.ENCRYPTED_NEGATION} ) - @property - def encrypted_negation_count_per_parameter(self) -> Dict[Parameter, int]: + def encrypted_negation_count_per_parameter( + self, function: str = "main" + ) -> Dict[Parameter, int]: """ Get the number of encrypted negations per parameter in the compiled program. """ - return self._compilation_feedback.circuit("main").count_per_parameter( + return self._compilation_feedback.circuit(function).count_per_parameter( operations={PrimitiveOperation.ENCRYPTED_NEGATION}, key_types={KeyType.SECRET}, client_parameters=self.client_specs.client_parameters, ) - @property - def encrypted_negation_count_per_tag(self) -> Dict[str, int]: + def encrypted_negation_count_per_tag(self, function: str = "main") -> Dict[str, int]: """ Get the number of encrypted negations per tag in the compiled program. """ - return self._compilation_feedback.circuit("main").count_per_tag( + return self._compilation_feedback.circuit(function).count_per_tag( operations={PrimitiveOperation.ENCRYPTED_NEGATION}, ) - @property - def encrypted_negation_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]: + def encrypted_negation_count_per_tag_per_parameter( + self, function: str = "main" + ) -> Dict[str, Dict[Parameter, int]]: """ Get the number of encrypted negations per tag per parameter in the compiled program. """ - return self._compilation_feedback.circuit("main").count_per_tag_per_parameter( + return self._compilation_feedback.circuit(function).count_per_tag_per_parameter( operations={PrimitiveOperation.ENCRYPTED_NEGATION}, key_types={KeyType.SECRET}, client_parameters=self.client_specs.client_parameters, @@ -745,14 +742,8 @@ def statistics(self) -> Dict: Get all statistics of the compiled program. """ attributes = [ - "size_of_secret_keys", - "size_of_bootstrap_keys", - "size_of_keyswitch_keys", "size_of_inputs", "size_of_outputs", - "p_error", - "global_p_error", - "complexity", "programmable_bootstrap_count", "programmable_bootstrap_count_per_parameter", "programmable_bootstrap_count_per_tag", @@ -782,4 +773,11 @@ def statistics(self) -> Dict: "encrypted_negation_count_per_tag", "encrypted_negation_count_per_tag_per_parameter", ] - return {attribute: getattr(self, attribute) for attribute in attributes} + output = {attribute: getattr(self, attribute)() for attribute in attributes} + output["size_of_secret_keys"] = self.size_of_secret_keys + output["size_of_bootstrap_keys"] = self.size_of_bootstrap_keys + output["size_of_keyswitch_keys"] = self.size_of_keyswitch_keys + output["p_error"] = self.p_error + output["global_p_error"] = self.global_p_error + output["complexity"] = self.complexity + return output diff --git a/frontends/concrete-python/concrete/fhe/compilation/utils.py b/frontends/concrete-python/concrete/fhe/compilation/utils.py index 3c169436ab..2ba2c7c392 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/utils.py +++ b/frontends/concrete-python/concrete/fhe/compilation/utils.py @@ -15,7 +15,7 @@ from ..representation import Graph, Node, Operation from ..tracing import ScalarAnnotation from ..values import ValueDescription -from .artifacts import DebugArtifacts +from .artifacts import FunctionDebugArtifacts from .specs import ClientSpecs # ruff: noqa: ERA001 @@ -68,6 +68,7 @@ def inputset( def validate_input_args( client_specs: ClientSpecs, *args: Optional[Union[int, np.ndarray, List]], + function_name: str = "main", ) -> List[Optional[Union[int, np.ndarray]]]: """Validate input arguments. @@ -76,11 +77,15 @@ def validate_input_args( client specification *args (Optional[Union[int, np.ndarray, List]]): argument(s) for evaluation + function_name (str): name of the function to verify Returns: List[Optional[Union[int, np.ndarray]]]: ordered validated args """ - client_parameters_json = json.loads(client_specs.client_parameters.serialize())["circuits"][0] + functions_parameters = json.loads(client_specs.client_parameters.serialize())["circuits"] + client_parameters_json = next( + filter(lambda x: x["name"] == function_name, functions_parameters) + ) assert "inputs" in client_parameters_json input_specs = client_parameters_json["inputs"] if len(args) != len(input_specs): @@ -153,7 +158,7 @@ def validate_input_args( return ordered_sanitized_args -def fuse(graph: Graph, artifacts: Optional[DebugArtifacts] = None): +def fuse(graph: Graph, artifacts: Optional[FunctionDebugArtifacts] = None): """ Fuse appropriate subgraphs in a graph to a single Operation.Generic node. diff --git a/frontends/concrete-python/concrete/fhe/mlir/converter.py b/frontends/concrete-python/concrete/fhe/mlir/converter.py index 533184eaeb..7ca4e86ef1 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/converter.py +++ b/frontends/concrete-python/concrete/fhe/mlir/converter.py @@ -19,7 +19,7 @@ from mlir.ir import Module as MlirModule from ..compilation.configuration import Configuration, Exactness -from ..representation import Graph, Node, Operation +from ..representation import Graph, GraphProcessor, MultiGraphProcessor, Node, Operation from .context import Context from .conversion import Conversion from .processors import * # pylint: disable=wildcard-import @@ -38,17 +38,17 @@ class Converter: def __init__(self, configuration: Configuration): self.configuration = configuration - def convert( + def convert_many( self, - graph: Graph, + graphs: Dict[str, Graph], mlir_context: MlirContext, ) -> MlirModule: """ - Convert a computation graph to MLIR. + Convert multiple computation graphs to an MLIR module. Args: - graph (Graph): - graph to convert + graphs (Dict[str, Graph]): + graphs to convert mlir_context (MlirContext): MLIR Context to use for module generation @@ -57,46 +57,76 @@ def convert( MlirModule: In-memory MLIR module corresponding to the graph """ - - self.process(graph) + self.process(graphs) with mlir_context as context, MlirLocation.unknown(): concrete.lang.register_dialects(context) # pylint: disable=no-member module = MlirModule.create() with MlirInsertionPoint(module.body): - ctx = Context(context, graph, self.configuration) + for name, graph in graphs.items(): + # pylint: disable=cell-var-from-loop + # ruff: noqa: B023 + ctx = Context(context, graph, self.configuration) + + input_types = [ctx.typeof(node).mlir for node in graph.ordered_inputs()] + + @func.FuncOp.from_py_func(*input_types, name=name) + def main(*args): + for index, node in enumerate(graph.ordered_inputs()): + conversion = Conversion(node, args[index]) + if "original_bit_width" in node.properties: + conversion.set_original_bit_width( + node.properties["original_bit_width"] + ) + ctx.conversions[node] = conversion + + ordered_nodes = [ + node + for node in nx.lexicographical_topological_sort(graph.graph) + if node.operation != Operation.Input + ] + + for progress_index, node in enumerate(ordered_nodes): + self.trace_progress(self.configuration, progress_index, ordered_nodes) + preds = [ctx.conversions[pred] for pred in graph.ordered_preds_of(node)] + self.node(ctx, node, preds) + self.trace_progress(self.configuration, len(ordered_nodes), ordered_nodes) + + outputs = [] + for node in graph.ordered_outputs(): + assert node in ctx.conversions + outputs.append(ctx.conversions[node].result) + + return tuple(outputs) - input_types = [ctx.typeof(node).mlir for node in graph.ordered_inputs()] + return module - @func.FuncOp.from_py_func(*input_types) - def main(*args): - for index, node in enumerate(graph.ordered_inputs()): - conversion = Conversion(node, args[index]) - if "original_bit_width" in node.properties: - conversion.set_original_bit_width(node.properties["original_bit_width"]) - ctx.conversions[node] = conversion + def convert( + self, + graph: Graph, + mlir_context: MlirContext, + name: str = "main", + ) -> MlirModule: + """ + Convert a computation graph to MLIR. - ordered_nodes = [ - node - for node in nx.lexicographical_topological_sort(graph.graph) - if node.operation != Operation.Input - ] + Args: + graph (Graph): + graph to convert - for progress_index, node in enumerate(ordered_nodes): - self.trace_progress(self.configuration, progress_index, ordered_nodes) - preds = [ctx.conversions[pred] for pred in graph.ordered_preds_of(node)] - self.node(ctx, node, preds) - self.trace_progress(self.configuration, len(ordered_nodes), ordered_nodes) + mlir_context (MlirContext): + MLIR Context to use for module generation - outputs = [] - for node in graph.ordered_outputs(): - assert node in ctx.conversions - outputs.append(ctx.conversions[node].result) + name (str): + name of the function to convert - return tuple(outputs) + Return: + MlirModule: + In-memory MLIR module corresponding to the graph + """ - return module + return self.convert_many({name: graph}, mlir_context) @staticmethod def stdout_with_ansi_support() -> bool: @@ -173,13 +203,13 @@ def trace_progress(cls, configuration: Configuration, progress_index: int, nodes return concrete.lang.dialects.tracing.TraceMessageOp(msg=msg) # pylint: disable=no-member - def process(self, graph: Graph): + def process(self, graphs: Dict[str, Graph]): """ Process a computation graph for MLIR conversion. Args: - graph (Graph): - graph to process + graphs (Dict[str, Graph]): + graphs to process """ configuration = self.configuration @@ -205,7 +235,12 @@ def process(self, graph: Graph): ) for processor in pipeline: - processor.apply(graph) + assert isinstance(processor, GraphProcessor) + if isinstance(processor, MultiGraphProcessor): + processor.apply_many(graphs) + else: + for graph in graphs.values(): + processor.apply(graph) def node(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion: """ diff --git a/frontends/concrete-python/concrete/fhe/mlir/processors/assign_bit_widths.py b/frontends/concrete-python/concrete/fhe/mlir/processors/assign_bit_widths.py index f1359630f6..6d3f5a9c1b 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/processors/assign_bit_widths.py +++ b/frontends/concrete-python/concrete/fhe/mlir/processors/assign_bit_widths.py @@ -14,10 +14,10 @@ MultivariateStrategy, ) from ...dtypes import Integer -from ...representation import Graph, GraphProcessor, Node, Operation +from ...representation import Graph, MultiGraphProcessor, Node, Operation -class AssignBitWidths(GraphProcessor): +class AssignBitWidths(MultiGraphProcessor): """ AssignBitWidths graph processor, to assign proper bit-widths to be compatible with FHE. @@ -56,52 +56,54 @@ def __init__( self.multivariate_strategy_preference = multivariate_strategy_preference self.min_max_strategy_preference = min_max_strategy_preference - def apply(self, graph: Graph): + def apply_many(self, graphs: Dict[str, Graph]): optimizer = z3.Optimize() - max_bit_width: z3.Int = z3.Int("max") bit_widths: Dict[Node, z3.Int] = {} - additional_constraints = AdditionalConstraints( - optimizer, - graph, - bit_widths, - self.comparison_strategy_preference, - self.bitwise_strategy_preference, - self.shifts_with_promotion, - self.multivariate_strategy_preference, - self.min_max_strategy_preference, - ) + for graph_name, graph in graphs.items(): + max_bit_width: z3.Int = z3.Int(f"{graph_name}.max") + + additional_constraints = AdditionalConstraints( + optimizer, + graph, + bit_widths, + self.comparison_strategy_preference, + self.bitwise_strategy_preference, + self.shifts_with_promotion, + self.multivariate_strategy_preference, + self.min_max_strategy_preference, + ) - nodes = graph.query_nodes(ordered=True) - for i, node in enumerate(nodes): - assert isinstance(node.output.dtype, Integer) - required_bit_width = node.output.dtype.bit_width + nodes = graph.query_nodes(ordered=True) + for i, node in enumerate(nodes): + assert isinstance(node.output.dtype, Integer) + required_bit_width = node.output.dtype.bit_width - bit_width_hint = node.properties.get("bit_width_hint") - if bit_width_hint is not None: - required_bit_width = max(required_bit_width, bit_width_hint) + bit_width_hint = node.properties.get("bit_width_hint") + if bit_width_hint is not None: + required_bit_width = max(required_bit_width, bit_width_hint) - bit_width = z3.Int(f"%{i}") - bit_widths[node] = bit_width + bit_width = z3.Int(f"{graph_name}.%{i}") + bit_widths[node] = bit_width - base_constraint = bit_width >= required_bit_width - node.bit_width_constraints.append(base_constraint) + base_constraint = bit_width >= required_bit_width + node.bit_width_constraints.append(base_constraint) - optimizer.add(base_constraint) - optimizer.add(max_bit_width >= bit_width) + optimizer.add(base_constraint) + optimizer.add(max_bit_width >= bit_width) - additional_constraints.generate_for(node, bit_width) + additional_constraints.generate_for(node, bit_width) - if self.single_precision: - for bit_width in bit_widths.values(): - optimizer.add(bit_width == max_bit_width) + if self.single_precision: + for bit_width in bit_widths.values(): + optimizer.add(bit_width == max_bit_width) - if self.composable: - input_output_bitwidth = z3.Int("input_output") - for node in chain(graph.input_nodes.values(), graph.output_nodes.values()): - bit_width = bit_widths[node] - optimizer.add(bit_width == input_output_bitwidth) + if self.composable: + input_output_bitwidth = z3.Int("input_output") + for node in chain(graph.input_nodes.values(), graph.output_nodes.values()): + bit_width = bit_widths[node] + optimizer.add(bit_width == input_output_bitwidth) optimizer.minimize(sum(bit_width for bit_width in bit_widths.values())) @@ -120,9 +122,9 @@ def apply(self, graph: Graph): node.output.dtype.bit_width, ) node.output.dtype.bit_width = new_bit_width - - graph.bit_width_constraints = optimizer - graph.bit_width_assignments = model + for graph in graphs.values(): + graph.bit_width_constraints = optimizer + graph.bit_width_assignments = model class AdditionalConstraints: diff --git a/frontends/concrete-python/concrete/fhe/representation/__init__.py b/frontends/concrete-python/concrete/fhe/representation/__init__.py index 34640a94e1..30825b1f7a 100644 --- a/frontends/concrete-python/concrete/fhe/representation/__init__.py +++ b/frontends/concrete-python/concrete/fhe/representation/__init__.py @@ -2,6 +2,6 @@ Define structures used to represent computation. """ -from .graph import Graph, GraphProcessor +from .graph import Graph, GraphProcessor, MultiGraphProcessor from .node import Node from .operation import Operation diff --git a/frontends/concrete-python/concrete/fhe/representation/graph.py b/frontends/concrete-python/concrete/fhe/representation/graph.py index aa278f43a1..2ea83e4bfd 100644 --- a/frontends/concrete-python/concrete/fhe/representation/graph.py +++ b/frontends/concrete-python/concrete/fhe/representation/graph.py @@ -40,12 +40,15 @@ class Graph: bit_width_constraints: Optional[z3.Optimize] bit_width_assignments: Optional[z3.Model] + name: str + def __init__( self, graph: nx.MultiDiGraph, input_nodes: Dict[int, Node], output_nodes: Dict[int, Node], is_direct: bool = False, + name: str = "main", ): self.graph = graph @@ -59,6 +62,8 @@ def __init__( self.bit_width_assignments = None self.bit_width_constraints = None + self.name = name + self.prune_useless_nodes() def __call__( @@ -584,18 +589,20 @@ def format_bit_width_assignments(self) -> str: lines = [] for variable in self.bit_width_assignments.decls(): # type: ignore - width = self.bit_width_assignments.get_interp(variable) # type: ignore - lines.append(f"{variable} = {width}") + if variable.name().startswith(f"{self.name}.") or variable.name() == "input_output": + width = self.bit_width_assignments.get_interp(variable) # type: ignore + lines.append(f"{variable} = {width}") def sorter(line: str) -> int: - if line.startswith("max"): + if line.startswith(f"{self.name}.max"): # we won't have 4 million nodes... return 2**32 - - assert line.startswith("%") + if line.startswith("input_output"): + # this is the composable constraint + return 2**32 equals_position = line.find("=") - index = line[1 : equals_position - 1] + index = line[len(self.name) + 2 : equals_position - 1] return int(index) result = "" @@ -967,6 +974,8 @@ def integer_range( class GraphProcessor(ABC): """ GraphProcessor base class, to define the API for a graph processing pipeline. + + Process a single graph. """ @abstractmethod @@ -998,3 +1007,23 @@ def error(graph: Graph, highlights: Mapping[Node, Union[str, List[str]]]): highlighted_nodes=highlights_with_location ) raise RuntimeError(message) + + +class MultiGraphProcessor(GraphProcessor): + """ + MultiGraphProcessor base class, to define the API for a multiple graph processing pipeline. + + Processes multiple graphs at once. + """ + + @abstractmethod + def apply_many(self, graphs: Dict[str, Graph]): + """ + Process a dictionnary of graphs. + """ + + def apply(self, graph: Graph): + """ + Process a single graph. + """ + return self.apply_many({"main": graph}) diff --git a/frontends/concrete-python/concrete/fhe/tracing/tracer.py b/frontends/concrete-python/concrete/fhe/tracing/tracer.py index c5592ee9e7..d6359d0764 100644 --- a/frontends/concrete-python/concrete/fhe/tracing/tracer.py +++ b/frontends/concrete-python/concrete/fhe/tracing/tracer.py @@ -38,6 +38,7 @@ def trace( function: Callable, parameters: Dict[str, ValueDescription], is_direct: bool = False, + name: str = "main", ) -> Graph: """ Trace `function` and create the `Graph` that represents it. @@ -53,6 +54,9 @@ def trace( is_direct (bool, default = False): whether the tracing is done on actual parameters or placeholders + name (str, default = "main"): + the name of the function being traced + Returns: Graph: computation graph corresponding to `function` @@ -164,7 +168,7 @@ def create_graph_from_output_tracers( output_idx: tracer.computation for output_idx, tracer in enumerate(output_tracers) } - return Graph(graph, input_nodes, output_nodes, is_direct) + return Graph(graph, input_nodes, output_nodes, is_direct, name) # pylint: enable=too-many-statements diff --git a/frontends/concrete-python/tests/compilation/test_artifacts.py b/frontends/concrete-python/tests/compilation/test_artifacts.py index edd45fe5c4..78e788ca91 100644 --- a/frontends/concrete-python/tests/compilation/test_artifacts.py +++ b/frontends/concrete-python/tests/compilation/test_artifacts.py @@ -35,13 +35,13 @@ def f(x): assert (tmpdir / "environment.txt").exists() assert (tmpdir / "requirements.txt").exists() - assert (tmpdir / "function.txt").exists() - assert (tmpdir / "parameters.txt").exists() + assert (tmpdir / "main.txt").exists() + assert (tmpdir / "main.parameters.txt").exists() - assert (tmpdir / "1.initial.graph.txt").exists() - assert (tmpdir / "2.after-fusing.graph.txt").exists() - assert (tmpdir / "3.after-fusing.graph.txt").exists() - assert (tmpdir / "4.final.graph.txt").exists() + assert (tmpdir / "main.1.initial.graph.txt").exists() + assert (tmpdir / "main.2.after-fusing.graph.txt").exists() + assert (tmpdir / "main.3.after-fusing.graph.txt").exists() + assert (tmpdir / "main.4.final.graph.txt").exists() assert (tmpdir / "mlir.txt").exists() assert (tmpdir / "client_parameters.json").exists() @@ -51,13 +51,13 @@ def f(x): assert (tmpdir / "environment.txt").exists() assert (tmpdir / "requirements.txt").exists() - assert (tmpdir / "function.txt").exists() - assert (tmpdir / "parameters.txt").exists() + assert (tmpdir / "main.txt").exists() + assert (tmpdir / "main.parameters.txt").exists() - assert (tmpdir / "1.initial.graph.txt").exists() - assert (tmpdir / "2.after-fusing.graph.txt").exists() - assert (tmpdir / "3.after-fusing.graph.txt").exists() - assert (tmpdir / "4.final.graph.txt").exists() + assert (tmpdir / "main.1.initial.graph.txt").exists() + assert (tmpdir / "main.2.after-fusing.graph.txt").exists() + assert (tmpdir / "main.3.after-fusing.graph.txt").exists() + assert (tmpdir / "main.4.final.graph.txt").exists() assert (tmpdir / "mlir.txt").exists() assert (tmpdir / "client_parameters.json").exists() diff --git a/frontends/concrete-python/tests/compilation/test_circuit.py b/frontends/concrete-python/tests/compilation/test_circuit.py index 33b0b69113..f9b41206fc 100644 --- a/frontends/concrete-python/tests/compilation/test_circuit.py +++ b/frontends/concrete-python/tests/compilation/test_circuit.py @@ -683,8 +683,11 @@ def test_statistics(function, parameters, expected_statistics, helpers): for name, expected_value in expected_statistics.items(): assert hasattr(circuit, name) + attr = getattr(circuit, name) + if callable(attr): + attr = attr() assert ( - getattr(circuit, name) == expected_value + attr == expected_value ), f""" Expected {name} to be {expected_value} but it's {getattr(circuit, name)} diff --git a/frontends/concrete-python/tests/compilation/test_configuration.py b/frontends/concrete-python/tests/compilation/test_configuration.py index d74a77c4b7..3fbf7b50fb 100644 --- a/frontends/concrete-python/tests/compilation/test_configuration.py +++ b/frontends/concrete-python/tests/compilation/test_configuration.py @@ -217,38 +217,38 @@ def test_configuration_bad_fork(kwargs, expected_error, expected_message): """ %0: - %0 >= 3 + main.%0 >= 3 %1: - %1 >= 7 + main.%1 >= 7 %2: - %2 >= 2 + main.%2 >= 2 %3: - %3 >= 5 + main.%3 >= 5 %4: - %4 >= 8 - %3 == %1 - %1 == %4 + main.%4 >= 8 + main.%3 == main.%1 + main.%1 == main.%4 """, """ - %0 = 3 - %1 = 8 - %2 = 2 - %3 = 8 - %4 = 8 -max = 8 + main.%0 = 3 + main.%1 = 8 + main.%2 = 2 + main.%3 = 8 + main.%4 = 8 +main.max = 8 """ if USE_MULTI_PRECISION else """ - %0 = 8 - %1 = 8 - %2 = 8 - %3 = 8 - %4 = 8 -max = 8 + main.%0 = 8 + main.%1 = 8 + main.%2 = 8 + main.%3 = 8 + main.%4 = 8 +main.max = 8 """, ), diff --git a/frontends/concrete-python/tests/compilation/test_decorators.py b/frontends/concrete-python/tests/compilation/test_decorators.py index 24e0c8c0bd..05bd51de68 100644 --- a/frontends/concrete-python/tests/compilation/test_decorators.py +++ b/frontends/concrete-python/tests/compilation/test_decorators.py @@ -87,21 +87,21 @@ def function(x): Bit-Width Constraints -------------------------------------------------------------------------------- %0: - %0 >= 4 + main.%0 >= 4 %1: - %1 >= 6 + main.%1 >= 6 %2: - %2 >= 6 - %0 == %1 - %1 == %2 + main.%2 >= 6 + main.%0 == main.%1 + main.%1 == main.%2 -------------------------------------------------------------------------------- Bit-Width Assignments -------------------------------------------------------------------------------- - %0 = 6 - %1 = 6 - %2 = 6 -max = 6 + main.%0 = 6 + main.%1 = 6 + main.%2 = 6 +main.max = 6 -------------------------------------------------------------------------------- Bit-Width Assigned Computation Graph diff --git a/frontends/concrete-python/tests/compilation/test_program.py b/frontends/concrete-python/tests/compilation/test_program.py new file mode 100644 index 0000000000..11e52a2444 --- /dev/null +++ b/frontends/concrete-python/tests/compilation/test_program.py @@ -0,0 +1,149 @@ +""" +Tests of everything related to multi-circuit. +""" +import numpy as np +import pytest + +from concrete import fhe + +# pylint: disable=missing-class-docstring, missing-function-docstring, no-self-argument +# pylint: disable=unused-variable, no-member +# ruff: noqa: N805 + + +def test_empty_module(): + """ + Test that defining a module without functions is an error. + """ + with pytest.raises( + RuntimeError, match="Tried to define an @fhe.module without any @fhe.function" + ): + + @fhe.module() + class Module: + def square(x): + return x**2 + + +def test_call_clear_circuits(): + """ + Test that calling clear functions works. + """ + + @fhe.module() + class Module: + @fhe.function({"x": "encrypted"}) + def square(x): + return x**2 + + @fhe.function({"x": "encrypted", "y": "encrypted"}) + def add_sub(x, y): + return (x + y), (x - y) + + @fhe.function({"x": "encrypted", "y": "encrypted"}) + def mul(x, y): + return x * y + + assert Module.square(2) == 4 + assert Module.add_sub(2, 3) == (5, -1) + assert Module.mul(3, 4) == 12 + + +def test_compile(helpers): + """ + Test that compiling a module works. + """ + + @fhe.module() + class Module: + @fhe.function({"x": "encrypted"}) + def inc(x): + return x + 1 + + @fhe.function({"x": "encrypted"}) + def dec(x): + return x - 1 + + inputset = [np.random.randint(1, 20, size=()) for _ in range(100)] + configuration = helpers.configuration().fork( + p_error=0.1, + parameter_selection_strategy="v0", + composable=True, + verbose=True, + ) + Module.compile( + {"inc": inputset, "dec": inputset}, + configuration, + ) + + +def test_compiled_clear_call(helpers): + """ + Test that cleartext execution works on compiled objects. + """ + + @fhe.module() + class Module: + @fhe.function({"x": "encrypted"}) + def inc(x): + return x + 1 + + @fhe.function({"x": "encrypted"}) + def dec(x): + return x - 1 + + inputset = [np.random.randint(1, 20, size=()) for _ in range(100)] + configuration = helpers.configuration().fork( + p_error=0.1, + parameter_selection_strategy="v0", + composable=True, + ) + module = Module.compile( + {"inc": inputset, "dec": inputset}, + configuration, + ) + + assert module.inc(5) == 6 + assert module.dec(5) == 4 + + +def test_encrypted_execution(helpers): + """ + Test that encrypted execution works. + """ + + @fhe.module() + class Module: + @fhe.function({"x": "encrypted"}) + def inc(x): + return x + 1 % 20 + + @fhe.function({"x": "encrypted"}) + def dec(x): + return x - 1 % 20 + + inputset = [np.random.randint(1, 20, size=()) for _ in range(100)] + configuration = helpers.configuration().fork( + p_error=0.1, + parameter_selection_strategy="v0", + composable=True, + ) + module = Module.compile( + {"inc": inputset, "dec": inputset}, + configuration, + ) + + x = 5 + x_enc = module.inc.encrypt(x) + x_inc_enc = module.inc.run(x_enc) + x_inc = module.inc.decrypt(x_inc_enc) + assert x_inc == 6 + + x_inc_dec_enc = module.dec.run(x_inc_enc) + x_inc_dec = module.dec.decrypt(x_inc_dec_enc) + assert x_inc_dec == 5 + + for _ in range(10): + x_enc = module.inc.run(x_enc) + x_dec = module.inc.decrypt(x_enc) + assert x_dec == 15 diff --git a/frontends/concrete-python/tests/mlir/test_converter.py b/frontends/concrete-python/tests/mlir/test_converter.py index c7370da4fe..cce0924bcb 100644 --- a/frontends/concrete-python/tests/mlir/test_converter.py +++ b/frontends/concrete-python/tests/mlir/test_converter.py @@ -1723,7 +1723,7 @@ def test_converter_process_multi_precision(function, parameters, expected_graph, inputset = helpers.generate_inputset(parameters) graph = compiler.trace(inputset, configuration) - GraphConverter(configuration).process(graph) + GraphConverter(configuration).process({"main": graph}) for node in graph.query_nodes(): if "original_bit_width" in node.properties: del node.properties["original_bit_width"] @@ -1765,7 +1765,7 @@ def test_converter_process_single_precision(function, parameters, expected_graph inputset = helpers.generate_inputset(parameters) graph = compiler.trace(inputset, configuration) - GraphConverter(configuration).process(graph) + GraphConverter(configuration).process({"main": graph}) for node in graph.query_nodes(): if "original_bit_width" in node.properties: del node.properties["original_bit_width"]