feat(optimizer): multi-parameters, variance and cost value and coefficient compression
rudy-6-4 committed Sep 25, 2023
1 parent 311a24a commit dcf7329
Showing 6 changed files with 204 additions and 46 deletions.
@@ -24,7 +24,7 @@ impl fmt::Display for OperationsCount {
let mut add_plus = "";
let counts = &self.counts;
let nb_partitions = counts.nb_partitions();
-let index = counts.index;
+let index = &counts.index;
for src_partition in 0..nb_partitions {
for dst_partition in 0..nb_partitions {
let coeff = counts.values[index.keyswitch_to_small(src_partition, dst_partition)];
@@ -113,4 +113,23 @@ impl Complexity {
(complexity_cut - actual_complexity) / fks_coeff
}

pub fn compressed(self) -> Self {
let mut detect_used: Vec<bool> = vec![false; self.counts.len()];
for (i, &count) in self.counts.iter().enumerate() {
if count > 0.0 {
detect_used[i] = true;
}
}
Self {
counts: self.counts.compress(&detect_used),
}
}

pub fn zero_cost(&self) -> OperationsValue {
if self.counts.index.is_compressed() {
OperationsValue::zero_compressed(&self.counts.index)
} else {
OperationsValue::zero(self.counts.nb_partitions())
}
}
}
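
The detection step above can be read in isolation: a coefficient of the cost vector is considered used as soon as its count is non-zero, and only used coefficients keep a dedicated slot after compression. A minimal standalone sketch of that idea, assuming a hypothetical free function `detect_used` rather than the crate's API:

```rust
// Hypothetical standalone version of the detection loop in `Complexity::compressed`;
// this is an illustration, not part of the optimizer's API.
fn detect_used(counts: &[f64]) -> Vec<bool> {
    // A coefficient is "used" as soon as it has a strictly positive count.
    counts.iter().map(|&count| count > 0.0).collect()
}

fn main() {
    let counts = [2.0, 0.0, 0.0, 5.0];
    // Only positions 0 and 3 will keep a dedicated slot after compression.
    assert_eq!(detect_used(&counts), vec![true, false, false, true]);
}
```
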
@@ -6,6 +6,7 @@ use crate::utils::f64::f64_dot;
use super::operations_value::OperationsValue;
use super::partitions::PartitionIndex;

#[derive(Clone)]
pub struct Feasible {
// TODO: move kappa here
pub constraints: Vec<VarianceConstraint>,
@@ -204,4 +205,44 @@ impl Feasible {
.collect();
Self::of(&partition_constraints, self.kappa, self.global_p_error)
}

pub fn compressed(self) -> Self {
let mut detect_used: Vec<bool> = vec![false; self.constraints[0].variance.coeffs.len()];
for constraint in &self.constraints {
for (i, &coeff) in constraint.variance.coeffs.iter().enumerate() {
if coeff > 0.0 {
detect_used[i] = true;
}
}
}
let compress = |c: &VarianceConstraint| VarianceConstraint {
variance: c.variance.compress(&detect_used),
..(*c)
};
let constraints = self.constraints.iter().map(compress).collect();
let undominated_constraints = self.undominated_constraints.iter().map(compress).collect();
Self {
constraints,
undominated_constraints,
..self
}
}

pub fn zero_variance(&self) -> OperationsValue {
if self.undominated_constraints[0]
.variance
.coeffs
.index
.is_compressed()
{
OperationsValue::zero_compressed(&self.undominated_constraints[0].variance.coeffs.index)
} else {
OperationsValue::zero(
self.undominated_constraints[0]
.variance
.coeffs
.nb_partitions(),
)
}
}
}
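
For the feasibility side, the used-coefficient mask is the union over all variance constraints: a coefficient survives as soon as any constraint has a non-zero coefficient for it. A sketch of that union over plain coefficient rows, with `detect_used_any` as a hypothetical stand-in for the loop over `VarianceConstraint`s above:

```rust
// Hypothetical sketch of the union-of-used-coefficients logic in `Feasible::compressed`.
fn detect_used_any(rows: &[Vec<f64>]) -> Vec<bool> {
    let mut used = vec![false; rows[0].len()];
    for coeffs in rows {
        for (i, &coeff) in coeffs.iter().enumerate() {
            if coeff > 0.0 {
                used[i] = true;
            }
        }
    }
    used
}

fn main() {
    let rows = vec![vec![1.0, 0.0, 0.0], vec![0.0, 0.0, 3.0]];
    // A coefficient is kept if *any* constraint uses it.
    assert_eq!(detect_used_any(&rows), vec![true, false, true]);
}
```
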
@@ -3,7 +3,7 @@ use std::ops::{Deref, DerefMut};
/**
* Index actual operations (input, ks, pbs, fks, modulus switching, etc).
*/
-#[derive(Clone, Copy, Debug, Eq, PartialEq, PartialOrd)]
+#[derive(Clone, Debug, Eq, PartialEq, PartialOrd)]
pub struct Indexing {
/* Values order
[
@@ -20,64 +20,111 @@
]
*/
pub nb_partitions: usize,
pub compressed_index: Vec<usize>,
}

-pub const VALUE_INDEX_FRESH: usize = 0;
-pub const VALUE_INDEX_PBS: usize = 1;
-pub const VALUE_INDEX_MODULUS: usize = 2;
+const VALUE_INDEX_FRESH: usize = 0;
+const VALUE_INDEX_PBS: usize = 1;
+const VALUE_INDEX_MODULUS: usize = 2;
// number of values always present for a partition
-pub const STABLE_NB_VALUES_BY_PARTITION: usize = 3;
+const STABLE_NB_VALUES_BY_PARTITION: usize = 3;

pub const COMPRESSED_0_INDEX: usize = 0; // all 0.0 values are indexed here
pub const COMPRESSED_FIRST_FREE_INDEX: usize = 1;

impl Indexing {
-fn nb_keyswitchs_per_partition(self) -> usize {
fn uncompressed(nb_partitions: usize) -> Self {
Self {
nb_partitions,
compressed_index: vec![],
}
}

fn compress(&self, used: &[bool]) -> Self {
assert!(!self.is_compressed());
let mut compressed_index = vec![COMPRESSED_0_INDEX; self.nb_coeff()];
let mut index = COMPRESSED_FIRST_FREE_INDEX;
for (i, &is_used) in used.iter().enumerate() {
if is_used {
compressed_index[i] = index;
index += 1;
}
}
Self {
compressed_index,
..(*self)
}
}

pub fn is_compressed(&self) -> bool {
!self.compressed_index.is_empty()
}

+fn nb_keyswitchs_per_partition(&self) -> usize {
self.nb_partitions
}

-pub fn nb_coeff_per_partition(self) -> usize {
pub fn maybe_compressed(&self, i: usize) -> usize {
if self.is_compressed() {
self.compressed_index[i]
} else {
i
}
}

+pub fn nb_coeff_per_partition(&self) -> usize {
STABLE_NB_VALUES_BY_PARTITION + 2 * self.nb_partitions
}

-pub fn nb_coeff(self) -> usize {
+pub fn nb_coeff(&self) -> usize {
self.nb_partitions * (STABLE_NB_VALUES_BY_PARTITION + 2 * self.nb_partitions)
}

-pub fn input(self, partition: usize) -> usize {
+pub fn input(&self, partition: usize) -> usize {
assert!(partition < self.nb_partitions);
-partition * self.nb_coeff_per_partition() + VALUE_INDEX_FRESH
+self.maybe_compressed(partition * self.nb_coeff_per_partition() + VALUE_INDEX_FRESH)
}

-pub fn pbs(self, partition: usize) -> usize {
+pub fn pbs(&self, partition: usize) -> usize {
assert!(partition < self.nb_partitions);
-partition * self.nb_coeff_per_partition() + VALUE_INDEX_PBS
+self.maybe_compressed(partition * self.nb_coeff_per_partition() + VALUE_INDEX_PBS)
}

-pub fn modulus_switching(self, partition: usize) -> usize {
+pub fn modulus_switching(&self, partition: usize) -> usize {
assert!(partition < self.nb_partitions);
-partition * self.nb_coeff_per_partition() + VALUE_INDEX_MODULUS
+self.maybe_compressed(partition * self.nb_coeff_per_partition() + VALUE_INDEX_MODULUS)
}

-pub fn keyswitch_to_small(self, src_partition: usize, dst_partition: usize) -> usize {
+pub fn keyswitch_to_small(&self, src_partition: usize, dst_partition: usize) -> usize {
assert!(src_partition < self.nb_partitions);
assert!(dst_partition < self.nb_partitions);
-// Skip other partition
-dst_partition * self.nb_coeff_per_partition()
-    // Skip non keyswitchs
-    + STABLE_NB_VALUES_BY_PARTITION
-    // Select the right keyswitch to small
-    + src_partition
+self.maybe_compressed(
+    // Skip other partition
+    dst_partition * self.nb_coeff_per_partition()
+    // Skip non keyswitchs
+    + STABLE_NB_VALUES_BY_PARTITION
+    // Select the right keyswitch to small
+    + src_partition,
+)
}

-pub fn keyswitch_to_big(self, src_partition: usize, dst_partition: usize) -> usize {
+pub fn keyswitch_to_big(&self, src_partition: usize, dst_partition: usize) -> usize {
assert!(src_partition < self.nb_partitions);
assert!(dst_partition < self.nb_partitions);
-// Skip other partition
-dst_partition * self.nb_coeff_per_partition()
-    // Skip non keyswitchs
-    + STABLE_NB_VALUES_BY_PARTITION
-    // Skip keyswitch to small
-    + self.nb_keyswitchs_per_partition()
-    // Select the right keyswitch to big
-    + src_partition
+self.maybe_compressed(
+    // Skip other partition
+    dst_partition * self.nb_coeff_per_partition()
+    // Skip non keyswitchs
+    + STABLE_NB_VALUES_BY_PARTITION
+    // Skip keyswitch to small
+    + self.nb_keyswitchs_per_partition()
+    // Select the right keyswitch to big
+    + src_partition,
+)
}

pub fn compressed_size(&self) -> usize {
self.compressed_index.iter().copied().max().unwrap_or(0) + 1
}
}
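
The mapping built by `compress` is simple: every unused position is aliased to `COMPRESSED_0_INDEX`, each used position gets the next free slot starting at `COMPRESSED_FIRST_FREE_INDEX`, and `compressed_size` is the highest slot plus one. A self-contained sketch of that mapping, with `compress_index` as a hypothetical stand-in for the method above:

```rust
// Hypothetical standalone version of `Indexing::compress`: unused positions all alias
// slot 0, used positions get consecutive slots starting at 1.
const COMPRESSED_0_INDEX: usize = 0;
const COMPRESSED_FIRST_FREE_INDEX: usize = 1;

fn compress_index(used: &[bool]) -> Vec<usize> {
    let mut next_free = COMPRESSED_FIRST_FREE_INDEX;
    used.iter()
        .map(|&is_used| {
            if is_used {
                let slot = next_free;
                next_free += 1;
                slot
            } else {
                COMPRESSED_0_INDEX
            }
        })
        .collect()
}

fn main() {
    let index = compress_index(&[true, false, true, false]);
    assert_eq!(index, vec![1, 0, 2, 0]);
    // compressed_size: the shared zero slot plus one slot per used coefficient.
    assert_eq!(index.iter().max().unwrap() + 1, 3);
}
```
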

@@ -92,23 +139,36 @@ pub struct OperationsValue {

impl OperationsValue {
pub const ZERO: Self = Self {
-index: Indexing { nb_partitions: 0 },
+index: Indexing {
+    nb_partitions: 0,
+    compressed_index: vec![],
+},
values: vec![],
};

pub fn zero(nb_partitions: usize) -> Self {
-let index = Indexing { nb_partitions };
+let index = Indexing::uncompressed(nb_partitions);
+let nb_coeff = index.nb_coeff();
Self {
index,
-values: vec![0.0; index.nb_coeff()],
+values: vec![0.0; nb_coeff],
}
}

pub fn zero_compressed(index: &Indexing) -> Self {
assert!(index.is_compressed());
Self {
index: index.clone(),
values: vec![0.0; index.compressed_size()],
}
}

pub fn nan(nb_partitions: usize) -> Self {
-let index = Indexing { nb_partitions };
+let index = Indexing::uncompressed(nb_partitions);
+let nb_coeff = index.nb_coeff();
Self {
index,
-values: vec![f64::NAN; index.nb_coeff()],
+values: vec![f64::NAN; nb_coeff],
}
}

@@ -135,6 +195,32 @@ impl OperationsValue {
pub fn nb_partitions(&self) -> usize {
self.index.nb_partitions
}

pub fn compress(&self, used: &[bool]) -> Self {
self.compress_with(self.index.compress(used))
}

pub fn compress_like(&self, other: Self) -> Self {
self.compress_with(other.index)
}

fn compress_with(&self, index: Indexing) -> Self {
assert!(!index.compressed_index.is_empty());
assert!(self.index.compressed_index.is_empty());
let mut values = vec![0.0; index.compressed_size()];
for (i, &value) in self.values.iter().enumerate() {
#[allow(clippy::option_if_let_else)]
let j = index.compressed_index[i];
if j == COMPRESSED_0_INDEX {
assert!(value == 0.0, "Cannot compress non null value");
} else {
values[j] = value;
}
}
assert!(values[COMPRESSED_0_INDEX] == 0.0);
assert!(!index.compressed_index.is_empty());
Self { index, values }
}
}
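
Given such an index, `compress_with` only has to move each value into its compressed slot; anything aliased to the shared zero slot must itself be 0.0, which the assert enforces. A sketch over plain slices, with `pack` as a hypothetical stand-in for the method above:

```rust
// Hypothetical sketch of the value-packing step in `OperationsValue::compress_with`.
fn pack(values: &[f64], compressed_index: &[usize], compressed_size: usize) -> Vec<f64> {
    let mut packed = vec![0.0; compressed_size];
    for (i, &value) in values.iter().enumerate() {
        let slot = compressed_index[i];
        if slot == 0 {
            // Everything aliased to the shared zero slot must be a true zero.
            assert!(value == 0.0, "Cannot compress non null value");
        } else {
            packed[slot] = value;
        }
    }
    packed
}

fn main() {
    let packed = pack(&[2.5, 0.0, 4.0, 0.0], &[1, 0, 2, 0], 3);
    assert_eq!(packed, vec![0.0, 2.5, 4.0]);
}
```
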

impl Deref for OperationsValue {
@@ -646,8 +646,8 @@ fn optimize_macro(

let fks_to_optimize = fks_to_optimize(nb_partitions, used_conversion_keyswitch, partition);
let operations = OperationsCV {
-variance: OperationsValue::zero(nb_partitions),
-cost: OperationsValue::zero(nb_partitions),
+variance: feasible.zero_variance(),
+cost: complexity.zero_cost(),
};
let partition_feasible = feasible.filter_constraints(partition);

@@ -907,8 +907,8 @@ pub fn optimize(

let mut caches = persistent_caches.caches();

-let feasible = Feasible::of(&dag.variance_constraints, kappa, None);
-let complexity = Complexity::of(&dag.operations_count);
+let feasible = Feasible::of(&dag.variance_constraints, kappa, None).compressed();
+let complexity = Complexity::of(&dag.operations_count).compressed();
let used_tlu_keyswitch = used_tlu_keyswitch(&dag);
let used_conversion_keyswitch = used_conversion_keyswitch(&dag);

@@ -1071,8 +1071,8 @@ fn sanity_check(
);
let nb_partitions = params.macro_params.len();
let mut operations = OperationsCV {
-variance: OperationsValue::zero(nb_partitions),
-cost: OperationsValue::zero(nb_partitions),
+variance: feasible.zero_variance(),
+cost: complexity.zero_cost(),
};
let micro_params = &params.micro_params;
for partition in 0..nb_partitions {
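
These call-site changes are what make the compressed representation usable end to end: once `feasible` and `complexity` are compressed, the running variance and cost accumulators must be allocated with the same compressed length, otherwise every dot product against the constraint coefficients would mix layouts. A back-of-the-envelope illustration, with made-up sizes:

```rust
// Illustration only; the partition count and the number of surviving coefficients
// are made up for the example.
fn main() {
    let nb_partitions = 2;
    // Uncompressed layout: 3 fixed values plus 2 keyswitch tables per partition.
    let uncompressed_len = nb_partitions * (3 + 2 * nb_partitions);
    assert_eq!(uncompressed_len, 14);
    // After compression: the shared zero slot plus, say, 4 used coefficients.
    let compressed_len = 1 + 4;
    // An accumulator built with `OperationsValue::zero(nb_partitions)` would have 14
    // entries and could not be dotted against 5 compressed coefficients, hence
    // `feasible.zero_variance()` and `complexity.zero_cost()`.
    assert_ne!(uncompressed_len, compressed_len);
}
```
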
@@ -1,7 +1,7 @@
use std::fmt;

use crate::optimization::dag::multi_parameters::operations_value::{
-OperationsValue, VALUE_INDEX_FRESH, VALUE_INDEX_PBS,
+OperationsValue
};

/**
@@ -147,9 +147,8 @@ impl SymbolicVariance {
// this is the maximum value of fresh base noise and pbs base noise
let mut current_max: f64 = 0.0;
for partition in 0..self.nb_partitions() {
-let partition_offset = partition * self.coeffs.index.nb_coeff_per_partition();
-let fresh_coeff = self.coeffs[partition_offset + VALUE_INDEX_FRESH];
-let pbs_noise_coeff = self.coeffs[partition_offset + VALUE_INDEX_PBS];
+let fresh_coeff = self.coeff_input(partition);
+let pbs_noise_coeff = self.coeff_pbs(partition);
current_max = current_max.max(fresh_coeff).max(pbs_noise_coeff);
}
assert!(1.0 <= current_max);
@@ -177,6 +176,13 @@ impl SymbolicVariance {
}
Self { coeffs, ..*self }
}

pub fn compress(&self, detect_used: &[bool]) -> Self {
Self {
coeffs: self.coeffs.compress(detect_used),
..(*self)
}
}
}

impl fmt::Display for SymbolicVariance {
@@ -3,6 +3,12 @@ pub fn f64_max(values: &[f64], default: f64) -> f64 {
}

pub fn f64_dot(a: &[f64], b: &[f64]) -> f64 {
assert!(
a.len() == b.len(),
"Dot incompatible size: {} vs {}",
a.len(),
b.len()
);
let mut sum = 0.0;
for i in 0..a.len() {
sum += a[i] * b[i];
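
The new assertion in `f64_dot` is the safety net for exactly that kind of layout mismatch: dotting a compressed vector against an uncompressed one now fails with an explicit size message instead of an out-of-bounds panic (when the second slice is shorter) or a silently wrong sum (when it is longer). A minimal restatement with a usage example, not the original source file:

```rust
// Minimal restatement of `f64_dot` with the added length check, for illustration.
fn f64_dot(a: &[f64], b: &[f64]) -> f64 {
    assert!(
        a.len() == b.len(),
        "Dot incompatible size: {} vs {}",
        a.len(),
        b.len()
    );
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

fn main() {
    assert_eq!(f64_dot(&[1.0, 2.0], &[3.0, 4.0]), 11.0);
    // f64_dot(&[1.0, 2.0], &[3.0]) would panic with "Dot incompatible size: 2 vs 1".
}
```
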

1 comment on commit dcf7329

@github-actions

Benchmark

| Benchmark suite | Current: dcf7329 | Previous: 311a24a | Ratio |
| --- | --- | --- | --- |
| v0 PBS table generation | 88480005 ns/iter (± 427124) | 89409197 ns/iter (± 205827) | 0.99 |
| v0 PBS simulate dag table generation | 54397444 ns/iter (± 131472) | 59011192 ns/iter (± 148339) | 0.92 |
| v0 WoP-PBS table generation | 144758715 ns/iter (± 156850) | 131727207 ns/iter (± 481389) | 1.10 |

This comment was automatically generated by workflow using github-action-benchmark.
