From 883f19fefd92671c8f50847e3de880666b1ba059 Mon Sep 17 00:00:00 2001 From: Umut Date: Mon, 31 Jul 2023 16:21:07 +0200 Subject: [PATCH 1/2] chore(frontend-python): fix pylint issues --- .../concrete-python/concrete/fhe/compilation/compiler.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/frontends/concrete-python/concrete/fhe/compilation/compiler.py b/frontends/concrete-python/concrete/fhe/compilation/compiler.py index f0744f0f21..015c260878 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/compiler.py +++ b/frontends/concrete-python/concrete/fhe/compilation/compiler.py @@ -2,6 +2,8 @@ Declaration of `Compiler` class. """ +# pylint: disable=import-error,no-name-in-module + import inspect import os import traceback @@ -22,6 +24,8 @@ from .configuration import Configuration from .utils import fuse +# pylint: enable=import-error,no-name-in-module + @unique class EncryptionStatus(str, Enum): From e40a4993e77ffa2da9d025826a0eb3ee99bee2fc Mon Sep 17 00:00:00 2001 From: Umut Date: Wed, 26 Jul 2023 14:27:32 +0200 Subject: [PATCH 2/2] feat(compiler): add more detailed statistics --- .../Support/CompilationFeedback.h | 43 ++- .../lib/Bindings/Python/CMakeLists.txt | 1 + .../lib/Bindings/Python/CompilerAPIModule.cpp | 128 ++++++- .../Python/concrete/compiler/__init__.py | 1 + .../concrete/compiler/compilation_feedback.py | 168 ++++++++- .../Python/concrete/compiler/parameter.py | 105 ++++++ .../compiler/lib/Bindings/Rust/build.rs | 1 + .../TFHE/Analysis/ExtractStatistics.cpp | 181 ++++++++- .../lib/Support/CompilationFeedback.cpp | 184 ++++++++-- .../concrete-python/concrete/fhe/__init__.py | 2 +- .../concrete/fhe/compilation/circuit.py | 269 +++++++++++--- .../concrete/fhe/compilation/compiler.py | 25 +- .../concrete/fhe/compilation/server.py | 344 ++++++++++++++++-- .../tests/compilation/test_circuit.py | 91 +++-- 14 files changed, 1339 insertions(+), 204 deletions(-) create mode 100644 compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/parameter.py diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Support/CompilationFeedback.h b/compilers/concrete-compiler/compiler/include/concretelang/Support/CompilationFeedback.h index e2ae8145da..278e14b91a 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Support/CompilationFeedback.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Support/CompilationFeedback.h @@ -18,6 +18,30 @@ namespace concretelang { using StringError = ::concretelang::error::StringError; +enum class PrimitiveOperation { + PBS, + WOP_PBS, + KEY_SWITCH, + CLEAR_ADDITION, + ENCRYPTED_ADDITION, + CLEAR_MULTIPLICATION, + ENCRYPTED_NEGATION, +}; + +enum class KeyType { + SECRET, + BOOTSTRAP, + KEY_SWITCH, + PACKING_KEY_SWITCH, +}; + +struct Statistic { + std::string location; + PrimitiveOperation operation; + std::vector> keys; + size_t count; +}; + struct CompilationFeedback { double complexity; @@ -45,23 +69,8 @@ struct CompilationFeedback { /// @brief crt decomposition of outputs, if crt is not used, empty vectors std::vector> crtDecompositionsOfOutputs; - /// @brief number of programmable bootstraps in the entire circuit - uint64_t totalPbsCount = 0; - - /// @brief number of key switches in the entire circuit - uint64_t totalKsCount = 0; - - /// @brief number of clear additions in the entire circuit - uint64_t totalClearAdditionCount = 0; - - /// @brief number of encrypted additions in the entire circuit - uint64_t totalEncryptedAdditionCount = 0; - - /// @brief number of clear multiplications in the entire circuit - uint64_t totalClearMultiplicationCount = 0; - - /// @brief number of encrypted negations in the entire circuit - uint64_t totalEncryptedNegationCount = 0; + /// @brief statistics + std::vector statistics; /// Fill the sizes from the client parameters. void diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CMakeLists.txt b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CMakeLists.txt index 32dd0e08fb..92f9b2d6cd 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CMakeLists.txt +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CMakeLists.txt @@ -57,6 +57,7 @@ declare_mlir_python_sources( concrete/compiler/library_compilation_result.py concrete/compiler/library_support.py concrete/compiler/library_lambda.py + concrete/compiler/parameter.py concrete/compiler/public_arguments.py concrete/compiler/public_result.py concrete/compiler/evaluation_keys.py diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index ccfa2dadab..19b744ccf2 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -109,6 +109,35 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( options.simulate = simulate; }); + pybind11::enum_(m, + "PrimitiveOperation") + .value("PBS", mlir::concretelang::PrimitiveOperation::PBS) + .value("WOP_PBS", mlir::concretelang::PrimitiveOperation::WOP_PBS) + .value("KEY_SWITCH", mlir::concretelang::PrimitiveOperation::KEY_SWITCH) + .value("CLEAR_ADDITION", + mlir::concretelang::PrimitiveOperation::CLEAR_ADDITION) + .value("ENCRYPTED_ADDITION", + mlir::concretelang::PrimitiveOperation::ENCRYPTED_ADDITION) + .value("CLEAR_MULTIPLICATION", + mlir::concretelang::PrimitiveOperation::CLEAR_MULTIPLICATION) + .value("ENCRYPTED_NEGATION", + mlir::concretelang::PrimitiveOperation::ENCRYPTED_NEGATION) + .export_values(); + + pybind11::enum_(m, "KeyType") + .value("SECRET", mlir::concretelang::KeyType::SECRET) + .value("BOOTSTRAP", mlir::concretelang::KeyType::BOOTSTRAP) + .value("KEY_SWITCH", mlir::concretelang::KeyType::KEY_SWITCH) + .value("PACKING_KEY_SWITCH", + mlir::concretelang::KeyType::PACKING_KEY_SWITCH) + .export_values(); + + pybind11::class_(m, "Statistic") + .def_readonly("operation", &mlir::concretelang::Statistic::operation) + .def_readonly("location", &mlir::concretelang::Statistic::location) + .def_readonly("keys", &mlir::concretelang::Statistic::keys) + .def_readonly("count", &mlir::concretelang::Statistic::count); + pybind11::class_( m, "CompilationFeedback") .def_readonly("complexity", @@ -132,22 +161,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( .def_readonly( "crt_decompositions_of_outputs", &mlir::concretelang::CompilationFeedback::crtDecompositionsOfOutputs) - .def_readonly("total_pbs_count", - &mlir::concretelang::CompilationFeedback::totalPbsCount) - .def_readonly("total_ks_count", - &mlir::concretelang::CompilationFeedback::totalKsCount) - .def_readonly( - "total_clear_addition_count", - &mlir::concretelang::CompilationFeedback::totalClearAdditionCount) - .def_readonly( - "total_encrypted_addition_count", - &mlir::concretelang::CompilationFeedback::totalEncryptedAdditionCount) - .def_readonly("total_clear_multiplication_count", - &mlir::concretelang::CompilationFeedback:: - totalClearMultiplicationCount) - .def_readonly("total_encrypted_negation_count", - &mlir::concretelang::CompilationFeedback:: - totalEncryptedNegationCount); + .def_readonly("statistics", + &mlir::concretelang::CompilationFeedback::statistics); pybind11::class_( m, "JITCompilationResult"); @@ -320,6 +335,76 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( pybind11::class_(m, "KeySetCache") .def(pybind11::init()); + pybind11::class_<::concretelang::clientlib::LweSecretKeyParam>( + m, "LweSecretKeyParam") + .def_readonly("dimension", + &::concretelang::clientlib::LweSecretKeyParam::dimension); + + pybind11::class_<::concretelang::clientlib::BootstrapKeyParam>( + m, "BootstrapKeyParam") + .def_readonly( + "input_secret_key_id", + &::concretelang::clientlib::BootstrapKeyParam::inputSecretKeyID) + .def_readonly( + "output_secret_key_id", + &::concretelang::clientlib::BootstrapKeyParam::outputSecretKeyID) + .def_readonly("level", + &::concretelang::clientlib::BootstrapKeyParam::level) + .def_readonly("base_log", + &::concretelang::clientlib::BootstrapKeyParam::baseLog) + .def_readonly( + "glwe_dimension", + &::concretelang::clientlib::BootstrapKeyParam::glweDimension) + .def_readonly("variance", + &::concretelang::clientlib::BootstrapKeyParam::variance) + .def_readonly( + "polynomial_size", + &::concretelang::clientlib::BootstrapKeyParam::polynomialSize) + .def_readonly( + "input_lwe_dimension", + &::concretelang::clientlib::BootstrapKeyParam::inputLweDimension); + + pybind11::class_<::concretelang::clientlib::KeyswitchKeyParam>( + m, "KeyswitchKeyParam") + .def_readonly( + "input_secret_key_id", + &::concretelang::clientlib::KeyswitchKeyParam::inputSecretKeyID) + .def_readonly( + "output_secret_key_id", + &::concretelang::clientlib::KeyswitchKeyParam::outputSecretKeyID) + .def_readonly("level", + &::concretelang::clientlib::KeyswitchKeyParam::level) + .def_readonly("base_log", + &::concretelang::clientlib::KeyswitchKeyParam::baseLog) + .def_readonly("variance", + &::concretelang::clientlib::KeyswitchKeyParam::variance); + + pybind11::class_<::concretelang::clientlib::PackingKeyswitchKeyParam>( + m, "PackingKeyswitchKeyParam") + .def_readonly("input_secret_key_id", + &::concretelang::clientlib::PackingKeyswitchKeyParam:: + inputSecretKeyID) + .def_readonly("output_secret_key_id", + &::concretelang::clientlib::PackingKeyswitchKeyParam:: + outputSecretKeyID) + .def_readonly("level", + &::concretelang::clientlib::PackingKeyswitchKeyParam::level) + .def_readonly( + "base_log", + &::concretelang::clientlib::PackingKeyswitchKeyParam::baseLog) + .def_readonly( + "glwe_dimension", + &::concretelang::clientlib::PackingKeyswitchKeyParam::glweDimension) + .def_readonly( + "polynomial_size", + &::concretelang::clientlib::PackingKeyswitchKeyParam::polynomialSize) + .def_readonly("input_lwe_dimension", + &::concretelang::clientlib::PackingKeyswitchKeyParam:: + inputLweDimension) + .def_readonly( + "variance", + &::concretelang::clientlib::PackingKeyswitchKeyParam::variance); + pybind11::class_(m, "ClientParameters") .def_static("deserialize", [](const pybind11::bytes &buffer) { @@ -353,7 +438,16 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( } } return result; - }); + }) + .def_readonly("secret_keys", + &mlir::concretelang::ClientParameters::secretKeys) + .def_readonly("bootstrap_keys", + &mlir::concretelang::ClientParameters::bootstrapKeys) + .def_readonly("keyswitch_keys", + &mlir::concretelang::ClientParameters::keyswitchKeys) + .def_readonly( + "packing_keyswitch_keys", + &mlir::concretelang::ClientParameters::packingKeyswitchKeys); pybind11::class_(m, "KeySet") .def_static("deserialize", diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/__init__.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/__init__.py index 308df4d936..ae1ff6c7d0 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/__init__.py +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/__init__.py @@ -39,6 +39,7 @@ from .value_exporter import ValueExporter from .simulated_value_decrypter import SimulatedValueDecrypter from .simulated_value_exporter import SimulatedValueExporter +from .parameter import Parameter def init_dfr(): diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_feedback.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_feedback.py index 8def9497a5..a3274fde22 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_feedback.py +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_feedback.py @@ -3,13 +3,19 @@ """Compilation feedback.""" +from typing import Dict, Set + # pylint: disable=no-name-in-module,import-error,too-many-instance-attributes from mlir._mlir_libs._concretelang._compiler import ( CompilationFeedback as _CompilationFeedback, + KeyType, + PrimitiveOperation, ) # pylint: enable=no-name-in-module,import-error +from .client_parameters import ClientParameters +from .parameter import Parameter from .wrapper import WrapperCpp @@ -41,19 +47,153 @@ def __init__(self, compilation_feedback: _CompilationFeedback): self.crt_decompositions_of_outputs = ( compilation_feedback.crt_decompositions_of_outputs ) - self.total_pbs_count = compilation_feedback.total_pbs_count - self.total_ks_count = compilation_feedback.total_ks_count - self.total_clear_addition_count = ( - compilation_feedback.total_clear_addition_count - ) - self.total_encrypted_addition_count = ( - compilation_feedback.total_encrypted_addition_count - ) - self.total_clear_multiplication_count = ( - compilation_feedback.total_clear_multiplication_count - ) - self.total_encrypted_negation_count = ( - compilation_feedback.total_encrypted_negation_count - ) + self.statistics = compilation_feedback.statistics super().__init__(compilation_feedback) + + def count(self, *, operations: Set[PrimitiveOperation]) -> int: + """ + Count the amount of specified operations in the program. + + Args: + operations (Set[PrimitiveOperation]): + set of operations used to filter the statistics + + Returns: + int: + number of specified operations in the program + """ + + return sum( + statistic.count + for statistic in self.statistics + if statistic.operation in operations + ) + + def count_per_parameter( + self, + *, + operations: Set[PrimitiveOperation], + key_types: Set[KeyType], + client_parameters: ClientParameters, + ) -> Dict[Parameter, int]: + """ + Count the amount of specified operations in the program and group by parameters. + + Args: + operations (Set[PrimitiveOperation]): + set of operations used to filter the statistics + + key_types (Set[KeyType]): + set of key types used to filter the statistics + + client_parameters (ClientParameters): + client parameters required for grouping by parameters + + Returns: + Dict[Parameter, int]: + number of specified operations per parameter in the program + """ + + result = {} + for statistic in self.statistics: + if statistic.operation not in operations: + continue + + for key_type, key_index in statistic.keys: + if key_type not in key_types: + continue + + parameter = Parameter(client_parameters, key_type, key_index) + if parameter not in result: + result[parameter] = 0 + result[parameter] += statistic.count + + return result + + def count_per_tag(self, *, operations: Set[PrimitiveOperation]) -> Dict[str, int]: + """ + Count the amount of specified operations in the program and group by tags. + + Args: + operations (Set[PrimitiveOperation]): + set of operations used to filter the statistics + + Returns: + Dict[str, int]: + number of specified operations per tag in the program + """ + + result = {} + for statistic in self.statistics: + if statistic.operation not in operations: + continue + + file_and_maybe_tag = statistic.location.split("@") + tag = "" if len(file_and_maybe_tag) == 1 else file_and_maybe_tag[1].strip() + + tag_components = tag.split(".") + for i in range(1, len(tag_components) + 1): + current_tag = ".".join(tag_components[0:i]) + if current_tag == "": + continue + + if current_tag not in result: + result[current_tag] = 0 + + result[current_tag] += statistic.count + + return result + + def count_per_tag_per_parameter( + self, + *, + operations: Set[PrimitiveOperation], + key_types: Set[KeyType], + client_parameters: ClientParameters, + ) -> Dict[str, Dict[Parameter, int]]: + """ + Count the amount of specified operations in the program and group by tags and parameters. + + Args: + operations (Set[PrimitiveOperation]): + set of operations used to filter the statistics + + key_types (Set[KeyType]): + set of key types used to filter the statistics + + client_parameters (ClientParameters): + client parameters required for grouping by parameters + + Returns: + Dict[str, Dict[Parameter, int]]: + number of specified operations per tag per parameter in the program + """ + + result: Dict[str, Dict[int, int]] = {} + for statistic in self.statistics: + if statistic.operation not in operations: + continue + + file_and_maybe_tag = statistic.location.split("@") + tag = "" if len(file_and_maybe_tag) == 1 else file_and_maybe_tag[1].strip() + + tag_components = tag.split(".") + for i in range(1, len(tag_components) + 1): + current_tag = ".".join(tag_components[0:i]) + if current_tag == "": + continue + + if current_tag not in result: + result[current_tag] = {} + + for key_type, key_index in statistic.keys: + if key_type not in key_types: + continue + + parameter = Parameter(client_parameters, key_type, key_index) + if parameter not in result[current_tag]: + result[current_tag][parameter] = 0 + result[current_tag][parameter] += statistic.count + + return result diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/parameter.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/parameter.py new file mode 100644 index 0000000000..2649ef547e --- /dev/null +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/parameter.py @@ -0,0 +1,105 @@ +""" +Parameter. +""" + +# pylint: disable=no-name-in-module,import-error + +from typing import Union + +from mlir._mlir_libs._concretelang._compiler import ( + LweSecretKeyParam, + BootstrapKeyParam, + KeyswitchKeyParam, + PackingKeyswitchKeyParam, + KeyType, +) + +from .client_parameters import ClientParameters + +# pylint: enable=no-name-in-module,import-error + + +class Parameter: + """ + An FHE parameter. + """ + + _inner: Union[ + LweSecretKeyParam, + BootstrapKeyParam, + KeyswitchKeyParam, + PackingKeyswitchKeyParam, + ] + + def __init__( + self, + client_parameters: ClientParameters, + key_type: KeyType, + key_index: int, + ): + if key_type == KeyType.SECRET: + self._inner = client_parameters.cpp().secret_keys[key_index] + elif key_type == KeyType.BOOTSTRAP: + self._inner = client_parameters.cpp().bootstrap_keys[key_index] + elif key_type == KeyType.KEY_SWITCH: + self._inner = client_parameters.cpp().keyswitch_keys[key_index] + elif key_type == KeyType.PACKING_KEY_SWITCH: + self._inner = client_parameters.cpp().packing_keyswitch_keys[key_index] + else: + raise ValueError("invalid key type") + + def __getattr__(self, item): + return getattr(self._inner, item) + + def __repr__(self): + param = self._inner + + if isinstance(param, LweSecretKeyParam): + result = f"LweSecretKeyParam(" f"dimension={param.dimension}" f")" + + elif isinstance(param, BootstrapKeyParam): + result = ( + f"BootstrapKeyParam(" + f"polynomial_size={param.polynomial_size}, " + f"glwe_dimension={param.glwe_dimension}, " + f"input_lwe_dimension={param.input_lwe_dimension}, " + f"level={param.level}, " + f"base_log={param.base_log}, " + f"variance={param.variance}" + f")" + ) + + elif isinstance(param, KeyswitchKeyParam): + result = ( + f"KeyswitchKeyParam(" + f"level={param.level}, " + f"base_log={param.base_log}, " + f"variance={param.variance}" + f")" + ) + + elif isinstance(param, PackingKeyswitchKeyParam): + result = ( + f"PackingKeyswitchKeyParam(" + f"polynomial_size={param.polynomial_size}, " + f"glwe_dimension={param.glwe_dimension}, " + f"input_lwe_dimension={param.input_lwe_dimension}" + f"level={param.level}, " + f"base_log={param.base_log}, " + f"variance={param.variance}" + f")" + ) + + else: + assert False + + return result + + def __str__(self): + return repr(self) + + def __hash__(self): + return hash(str(self)) + + def __eq__(self, other): + return str(self) == str(other) diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Rust/build.rs b/compilers/concrete-compiler/compiler/lib/Bindings/Rust/build.rs index 9b702ea007..37aa1aed23 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Rust/build.rs +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Rust/build.rs @@ -301,6 +301,7 @@ const CONCRETE_COMPILER_STATIC_LIBS: &[&str] = &[ "SDFGDialect", "ExtractSDFGOps", "SDFGToStreamEmulator", + "TFHEDialectAnalysis", ]; fn main() { diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Analysis/ExtractStatistics.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Analysis/ExtractStatistics.cpp index c5045a46be..2a13c03f22 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Analysis/ExtractStatistics.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/TFHE/Analysis/ExtractStatistics.cpp @@ -10,6 +10,17 @@ using namespace mlir; using TFHE::ExtractTFHEStatisticsPass; +// ######### +// Utilities +// ######### + +template std::string locationOf(Op op) { + auto location = std::string(); + auto locationStream = llvm::raw_string_ostream(location); + op.getLoc()->print(locationStream); + return location.substr(5, location.size() - 2 - 5); // remove loc(" and ") +} + // ####### // scf.for // ####### @@ -103,7 +114,24 @@ static std::optional on_exit(scf::ForOp &op, static std::optional on_enter(TFHE::AddGLWEOp &op, ExtractTFHEStatisticsPass &pass) { - pass.feedback.totalEncryptedAdditionCount += pass.iterations; + auto resultingKey = op.getType().getKey().getNormalized(); + + auto location = locationOf(op); + auto operation = PrimitiveOperation::ENCRYPTED_ADDITION; + auto keys = std::vector>(); + auto count = pass.iterations; + + std::pair key = + std::make_pair(KeyType::SECRET, (size_t)resultingKey->index); + keys.push_back(key); + + pass.feedback.statistics.push_back(Statistic{ + location, + operation, + keys, + count, + }); + return std::nullopt; } @@ -113,7 +141,24 @@ static std::optional on_enter(TFHE::AddGLWEOp &op, static std::optional on_enter(TFHE::AddGLWEIntOp &op, ExtractTFHEStatisticsPass &pass) { - pass.feedback.totalClearAdditionCount += pass.iterations; + auto resultingKey = op.getType().getKey().getNormalized(); + + auto location = locationOf(op); + auto operation = PrimitiveOperation::CLEAR_ADDITION; + auto keys = std::vector>(); + auto count = pass.iterations; + + std::pair key = + std::make_pair(KeyType::SECRET, (size_t)resultingKey->index); + keys.push_back(key); + + pass.feedback.statistics.push_back(Statistic{ + location, + operation, + keys, + count, + }); + return std::nullopt; } @@ -123,7 +168,24 @@ static std::optional on_enter(TFHE::AddGLWEIntOp &op, static std::optional on_enter(TFHE::BootstrapGLWEOp &op, ExtractTFHEStatisticsPass &pass) { - pass.feedback.totalPbsCount += pass.iterations; + auto bsk = op.getKey(); + + auto location = locationOf(op); + auto operation = PrimitiveOperation::PBS; + auto keys = std::vector>(); + auto count = pass.iterations; + + std::pair key = + std::make_pair(KeyType::BOOTSTRAP, (size_t)bsk.getIndex()); + keys.push_back(key); + + pass.feedback.statistics.push_back(Statistic{ + location, + operation, + keys, + count, + }); + return std::nullopt; } @@ -133,7 +195,24 @@ static std::optional on_enter(TFHE::BootstrapGLWEOp &op, static std::optional on_enter(TFHE::KeySwitchGLWEOp &op, ExtractTFHEStatisticsPass &pass) { - pass.feedback.totalKsCount += pass.iterations; + auto ksk = op.getKey(); + + auto location = locationOf(op); + auto operation = PrimitiveOperation::KEY_SWITCH; + auto keys = std::vector>(); + auto count = pass.iterations; + + std::pair key = + std::make_pair(KeyType::KEY_SWITCH, (size_t)ksk.getIndex()); + keys.push_back(key); + + pass.feedback.statistics.push_back(Statistic{ + location, + operation, + keys, + count, + }); + return std::nullopt; } @@ -143,7 +222,24 @@ static std::optional on_enter(TFHE::KeySwitchGLWEOp &op, static std::optional on_enter(TFHE::MulGLWEIntOp &op, ExtractTFHEStatisticsPass &pass) { - pass.feedback.totalClearMultiplicationCount += pass.iterations; + auto resultingKey = op.getType().getKey().getNormalized(); + + auto location = locationOf(op); + auto operation = PrimitiveOperation::CLEAR_MULTIPLICATION; + auto keys = std::vector>(); + auto count = pass.iterations; + + std::pair key = + std::make_pair(KeyType::SECRET, (size_t)resultingKey->index); + keys.push_back(key); + + pass.feedback.statistics.push_back(Statistic{ + location, + operation, + keys, + count, + }); + return std::nullopt; } @@ -153,7 +249,24 @@ static std::optional on_enter(TFHE::MulGLWEIntOp &op, static std::optional on_enter(TFHE::NegGLWEOp &op, ExtractTFHEStatisticsPass &pass) { - pass.feedback.totalEncryptedNegationCount += pass.iterations; + auto resultingKey = op.getType().getKey().getNormalized(); + + auto location = locationOf(op); + auto operation = PrimitiveOperation::ENCRYPTED_NEGATION; + auto keys = std::vector>(); + auto count = pass.iterations; + + std::pair key = + std::make_pair(KeyType::SECRET, (size_t)resultingKey->index); + keys.push_back(key); + + pass.feedback.statistics.push_back(Statistic{ + location, + operation, + keys, + count, + }); + return std::nullopt; } @@ -163,9 +276,34 @@ static std::optional on_enter(TFHE::NegGLWEOp &op, static std::optional on_enter(TFHE::SubGLWEIntOp &op, ExtractTFHEStatisticsPass &pass) { + auto resultingKey = op.getType().getKey().getNormalized(); + + auto location = locationOf(op); + auto keys = std::vector>(); + auto count = pass.iterations; + + std::pair key = + std::make_pair(KeyType::SECRET, (size_t)resultingKey->index); + keys.push_back(key); + // clear - encrypted = clear + neg(encrypted) - pass.feedback.totalEncryptedNegationCount += pass.iterations; - pass.feedback.totalClearAdditionCount += pass.iterations; + + auto operation = PrimitiveOperation::ENCRYPTED_NEGATION; + pass.feedback.statistics.push_back(Statistic{ + location, + operation, + keys, + count, + }); + + operation = PrimitiveOperation::CLEAR_ADDITION; + pass.feedback.statistics.push_back(Statistic{ + location, + operation, + keys, + count, + }); + return std::nullopt; } @@ -175,7 +313,32 @@ static std::optional on_enter(TFHE::SubGLWEIntOp &op, static std::optional on_enter(TFHE::WopPBSGLWEOp &op, ExtractTFHEStatisticsPass &pass) { - pass.feedback.totalPbsCount += pass.iterations; + auto bsk = op.getBsk(); + auto ksk = op.getKsk(); + auto pksk = op.getPksk(); + + auto location = locationOf(op); + auto operation = PrimitiveOperation::WOP_PBS; + auto keys = std::vector>(); + auto count = pass.iterations; + + std::pair key = + std::make_pair(KeyType::BOOTSTRAP, (size_t)bsk.getIndex()); + keys.push_back(key); + + key = std::make_pair(KeyType::KEY_SWITCH, (size_t)ksk.getIndex()); + keys.push_back(key); + + key = std::make_pair(KeyType::PACKING_KEY_SWITCH, (size_t)pksk.getIndex()); + keys.push_back(key); + + pass.feedback.statistics.push_back(Statistic{ + location, + operation, + keys, + count, + }); + return std::nullopt; } diff --git a/compilers/concrete-compiler/compiler/lib/Support/CompilationFeedback.cpp b/compilers/concrete-compiler/compiler/lib/Support/CompilationFeedback.cpp index e2a8d91e10..dffdfc76e2 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/CompilationFeedback.cpp +++ b/compilers/concrete-compiler/compiler/lib/Support/CompilationFeedback.cpp @@ -6,7 +6,6 @@ #include #include -#include "boost/outcome.h" #include "llvm/Support/JSON.h" #include "concretelang/Support/CompilationFeedback.h" @@ -62,13 +61,6 @@ void CompilationFeedback::fillFromClientParameters( } crtDecompositionsOfOutputs.push_back(decomposition); } - // Stats - totalPbsCount = 0; - totalKsCount = 0; - totalClearAdditionCount = 0; - totalEncryptedAdditionCount = 0; - totalClearMultiplicationCount = 0; - totalEncryptedNegationCount = 0; } outcome::checked @@ -81,7 +73,7 @@ CompilationFeedback::load(std::string jsonPath) { } auto expectedCompFeedback = llvm::json::parse(content); if (auto err = expectedCompFeedback.takeError()) { - return StringError("Cannot open client parameters: ") + return StringError("Cannot open compilation feedback: ") << llvm::toString(std::move(err)) << "\n" << content << "\n"; } @@ -99,34 +91,166 @@ llvm::json::Value toJSON(const mlir::concretelang::CompilationFeedback &v) { {"totalInputsSize", v.totalInputsSize}, {"totalOutputsSize", v.totalOutputsSize}, {"crtDecompositionsOfOutputs", v.crtDecompositionsOfOutputs}, - {"totalPbsCount", v.totalPbsCount}, - {"totalKsCount", v.totalKsCount}, - {"totalClearAdditionCount", v.totalClearAdditionCount}, - {"totalEncryptedAdditionCount", v.totalEncryptedAdditionCount}, - {"totalClearMultiplicationCount", v.totalClearMultiplicationCount}, - {"totalEncryptedNegationCount", v.totalEncryptedNegationCount}, }; + + auto statisticsJson = llvm::json::Array(); + for (auto statistic : v.statistics) { + auto statisticJson = llvm::json::Object(); + statisticJson.insert({"location", statistic.location}); + switch (statistic.operation) { + case PrimitiveOperation::PBS: + statisticJson.insert({"operation", "PBS"}); + break; + case PrimitiveOperation::WOP_PBS: + statisticJson.insert({"operation", "WOP_PBS"}); + break; + case PrimitiveOperation::KEY_SWITCH: + statisticJson.insert({"operation", "KEY_SWITCH"}); + break; + case PrimitiveOperation::CLEAR_ADDITION: + statisticJson.insert({"operation", "CLEAR_ADDITION"}); + break; + case PrimitiveOperation::ENCRYPTED_ADDITION: + statisticJson.insert({"operation", "ENCRYPTED_ADDITION"}); + break; + case PrimitiveOperation::CLEAR_MULTIPLICATION: + statisticJson.insert({"operation", "CLEAR_MULTIPLICATION"}); + break; + case PrimitiveOperation::ENCRYPTED_NEGATION: + statisticJson.insert({"operation", "ENCRYPTED_NEGATION"}); + break; + } + auto keysJson = llvm::json::Array(); + for (auto &key : statistic.keys) { + KeyType type = key.first; + size_t index = key.second; + + auto keyJson = llvm::json::Array(); + switch (type) { + case KeyType::SECRET: + keyJson.push_back("SECRET"); + break; + case KeyType::BOOTSTRAP: + keyJson.push_back("BOOTSTRAP"); + break; + case KeyType::KEY_SWITCH: + keyJson.push_back("KEY_SWITCH"); + break; + case KeyType::PACKING_KEY_SWITCH: + keyJson.push_back("PACKING_KEY_SWITCH"); + break; + } + keyJson.push_back((int64_t)index); + + keysJson.push_back(std::move(keyJson)); + } + statisticJson.insert({"keys", std::move(keysJson)}); + statisticJson.insert({"count", (int64_t)statistic.count}); + + statisticsJson.push_back(std::move(statisticJson)); + } + object.insert({"statistics", std::move(statisticsJson)}); + return object; } bool fromJSON(const llvm::json::Value j, mlir::concretelang::CompilationFeedback &v, llvm::json::Path p) { llvm::json::ObjectMapper O(j, p); - return O && O.map("complexity", v.complexity) && O.map("pError", v.pError) && - O.map("globalPError", v.globalPError) && - O.map("totalSecretKeysSize", v.totalSecretKeysSize) && - O.map("totalBootstrapKeysSize", v.totalBootstrapKeysSize) && - O.map("totalKeyswitchKeysSize", v.totalKeyswitchKeysSize) && - O.map("totalInputsSize", v.totalInputsSize) && - O.map("totalOutputsSize", v.totalOutputsSize) && - O.map("crtDecompositionsOfOutputs", v.crtDecompositionsOfOutputs) && - O.map("totalPbsCount", v.totalPbsCount) && - O.map("totalKsCount", v.totalKsCount) && - O.map("totalClearAdditionCount", v.totalClearAdditionCount) && - O.map("totalEncryptedAdditionCount", v.totalEncryptedAdditionCount) && - O.map("totalClearMultiplicationCount", - v.totalClearMultiplicationCount) && - O.map("totalEncryptedNegationCount", v.totalEncryptedNegationCount); + + bool is_success = + O && O.map("complexity", v.complexity) && O.map("pError", v.pError) && + O.map("globalPError", v.globalPError) && + O.map("totalSecretKeysSize", v.totalSecretKeysSize) && + O.map("totalBootstrapKeysSize", v.totalBootstrapKeysSize) && + O.map("totalKeyswitchKeysSize", v.totalKeyswitchKeysSize) && + O.map("totalInputsSize", v.totalInputsSize) && + O.map("totalOutputsSize", v.totalOutputsSize) && + O.map("crtDecompositionsOfOutputs", v.crtDecompositionsOfOutputs); + + if (!is_success) { + return false; + } + + auto object = j.getAsObject(); + if (!object) { + return false; + } + + auto statistics = object->getArray("statistics"); + if (!statistics) { + return false; + } + + for (auto statisticValue : *statistics) { + auto statistic = statisticValue.getAsObject(); + if (!statistic) { + return false; + } + + auto location = statistic->getString("location"); + auto operationStr = statistic->getString("operation"); + auto keysArray = statistic->getArray("keys"); + auto count = statistic->getInteger("count"); + + if (!operationStr || !location || !keysArray || !count) { + return false; + } + + PrimitiveOperation operation; + if (operationStr.value() == "PBS") { + operation = PrimitiveOperation::PBS; + } else if (operationStr.value() == "KEY_SWITCH") { + operation = PrimitiveOperation::KEY_SWITCH; + } else if (operationStr.value() == "WOP_PBS") { + operation = PrimitiveOperation::WOP_PBS; + } else if (operationStr.value() == "CLEAR_ADDITION") { + operation = PrimitiveOperation::CLEAR_ADDITION; + } else if (operationStr.value() == "ENCRYPTED_ADDITION") { + operation = PrimitiveOperation::ENCRYPTED_ADDITION; + } else if (operationStr.value() == "CLEAR_MULTIPLICATION") { + operation = PrimitiveOperation::CLEAR_MULTIPLICATION; + } else if (operationStr.value() == "ENCRYPTED_NEGATION") { + operation = PrimitiveOperation::ENCRYPTED_NEGATION; + } else { + return false; + } + + auto keys = std::vector>(); + for (auto keyValue : *keysArray) { + llvm::json::Array *keyArray = keyValue.getAsArray(); + if (!keyArray || keyArray->size() != 2) { + return false; + } + + auto typeStr = keyArray->front().getAsString(); + auto index = keyArray->back().getAsInteger(); + + if (!typeStr || !index) { + return false; + } + + KeyType type; + if (typeStr.value() == "SECRET") { + type = KeyType::SECRET; + } else if (typeStr.value() == "BOOTSTRAP") { + type = KeyType::BOOTSTRAP; + } else if (typeStr.value() == "KEY_SWITCH") { + type = KeyType::KEY_SWITCH; + } else if (typeStr.value() == "PACKING_KEY_SWITCH") { + type = KeyType::PACKING_KEY_SWITCH; + } else { + return false; + } + + keys.push_back(std::make_pair(type, (size_t)*index)); + } + + v.statistics.push_back( + Statistic{location->str(), operation, keys, (uint64_t)*count}); + } + + return true; } } // namespace concretelang diff --git a/frontends/concrete-python/concrete/fhe/__init__.py b/frontends/concrete-python/concrete/fhe/__init__.py index 53c1d91cda..38e9a7c202 100644 --- a/frontends/concrete-python/concrete/fhe/__init__.py +++ b/frontends/concrete-python/concrete/fhe/__init__.py @@ -4,7 +4,7 @@ # pylint: disable=import-error,no-name-in-module -from concrete.compiler import EvaluationKeys, PublicArguments, PublicResult +from concrete.compiler import EvaluationKeys, Parameter, PublicArguments, PublicResult from .compilation import ( DEFAULT_GLOBAL_P_ERROR, diff --git a/frontends/concrete-python/concrete/fhe/compilation/circuit.py b/frontends/concrete-python/concrete/fhe/compilation/circuit.py index fe57c1d1b0..bef16e9161 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/circuit.py +++ b/frontends/concrete-python/concrete/fhe/compilation/circuit.py @@ -4,10 +4,15 @@ # pylint: disable=import-error,no-member,no-name-in-module -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np -from concrete.compiler import CompilationContext, SimulatedValueDecrypter, SimulatedValueExporter +from concrete.compiler import ( + CompilationContext, + Parameter, + SimulatedValueDecrypter, + SimulatedValueExporter, +) from mlir.ir import Module as MlirModule from ..internal.utils import assert_that @@ -266,128 +271,304 @@ def cleanup(self): if hasattr(self, "server"): # pragma: no cover self.server.cleanup() + # Properties + + def _property(self, name: str) -> Any: + """ + Get a property of the circuit by name. + + Args: + name (str): + name of the property + + Returns: + Any: + statistic + """ + + if hasattr(self, "simulator"): + return getattr(self.simulator, name) # pragma: no cover + + if not hasattr(self, "server"): + self.enable_fhe_execution() # pragma: no cover + + return getattr(self.server, name) + @property def size_of_secret_keys(self) -> int: """ Get size of the secret keys of the circuit. """ - return self._statistic("size_of_secret_keys") + return self._property("size_of_secret_keys") # pragma: no cover @property def size_of_bootstrap_keys(self) -> int: """ Get size of the bootstrap keys of the circuit. """ - return self._statistic("size_of_bootstrap_keys") + return self._property("size_of_bootstrap_keys") # pragma: no cover @property def size_of_keyswitch_keys(self) -> int: """ Get size of the key switch keys of the circuit. """ - return self._statistic("size_of_keyswitch_keys") + return self._property("size_of_keyswitch_keys") # pragma: no cover @property def size_of_inputs(self) -> int: """ Get size of the inputs of the circuit. """ - return self._statistic("size_of_inputs") + return self._property("size_of_inputs") # pragma: no cover @property def size_of_outputs(self) -> int: """ Get size of the outputs of the circuit. """ - return self._statistic("size_of_outputs") + return self._property("size_of_outputs") # pragma: no cover @property def p_error(self) -> int: """ Get probability of error for each simple TLU (on a scalar). """ - return self._statistic("p_error") + return self._property("p_error") # pragma: no cover @property def global_p_error(self) -> int: """ Get the probability of having at least one simple TLU error during the entire execution. """ - return self._statistic("global_p_error") + return self._property("global_p_error") # pragma: no cover @property def complexity(self) -> float: """ Get complexity of the circuit. """ - return self._statistic("complexity") + return self._property("complexity") # pragma: no cover + + # Programmable Bootstrap Statistics @property - def total_pbs_count(self) -> int: + def programmable_bootstrap_count(self) -> int: """ - Get the total number of programmable bootstraps in the circuit. + Get the number of programmable bootstraps in the circuit. """ - return self._statistic("total_pbs_count") + return self._property("programmable_bootstrap_count") # pragma: no cover @property - def total_ks_count(self) -> int: + def programmable_bootstrap_count_per_parameter(self) -> Dict[Parameter, int]: """ - Get the total number of key switches in the circuit. + Get the number of programmable bootstraps per bit width in the circuit. """ - return self._statistic("total_ks_count") + return self._property("programmable_bootstrap_count_per_parameter") # pragma: no cover @property - def total_clear_addition_count(self) -> int: + def programmable_bootstrap_count_per_tag(self) -> Dict[str, int]: """ - Get the total number of clear additions in the circuit. + Get the number of programmable bootstraps per tag in the circuit. """ - return self._statistic("total_clear_addition_count") + return self._property("programmable_bootstrap_count_per_tag") # pragma: no cover @property - def total_encrypted_addition_count(self) -> int: + def programmable_bootstrap_count_per_tag_per_parameter(self) -> Dict[str, Dict[int, int]]: """ - Get the total number of encrypted additions in the circuit. + Get the number of programmable bootstraps per tag per bit width in the circuit. """ - return self._statistic("total_encrypted_addition_count") + return self._property( + "programmable_bootstrap_count_per_tag_per_parameter" + ) # pragma: no cover + + # Key Switch Statistics @property - def total_clear_multiplication_count(self) -> int: + def key_switch_count(self) -> int: """ - Get the total number of clear multiplications in the circuit. + Get the number of key switches in the circuit. """ - return self._statistic("total_clear_multiplication_count") + return self._property("key_switch_count") # pragma: no cover @property - def total_encrypted_negation_count(self) -> int: + def key_switch_count_per_parameter(self) -> Dict[Parameter, int]: """ - Get the total number of encrypted negations in the circuit. + Get the number of key switches per parameter in the circuit. """ - return self._statistic("total_encrypted_negation_count") + return self._property("key_switch_count_per_parameter") # pragma: no cover @property - def statistics(self) -> dict: + def key_switch_count_per_tag(self) -> Dict[str, int]: """ - Get all circuit statistics in a dict. + Get the number of key switches per tag in the circuit. """ - return self._statistic("statistics") + return self._property("key_switch_count_per_tag") # pragma: no cover - def _statistic(self, name: str) -> Any: + @property + def key_switch_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]: + """ + Get the number of key switches per tag per parameter in the circuit. """ - Get a statistic of the circuit by name. + return self._property("key_switch_count_per_tag_per_parameter") # pragma: no cover - Args: - name (str): - name of the statistic + # Packing Key Switch Statistics - Returns: - Any: - statistic + @property + def packing_key_switch_count(self) -> int: + """ + Get the number of packing key switches in the circuit. """ + return self._property("packing_key_switch_count") # pragma: no cover - if hasattr(self, "simulator"): - return getattr(self.simulator, name) # pragma: no cover + @property + def packing_key_switch_count_per_parameter(self) -> Dict[Parameter, int]: + """ + Get the number of packing key switches per parameter in the circuit. + """ + return self._property("packing_key_switch_count_per_parameter") # pragma: no cover - if not hasattr(self, "server"): - self.enable_fhe_execution() # pragma: no cover + @property + def packing_key_switch_count_per_tag(self) -> Dict[str, int]: + """ + Get the number of packing key switches per tag in the circuit. + """ + return self._property("packing_key_switch_count_per_tag") # pragma: no cover - return getattr(self.server, name) + @property + def packing_key_switch_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]: + """ + Get the number of packing key switches per tag per parameter in the circuit. + """ + return self._property("packing_key_switch_count_per_tag_per_parameter") # pragma: no cover + + # Clear Addition Statistics + + @property + def clear_addition_count(self) -> int: + """ + Get the number of clear additions in the circuit. + """ + return self._property("clear_addition_count") # pragma: no cover + + @property + def clear_addition_count_per_parameter(self) -> Dict[Parameter, int]: + """ + Get the number of clear additions per parameter in the circuit. + """ + return self._property("clear_addition_count_per_parameter") # pragma: no cover + + @property + def clear_addition_count_per_tag(self) -> Dict[str, int]: + """ + Get the number of clear additions per tag in the circuit. + """ + return self._property("clear_addition_count_per_tag") # pragma: no cover + + @property + def clear_addition_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]: + """ + Get the number of clear additions per tag per parameter in the circuit. + """ + return self._property("clear_addition_count_per_tag_per_parameter") # pragma: no cover + + # Encrypted Addition Statistics + + @property + def encrypted_addition_count(self) -> int: + """ + Get the number of encrypted additions in the circuit. + """ + return self._property("encrypted_addition_count") # pragma: no cover + + @property + def encrypted_addition_count_per_parameter(self) -> Dict[Parameter, int]: + """ + Get the number of encrypted additions per parameter in the circuit. + """ + return self._property("encrypted_addition_count_per_parameter") # pragma: no cover + + @property + def encrypted_addition_count_per_tag(self) -> Dict[str, int]: + """ + Get the number of encrypted additions per tag in the circuit. + """ + return self._property("encrypted_addition_count_per_tag") # pragma: no cover + + @property + def encrypted_addition_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]: + """ + Get the number of encrypted additions per tag per parameter in the circuit. + """ + return self._property("encrypted_addition_count_per_tag_per_parameter") # pragma: no cover + + # Clear Multiplication Statistics + + @property + def clear_multiplication_count(self) -> int: + """ + Get the number of clear multiplications in the circuit. + """ + return self._property("clear_multiplication_count") # pragma: no cover + + @property + def clear_multiplication_count_per_parameter(self) -> Dict[Parameter, int]: + """ + Get the number of clear multiplications per parameter in the circuit. + """ + return self._property("clear_multiplication_count_per_parameter") # pragma: no cover + + @property + def clear_multiplication_count_per_tag(self) -> Dict[str, int]: + """ + Get the number of clear multiplications per tag in the circuit. + """ + return self._property("clear_multiplication_count_per_tag") # pragma: no cover + + @property + def clear_multiplication_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]: + """ + Get the number of clear multiplications per tag per parameter in the circuit. + """ + return self._property( + "clear_multiplication_count_per_tag_per_parameter" + ) # pragma: no cover + + # Encrypted Negation Statistics + + @property + def encrypted_negation_count(self) -> int: + """ + Get the number of encrypted negations in the circuit. + """ + return self._property("encrypted_negation_count") # pragma: no cover + + @property + def encrypted_negation_count_per_parameter(self) -> Dict[Parameter, int]: + """ + Get the number of encrypted negations per parameter in the circuit. + """ + return self._property("encrypted_negation_count_per_parameter") # pragma: no cover + + @property + def encrypted_negation_count_per_tag(self) -> Dict[str, int]: + """ + Get the number of encrypted negations per tag in the circuit. + """ + return self._property("encrypted_negation_count_per_tag") # pragma: no cover + + @property + def encrypted_negation_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]: + """ + Get the number of encrypted negations per tag per parameter in the circuit. + """ + return self._property("encrypted_negation_count_per_tag_per_parameter") # pragma: no cover + + # All Statistics + + @property + def statistics(self) -> Dict: + """ + Get all statistics of the circuit. + """ + return self._property("statistics") # pragma: no cover diff --git a/frontends/concrete-python/concrete/fhe/compilation/compiler.py b/frontends/concrete-python/concrete/fhe/compilation/compiler.py index 015c260878..af98268f8e 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/compiler.py +++ b/frontends/concrete-python/concrete/fhe/compilation/compiler.py @@ -485,7 +485,7 @@ def compile( ) columns = 0 - if show_graph or show_mlir or show_optimizer: + if show_graph or show_mlir or show_optimizer or show_statistics: graph = ( self.graph.format() if self.configuration.verbose or self.configuration.show_graph @@ -556,8 +556,27 @@ def compile( print("Statistics") print("-" * columns) - for name, value in circuit.statistics.items(): - print(f"{name}: {value}") + + def pretty(d, indent=0): # pragma: no cover + if indent > 0: + print("{") + + for key, value in d.items(): + if isinstance(value, dict) and len(value) == 0: + continue + + print(" " * indent + str(key) + ": ", end="") + + if isinstance(value, dict): + pretty(value, indent + 1) + else: + print(value) + + if indent > 0: + print(" " * (indent - 1) + "}") + + pretty(circuit.statistics) + print("-" * columns) print() diff --git a/frontends/concrete-python/concrete/fhe/compilation/server.py b/frontends/concrete-python/concrete/fhe/compilation/server.py index edf78e88b1..f622ad580e 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/server.py +++ b/frontends/concrete-python/concrete/fhe/compilation/server.py @@ -8,7 +8,7 @@ import shutil import tempfile from pathlib import Path -from typing import List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union # mypy: disable-error-code=attr-defined import concrete.compiler @@ -23,11 +23,12 @@ LibraryCompilationResult, LibraryLambda, LibrarySupport, + Parameter, PublicArguments, set_compiler_logging, set_llvm_debug_flag, ) -from mlir._mlir_libs._concretelang._compiler import OptimizerStrategy +from mlir._mlir_libs._concretelang._compiler import KeyType, OptimizerStrategy, PrimitiveOperation from mlir.ir import Module as MlirModule from ..internal.utils import assert_that @@ -435,66 +436,333 @@ def complexity(self) -> float: """ return self._compilation_feedback.complexity + # Programmable Bootstrap Statistics + + @property + def programmable_bootstrap_count(self) -> int: + """ + Get the number of programmable bootstraps in the compiled program. + """ + return self._compilation_feedback.count( + operations={PrimitiveOperation.PBS, PrimitiveOperation.WOP_PBS}, + ) + + @property + def programmable_bootstrap_count_per_parameter(self) -> Dict[Parameter, int]: + """ + Get the number of programmable bootstraps per parameter in the compiled program. + """ + return self._compilation_feedback.count_per_parameter( + operations={PrimitiveOperation.PBS, PrimitiveOperation.WOP_PBS}, + key_types={KeyType.BOOTSTRAP}, + client_parameters=self.client_specs.client_parameters, + ) + + @property + def programmable_bootstrap_count_per_tag(self) -> Dict[str, int]: + """ + Get the number of programmable bootstraps per tag in the compiled program. + """ + return self._compilation_feedback.count_per_tag( + operations={PrimitiveOperation.PBS, PrimitiveOperation.WOP_PBS}, + ) + @property - def total_pbs_count(self) -> int: + def programmable_bootstrap_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]: """ - Get the total number of programmable bootstraps in the compiled program. + Get the number of programmable bootstraps per tag per parameter in the compiled program. """ - return self._compilation_feedback.total_pbs_count + return self._compilation_feedback.count_per_tag_per_parameter( + operations={PrimitiveOperation.PBS, PrimitiveOperation.WOP_PBS}, + key_types={KeyType.BOOTSTRAP}, + client_parameters=self.client_specs.client_parameters, + ) + + # Key Switch Statistics @property - def total_ks_count(self) -> int: + def key_switch_count(self) -> int: """ - Get the total number of key switches in the compiled program. + Get the number of key switches in the compiled program. """ - return self._compilation_feedback.total_ks_count + return self._compilation_feedback.count( + operations={PrimitiveOperation.KEY_SWITCH, PrimitiveOperation.WOP_PBS}, + ) @property - def total_clear_addition_count(self) -> int: + def key_switch_count_per_parameter(self) -> Dict[Parameter, int]: """ - Get the total number of clear additions in the compiled program. + Get the number of key switches per parameter in the compiled program. """ - return self._compilation_feedback.total_clear_addition_count + return self._compilation_feedback.count_per_parameter( + operations={PrimitiveOperation.KEY_SWITCH, PrimitiveOperation.WOP_PBS}, + key_types={KeyType.KEY_SWITCH}, + client_parameters=self.client_specs.client_parameters, + ) @property - def total_encrypted_addition_count(self) -> int: + def key_switch_count_per_tag(self) -> Dict[str, int]: """ - Get the total number of encrypted additions in the compiled program. + Get the number of key switches per tag in the compiled program. """ - return self._compilation_feedback.total_encrypted_addition_count + return self._compilation_feedback.count_per_tag( + operations={PrimitiveOperation.KEY_SWITCH, PrimitiveOperation.WOP_PBS}, + ) @property - def total_clear_multiplication_count(self) -> int: + def key_switch_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]: """ - Get the total number of clear multiplications in the compiled program. + Get the number of key switches per tag per parameter in the compiled program. """ - return self._compilation_feedback.total_clear_multiplication_count + return self._compilation_feedback.count_per_tag_per_parameter( + operations={PrimitiveOperation.KEY_SWITCH, PrimitiveOperation.WOP_PBS}, + key_types={KeyType.KEY_SWITCH}, + client_parameters=self.client_specs.client_parameters, + ) + + # Packing Key Switch Statistics @property - def total_encrypted_negation_count(self) -> int: + def packing_key_switch_count(self) -> int: """ - Get the total number of encrypted negations in the compiled program. + Get the number of packing key switches in the compiled program. """ - return self._compilation_feedback.total_encrypted_negation_count + return self._compilation_feedback.count(operations={PrimitiveOperation.WOP_PBS}) @property - def statistics(self) -> dict: + def packing_key_switch_count_per_parameter(self) -> Dict[Parameter, int]: """ - Get all program statistics in a dict. + Get the number of packing key switches per parameter in the compiled program. """ - return { - "size_of_secret_keys": self.size_of_secret_keys, - "size_of_bootstrap_keys": self.size_of_bootstrap_keys, - "size_of_keyswitch_keys": self.size_of_keyswitch_keys, - "size_of_inputs": self.size_of_inputs, - "size_of_outputs": self.size_of_outputs, - "p_error": self.p_error, - "global_p_error": self.global_p_error, - "complexity": self.complexity, - "total_pbs_count": self.total_pbs_count, - "total_ks_count": self.total_ks_count, - "total_clear_addition_count": self.total_clear_addition_count, - "total_encrypted_addition_count": self.total_encrypted_addition_count, - "total_clear_multiplication_count": self.total_clear_multiplication_count, - "total_encrypted_negation_count": self.total_encrypted_negation_count, - } + return self._compilation_feedback.count_per_parameter( + operations={PrimitiveOperation.WOP_PBS}, + key_types={KeyType.PACKING_KEY_SWITCH}, + client_parameters=self.client_specs.client_parameters, + ) + + @property + def packing_key_switch_count_per_tag(self) -> Dict[str, int]: + """ + Get the number of packing key switches per tag in the compiled program. + """ + return self._compilation_feedback.count_per_tag(operations={PrimitiveOperation.WOP_PBS}) + + @property + def packing_key_switch_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]: + """ + Get the number of packing key switches per tag per parameter in the compiled program. + """ + return self._compilation_feedback.count_per_tag_per_parameter( + operations={PrimitiveOperation.WOP_PBS}, + key_types={KeyType.PACKING_KEY_SWITCH}, + client_parameters=self.client_specs.client_parameters, + ) + + # Clear Addition Statistics + + @property + def clear_addition_count(self) -> int: + """ + Get the number of clear additions in the compiled program. + """ + return self._compilation_feedback.count(operations={PrimitiveOperation.CLEAR_ADDITION}) + + @property + def clear_addition_count_per_parameter(self) -> Dict[Parameter, int]: + """ + Get the number of clear additions per parameter in the compiled program. + """ + return self._compilation_feedback.count_per_parameter( + operations={PrimitiveOperation.CLEAR_ADDITION}, + key_types={KeyType.SECRET}, + client_parameters=self.client_specs.client_parameters, + ) + + @property + def clear_addition_count_per_tag(self) -> Dict[str, int]: + """ + Get the number of clear additions per tag in the compiled program. + """ + return self._compilation_feedback.count_per_tag( + operations={PrimitiveOperation.CLEAR_ADDITION}, + ) + + @property + def clear_addition_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]: + """ + Get the number of clear additions per tag per parameter in the compiled program. + """ + return self._compilation_feedback.count_per_tag_per_parameter( + operations={PrimitiveOperation.CLEAR_ADDITION}, + key_types={KeyType.SECRET}, + client_parameters=self.client_specs.client_parameters, + ) + + # Encrypted Addition Statistics + + @property + def encrypted_addition_count(self) -> int: + """ + Get the number of encrypted additions in the compiled program. + """ + return self._compilation_feedback.count(operations={PrimitiveOperation.ENCRYPTED_ADDITION}) + + @property + def encrypted_addition_count_per_parameter(self) -> Dict[Parameter, int]: + """ + Get the number of encrypted additions per parameter in the compiled program. + """ + return self._compilation_feedback.count_per_parameter( + operations={PrimitiveOperation.ENCRYPTED_ADDITION}, + key_types={KeyType.SECRET}, + client_parameters=self.client_specs.client_parameters, + ) + + @property + def encrypted_addition_count_per_tag(self) -> Dict[str, int]: + """ + Get the number of encrypted additions per tag in the compiled program. + """ + return self._compilation_feedback.count_per_tag( + operations={PrimitiveOperation.ENCRYPTED_ADDITION}, + ) + + @property + def encrypted_addition_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]: + """ + Get the number of encrypted additions per tag per parameter in the compiled program. + """ + return self._compilation_feedback.count_per_tag_per_parameter( + operations={PrimitiveOperation.ENCRYPTED_ADDITION}, + key_types={KeyType.SECRET}, + client_parameters=self.client_specs.client_parameters, + ) + + # Clear Multiplication Statistics + + @property + def clear_multiplication_count(self) -> int: + """ + Get the number of clear multiplications in the compiled program. + """ + return self._compilation_feedback.count( + operations={PrimitiveOperation.CLEAR_MULTIPLICATION}, + ) + + @property + def clear_multiplication_count_per_parameter(self) -> Dict[Parameter, int]: + """ + Get the number of clear multiplications per parameter in the compiled program. + """ + return self._compilation_feedback.count_per_parameter( + operations={PrimitiveOperation.CLEAR_MULTIPLICATION}, + key_types={KeyType.SECRET}, + client_parameters=self.client_specs.client_parameters, + ) + + @property + def clear_multiplication_count_per_tag(self) -> Dict[str, int]: + """ + Get the number of clear multiplications per tag in the compiled program. + """ + return self._compilation_feedback.count_per_tag( + operations={PrimitiveOperation.CLEAR_MULTIPLICATION}, + ) + + @property + def clear_multiplication_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]: + """ + Get the number of clear multiplications per tag per parameter in the compiled program. + """ + return self._compilation_feedback.count_per_tag_per_parameter( + operations={PrimitiveOperation.CLEAR_MULTIPLICATION}, + key_types={KeyType.SECRET}, + client_parameters=self.client_specs.client_parameters, + ) + + # Encrypted Negation Statistics + + @property + def encrypted_negation_count(self) -> int: + """ + Get the number of encrypted negations in the compiled program. + """ + return self._compilation_feedback.count(operations={PrimitiveOperation.ENCRYPTED_NEGATION}) + + @property + def encrypted_negation_count_per_parameter(self) -> Dict[Parameter, int]: + """ + Get the number of encrypted negations per parameter in the compiled program. + """ + return self._compilation_feedback.count_per_parameter( + operations={PrimitiveOperation.ENCRYPTED_NEGATION}, + key_types={KeyType.SECRET}, + client_parameters=self.client_specs.client_parameters, + ) + + @property + def encrypted_negation_count_per_tag(self) -> Dict[str, int]: + """ + Get the number of encrypted negations per tag in the compiled program. + """ + return self._compilation_feedback.count_per_tag( + operations={PrimitiveOperation.ENCRYPTED_NEGATION}, + ) + + @property + def encrypted_negation_count_per_tag_per_parameter(self) -> Dict[str, Dict[Parameter, int]]: + """ + Get the number of encrypted negations per tag per parameter in the compiled program. + """ + return self._compilation_feedback.count_per_tag_per_parameter( + operations={PrimitiveOperation.ENCRYPTED_NEGATION}, + key_types={KeyType.SECRET}, + client_parameters=self.client_specs.client_parameters, + ) + + # All Statistics + + @property + def statistics(self) -> Dict: + """ + Get all statistics of the compiled program. + """ + attributes = [ + "size_of_secret_keys", + "size_of_bootstrap_keys", + "size_of_keyswitch_keys", + "size_of_inputs", + "size_of_outputs", + "p_error", + "global_p_error", + "complexity", + "programmable_bootstrap_count", + "programmable_bootstrap_count_per_parameter", + "programmable_bootstrap_count_per_tag", + "programmable_bootstrap_count_per_tag_per_parameter", + "key_switch_count", + "key_switch_count_per_parameter", + "key_switch_count_per_tag", + "key_switch_count_per_tag_per_parameter", + "packing_key_switch_count", + "packing_key_switch_count_per_parameter", + "packing_key_switch_count_per_tag", + "packing_key_switch_count_per_tag_per_parameter", + "clear_addition_count", + "clear_addition_count_per_parameter", + "clear_addition_count_per_tag", + "clear_addition_count_per_tag_per_parameter", + "encrypted_addition_count", + "encrypted_addition_count_per_parameter", + "encrypted_addition_count_per_tag", + "encrypted_addition_count_per_tag_per_parameter", + "clear_multiplication_count", + "clear_multiplication_count_per_parameter", + "clear_multiplication_count_per_tag", + "clear_multiplication_count_per_tag_per_parameter", + "encrypted_negation_count", + "encrypted_negation_count_per_parameter", + "encrypted_negation_count_per_tag", + "encrypted_negation_count_per_tag_per_parameter", + ] + return {attribute: getattr(self, attribute) for attribute in attributes} diff --git a/frontends/concrete-python/tests/compilation/test_circuit.py b/frontends/concrete-python/tests/compilation/test_circuit.py index 08dd63c581..b60a6062f5 100644 --- a/frontends/concrete-python/tests/compilation/test_circuit.py +++ b/frontends/concrete-python/tests/compilation/test_circuit.py @@ -523,6 +523,20 @@ def f(x, y): assert f(*inputset[0]) == circuit.simulate(*inputset[0]) +def tagged_function(x, y, z): + """ + A tagged function to test statistics. + """ + with fhe.tag("a"): + x = fhe.univariate(lambda v: v)(x) + with fhe.tag("b"): + y = fhe.univariate(lambda v: v)(y) + with fhe.tag("c"): + z = fhe.univariate(lambda v: v)(z) + + return x + y + z + + @pytest.mark.parametrize( "function,parameters,expected_statistics", [ @@ -532,12 +546,11 @@ def f(x, y): "x": {"status": "encrypted", "range": [0, 10], "shape": ()}, }, { - "total_pbs_count": 1, - "total_ks_count": 1, - "total_clear_addition_count": 0, - "total_encrypted_addition_count": 0, - "total_clear_multiplication_count": 0, - "total_encrypted_negation_count": 0, + "programmable_bootstrap_count": 1, + "clear_addition_count": 0, + "encrypted_addition_count": 0, + "clear_multiplication_count": 0, + "encrypted_negation_count": 0, }, id="x**2 | x.is_encrypted | x.shape == ()", ), @@ -547,11 +560,11 @@ def f(x, y): "x": {"status": "encrypted", "range": [0, 10], "shape": (3,)}, }, { - "total_pbs_count": 3, - "total_clear_addition_count": 0, - "total_encrypted_addition_count": 0, - "total_clear_multiplication_count": 0, - "total_encrypted_negation_count": 0, + "programmable_bootstrap_count": 3, + "clear_addition_count": 0, + "encrypted_addition_count": 0, + "clear_multiplication_count": 0, + "encrypted_negation_count": 0, }, id="x**2 | x.is_encrypted | x.shape == (3,)", ), @@ -561,11 +574,11 @@ def f(x, y): "x": {"status": "encrypted", "range": [0, 10], "shape": (3, 2)}, }, { - "total_pbs_count": 3 * 2, - "total_clear_addition_count": 0, - "total_encrypted_addition_count": 0, - "total_clear_multiplication_count": 0, - "total_encrypted_negation_count": 0, + "programmable_bootstrap_count": 3 * 2, + "clear_addition_count": 0, + "encrypted_addition_count": 0, + "clear_multiplication_count": 0, + "encrypted_negation_count": 0, }, id="x**2 | x.is_encrypted | x.shape == (3, 2)", ), @@ -576,11 +589,11 @@ def f(x, y): "y": {"status": "encrypted", "range": [0, 10], "shape": ()}, }, { - "total_pbs_count": 2, - "total_clear_addition_count": 1, - "total_encrypted_addition_count": 3, - "total_clear_multiplication_count": 0, - "total_encrypted_negation_count": 2, + "programmable_bootstrap_count": 2, + "clear_addition_count": 1, + "encrypted_addition_count": 3, + "clear_multiplication_count": 0, + "encrypted_negation_count": 2, }, id="x * y | x.is_encrypted | x.shape == () | y.is_encrypted | y.shape == ()", ), @@ -591,11 +604,11 @@ def f(x, y): "y": {"status": "encrypted", "range": [0, 10], "shape": (3,)}, }, { - "total_pbs_count": 3 * 2, - "total_clear_addition_count": 3 * 1, - "total_encrypted_addition_count": 3 * 3, - "total_clear_multiplication_count": 0, - "total_encrypted_negation_count": 3 * 2, + "programmable_bootstrap_count": 3 * 2, + "clear_addition_count": 3 * 1, + "encrypted_addition_count": 3 * 3, + "clear_multiplication_count": 0, + "encrypted_negation_count": 3 * 2, }, id="x * y | x.is_encrypted | x.shape == (3,) | y.is_encrypted | y.shape == (3,)", ), @@ -606,14 +619,30 @@ def f(x, y): "y": {"status": "encrypted", "range": [0, 10], "shape": (3, 2)}, }, { - "total_pbs_count": 3 * 2 * 2, - "total_clear_addition_count": 3 * 2 * 1, - "total_encrypted_addition_count": 3 * 2 * 3, - "total_clear_multiplication_count": 0, - "total_encrypted_negation_count": 3 * 2 * 2, + "programmable_bootstrap_count": 3 * 2 * 2, + "clear_addition_count": 3 * 2 * 1, + "encrypted_addition_count": 3 * 2 * 3, + "clear_multiplication_count": 0, + "encrypted_negation_count": 3 * 2 * 2, }, id="x * y | x.is_encrypted | x.shape == (3, 2) | y.is_encrypted | y.shape == (3, 2)", ), + pytest.param( + tagged_function, + { + "x": {"status": "encrypted", "range": [0, 2**3 - 1], "shape": ()}, + "y": {"status": "encrypted", "range": [0, 2**4 - 1], "shape": ()}, + "z": {"status": "encrypted", "range": [0, 2**5 - 1], "shape": ()}, + }, + { + "programmable_bootstrap_count_per_tag": { + "a": 3, + "a.b": 2, + "a.b.c": 1, + }, + }, + id="tagged_function", + ), ], ) def test_statistics(function, parameters, expected_statistics, helpers):