Skip to content

Commit

Permalink
feat(optimizer): support partitionning on precision and norm2
Browse files Browse the repository at this point in the history
  • Loading branch information
rudy-6-4 committed Sep 25, 2023
1 parent db50074 commit c1aa78f
Show file tree
Hide file tree
Showing 19 changed files with 412 additions and 122 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ enum Strategy {
std::string const StrategyLabel[] = {"V0", "dag-mono", "dag-multi"};

constexpr Strategy DEFAULT_STRATEGY = Strategy::DAG_MONO;
constexpr concrete_optimizer::MultiParamStrategy DEFAULT_MULTI_PARAM_STRATEGY =
concrete_optimizer::MultiParamStrategy::ByPrecision;
constexpr bool DEFAULT_KEY_SHARING = true;

struct Config {
Expand All @@ -104,6 +106,7 @@ struct Config {
bool display;
Strategy strategy;
bool key_sharing;
concrete_optimizer::MultiParamStrategy multi_param_strategy;
std::uint64_t security;
double fallback_log_norm_woppbs;
bool use_gpu_constraints;
Expand All @@ -119,6 +122,7 @@ constexpr Config DEFAULT_CONFIG = {
DEFAULT_DISPLAY,
DEFAULT_STRATEGY,
DEFAULT_KEY_SHARING,
DEFAULT_MULTI_PARAM_STRATEGY,
DEFAULT_SECURITY,
DEFAULT_FALLBACK_LOG_NORM_WOPPBS,
DEFAULT_USE_GPU_CONSTRAINTS,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ concrete_optimizer::Options options_from_config(optimizer::Config config) {
concrete_optimizer::Options options = {
/* .security_level = */ config.security,
/* .maximum_acceptable_error_probability = */ config.p_error,
/* . key_sharing */ config.key_sharing,
/* .key_sharing = */ config.key_sharing,
/* .multi_param_strategy = */ config.multi_param_strategy,
/* .default_log_norm2_woppbs = */ config.fallback_log_norm_woppbs,
/* .use_gpu_constraints = */ config.use_gpu_constraints,
/* .encoding = */ config.encoding,
Expand Down
17 changes: 17 additions & 0 deletions compilers/concrete-compiler/compiler/src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,21 @@ llvm::cl::opt<double> fallbackLogNormWoppbs(
"when the precise value can't be computed."),
llvm::cl::init(optimizer::DEFAULT_CONFIG.fallback_log_norm_woppbs));

llvm::cl::opt<concrete_optimizer::MultiParamStrategy>
optimizerMultiParamStrategy(
"optimizer-multi-parameter-strategy",
llvm::cl::desc(
"Select the concrete optimizer multi parameter strategy"),
llvm::cl::init(optimizer::DEFAULT_MULTI_PARAM_STRATEGY),
llvm::cl::values(clEnumValN(
concrete_optimizer::MultiParamStrategy::ByPrecision, "by-precision",
"One partition set for each possible input TLU precision")),
llvm::cl::values(clEnumValN(
concrete_optimizer::MultiParamStrategy::ByPrecisionAndNorm2,
"by-precision-and-norm2",
"One partition set for each possible input TLU precision and "
"output norm2")));

llvm::cl::opt<concrete_optimizer::Encoding> optimizerEncoding(
"force-encoding", llvm::cl::desc("Choose cyphertext encoding."),
llvm::cl::init(optimizer::DEFAULT_CONFIG.encoding),
Expand Down Expand Up @@ -488,6 +503,8 @@ cmdlineCompilationOptions() {
options.optimizerConfig.display = cmdline::displayOptimizerChoice;
options.optimizerConfig.strategy = cmdline::optimizerStrategy;
options.optimizerConfig.key_sharing = cmdline::optimizerKeySharing;
options.optimizerConfig.multi_param_strategy =
cmdline::optimizerMultiParamStrategy;
options.optimizerConfig.encoding = cmdline::optimizerEncoding;
options.optimizerConfig.cache_on_disk = !cmdline::optimizerNoCacheOnDisk;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,22 @@ parseEndToEndCommandLine(int argc, char **argv) {
"evaluation "
"keys")));

llvm::cl::opt<concrete_optimizer::MultiParamStrategy>
optimizerMultiParamStrategy(
"optimizer-multi-parameter-strategy",
llvm::cl::desc(
"Select the concrete optimizer multi parameter strategy"),
llvm::cl::init(optimizer::DEFAULT_MULTI_PARAM_STRATEGY),
llvm::cl::values(clEnumValN(
concrete_optimizer::MultiParamStrategy::ByPrecision,
"by-precision",
"One partition set for each possible input TLU precision")),
llvm::cl::values(clEnumValN(
concrete_optimizer::MultiParamStrategy::ByPrecisionAndNorm2,
"by-precision-and-norm2",
"One partition set for each possible input TLU precision and "
"output norm2")));

// JIT or Library support
llvm::cl::opt<bool> jit(
"jit",
Expand Down Expand Up @@ -130,6 +146,8 @@ parseEndToEndCommandLine(int argc, char **argv) {
compilationOptions.optimizerConfig.display = optimizerDisplay.getValue();
compilationOptions.optimizerConfig.security = securityLevel.getValue();
compilationOptions.optimizerConfig.strategy = optimizerStrategy.getValue();
compilationOptions.optimizerConfig.multi_param_strategy =
optimizerMultiParamStrategy.getValue();

mlir::concretelang::setupLogging(verbose.getValue());

Expand Down
10 changes: 10 additions & 0 deletions compilers/concrete-optimizer/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ use concrete_optimizer::dag::operator::{
};
use concrete_optimizer::dag::unparametrized;
use concrete_optimizer::optimization::config::{Config, SearchSpace};
use concrete_optimizer::optimization::dag::multi_parameters::keys_spec::{self, CircuitSolution};
use concrete_optimizer::optimization::dag::multi_parameters::keys_spec;
use concrete_optimizer::optimization::dag::multi_parameters::keys_spec::CircuitSolution;
use concrete_optimizer::optimization::dag::multi_parameters::partition_cut::PartitionCut;
use concrete_optimizer::optimization::dag::solo_key::optimize_generic::{
Encoding, Solution as DagSolution,
};
Expand Down Expand Up @@ -554,6 +556,13 @@ impl OperationDag {
let search_space = SearchSpace::default(processing_unit);

let encoding = options.encoding.into();
#[allow(clippy::wildcard_in_or_patterns)]
let p_cut = match options.multi_param_strategy {
ffi::MultiParamStrategy::ByPrecisionAndNorm2 => {
PartitionCut::maximal_partitionning(&self.0)
}
ffi::MultiParamStrategy::ByPrecision | _ => PartitionCut::for_each_precision(&self.0),
};
let circuit_sol =
concrete_optimizer::optimization::dag::multi_parameters::optimize_generic::optimize(
&self.0,
Expand All @@ -562,7 +571,7 @@ impl OperationDag {
encoding,
options.default_log_norm2_woppbs,
&caches_from(options),
&None,
&Some(p_cut),
);
circuit_sol.into()
}
Expand Down Expand Up @@ -729,12 +738,20 @@ mod ffi {
pub crt_decomposition: Vec<u64>,
}

#[derive(Debug, Clone, Copy)]
#[namespace = "concrete_optimizer"]
pub enum MultiParamStrategy {
ByPrecision,
ByPrecisionAndNorm2,
}

#[namespace = "concrete_optimizer"]
#[derive(Debug, Clone, Copy)]
pub struct Options {
pub security_level: u64,
pub maximum_acceptable_error_probability: f64,
pub key_sharing: bool,
pub multi_param_strategy: MultiParamStrategy,
pub default_log_norm2_woppbs: f64,
pub use_gpu_constraints: bool,
pub encoding: Encoding,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -944,6 +944,7 @@ namespace concrete_optimizer {
struct OperationDag;
struct Weights;
enum class Encoding : ::std::uint8_t;
enum class MultiParamStrategy : ::std::uint8_t;
struct Options;
namespace dag {
struct OperatorIndex;
Expand Down Expand Up @@ -1070,12 +1071,21 @@ struct DagSolution final {
#endif // CXXBRIDGE1_STRUCT_concrete_optimizer$dag$DagSolution
} // namespace dag

#ifndef CXXBRIDGE1_ENUM_concrete_optimizer$MultiParamStrategy
#define CXXBRIDGE1_ENUM_concrete_optimizer$MultiParamStrategy
enum class MultiParamStrategy : ::std::uint8_t {
ByPrecision = 0,
ByPrecisionAndNorm2 = 1,
};
#endif // CXXBRIDGE1_ENUM_concrete_optimizer$MultiParamStrategy

#ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$Options
#define CXXBRIDGE1_STRUCT_concrete_optimizer$Options
struct Options final {
::std::uint64_t security_level;
double maximum_acceptable_error_probability;
bool key_sharing;
::concrete_optimizer::MultiParamStrategy multi_param_strategy;
double default_log_norm2_woppbs;
bool use_gpu_constraints;
::concrete_optimizer::Encoding encoding;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -925,6 +925,7 @@ namespace concrete_optimizer {
struct OperationDag;
struct Weights;
enum class Encoding : ::std::uint8_t;
enum class MultiParamStrategy : ::std::uint8_t;
struct Options;
namespace dag {
struct OperatorIndex;
Expand Down Expand Up @@ -1051,12 +1052,21 @@ struct DagSolution final {
#endif // CXXBRIDGE1_STRUCT_concrete_optimizer$dag$DagSolution
} // namespace dag

#ifndef CXXBRIDGE1_ENUM_concrete_optimizer$MultiParamStrategy
#define CXXBRIDGE1_ENUM_concrete_optimizer$MultiParamStrategy
enum class MultiParamStrategy : ::std::uint8_t {
ByPrecision = 0,
ByPrecisionAndNorm2 = 1,
};
#endif // CXXBRIDGE1_ENUM_concrete_optimizer$MultiParamStrategy

#ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$Options
#define CXXBRIDGE1_STRUCT_concrete_optimizer$Options
struct Options final {
::std::uint64_t security_level;
double maximum_acceptable_error_probability;
bool key_sharing;
::concrete_optimizer::MultiParamStrategy multi_param_strategy;
double default_log_norm2_woppbs;
bool use_gpu_constraints;
::concrete_optimizer::Encoding encoding;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ concrete_optimizer::Options default_options() {
.security_level = SECURITY_128B,
.maximum_acceptable_error_probability = P_ERROR,
.key_sharing = false,
.multi_param_strategy = concrete_optimizer::MultiParamStrategy::ByPrecision,
.default_log_norm2_woppbs = WOP_FALLBACK_LOG_NORM,
.use_gpu_constraints = false,
.encoding = concrete_optimizer::Encoding::Auto,
Expand Down
9 changes: 5 additions & 4 deletions compilers/concrete-optimizer/concrete-optimizer/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
concrete-security-curves = { path = "../../../tools/parameter-curves/concrete-security-curves-rust" }
bincode = "1.3"
concrete-cpu-noise-model = { path = "../../../backends/concrete-cpu/noise-model/" }
concrete-security-curves = { path = "../../../tools/parameter-curves/concrete-security-curves-rust" }
file-lock = "2.1.6"
serde = { version = "1.0", features = ["derive"] }
bincode = "1.3"
ordered-float = "3.9.1"
puruspe = "0.2.0"
rustc-hash = "1.1"
rand = "0.8"
rustc-hash = "1.1"
serde = { version = "1.0", features = ["derive"] }

[dev-dependencies]
approx = "0.5"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
use std::collections::HashSet;

use crate::dag::operator::{
dot_kind, DotKind, LevelledComplexity, Operator, OperatorIndex, Precision, Shape,
};
use crate::dag::rewrite::round::expand_round_and_index_map;
use crate::dag::unparametrized;
use crate::optimization::config::NoiseBoundConfig;
use crate::optimization::dag::multi_parameters::partition_cut::PartitionCut;
use crate::optimization::dag::multi_parameters::partitionning::partitionning_with_preferred;
use crate::optimization::dag::multi_parameters::partitions::{
InstructionPartition, PartitionIndex, Transition,
};
use crate::optimization::dag::multi_parameters::precision_cut::PrecisionCut;
use crate::optimization::dag::multi_parameters::symbolic_variance::SymbolicVariance;
use crate::optimization::dag::solo_key::analyze::{
extra_final_values_to_check, first, safe_noise_bound,
Expand Down Expand Up @@ -43,13 +41,13 @@ pub struct AnalyzedDag {
pub operations_count_per_instrs: Vec<OperationsCount>,
pub operations_count: OperationsCount,
pub instruction_rewrite_index: Vec<Vec<OperatorIndex>>,
pub p_cut: PrecisionCut,
pub p_cut: PartitionCut,
}

pub fn analyze(
dag: &unparametrized::OperationDag,
noise_config: &NoiseBoundConfig,
p_cut: &Option<PrecisionCut>,
p_cut: &Option<PartitionCut>,
default_partition: PartitionIndex,
) -> AnalyzedDag {
let (dag, instruction_rewrite_index) = expand_round_and_index_map(dag);
Expand All @@ -59,7 +57,7 @@ pub fn analyze(
#[allow(clippy::option_if_let_else)]
let p_cut = match p_cut {
Some(p_cut) => p_cut.clone(),
None => maximal_p_cut(&dag),
None => PartitionCut::for_each_precision(&dag),
};
let partitions = partitionning_with_preferred(&dag, &p_cut, default_partition);
let instrs_partition = partitions.instrs_partition;
Expand Down Expand Up @@ -373,19 +371,6 @@ fn sum_operations_count(all_counts: &[OperationsCount]) -> OperationsCount {
OperationsCount { counts: sum_counts }
}

fn maximal_p_cut(dag: &unparametrized::OperationDag) -> PrecisionCut {
let mut lut_in_precisions: HashSet<_> = HashSet::default();
for op in &dag.operators {
if let Op::Lut { input, .. } = op {
_ = lut_in_precisions.insert(dag.out_precisions[input.i]);
}
}
let mut p_cut: Vec<_> = lut_in_precisions.iter().copied().collect();
p_cut.sort_unstable();
_ = p_cut.pop();
PrecisionCut { p_cut }
}

#[cfg(test)]
pub mod tests {
use super::*;
Expand All @@ -404,7 +389,7 @@ pub mod tests {
dag: &unparametrized::OperationDag,
default_partition: PartitionIndex,
) -> AnalyzedDag {
let p_cut = PrecisionCut { p_cut: vec![2] };
let p_cut = PartitionCut::for_each_precision(&dag);
super::analyze(dag, &CONFIG, &Some(p_cut), default_partition)
}

Expand Down Expand Up @@ -861,17 +846,12 @@ pub mod tests {
let mut dag = unparametrized::OperationDag::new();
let max_precision = 10;
let mut lut_input = dag.add_input(max_precision, Shape::number());
let mut p_cut = vec![];
for out_precision in (1..=max_precision).rev() {
lut_input = dag.add_lut(lut_input, FunctionTable::UNKWOWN, out_precision);
}
_ = dag.add_lut(lut_input, FunctionTable::UNKWOWN, 1);
for out_precision in 1..max_precision {
p_cut.push(out_precision);
}
eprintln!("{}", dag.dump());
let p_cut = PrecisionCut { p_cut };
eprintln!("{p_cut}");
let precisions: Vec<_> = (1..=max_precision).collect();
let p_cut = PartitionCut::from_precisions(&precisions);
let dag = super::analyze(&dag, &CONFIG, &Some(p_cut.clone()), LOW_PRECISION_PARTITION);
assert!(dag.nb_partitions == p_cut.p_cut.len() + 1);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ pub mod keys_spec;
mod operations_value;
pub mod optimize;
pub mod optimize_generic;
pub mod partition_cut;
mod partitionning;
mod partitions;
mod precision_cut;
mod symbolic_variance;
mod union_find;
mod variance_constraint;
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ use crate::parameters::GlweParameters;

use crate::optimization::dag::multi_parameters::complexity::Complexity;
use crate::optimization::dag::multi_parameters::feasible::Feasible;
use crate::optimization::dag::multi_parameters::partition_cut::PartitionCut;
use crate::optimization::dag::multi_parameters::partitions::PartitionIndex;
use crate::optimization::dag::multi_parameters::precision_cut::PrecisionCut;
use crate::optimization::dag::multi_parameters::{analyze, keys_spec};

use super::keys_spec::InstructionKeys;
Expand Down Expand Up @@ -870,7 +870,7 @@ pub fn optimize(
config: Config,
search_space: &SearchSpace,
persistent_caches: &PersistDecompCaches,
p_cut: &Option<PrecisionCut>,
p_cut: &Option<PartitionCut>,
default_partition: PartitionIndex,
) -> Option<(AnalyzedDag, Parameters)> {
let ciphertext_modulus_log = config.ciphertext_modulus_log;
Expand Down Expand Up @@ -1126,7 +1126,7 @@ pub fn optimize_to_circuit_solution(
config: Config,
search_space: &SearchSpace,
persistent_caches: &PersistDecompCaches,
p_cut: &Option<PrecisionCut>,
p_cut: &Option<PartitionCut>,
) -> keys_spec::CircuitSolution {
if lut_count_from_dag(dag) == 0 {
let nb_instr = dag.operators.len();
Expand Down
Loading

0 comments on commit c1aa78f

Please sign in to comment.