Skip to content

Commit

Permalink
feat(optimizer): multi-parameters, direct variance and cost operation…
Browse files Browse the repository at this point in the history
… bound
  • Loading branch information
rudy-6-4 committed Sep 26, 2023
1 parent 7e6ce03 commit c84372e
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,40 @@ impl Complexity {
pub fn complexity(&self, costs: &OperationsValue) -> f64 {
f64_dot(&self.counts, costs)
}

pub fn ks_max_cost(
&self,
complexity_cut: f64,
costs: &OperationsValue,
src_partition: usize,
dst_partition: usize,
) -> f64 {
let ks_index = costs.index.keyswitch_to_small(src_partition, dst_partition);
let actual_ks_cost = costs.values[ks_index];
let ks_coeff = self.counts[self
.counts
.index
.keyswitch_to_small(src_partition, dst_partition)];
let actual_complexity = self.complexity(costs) - ks_coeff * actual_ks_cost;

(complexity_cut - actual_complexity) / ks_coeff
}

pub fn fks_max_cost(
&self,
complexity_cut: f64,
costs: &OperationsValue,
src_partition: usize,
dst_partition: usize,
) -> f64 {
let fks_index = costs.index.keyswitch_to_big(src_partition, dst_partition);
let actual_fks_cost = costs.values[fks_index];
let fks_coeff = self.counts[self
.counts
.index
.keyswitch_to_big(src_partition, dst_partition)];
let actual_complexity = self.complexity(costs) - fks_coeff * actual_fks_cost;

(complexity_cut - actual_complexity) / fks_coeff
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,87 @@ impl Feasible {
}
}

pub fn pbs_max_feasible_variance(
&self,
operations_variance: &OperationsValue,
partition: usize,
) -> f64 {
let pbs_index = operations_variance.index.pbs(partition);
let actual_pbs_variance = operations_variance.values[pbs_index];

let mut smallest_pbs_max_variance = std::f64::MAX;

for constraint in &self.undominated_constraints {
let pbs_coeff = constraint.variance.coeff_pbs(partition);
if pbs_coeff == 0.0 {
continue;
}
let actual_variance = f64_dot(operations_variance, &constraint.variance.coeffs)
- pbs_coeff * actual_pbs_variance;
let pbs_max_variance = (constraint.safe_variance_bound - actual_variance) / pbs_coeff;
smallest_pbs_max_variance = smallest_pbs_max_variance.min(pbs_max_variance);
}
smallest_pbs_max_variance
}

pub fn ks_max_feasible_variance(
&self,
operations_variance: &OperationsValue,
src_partition: usize,
dst_partition: usize,
) -> f64 {
let ks_index = operations_variance
.index
.keyswitch_to_small(src_partition, dst_partition);
let actual_ks_variance = operations_variance.values[ks_index];

let mut smallest_ks_max_variance = std::f64::MAX;

for constraint in &self.undominated_constraints {
let ks_coeff = constraint
.variance
.coeff_keyswitch_to_small(src_partition, dst_partition);
if ks_coeff == 0.0 {
continue;
}
let actual_variance = f64_dot(operations_variance, &constraint.variance.coeffs)
- ks_coeff * actual_ks_variance;
let ks_max_variance = (constraint.safe_variance_bound - actual_variance) / ks_coeff;
smallest_ks_max_variance = smallest_ks_max_variance.min(ks_max_variance);
}

smallest_ks_max_variance
}

pub fn fks_max_feasible_variance(
&self,
operations_variance: &OperationsValue,
src_partition: usize,
dst_partition: usize,
) -> f64 {
let fks_index = operations_variance
.index
.keyswitch_to_big(src_partition, dst_partition);
let actual_fks_variance = operations_variance.values[fks_index];

let mut smallest_fks_max_variance = std::f64::MAX;

for constraint in &self.undominated_constraints {
let fks_coeff = constraint
.variance
.coeff_partition_keyswitch_to_big(src_partition, dst_partition);
if fks_coeff == 0.0 {
continue;
}
let actual_variance = f64_dot(operations_variance, &constraint.variance.coeffs)
- fks_coeff * actual_fks_variance;
let fks_max_variance = (constraint.safe_variance_bound - actual_variance) / fks_coeff;
smallest_fks_max_variance = smallest_fks_max_variance.min(fks_max_variance);
}

smallest_fks_max_variance
}

pub fn feasible(&self, operations_variance: &OperationsValue) -> bool {
if self.global_p_error.is_none() {
self.local_feasible(operations_variance)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,23 @@ impl Indexing {
}

pub fn input(self, partition: usize) -> usize {
assert!(partition < self.nb_partitions);
partition * self.nb_coeff_per_partition() + VALUE_INDEX_FRESH
}

pub fn pbs(self, partition: usize) -> usize {
assert!(partition < self.nb_partitions);
partition * self.nb_coeff_per_partition() + VALUE_INDEX_PBS
}

pub fn modulus_switching(self, partition: usize) -> usize {
assert!(partition < self.nb_partitions);
partition * self.nb_coeff_per_partition() + VALUE_INDEX_MODULUS
}

pub fn keyswitch_to_small(self, src_partition: usize, dst_partition: usize) -> usize {
assert!(src_partition < self.nb_partitions);
assert!(dst_partition < self.nb_partitions);
// Skip other partition
dst_partition * self.nb_coeff_per_partition()
// Skip non keyswitchs
Expand All @@ -63,6 +68,8 @@ impl Indexing {
}

pub fn keyswitch_to_big(self, src_partition: usize, dst_partition: usize) -> usize {
assert!(src_partition < self.nb_partitions);
assert!(dst_partition < self.nb_partitions);
// Skip other partition
dst_partition * self.nb_coeff_per_partition()
// Skip non keyswitchs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,18 @@ fn optimize_1_ks(
cut_complexity: f64,
) -> Option<KsComplexityNoise> {
// find the first feasible (and less complex)
let ks_max_variance = feasible.ks_max_feasible_variance(&operations.variance, ks_src, ks_dst);
let ks_max_cost = complexity.ks_max_cost(cut_complexity, &operations.cost, ks_src, ks_dst);
for &ks_quantity in ks_pareto {
// variance is decreasing, complexity is increasing
*operations.variance.ks(ks_src, ks_dst) = ks_quantity.noise(ks_input_lwe_dim);
*operations.cost.ks(ks_src, ks_dst) = ks_quantity.complexity(ks_input_lwe_dim);
if complexity.complexity(&operations.cost) > cut_complexity {
let ks_cost = ks_quantity.complexity(ks_input_lwe_dim);
let ks_variance = ks_quantity.noise(ks_input_lwe_dim);
if ks_cost > ks_max_cost {
return None;
}
if feasible.feasible(&operations.variance) {
if ks_variance <= ks_max_variance {
*operations.variance.ks(ks_src, ks_dst) = ks_variance;
*operations.cost.ks(ks_src, ks_dst) = ks_cost;
return Some(ks_quantity);
}
}
Expand Down Expand Up @@ -174,6 +178,10 @@ fn optimize_1_fks_and_all_compatible_ks(
let mut cut_complexity = cut_complexity;
let same_dim = input_glwe == output_glwe;

let fks_max_variance =
feasible.fks_max_feasible_variance(&operations.variance, fks_src, fks_dst);
let mut fks_max_cost =
complexity.fks_max_cost(cut_complexity, &operations.cost, fks_src, fks_dst);
for &ks_quantity in &ks_pareto {
// OPT: add a pareto cache for fks
let fks_quantity = if same_dim {
Expand Down Expand Up @@ -212,17 +220,20 @@ fn optimize_1_fks_and_all_compatible_ks(
dst_glwe_param: output_glwe,
}
};
*operations.cost.fks(fks_src, fks_dst) = fks_quantity.complexity;
*operations.variance.fks(fks_src, fks_dst) = fks_quantity.noise;

if complexity.complexity(&operations.cost) > cut_complexity {
if fks_quantity.complexity > fks_max_cost {
// complexity is strictly increasing by level
// next complexity will be worse
return best_sol;
}
if !feasible.feasible(&operations.variance) {

if fks_quantity.noise > fks_max_variance {
continue;
}

*operations.cost.fks(fks_src, fks_dst) = fks_quantity.complexity;
*operations.variance.fks(fks_src, fks_dst) = fks_quantity.noise;

let sol = optimize_many_independant_ks(
macro_parameters,
ks_src,
Expand All @@ -243,6 +254,7 @@ fn optimize_1_fks_and_all_compatible_ks(
continue;
}
cut_complexity = cost;
fks_max_cost = complexity.fks_max_cost(cut_complexity, &operations.cost, fks_src, fks_dst);
// COULD: handle complexity tie
let bests = Best1FksAndManyKs {
fks: Some((fks_src, fks_quantity)),
Expand Down Expand Up @@ -334,20 +346,26 @@ fn optimize_1_cmux_and_dst_exclusive_fks_subset_and_all_ks(
let mut best_sol_complexity = cut_complexity;
let mut best_sol_p_error = best_p_error;
let mut best_sol_global_p_error = 1.0;

let pbs_max_feasible_variance =
feasible.pbs_max_feasible_variance(&operations.variance, partition);
for &cmux_quantity in cmux_pareto {
// increasing complexity, decreasing variance

// Lower bounds cuts
let pbs_cost = cmux_quantity.complexity_br(internal_dim);
*operations.cost.pbs(partition) = pbs_cost;
// Lower bounds cuts
let lower_cost = complexity.complexity(&operations.cost);
if lower_cost > best_sol_complexity {
continue;
}

let pbs_variance = cmux_quantity.noise_br(internal_dim);
*operations.variance.pbs(partition) = pbs_variance;
if !feasible.feasible(&operations.variance) {
if pbs_variance > pbs_max_feasible_variance {
continue;
}

*operations.variance.pbs(partition) = pbs_variance;
let sol = optimize_dst_exclusive_fks_subset_and_all_ks(
macro_parameters,
fks_paretos,
Expand Down

0 comments on commit c84372e

Please sign in to comment.