From 22280754dbc443c644a869846ccb8d5756bd6a63 Mon Sep 17 00:00:00 2001 From: rudy Date: Tue, 19 Sep 2023 17:25:08 +0200 Subject: [PATCH] fix(optimizer): tolerate overspecified partition_cut --- .../dag/multi_parameters/partition_cut.rs | 14 +++++++++++++- .../dag/multi_parameters/partitionning.rs | 7 +++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partition_cut.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partition_cut.rs index fb01efb810..502515d38f 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partition_cut.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partition_cut.rs @@ -182,7 +182,7 @@ impl PartitionCut { _ = lut_partition.insert((input_precision, OrderedFloat(*output_norm2))); } } - let mut p_cut : Vec<_> = lut_partition.iter().copied().collect(); + let mut p_cut: Vec<_> = lut_partition.iter().copied().collect(); p_cut.sort_by(|a, b| a.partial_cmp(b).unwrap()); _ = p_cut.pop(); let p_cut = p_cut.iter().map(|(p, n)| (*p, n.into_inner())).collect(); @@ -192,6 +192,18 @@ impl PartitionCut { } } + pub fn delete_unused_cut(&self, used: &HashSet) -> Self { + let mut p_cut = vec![]; + for (i, &cut) in self.p_cut.iter().enumerate() { + if used.contains(&i) { + p_cut.push(cut); + } + } + Self { + p_cut, + rnorm2: self.rnorm2.clone(), + } + } } impl std::fmt::Display for PartitionCut { diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitionning.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitionning.rs index e2d6985145..48fd270f2d 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitionning.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/partitionning.rs @@ -139,6 +139,13 @@ fn resolve_by_levelled_block( .copied() .collect(); let nb_partitions = present_partitions.len().max(1); // no tlu = no constraints + if p_cut.p_cut.len() + 1 != nb_partitions { + return resolve_by_levelled_block( + dag, + &p_cut.delete_unused_cut(&present_partitions), + default_partition, + ); + } if nb_partitions == 1 { return only_1_partition(dag); }