diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Support/V0Parameters.h b/compilers/concrete-compiler/compiler/include/concretelang/Support/V0Parameters.h index 24dd6153e9..eb7d5509e8 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Support/V0Parameters.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Support/V0Parameters.h @@ -67,20 +67,19 @@ struct V0Parameter { }; namespace optimizer { -constexpr double DEFAULT_GLOBAL_P_ERROR = 1.0 / 100000.0; -constexpr double UNSPECIFIED_P_ERROR = NAN; // will use DEFAULT_GLOBAL_P_ERROR -constexpr double UNSPECIFIED_GLOBAL_P_ERROR = +const double DEFAULT_GLOBAL_P_ERROR = 1.0 / 100000.0; +const double UNSPECIFIED_P_ERROR = NAN; // will use DEFAULT_GLOBAL_P_ERROR +const double UNSPECIFIED_GLOBAL_P_ERROR = NAN; // will use DEFAULT_GLOBAL_P_ERROR -constexpr uint DEFAULT_SECURITY = 128; -constexpr uint DEFAULT_FALLBACK_LOG_NORM_WOPPBS = 8; -constexpr bool DEFAULT_DISPLAY = false; -constexpr bool DEFAULT_USE_GPU_CONSTRAINTS = false; -constexpr concrete_optimizer::Encoding DEFAULT_ENCODING = +const uint DEFAULT_SECURITY = 128; +const uint DEFAULT_FALLBACK_LOG_NORM_WOPPBS = 8; +const bool DEFAULT_DISPLAY = false; +const bool DEFAULT_USE_GPU_CONSTRAINTS = false; +const concrete_optimizer::Encoding DEFAULT_ENCODING = concrete_optimizer::Encoding::Auto; -constexpr bool DEFAULT_CACHE_ON_DISK = true; -constexpr uint32_t DEFAULT_CIPHERTEXT_MODULUS_LOG = 64; -constexpr uint32_t DEFAULT_FFT_PRECISION = 53; -constexpr bool DEFAULT_COMPOSABLE = false; +const bool DEFAULT_CACHE_ON_DISK = true; +const uint32_t DEFAULT_CIPHERTEXT_MODULUS_LOG = 64; +const uint32_t DEFAULT_FFT_PRECISION = 53; /// The strategy of the crypto optimization enum Strategy { @@ -96,10 +95,19 @@ enum Strategy { std::string const StrategyLabel[] = {"V0", "dag-mono", "dag-multi"}; -constexpr Strategy DEFAULT_STRATEGY = Strategy::DAG_MULTI; -constexpr concrete_optimizer::MultiParamStrategy DEFAULT_MULTI_PARAM_STRATEGY = +const Strategy DEFAULT_STRATEGY = Strategy::DAG_MULTI; +const concrete_optimizer::MultiParamStrategy DEFAULT_MULTI_PARAM_STRATEGY = concrete_optimizer::MultiParamStrategy::ByPrecision; -constexpr bool DEFAULT_KEY_SHARING = true; +const bool DEFAULT_KEY_SHARING = true; + +struct CompositionRule { + std::string from_func; + size_t from_pos; + std::string to_func; + size_t to_pos; +}; + +const std::vector DEFAULT_COMPOSITION_RULES = {}; struct Config { double p_error; @@ -115,10 +123,10 @@ struct Config { bool cache_on_disk; uint32_t ciphertext_modulus_log; uint32_t fft_precision; - bool composable; + std::vector composition_rules; }; -constexpr Config DEFAULT_CONFIG = { +const Config DEFAULT_CONFIG = { UNSPECIFIED_P_ERROR, UNSPECIFIED_GLOBAL_P_ERROR, DEFAULT_DISPLAY, @@ -132,7 +140,7 @@ constexpr Config DEFAULT_CONFIG = { DEFAULT_CACHE_ON_DISK, DEFAULT_CIPHERTEXT_MODULUS_LOG, DEFAULT_FFT_PRECISION, - DEFAULT_COMPOSABLE, + DEFAULT_COMPOSITION_RULES, }; using Dag = rust::Box; diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index c53b2dca56..4ad6e9bec4 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -4,6 +4,7 @@ // for license information. #include "concretelang/Bindings/Python/CompilerAPIModule.h" +#include "concrete-optimizer.hpp" #include "concrete-protocol.capnp.h" #include "concretelang/ClientLib/ClientLib.h" #include "concretelang/Common/Compat.h" @@ -744,9 +745,11 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( [](CompilationOptions &options, double global_p_error) { options.optimizerConfig.global_p_error = global_p_error; }) - .def("set_composable", - [](CompilationOptions &options, bool composable) { - options.optimizerConfig.composable = composable; + .def("add_composition", + [](CompilationOptions &options, std::string from_func, + size_t from_pos, std::string to_func, size_t to_pos) { + options.optimizerConfig.composition_rules.push_back( + {from_func, from_pos, to_func, to_pos}); }) .def("set_security_level", [](CompilationOptions &options, int security_level) { diff --git a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_options.py b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_options.py index 7759718e70..b614d5cdaa 100644 --- a/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_options.py +++ b/compilers/concrete-compiler/compiler/lib/Bindings/Python/concrete/compiler/compilation_options.py @@ -61,18 +61,27 @@ def new(backend=_Backend.CPU) -> "CompilationOptions": # pylint: enable=arguments-differ - def set_composable(self, composable: bool): - """Set option for composition. + def add_composition(self, from_func: str, from_pos: int, to_func: str, to_pos: int): + """Adds a composition rule. Args: - composable (bool): whether to turn it on or off + from_func(str): the name of the circuit the output comes from. + from_pos(int): the return position of the output. + to_func(str): the name of the circuit the input targets. + to_pos(int): the argument position of the input. Raises: - TypeError: if the value to set is not boolean + TypeError: if the inputs do not have the proper type. """ - if not isinstance(composable, bool): - raise TypeError("can't set the option to a non-boolean value") - self.cpp().set_composable(composable) + if not isinstance(from_func, str): + raise TypeError("expected `from_func` to be (str)") + if not isinstance(from_pos, int): + raise TypeError("expected `from_pos` to be (int)") + if not isinstance(to_func, str): + raise TypeError("expected `to_func` to be (str)") + if not isinstance(from_pos, int): + raise TypeError("expected `to_pos` to be (int)") + self.cpp().add_composition(from_func, from_pos, to_func, to_pos) def set_auto_parallelize(self, auto_parallelize: bool): """Set option for auto parallelization. diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp index 322bdb8c68..a27ca62cd2 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp @@ -976,15 +976,9 @@ std::unique_ptr createDagPass(optimizer::Config config, // Adds the composition rules to the void applyCompositionRules(optimizer::Config config, concrete_optimizer::Dag &dag) { - - if (config.composable) { - auto inputs = dag.get_input_indices(); - auto outputs = dag.get_output_indices(); - dag.add_compositions( - rust::Slice( - outputs.data(), outputs.size()), - rust::Slice( - inputs.data(), inputs.size())); + for (auto rule : config.composition_rules) { + dag.add_composition(rule.from_func, rule.from_pos, rule.to_func, + rule.to_pos); } } diff --git a/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp b/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp index ee8f938958..3260d920ff 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp +++ b/compilers/concrete-compiler/compiler/lib/Support/CompilerEngine.cpp @@ -178,12 +178,6 @@ CompilerEngine::getConcreteOptimizerDescription(CompilationResult &res) { if (!description->has_value()) { // The pass has not been run return std::nullopt; } - if (description->value().dag.value()->get_circuit_count() > 1 && - config.strategy != - mlir::concretelang::optimizer::V0) { // Multi circuits without V0 - return StreamStringError( - "Multi-circuits is only supported for V0 optimization."); - } return description; } diff --git a/compilers/concrete-compiler/compiler/lib/Support/V0Parameters.cpp b/compilers/concrete-compiler/compiler/lib/Support/V0Parameters.cpp index 05e7595ac1..3921c15111 100644 --- a/compilers/concrete-compiler/compiler/lib/Support/V0Parameters.cpp +++ b/compilers/concrete-compiler/compiler/lib/Support/V0Parameters.cpp @@ -377,7 +377,7 @@ getSolution(optimizer::Description &descr, ProgramCompilationFeedback &feedback, if (encoding != concrete_optimizer::Encoding::Crt) { config.encoding = concrete_optimizer::Encoding::Native; auto sol = getDagMultiSolution(descr.dag.value(), config); - if (sol.is_feasible || config.composable) { + if (sol.is_feasible || !config.composition_rules.empty()) { displayOptimizer(sol, descr, config); return toCompilerSolution(sol, feedback, config); } diff --git a/compilers/concrete-compiler/compiler/src/main.cpp b/compilers/concrete-compiler/compiler/src/main.cpp index 67c8d7adc0..8ff15b1891 100644 --- a/compilers/concrete-compiler/compiler/src/main.cpp +++ b/compilers/concrete-compiler/compiler/src/main.cpp @@ -332,12 +332,6 @@ llvm::cl::opt optimizerNoCacheOnDisk( "cache issues."), llvm::cl::init(false)); -llvm::cl::opt optimizerAllowComposition( - "optimizer-allow-composition", - llvm::cl::desc("Optimizer is parameterized to allow calling the circuit on " - "its own output without decryptions."), - llvm::cl::init(false)); - llvm::cl::list fhelinalgTileSizes( "fhelinalg-tile-sizes", llvm::cl::desc( @@ -508,7 +502,6 @@ cmdlineCompilationOptions() { cmdline::optimizerMultiParamStrategy; options.optimizerConfig.encoding = cmdline::optimizerEncoding; options.optimizerConfig.cache_on_disk = !cmdline::optimizerNoCacheOnDisk; - options.optimizerConfig.composable = cmdline::optimizerAllowComposition; if (!std::isnan(options.optimizerConfig.global_p_error) && options.optimizerConfig.strategy == optimizer::Strategy::V0) { diff --git a/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_distributed.cc b/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_distributed.cc index dd15aefcc3..57f569d436 100644 --- a/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_distributed.cc +++ b/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_distributed.cc @@ -81,7 +81,7 @@ func.func @main(%arg0: tensor<200x4x!FHE.eint<4>>) -> tensor<200x8x!FHE.eint<4>> )XXX", "main", false, true, true, DEFAULT_batchTFHEOps, DEFAULT_global_p_error, DEFAULT_chunkedIntegers, DEFAULT_chunkSize, - DEFAULT_chunkWidth, DEFAULT_composable, false); + DEFAULT_chunkWidth, false); const size_t dim0 = 200; const size_t dim1 = 4; @@ -121,8 +121,7 @@ TEST(Distributed, nn_med_sequential) { )XXX", "main", false, false, false, DEFAULT_batchTFHEOps, DEFAULT_global_p_error, DEFAULT_chunkedIntegers, - DEFAULT_chunkSize, DEFAULT_chunkWidth, DEFAULT_composable, - false); + DEFAULT_chunkSize, DEFAULT_chunkWidth, false); const size_t dim0 = 200; const size_t dim1 = 4; diff --git a/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_test.cc b/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_test.cc index 3357d0d150..aeffd5c3b4 100644 --- a/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_test.cc +++ b/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_test.cc @@ -396,23 +396,25 @@ func.func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<3> { TEST(CompileNotComposable, not_composable_1) { mlir::concretelang::CompilationOptions options; - options.optimizerConfig.composable = true; + options.optimizerConfig.composition_rules.push_back({"main", 0, "main", 0}); options.optimizerConfig.strategy = mlir::concretelang::optimizer::DAG_MULTI; TestProgram circuit(options); auto err = circuit.compile(R"XXX( func.func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<3> { - %cst_1 = arith.constant 1 : i4 - %1 = "FHE.add_eint_int"(%arg0, %cst_1) : (!FHE.eint<3>, i4) -> !FHE.eint<3> + %cst_2 = arith.constant 2 : i4 + %1 = "FHE.mul_eint_int"(%arg0, %cst_2) : (!FHE.eint<3>, i4) -> !FHE.eint<3> return %1: !FHE.eint<3> } )XXX"); ASSERT_OUTCOME_HAS_FAILURE_WITH_ERRORMSG( - err, "Program can not be composed: No luts in the circuit."); + err, "Program can not be composed: Dag is not composable, because of " + "output 1: Partition 0 has input coefficient 4"); } TEST(CompileNotComposable, not_composable_2) { mlir::concretelang::CompilationOptions options; - options.optimizerConfig.composable = true; + options.optimizerConfig.composition_rules.push_back({"main", 0, "main", 0}); + options.optimizerConfig.composition_rules.push_back({"main", 1, "main", 0}); options.optimizerConfig.display = true; options.optimizerConfig.strategy = mlir::concretelang::optimizer::DAG_MULTI; TestProgram circuit(options); @@ -430,25 +432,9 @@ func.func @main(%arg0: !FHE.eint<3>) -> (!FHE.eint<3>, !FHE.eint<3>) { "output 1: Partition 0 has input coefficient 4"); } -TEST(CompileComposable, composable_supported_dag_mono) { - mlir::concretelang::CompilationOptions options; - options.optimizerConfig.composable = true; - options.optimizerConfig.display = true; - options.optimizerConfig.strategy = mlir::concretelang::optimizer::DAG_MONO; - TestProgram circuit(options); - auto err = circuit.compile(R"XXX( -func.func @main(%arg0: !FHE.eint<3>) -> !FHE.eint<3> { - %cst_1 = arith.constant 1 : i4 - %1 = "FHE.add_eint_int"(%arg0, %cst_1) : (!FHE.eint<3>, i4) -> !FHE.eint<3> - return %1: !FHE.eint<3> -} -)XXX"); - assert(err.has_value()); -} - TEST(CompileComposable, composable_supported_v0) { mlir::concretelang::CompilationOptions options; - options.optimizerConfig.composable = true; + options.optimizerConfig.composition_rules.push_back({"main", 0, "main", 0}); options.optimizerConfig.display = true; options.optimizerConfig.strategy = mlir::concretelang::optimizer::V0; TestProgram circuit(options); diff --git a/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_test.h b/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_test.h index 4d305c20b6..bddfdc441a 100644 --- a/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_test.h +++ b/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_jit_test.h @@ -25,7 +25,6 @@ double DEFAULT_global_p_error = TEST_ERROR_RATE; bool DEFAULT_chunkedIntegers = false; unsigned int DEFAULT_chunkSize = 4; unsigned int DEFAULT_chunkWidth = 2; -bool DEFAULT_composable = false; bool DEFAULT_use_multi_parameter = true; // Jit-compiles the function specified by `func` from `src` and @@ -41,7 +40,6 @@ inline Result internalCheckedJit( bool chunkedIntegers = DEFAULT_chunkedIntegers, unsigned int chunkSize = DEFAULT_chunkSize, unsigned int chunkWidth = DEFAULT_chunkWidth, - bool composable = DEFAULT_composable, bool use_multi_parameter = DEFAULT_use_multi_parameter) { auto options = mlir::concretelang::CompilationOptions(); @@ -59,11 +57,6 @@ inline Result internalCheckedJit( options.dataflowParallelize = dataflowParallelize; #endif options.batchTFHEOps = batchTFHEOps; - if (composable) { - options.optimizerConfig.composable = composable; - options.optimizerConfig.strategy = - mlir::concretelang::optimizer::Strategy::DAG_MULTI; - } if (!use_multi_parameter) options.optimizerConfig.strategy = diff --git a/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_test.h b/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_test.h index 37cee04730..f4c218a5cb 100644 --- a/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_test.h +++ b/compilers/concrete-compiler/compiler/tests/end_to_end_tests/end_to_end_test.h @@ -237,9 +237,6 @@ std::string getOptionsName(mlir::concretelang::CompilationOptions compilation) { if (compilation.optimizerConfig.security != defaultOptions.optimizerConfig.security) os << "_optimizerSecurity" << compilation.optimizerConfig.security; - if (compilation.optimizerConfig.composable != - defaultOptions.optimizerConfig.composable) - os << "_optimizerSecurity" << compilation.optimizerConfig.composable; /// GPU if (compilation.emitGPUOps != defaultOptions.emitGPUOps) diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/Makefile b/compilers/concrete-optimizer/concrete-optimizer-cpp/Makefile index 82b6b2cc63..da55d3e43c 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/Makefile +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/Makefile @@ -20,7 +20,7 @@ INTERFACE_CPP = src/cpp/concrete-optimizer.cpp SOURCES = $(shell find $(ROOT)/concrete-optimizer/src) \ $(shell find $(ROOT)/concrete-optimizer-cpp/src -name '*.rs') -build: $(INTERFACE_LIB) +build: $(INTERFACE_LIB) $(INTERFACE_CPP) $(INTERFACE_HEADER) $(INTERFACE_LIB_ORIG) $(INTERFACE_HEADER_ORIG) $(INTERFACE_CPP_ORIG): $(SOURCES) cd $(ROOT) && cargo build -p concrete-optimizer-cpp --profile $(CARGO_PROFILE) diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs index c68574145c..a3a52c3ed4 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs @@ -16,6 +16,7 @@ use concrete_optimizer::optimization::decomposition; use concrete_optimizer::parameters::{BrDecompositionParameters, KsDecompositionParameters}; use concrete_optimizer::utils::cache::persistent::default_cache_dir; use concrete_optimizer::utils::viz::Viz; +use cxx::CxxString; fn no_solution() -> ffi::Solution { ffi::Solution { @@ -516,47 +517,48 @@ impl Dag { let search_space = SearchSpace::default(processing_unit); let encoding = options.encoding.into(); + if self.0.is_composed() { - let circuit_sol = - concrete_optimizer::optimization::dag::multi_parameters::optimize_generic::optimize( - &self.0, - config, - &search_space, - encoding, - options.default_log_norm2_woppbs, - &caches_from(options), - &Some(PartitionCut::empty()), - ); - let circuit_sol: ffi::CircuitSolution = circuit_sol.into(); - (&circuit_sol).into() - } else { - let result = - concrete_optimizer::optimization::dag::solo_key::optimize_generic::optimize( - &self.0, - config, - &search_space, - encoding, - options.default_log_norm2_woppbs, - &caches_from(options), - ); - result.map_or_else(no_dag_solution, |solution| solution.into()) + return no_dag_solution(); } + + let result = concrete_optimizer::optimization::dag::solo_key::optimize_generic::optimize( + &self.0, + config, + &search_space, + encoding, + options.default_log_norm2_woppbs, + &caches_from(options), + ); + result.map_or_else(no_dag_solution, |solution| solution.into()) } fn get_circuit_count(&self) -> usize { self.0.get_circuit_count() } - fn add_compositions(&mut self, froms: &[ffi::OperatorIndex], tos: &[ffi::OperatorIndex]) { - self.0.add_compositions( - froms - .iter() - .map(|a| OperatorIndex(a.index)) - .collect::>(), - tos.iter() - .map(|a| OperatorIndex(a.index)) - .collect::>(), - ); + unsafe fn add_composition<'a>( + &mut self, + from_func: &'a CxxString, + from_pos: usize, + to_func: &'a CxxString, + to_pos: usize, + ) { + let from_index = self + .0 + .get_circuit(from_func.to_str().unwrap()) + .get_output_operators_iter() + .nth(from_pos) + .unwrap() + .id; + let to_index = self + .0 + .get_circuit(to_func.to_str().unwrap()) + .get_input_operators_iter() + .nth(to_pos) + .unwrap() + .id; + self.0.add_composition(from_index, to_index); } fn add_all_compositions(&mut self) { @@ -800,7 +802,13 @@ mod ffi { fn optimize(self: &Dag, options: Options) -> DagSolution; - fn add_compositions(self: &mut Dag, froms: &[OperatorIndex], tos: &[OperatorIndex]); + unsafe fn add_composition<'a>( + self: &mut Dag, + from_func: &'a CxxString, + from_pos: usize, + to_func: &'a CxxString, + to_pos: usize, + ); fn add_all_compositions(self: &mut Dag); diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp index 182bc1cd2a..f20c22d655 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp @@ -972,7 +972,7 @@ struct Dag final : public ::rust::Opaque { ::rust::Box<::concrete_optimizer::DagBuilder> builder(::rust::String circuit) noexcept; ::rust::String dump() const noexcept; ::concrete_optimizer::dag::DagSolution optimize(::concrete_optimizer::Options options) const noexcept; - void add_compositions(::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> froms, ::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> tos) noexcept; + void add_composition(::std::string const &from_func, ::std::size_t from_pos, ::std::string const &to_func, ::std::size_t to_pos) noexcept; void add_all_compositions() noexcept; ::std::size_t get_circuit_count() const noexcept; ::concrete_optimizer::dag::CircuitSolution optimize_multi(::concrete_optimizer::Options options) const noexcept; @@ -1319,7 +1319,7 @@ void concrete_optimizer$cxxbridge1$DagBuilder$tag_operator_as_output(::concrete_ void concrete_optimizer$cxxbridge1$Dag$optimize(::concrete_optimizer::Dag const &self, ::concrete_optimizer::Options options, ::concrete_optimizer::dag::DagSolution *return$) noexcept; -void concrete_optimizer$cxxbridge1$Dag$add_compositions(::concrete_optimizer::Dag &self, ::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> froms, ::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> tos) noexcept; +void concrete_optimizer$cxxbridge1$Dag$add_composition(::concrete_optimizer::Dag &self, ::std::string const &from_func, ::std::size_t from_pos, ::std::string const &to_func, ::std::size_t to_pos) noexcept; void concrete_optimizer$cxxbridge1$Dag$add_all_compositions(::concrete_optimizer::Dag &self) noexcept; } // extern "C" @@ -1449,8 +1449,8 @@ ::concrete_optimizer::dag::DagSolution Dag::optimize(::concrete_optimizer::Optio return ::std::move(return$.value); } -void Dag::add_compositions(::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> froms, ::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> tos) noexcept { - concrete_optimizer$cxxbridge1$Dag$add_compositions(*this, froms, tos); +void Dag::add_composition(::std::string const &from_func, ::std::size_t from_pos, ::std::string const &to_func, ::std::size_t to_pos) noexcept { + concrete_optimizer$cxxbridge1$Dag$add_composition(*this, from_func, from_pos, to_func, to_pos); } void Dag::add_all_compositions() noexcept { diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp index e893d708d2..abd0685d65 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp @@ -953,7 +953,7 @@ struct Dag final : public ::rust::Opaque { ::rust::Box<::concrete_optimizer::DagBuilder> builder(::rust::String circuit) noexcept; ::rust::String dump() const noexcept; ::concrete_optimizer::dag::DagSolution optimize(::concrete_optimizer::Options options) const noexcept; - void add_compositions(::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> froms, ::rust::Slice<::concrete_optimizer::dag::OperatorIndex const> tos) noexcept; + void add_composition(::std::string const &from_func, ::std::size_t from_pos, ::std::string const &to_func, ::std::size_t to_pos) noexcept; void add_all_compositions() noexcept; ::std::size_t get_circuit_count() const noexcept; ::concrete_optimizer::dag::CircuitSolution optimize_multi(::concrete_optimizer::Options options) const noexcept; diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/tests/src/main.cpp b/compilers/concrete-optimizer/concrete-optimizer-cpp/tests/src/main.cpp index c2ca86991d..6bfd7bde27 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/tests/src/main.cpp +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/tests/src/main.cpp @@ -231,79 +231,6 @@ TEST test_multi_parameters_2_precision_crt() { assert(circuit_solution.circuit_keys.conversion_keyswitch_keys.size() == 0); } -TEST test_composable_dag_mono_fallback_on_dag_multi() { - auto dag = concrete_optimizer::dag::empty(); - auto builder = dag->builder("test"); - - std::vector shape = {}; - - concrete_optimizer::dag::OperatorIndex input1 = - builder->add_input(PRECISION_8B, slice(shape)); - - std::vector inputs = {input1}; - std::vector weight_vec = {1 << 8}; - rust::cxxbridge1::Box weights1 = - concrete_optimizer::weights::vector(slice(weight_vec)); - - input1 = builder->add_dot(slice(inputs), std::move(weights1)); - std::vector table = {}; - auto lut1 = builder->add_lut(input1, slice(table), PRECISION_8B); - std::vector lut1v = {lut1}; - rust::cxxbridge1::Box weights2 = - concrete_optimizer::weights::vector(slice(weight_vec)); - auto id = builder->add_dot(slice(lut1v), std::move(weights2)); - builder->tag_operator_as_output(id); - - auto options = default_options(); - auto solution1 = dag->optimize(options); - assert(!solution1.use_wop_pbs); - assert(solution1.p_error < options.maximum_acceptable_error_probability); - - dag->add_all_compositions(); - auto solution2 = dag->optimize(options); - assert(!solution2.use_wop_pbs); - assert(solution2.p_error < options.maximum_acceptable_error_probability); - assert(solution1.complexity < solution2.complexity); -} - -TEST test_non_composable_dag_mono_fallback_on_woppbs() { - auto dag = concrete_optimizer::dag::empty(); - auto builder = dag->builder("test"); - - std::vector shape = {}; - - concrete_optimizer::dag::OperatorIndex input1 = - builder->add_input(PRECISION_8B, slice(shape)); - - - std::vector inputs = {input1}; - std::vector weight_vec = {1 << 16}; - rust::cxxbridge1::Box weights1 = - concrete_optimizer::weights::vector(slice(weight_vec)); - - input1 = builder->add_dot(slice(inputs), std::move(weights1)); - std::vector table = {}; - auto lut1 = builder->add_lut(input1, slice(table), PRECISION_8B); - std::vector lut1v = {lut1}; - rust::cxxbridge1::Box weights2 = - concrete_optimizer::weights::vector(slice(weight_vec)); - auto id = builder->add_dot(slice(lut1v), std::move(weights2)); - builder->tag_operator_as_output(id); - - auto options = default_options(); - - auto solution1 = dag->optimize(options); - assert(!solution1.use_wop_pbs); - assert(solution1.p_error < options.maximum_acceptable_error_probability); - - dag->add_all_compositions(); - auto solution2 = dag->optimize(options); - assert(solution2.p_error < options.maximum_acceptable_error_probability); - assert(solution1.complexity < solution2.complexity); - assert(solution2.use_wop_pbs); - -} - int main() { test_v0(); @@ -314,8 +241,6 @@ int main() { test_multi_parameters_1_precision(); test_multi_parameters_2_precision(); test_multi_parameters_2_precision_crt(); - test_composable_dag_mono_fallback_on_dag_multi(); - test_non_composable_dag_mono_fallback_on_woppbs(); return 0; } diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/mod.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/mod.rs index 572d6561f8..c146803ab9 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/mod.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/mod.rs @@ -21,7 +21,7 @@ use crate::optimization::dag::multi_parameters::feasible::Feasible; use crate::optimization::dag::multi_parameters::partition_cut::PartitionCut; use crate::optimization::dag::multi_parameters::partitions::PartitionIndex; use crate::optimization::dag::multi_parameters::{analyze, keys_spec}; -use crate::optimization::Err::{NoParametersFound, NotComposable}; +use crate::optimization::Err::NoParametersFound; use super::keys_spec::InstructionKeys; @@ -1165,13 +1165,7 @@ pub fn optimize_to_circuit_solution( persistent_caches: &PersistDecompCaches, p_cut: &Option, ) -> keys_spec::CircuitSolution { - if lut_count_from_dag(dag) == 0 { - // If there are no lut in the dag the noise is never refresh so the dag cannot be composable - if dag.is_composed() { - return keys_spec::CircuitSolution::no_solution( - NotComposable("No luts in the circuit.".into()).to_string(), - ); - } + if lut_count_from_dag(dag) == 0 && !dag.is_composed() { let nb_instr = dag.operators.len(); if let Some(sol) = optimize_mono(dag, config, search_space, persistent_caches).best_solution { diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/utils/viz.rs b/compilers/concrete-optimizer/concrete-optimizer/src/utils/viz.rs index 4684790bcb..0ac2dc8e9d 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/utils/viz.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/utils/viz.rs @@ -184,6 +184,7 @@ macro_rules! _viz { )) .output() .expect("Failed to execute dot. Do you have graphviz installed ?"); + path }}; } @@ -191,13 +192,12 @@ macro_rules! _viz { #[allow(unused)] macro_rules! viz { ($path: expr, $object:expr) => { - $crate::utils::viz::_viz!($path, $object); + let path = $crate::utils::viz::_viz!($path, $object); println!( - "Viz of {}:{} visible at {}/{}", + "Viz of {}:{} visible at {}", file!(), line!(), - std::env::temp_dir().display(), - $path + path.display() ); }; ($object:expr) => { @@ -210,13 +210,12 @@ macro_rules! viz { #[allow(unused)] macro_rules! vizp { ($path: expr, $object:expr) => {{ - $crate::utils::viz::_viz!($path, $object); + let path = $crate::utils::viz::_viz!($path, $object); panic!( - "Viz of {}:{} visible at {}/{}", + "Viz of {}:{} visible at {}", file!(), line!(), - std::env::temp_dir().display(), - $path + path.display() ); }}; ($object:expr) => { diff --git a/frontends/concrete-python/concrete/fhe/__init__.py b/frontends/concrete-python/concrete/fhe/__init__.py index 7815052543..866e9dfa77 100644 --- a/frontends/concrete-python/concrete/fhe/__init__.py +++ b/frontends/concrete-python/concrete/fhe/__init__.py @@ -9,6 +9,9 @@ from .compilation import ( DEFAULT_GLOBAL_P_ERROR, DEFAULT_P_ERROR, + AllComposable, + AllInputs, + AllOutputs, ApproximateRoundingConfig, BitwiseStrategy, Circuit, @@ -16,19 +19,25 @@ ClientSpecs, ComparisonStrategy, Compiler, + CompositionPolicy, Configuration, DebugArtifacts, EncryptionStatus, Exactness, FunctionDebugArtifacts, + Input, Keys, MinMaxStrategy, ModuleDebugArtifacts, MultiParameterStrategy, MultivariateStrategy, + NotComposable, + Output, ParameterSelectionStrategy, Server, Value, + Wire, + Wired, inputset, ) from .compilation.decorators import circuit, compiler, function, module diff --git a/frontends/concrete-python/concrete/fhe/compilation/__init__.py b/frontends/concrete-python/concrete/fhe/compilation/__init__.py index ca7a5f02d1..8a2ff857d0 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/__init__.py +++ b/frontends/concrete-python/concrete/fhe/compilation/__init__.py @@ -6,6 +6,7 @@ from .circuit import Circuit from .client import Client from .compiler import Compiler, EncryptionStatus +from .composition import CompositionClause, CompositionPolicy, CompositionRule from .configuration import ( DEFAULT_GLOBAL_P_ERROR, DEFAULT_P_ERROR, @@ -21,7 +22,18 @@ ) from .keys import Keys from .module import FheFunction, FheModule -from .module_compiler import FunctionDef, ModuleCompiler +from .module_compiler import ( + AllComposable, + AllInputs, + AllOutputs, + FunctionDef, + Input, + ModuleCompiler, + NotComposable, + Output, + Wire, + Wired, +) from .server import Server from .specs import ClientSpecs from .utils import inputset diff --git a/frontends/concrete-python/concrete/fhe/compilation/circuit.py b/frontends/concrete-python/concrete/fhe/compilation/circuit.py index 3f8db94d0e..994c17cefd 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/circuit.py +++ b/frontends/concrete-python/concrete/fhe/compilation/circuit.py @@ -5,7 +5,7 @@ # pylint: disable=import-error,no-member,no-name-in-module from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import numpy as np from concrete.compiler import ( @@ -19,6 +19,7 @@ from ..internal.utils import assert_that from ..representation import Graph from .client import Client +from .composition import CompositionRule from .configuration import Configuration from .keys import Keys from .server import Server @@ -38,6 +39,7 @@ class Circuit: graph: Graph mlir_module: MlirModule compilation_context: CompilationContext + composition_rules: Optional[List[CompositionRule]] client: Client server: Server @@ -49,8 +51,10 @@ def __init__( mlir: MlirModule, compilation_context: CompilationContext, configuration: Optional[Configuration] = None, + composition_rules: Optional[Iterable[CompositionRule]] = None, ): self.configuration = configuration if configuration is not None else Configuration() + self.composition_rules = list(composition_rules) if composition_rules else [] self.graph = graph self.mlir_module = mlir @@ -118,6 +122,7 @@ def enable_fhe_simulation(self): self.configuration, is_simulated=True, compilation_context=self.compilation_context, + composition_rules=self.composition_rules, ) def enable_fhe_execution(self): @@ -127,7 +132,10 @@ def enable_fhe_execution(self): if not hasattr(self, "server"): self.server = Server.create( - self.mlir_module, self.configuration, compilation_context=self.compilation_context + self.mlir_module, + self.configuration, + compilation_context=self.compilation_context, + composition_rules=self.composition_rules, ) keyset_cache_directory = None diff --git a/frontends/concrete-python/concrete/fhe/compilation/compiler.py b/frontends/concrete-python/concrete/fhe/compilation/compiler.py index 99ac0c76b1..da01388401 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/compiler.py +++ b/frontends/concrete-python/concrete/fhe/compilation/compiler.py @@ -9,6 +9,7 @@ import traceback from copy import deepcopy from enum import Enum, unique +from itertools import product, repeat from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import numpy as np @@ -21,6 +22,7 @@ from ..values import ValueDescription from .artifacts import DebugArtifacts from .circuit import Circuit +from .composition import CompositionClause, CompositionRule from .configuration import Configuration from .utils import fuse, get_terminal_size @@ -520,9 +522,26 @@ def compile( print() + # We generate the composition rules if needed: + composition_rules = [] + if self.configuration.composable: + compo_froms = map( + CompositionClause.create, + zip(repeat(self.graph.name), range(len(self.graph.output_nodes))), + ) + compo_tos = map( + CompositionClause.create, + zip(repeat(self.graph.name), range(len(self.graph.input_nodes))), + ) + composition_rules = list( + map(CompositionRule.create, product(compo_froms, compo_tos)) + ) + # in-memory MLIR module mlir_context = self.compilation_context.mlir_context() - mlir_module = GraphConverter(self.configuration).convert(self.graph, mlir_context) + mlir_module = GraphConverter(self.configuration, composition_rules).convert( + self.graph, mlir_context + ) # textual representation of the MLIR module mlir_str = str(mlir_module).strip() if self.artifacts is not None: @@ -589,6 +608,7 @@ def compile( mlir_module, self.compilation_context, self.configuration, + composition_rules, ) if hasattr(circuit, "client"): diff --git a/frontends/concrete-python/concrete/fhe/compilation/composition.py b/frontends/concrete-python/concrete/fhe/compilation/composition.py new file mode 100644 index 0000000000..10bcadeab2 --- /dev/null +++ b/frontends/concrete-python/concrete/fhe/compilation/composition.py @@ -0,0 +1,53 @@ +""" +Declaration of classes related to composition. +""" + +# pylint: disable=import-error,no-name-in-module + +from typing import Iterable, List, NamedTuple, Protocol, Tuple, runtime_checkable + +from ..representation import Graph + + +class CompositionClause(NamedTuple): + """ + A raw composition clause. + """ + + func: str + pos: int + + @staticmethod + def create(tup: Tuple[str, int]) -> "CompositionClause": + """ + Create a composition clause from a tuple of a function name and a position. + """ + return CompositionClause(tup[0], tup[1]) + + +class CompositionRule(NamedTuple): + """ + A raw composition rule. + """ + + from_: CompositionClause + to: CompositionClause + + @staticmethod + def create(tup: Tuple[CompositionClause, CompositionClause]) -> "CompositionRule": + """ + Create a composition rule from a tuple containing an output clause and an input clause. + """ + return CompositionRule(tup[0], tup[1]) + + +@runtime_checkable +class CompositionPolicy(Protocol): + """ + A protocol for composition policies. + """ + + def get_rules_iter(self, funcs: List[Graph]) -> Iterable[CompositionRule]: + """ + Return an iterator over composition rules. + """ diff --git a/frontends/concrete-python/concrete/fhe/compilation/configuration.py b/frontends/concrete-python/concrete/fhe/compilation/configuration.py index 0c7b1c5fd6..e8438984b8 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/configuration.py +++ b/frontends/concrete-python/concrete/fhe/compilation/configuration.py @@ -977,7 +977,6 @@ class Configuration: shifts_with_promotion: bool multivariate_strategy_preference: List[MultivariateStrategy] min_max_strategy_preference: List[MinMaxStrategy] - composable: bool use_gpu: bool relu_on_bits_threshold: int relu_on_bits_chunk_size: int diff --git a/frontends/concrete-python/concrete/fhe/compilation/decorators.py b/frontends/concrete-python/concrete/fhe/compilation/decorators.py index 2c1f2b6e21..f73c4b3edc 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/decorators.py +++ b/frontends/concrete-python/concrete/fhe/compilation/decorators.py @@ -14,7 +14,7 @@ from .circuit import Circuit from .compiler import Compiler, EncryptionStatus from .configuration import Configuration -from .module_compiler import FunctionDef, ModuleCompiler +from .module_compiler import AllComposable, CompositionPolicy, FunctionDef, ModuleCompiler def circuit( @@ -179,7 +179,9 @@ def decoration(class_): if not functions: error = "Tried to define an @fhe.module without any @fhe.function" raise RuntimeError(error) - return ModuleCompiler([f for (_, f) in functions]) + composition = getattr(class_, "composition", AllComposable()) + assert isinstance(composition, CompositionPolicy) + return ModuleCompiler([f for (_, f) in functions], composition) return decoration diff --git a/frontends/concrete-python/concrete/fhe/compilation/module.py b/frontends/concrete-python/concrete/fhe/compilation/module.py index 1d20c90f12..ee4dce215c 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/module.py +++ b/frontends/concrete-python/concrete/fhe/compilation/module.py @@ -5,7 +5,7 @@ # pylint: disable=import-error,no-member,no-name-in-module from pathlib import Path -from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union import numpy as np from concrete.compiler import ( @@ -19,6 +19,7 @@ from ..internal.utils import assert_that from ..representation import Graph from .client import Client +from .composition import CompositionRule from .configuration import Configuration from .keys import Keys from .server import Server @@ -530,6 +531,7 @@ def __init__( mlir: MlirModule, compilation_context: CompilationContext, configuration: Optional[Configuration] = None, + composition_rules: Optional[Iterable[CompositionRule]] = None, ): assert configuration and (configuration.fhe_simulation or configuration.fhe_execution) @@ -548,7 +550,10 @@ def __init__( self.runtime = SimulationRt(server) else: server = Server.create( - self.mlir_module, self.configuration, compilation_context=self.compilation_context + self.mlir_module, + self.configuration, + compilation_context=self.compilation_context, + composition_rules=composition_rules, ) keyset_cache_directory = None diff --git a/frontends/concrete-python/concrete/fhe/compilation/module_compiler.py b/frontends/concrete-python/concrete/fhe/compilation/module_compiler.py index bd756d2180..de76e6b6a9 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/module_compiler.py +++ b/frontends/concrete-python/concrete/fhe/compilation/module_compiler.py @@ -7,8 +7,21 @@ import inspect import traceback from copy import deepcopy +from itertools import chain, product, repeat from pathlib import Path -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + NamedTuple, + Optional, + Protocol, + Tuple, + Union, + runtime_checkable, +) import numpy as np from concrete.compiler import CompilationContext @@ -20,6 +33,7 @@ from ..values import ValueDescription from .artifacts import FunctionDebugArtifacts, ModuleDebugArtifacts from .compiler import EncryptionStatus +from .composition import CompositionClause, CompositionPolicy, CompositionRule from .configuration import Configuration from .module import ExecutionRt, FheModule from .utils import fuse, get_terminal_size @@ -235,6 +249,164 @@ def __call__( return self.graph(*args) +class NotComposable: + """ + Composition policy that does not allow the forwarding of any output to any input. + """ + + def get_rules_iter(self, _funcs: List[FunctionDef]) -> Iterable[CompositionRule]: + """ + Return an iterator over composition rules. + """ + return [] + + +class AllComposable: + """ + Composition policy that allows to forward any output of the module to any of its input. + """ + + def get_rules_iter(self, funcs: List[Graph]) -> Iterable[CompositionRule]: + """ + Return an iterator over composition rules. + """ + outputs = chain( + *[ + map(CompositionClause.create, zip(repeat(f.name), range(len(f.output_nodes)))) + for f in funcs + ] + ) + inputs = chain( + *[ + map(CompositionClause.create, zip(repeat(f.name), range(len(f.input_nodes)))) + for f in funcs + ] + ) + return map(CompositionRule.create, product(outputs, inputs)) + + +@runtime_checkable +class WireOutput(Protocol): + """ + A protocol for wire outputs. + """ + + def get_outputs_iter(self) -> Iterable[CompositionClause]: + """ + Return an iterator over the possible outputs of the wire output. + """ + + +@runtime_checkable +class WireInput(Protocol): + """ + A protocol for wire inputs. + """ + + def get_inputs_iter(self) -> Iterable[CompositionClause]: + """ + Return an iterator over the possible inputs of the wire input. + """ + + +class Output(NamedTuple): + """ + The output of a given function of a module. + """ + + func: FunctionDef + pos: int + + def get_outputs_iter(self) -> Iterable[CompositionClause]: + """ + Return an iterator over the possible outputs of the wire output. + """ + return [CompositionClause(self.func.name, self.pos)] + + +class AllOutputs(NamedTuple): + """ + All the outputs of a given function of a module. + """ + + func: FunctionDef + + def get_outputs_iter(self) -> Iterable[CompositionClause]: + """ + Return an iterator over the possible outputs of the wire output. + """ + assert self.func.graph + return map( + CompositionClause.create, + zip(repeat(self.func.name), range(self.func.graph.outputs_count)), + ) + + +class Input(NamedTuple): + """ + The input of a given function of a module. + """ + + func: FunctionDef + pos: int + + def get_inputs_iter(self) -> Iterable[CompositionClause]: + """ + Return an iterator over the possible inputs of the wire input. + """ + return [CompositionClause(self.func.name, self.pos)] + + +class AllInputs(NamedTuple): + """ + All the inputs of a given function of a module. + """ + + func: FunctionDef + + def get_inputs_iter(self) -> Iterable[CompositionClause]: + """ + Return an iterator over the possible inputs of the wire input. + """ + assert self.func.graph + return map( + CompositionClause.create, + zip(repeat(self.func.name), range(self.func.graph.inputs_count)), + ) + + +class Wire(NamedTuple): + """ + A forwarding rule between an output and an input. + """ + + output: WireOutput + input: WireInput + + def get_rules_iter(self, _) -> Iterable[CompositionRule]: + """ + Return an iterator over composition rules. + """ + return map( + CompositionRule.create, + product(self.output.get_outputs_iter(), self.input.get_inputs_iter()), + ) + + +class Wired(NamedTuple): + """ + Composition policy which allows the forwarding of certain outputs to certain inputs. + """ + + wires: List[Wire] + + def get_rules_iter(self, _) -> Iterable[CompositionRule]: + """ + Return an iterator over composition rules. + """ + return chain(*[w.get_rules_iter(_) for w in self.wires]) + + class DebugManager: """ A debug manager, allowing streamlined debugging. @@ -450,15 +622,16 @@ class ModuleCompiler: default_configuration: Configuration functions: Dict[str, FunctionDef] compilation_context: CompilationContext + composition: CompositionPolicy - def __init__(self, functions: List[FunctionDef]): + def __init__(self, functions: List[FunctionDef], composition: CompositionPolicy): self.default_configuration = Configuration( p_error=0.00001, - composable=True, - parameter_selection_strategy="v0", + parameter_selection_strategy="multi", ) self.functions = {function.name: function for function in functions} self.compilation_context = CompilationContext.new() + self.composition = composition def compile( self, @@ -492,9 +665,6 @@ def compile( configuration = deepcopy(configuration) if len(kwargs) != 0: configuration = configuration.fork(**kwargs) - if not configuration.composable: - error = "Module can only be compiled with `composable` activated." - raise RuntimeError(error) module_artifacts = ( module_artifacts if module_artifacts is not None else ModuleDebugArtifacts() @@ -518,10 +688,18 @@ def compile( # Convert the graphs to an mlir module mlir_context = self.compilation_context.mlir_context() graphs = {} + for name, function in self.functions.items(): assert function.graph is not None graphs[name] = function.graph - mlir_module = GraphConverter(configuration).convert_many(graphs, mlir_context) + + # pylint: disable=protected-access + mlir_module = GraphConverter( + configuration, + self.composition.get_rules_iter( + list(filter(None, [f.graph for f in self.functions.values()])) + ), + ).convert_many(graphs, mlir_context) mlir_str = str(mlir_module).strip() dbg.debug_mlir(mlir_str) module_artifacts.add_mlir_to_compile(mlir_str) @@ -534,7 +712,16 @@ def compile( # Compile to a module! with dbg.debug_table("Optimizer", activate=dbg.show_optimizer()): - output = FheModule(graphs, mlir_module, self.compilation_context, configuration) + # pylint: disable=protected-access + output = FheModule( + graphs, + mlir_module, + self.compilation_context, + configuration, + self.composition.get_rules_iter( + list(filter(None, [f.graph for f in self.functions.values()])) + ), + ) if isinstance(output.runtime, ExecutionRt): client_parameters = output.runtime.client.specs.client_parameters module_artifacts.add_client_parameters(client_parameters.serialize()) diff --git a/frontends/concrete-python/concrete/fhe/compilation/server.py b/frontends/concrete-python/concrete/fhe/compilation/server.py index 6f31663f2a..06f9c6465a 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/server.py +++ b/frontends/concrete-python/concrete/fhe/compilation/server.py @@ -7,7 +7,7 @@ import shutil import tempfile from pathlib import Path -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, Iterable, List, Optional, Tuple, Union # mypy: disable-error-code=attr-defined import concrete.compiler @@ -35,6 +35,7 @@ from mlir.ir import Module as MlirModule from ..internal.utils import assert_that +from .composition import CompositionRule from .configuration import ( DEFAULT_GLOBAL_P_ERROR, DEFAULT_P_ERROR, @@ -95,6 +96,7 @@ def create( configuration: Configuration, is_simulated: bool = False, compilation_context: Optional[CompilationContext] = None, + composition_rules: Optional[Iterable[CompositionRule]] = None, ) -> "Server": """ Create a server using MLIR and output sign information. @@ -111,6 +113,9 @@ def create( compilation_context (CompilationContext): context to use for the Compiler + + composition_rules (Iterable[Tuple[str, int, str, int]]): + composition rules to be applied when compiling """ backend = Backend.GPU if configuration.use_gpu else Backend.CPU @@ -123,10 +128,12 @@ def create( options.set_auto_parallelize(configuration.auto_parallelize) options.set_compress_evaluation_keys(configuration.compress_evaluation_keys) options.set_compress_input_ciphertexts(configuration.compress_input_ciphertexts) - options.set_composable(configuration.composable) options.set_enable_overflow_detection_in_simulation( configuration.detect_overflow_in_simulation ) + composition_rules = composition_rules if composition_rules else [] + for rule in composition_rules: + options.add_composition(rule.from_.func, rule.from_.pos, rule.to.func, rule.to.pos) if configuration.auto_parallelize or configuration.dataflow_parallelize: # pylint: disable=c-extension-no-member,no-member diff --git a/frontends/concrete-python/concrete/fhe/mlir/converter.py b/frontends/concrete-python/concrete/fhe/mlir/converter.py index 6245c2fa18..42dea77e3e 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/converter.py +++ b/frontends/concrete-python/concrete/fhe/mlir/converter.py @@ -6,7 +6,7 @@ import math import sys -from typing import Dict, List, Tuple, Union +from typing import Dict, Iterable, List, Optional, Tuple, Union import concrete.lang import concrete.lang.dialects.tracing @@ -18,6 +18,7 @@ from mlir.ir import Location as MlirLocation from mlir.ir import Module as MlirModule +from ..compilation.composition import CompositionRule from ..compilation.configuration import Configuration, Exactness from ..representation import Graph, GraphProcessor, MultiGraphProcessor, Node, Operation from .context import Context @@ -34,9 +35,15 @@ class Converter: """ configuration: Configuration + composition_rules: List[CompositionRule] - def __init__(self, configuration: Configuration): + def __init__( + self, + configuration: Configuration, + composition_rules: Optional[Iterable[CompositionRule]] = None, + ): self.configuration = configuration + self.composition_rules = list(composition_rules) if composition_rules else [] def convert_many( self, @@ -213,6 +220,7 @@ def process(self, graphs: Dict[str, Graph]): """ configuration = self.configuration + composition_rules = self.composition_rules pipeline = ( configuration.additional_pre_processors @@ -220,7 +228,7 @@ def process(self, graphs: Dict[str, Graph]): CheckIntegerOnly(), AssignBitWidths( single_precision=configuration.single_precision, - composable=configuration.composable, + composition_rules=composition_rules, comparison_strategy_preference=configuration.comparison_strategy_preference, bitwise_strategy_preference=configuration.bitwise_strategy_preference, shifts_with_promotion=configuration.shifts_with_promotion, diff --git a/frontends/concrete-python/concrete/fhe/mlir/processors/assign_bit_widths.py b/frontends/concrete-python/concrete/fhe/mlir/processors/assign_bit_widths.py index da5a96028a..12cac54aa6 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/processors/assign_bit_widths.py +++ b/frontends/concrete-python/concrete/fhe/mlir/processors/assign_bit_widths.py @@ -2,11 +2,11 @@ Declaration of `AssignBitWidths` graph processor. """ -from itertools import chain from typing import Dict, List import z3 +from ...compilation.composition import CompositionRule from ...compilation.configuration import ( BitwiseStrategy, ComparisonStrategy, @@ -31,7 +31,7 @@ class AssignBitWidths(MultiGraphProcessor): """ single_precision: bool - composable: bool + composition_rules: List[CompositionRule] comparison_strategy_preference: List[ComparisonStrategy] bitwise_strategy_preference: List[BitwiseStrategy] shifts_with_promotion: bool @@ -41,7 +41,7 @@ class AssignBitWidths(MultiGraphProcessor): def __init__( self, single_precision: bool, - composable: bool, + composition_rules: List[CompositionRule], comparison_strategy_preference: List[ComparisonStrategy], bitwise_strategy_preference: List[BitwiseStrategy], shifts_with_promotion: bool, @@ -49,7 +49,7 @@ def __init__( min_max_strategy_preference: List[MinMaxStrategy], ): self.single_precision = single_precision - self.composable = composable + self.composition_rules = composition_rules self.comparison_strategy_preference = comparison_strategy_preference self.bitwise_strategy_preference = bitwise_strategy_preference self.shifts_with_promotion = shifts_with_promotion @@ -99,11 +99,11 @@ def apply_many(self, graphs: Dict[str, Graph]): for bit_width in bit_widths.values(): optimizer.add(bit_width == max_bit_width) - if self.composable: - input_output_bitwidth = z3.Int("input_output") - for node in chain(graph.input_nodes.values(), graph.output_nodes.values()): - bit_width = bit_widths[node] - optimizer.add(bit_width == input_output_bitwidth) + if self.composition_rules: + for compo in self.composition_rules: + from_node = graphs[compo.from_.func].ordered_outputs()[compo.from_.pos] + to_node = graphs[compo.to.func].ordered_inputs()[compo.to.pos] + optimizer.add(bit_widths[from_node] == bit_widths[to_node]) optimizer.minimize(sum(bit_width for bit_width in bit_widths.values())) diff --git a/frontends/concrete-python/concrete/fhe/representation/graph.py b/frontends/concrete-python/concrete/fhe/representation/graph.py index 400d42ee8b..5ba9c8ce91 100644 --- a/frontends/concrete-python/concrete/fhe/representation/graph.py +++ b/frontends/concrete-python/concrete/fhe/representation/graph.py @@ -971,6 +971,20 @@ def integer_range( return result + @property + def inputs_count(self) -> int: + """ + Returns the number of inputs of the graph. + """ + return len(self.input_nodes) + + @property + def outputs_count(self) -> int: + """ + Returns the number of outputs of the graph. + """ + return len(self.output_nodes) + class GraphProcessor(ABC): """ diff --git a/frontends/concrete-python/tests/compilation/test_program.py b/frontends/concrete-python/tests/compilation/test_modules.py similarity index 71% rename from frontends/concrete-python/tests/compilation/test_program.py rename to frontends/concrete-python/tests/compilation/test_modules.py index aaabb84809..a820bcf331 100644 --- a/frontends/concrete-python/tests/compilation/test_program.py +++ b/frontends/concrete-python/tests/compilation/test_modules.py @@ -1,5 +1,5 @@ """ -Tests of everything related to multi-circuit. +Tests of everything related to modules. """ import tempfile @@ -43,28 +43,6 @@ def square(x): return x**2 -def test_wrong_config(helpers): - """ - Test that defining a module with wrong configuration raises an error. - """ - - with pytest.raises(RuntimeError) as excinfo: - - @fhe.module() - class Module: - @fhe.function({"x": "encrypted"}) - def add(x): - return x + 2 - - inputset = [np.random.randint(1, 20, size=()) for _ in range(100)] - module = Module.compile( - {"add": inputset}, - composable=False, - ) - - assert str(excinfo.value) == ("Module can only be compiled with `composable` activated.") - - def test_wrong_info(): """ Test that defining a module with wrong information raises an error. @@ -429,3 +407,151 @@ def dec(x): x_enc = module.inc.run(x_enc) x_dec = module.inc.decrypt(x_enc) assert x_dec == 15 + + +def test_composition_policy_default(): + @fhe.module() + class Module: + @fhe.function({"x": "encrypted"}) + def square(x): + return x**2 + + @fhe.function({"x": "encrypted", "y": "encrypted"}) + def add_sub(x, y): + return (x + y), (x - y) + + @fhe.function({"x": "encrypted", "y": "encrypted"}) + def mul(x, y): + return x * y + + assert isinstance(Module.composition, fhe.CompositionPolicy) + assert isinstance(Module.composition, fhe.AllComposable) + + +def test_composition_policy_all_composable(): + @fhe.module() + class Module: + @fhe.function({"x": "encrypted"}) + def square(x): + return x**2 + + @fhe.function({"x": "encrypted", "y": "encrypted"}) + def add_sub(x, y): + return (x + y), (x - y) + + @fhe.function({"x": "encrypted", "y": "encrypted"}) + def mul(x, y): + return x * y + + composition = fhe.AllComposable() + + assert isinstance(Module.composition, fhe.CompositionPolicy) + assert isinstance(Module.composition, fhe.AllComposable) + + +def test_composition_policy_wires(): + @fhe.module() + class Module: + @fhe.function({"x": "encrypted"}) + def square(x): + return x**2 + + @fhe.function({"x": "encrypted", "y": "encrypted"}) + def add_sub(x, y): + return (x + y), (x - y) + + composition = fhe.Wired( + [ + fhe.Wire(fhe.AllOutputs(add_sub), fhe.AllInputs(add_sub)), + fhe.Wire(fhe.AllOutputs(add_sub), fhe.Input(square, 0)), + ] + ) + + assert isinstance(Module.composition, fhe.CompositionPolicy) + assert isinstance(Module.composition, fhe.Wired) + + +def test_composition_wired_enhances_complexity(): + @fhe.module() + class Module1: + @fhe.function({"x": "encrypted"}) + def _1(x): + return (x * 2) % 20 + + @fhe.function({"x": "encrypted"}) + def _2(x): + return (x * 2) % 200 + + composition = fhe.Wired( + [ + fhe.Wire(fhe.Output(_1, 0), fhe.Input(_2, 0)), + ] + ) + + module1 = Module1.compile( + { + "_1": [np.random.randint(1, 20, size=()) for _ in range(100)], + "_2": [np.random.randint(1, 200, size=()) for _ in range(100)], + }, + ) + + @fhe.module() + class Module2: + @fhe.function({"x": "encrypted"}) + def _1(x): + return (x * 2) % 20 + + @fhe.function({"x": "encrypted"}) + def _2(x): + return (x * 2) % 200 + + composition = fhe.AllComposable() + + module2 = Module2.compile( + { + "_1": [np.random.randint(1, 20, size=()) for _ in range(100)], + "_2": [np.random.randint(1, 200, size=()) for _ in range(100)], + }, + ) + + assert module1.complexity < module2.complexity + + +def test_composition_wired_compilation(): + @fhe.module() + class Module: + @fhe.function({"x": "encrypted"}) + def a(x): + return (x * 2) % 20 + + @fhe.function({"x": "encrypted"}) + def b(x): + return (x * 2) % 50 + + @fhe.function({"x": "encrypted"}) + def c(x): + return (x * 2) % 100 + + composition = fhe.Wired( + [ + fhe.Wire(fhe.Output(a, 0), fhe.Input(b, 0)), + fhe.Wire(fhe.Output(b, 0), fhe.Input(c, 0)), + ] + ) + + module = Module.compile( + { + "a": [np.random.randint(1, 20, size=()) for _ in range(100)], + "b": [np.random.randint(1, 50, size=()) for _ in range(100)], + "c": [np.random.randint(1, 100, size=()) for _ in range(100)], + }, + p_error=0.01, + ) + + inp_enc = module.a.encrypt(5) + a_enc = module.a.run(inp_enc) + assert module.a.decrypt(a_enc) == 10 + b_enc = module.b.run(a_enc) + assert module.b.decrypt(b_enc) == 20 + c_enc = module.c.run(b_enc) + assert module.c.decrypt(c_enc) == 40