From 283e15f6b4c1f5541c02350df9a0edad9483da0f Mon Sep 17 00:00:00 2001 From: rudy Date: Fri, 15 Sep 2023 14:37:58 +0200 Subject: [PATCH 1/2] fix(compiler): conversion to optimizer dag, bad dot before signed lut this has no effect apart making the shape incorrect --- .../lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp | 5 ++--- .../concrete-optimizer-cpp/src/concrete-optimizer.rs | 7 +++++++ .../concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp | 6 ++++++ .../concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp | 2 ++ 4 files changed, 17 insertions(+), 3 deletions(-) diff --git a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp index 2253b0182a..ac2499f17e 100644 --- a/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp +++ b/compilers/concrete-compiler/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp @@ -203,9 +203,8 @@ struct FunctionToDag { if (inputType.isSigned()) { // std::vector weights_vector{1}; auto addIndex = dag->add_dot(slice(encrypted_inputs), - concrete_optimizer::weights::vector( - slice(std::vector{1}))); - encrypted_inputs[0] = addIndex; + concrete_optimizer::weights::number(1)); + encrypted_input = addIndex; operatorIndexes.push_back(addIndex.index); } auto lutIndex = diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs index fdca093787..c37794676d 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/concrete-optimizer.rs @@ -574,6 +574,10 @@ fn vector(weights: &[i64]) -> Box { Box::new(Weights(operator::Weights::vector(weights))) } +fn number(weight: i64) -> Box { + Box::new(Weights(operator::Weights::number(weight))) +} + impl From for ffi::OperatorIndex { fn from(oi: OperatorIndex) -> Self { Self { index: oi.i } @@ -671,6 +675,9 @@ mod ffi { #[namespace = "concrete_optimizer::weights"] fn vector(weights: &[i64]) -> Box; + #[namespace = "concrete_optimizer::weights"] + fn number(weight: i64) -> Box; + fn optimize_multi(self: &OperationDag, _options: Options) -> CircuitSolution; fn NO_KEY_ID() -> u64; diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp index 765e60fc00..e51325713e 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp @@ -1292,6 +1292,8 @@ ::std::size_t concrete_optimizer$cxxbridge1$Weights$operator$alignof() noexcept; namespace weights { extern "C" { ::concrete_optimizer::Weights *concrete_optimizer$weights$cxxbridge1$vector(::rust::Slice<::std::int64_t const> weights) noexcept; + +::concrete_optimizer::Weights *concrete_optimizer$weights$cxxbridge1$number(::std::int64_t weight) noexcept; } // extern "C" } // namespace weights @@ -1391,6 +1393,10 @@ namespace weights { ::rust::Box<::concrete_optimizer::Weights> vector(::rust::Slice<::std::int64_t const> weights) noexcept { return ::rust::Box<::concrete_optimizer::Weights>::from_raw(concrete_optimizer$weights$cxxbridge1$vector(weights)); } + +::rust::Box<::concrete_optimizer::Weights> number(::std::int64_t weight) noexcept { + return ::rust::Box<::concrete_optimizer::Weights>::from_raw(concrete_optimizer$weights$cxxbridge1$number(weight)); +} } // namespace weights ::concrete_optimizer::dag::CircuitSolution OperationDag::optimize_multi(::concrete_optimizer::Options _options) const noexcept { diff --git a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp index ac92dadb6b..0d20da41df 100644 --- a/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp +++ b/compilers/concrete-optimizer/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp @@ -1232,6 +1232,8 @@ ::rust::Box<::concrete_optimizer::OperationDag> empty() noexcept; namespace weights { ::rust::Box<::concrete_optimizer::Weights> vector(::rust::Slice<::std::int64_t const> weights) noexcept; + +::rust::Box<::concrete_optimizer::Weights> number(::std::int64_t weight) noexcept; } // namespace weights ::std::uint64_t NO_KEY_ID() noexcept; From 130d03921e2ab2c016d14ffad571285d1b8483e8 Mon Sep 17 00:00:00 2001 From: rudy Date: Fri, 15 Sep 2023 14:47:50 +0200 Subject: [PATCH 2/2] fix(optimizer): incorrect broadcast shape --- .../src/dag/operator/dot_kind.rs | 31 ++++++++++++++++--- .../src/dag/unparametrized.rs | 2 +- .../dag/multi_parameters/analyze.rs | 2 +- .../src/optimization/dag/solo_key/analyze.rs | 10 +++--- 4 files changed, 33 insertions(+), 12 deletions(-) diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/dag/operator/dot_kind.rs b/compilers/concrete-optimizer/concrete-optimizer/src/dag/operator/dot_kind.rs index bc3b2331f8..3a31ed3e3e 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/dag/operator/dot_kind.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/dag/operator/dot_kind.rs @@ -11,7 +11,7 @@ pub enum DotKind { // inputs = [[x, y, z], [x, y, z]], weights = [[a,b,c]], = [same, same] // inputs = [[x, y, z], [u, v, w]], weights = [a, b], [x*a + u*b, y*a + v*b, z*c + w*c] // inputs = [[x, y, z]], weights = [a], [x*a, y*a, z*a] - Broadcast, + Broadcast { shape: Shape }, Unsupported, } @@ -25,13 +25,19 @@ pub fn dot_kind(nb_inputs: u64, input_shape: &Shape, weights: &ClearTensor } else if inputs_shape == weights.shape { DotKind::CompatibleTensor } else if nb_inputs == 1 && input_shape.erase_first_dim() == weights.shape { - DotKind::Broadcast + DotKind::Broadcast { + shape: Shape::vector(input_shape.first_dim_size()), + } } else if weights.shape.is_vector() && weights.shape.flat_size() == nb_inputs { // Same as simple but with tensor inputs - DotKind::Broadcast + DotKind::Broadcast { + shape: input_shape.clone(), + } } else if weights.shape.is_number() && nb_inputs == 1 { // Any input multiply by one number - DotKind::Broadcast + DotKind::Broadcast { + shape: input_shape.clone(), + } } else { DotKind::Unsupported } @@ -65,7 +71,22 @@ mod tests { }; assert_eq!( dot_kind(1, &s2x2, &Weights::vector([1, 2])), - DotKind::Broadcast + DotKind::Broadcast { + shape: Shape::vector(2) + } + ); + } + + #[test] + fn test_broadcast_scalar_mul() { + let s2x2 = Shape { + dimensions_size: vec![2, 2], + }; + assert_eq!( + dot_kind(1, &s2x2, &Weights::number(1)), + DotKind::Broadcast { + shape: s2x2.clone() + } ); } diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs b/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs index b6d457d7f7..08b8e96294 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs @@ -260,7 +260,7 @@ impl OperationDag { DotKind::Simple | DotKind::Tensor | DotKind::CompatibleTensor => { Shape::number() } - DotKind::Broadcast { .. } => Shape::vector(input_shape.first_dim_size()), + DotKind::Broadcast { shape } => shape, DotKind::Unsupported { .. } => { let weights_shape = &weights.shape; println!(); diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs index dcadad8c9f..1c308fce97 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs @@ -192,7 +192,7 @@ fn out_variance( let input_shape = first(inputs, out_shapes); let kind = dot_kind(inputs.len() as u64, input_shape, weights); match kind { - DK::Simple | DK::Tensor | DK::Broadcast => { + DK::Simple | DK::Tensor | DK::Broadcast { .. } => { let inputs_variance = (0..weights.values.len()).map(|j| { let input = if inputs.len() > 1 { inputs[j] diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs index d0a946915c..ddf17b697e 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs @@ -163,7 +163,7 @@ fn out_variance( let input_shape = first(inputs, out_shapes); let kind = dot_kind(inputs.len() as u64, input_shape, weights); match kind { - DK::Simple | DK::Tensor | DK::Broadcast => { + DK::Simple | DK::Tensor | DK::Broadcast { .. } => { let first_input = inputs[0]; let mut out_variance = SymbolicVariance::ZERO; for (j, &weight) in weights.values.iter().enumerate() { @@ -269,7 +269,7 @@ fn op_levelled_complexity( let input_shape = first(inputs, out_shapes); let kind = dot_kind(inputs.len() as u64, input_shape, weights); match kind { - DK::Simple | DK::Tensor | DK::Broadcast | DK::CompatibleTensor => { + DK::Simple | DK::Tensor | DK::Broadcast { .. } | DK::CompatibleTensor => { LevelledComplexity::ADDITION * (inputs.len() as u64) * input_shape.flat_size() } DK::Unsupported { .. } => panic!("Unsupported"), @@ -883,10 +883,10 @@ pub mod tests { let shape = Shape { dimensions_size: vec![2, 2], }; - let input1 = graph.add_input(1, shape); + let input1 = graph.add_input(1, &shape); let weights = &Weights::number(2); _ = graph.add_dot([input1], weights); - assert!(*graph.out_shapes.last().unwrap() == Shape::vector(2)); + assert!(*graph.out_shapes.last().unwrap() == shape); let analysis = analyze(&graph); assert_f64_eq(analysis.out_variances.last().unwrap().input_coeff, 4.0); } @@ -902,7 +902,7 @@ pub mod tests { let lut2 = graph.add_lut(input2, FunctionTable::UNKWOWN, 1); let weights = &Weights::vector([2, 3]); _ = graph.add_dot([input1, lut2], weights); - assert!(*graph.out_shapes.last().unwrap() == Shape::vector(2)); + assert!(*graph.out_shapes.last().unwrap() == shape); let analysis = analyze(&graph); assert_f64_eq(analysis.out_variances.last().unwrap().input_coeff, 4.0); assert_f64_eq(analysis.out_variances.last().unwrap().lut_coeff, 9.0);