Fix/compiler optimizer bad signed dot #578

Merged · 2 commits · Sep 25, 2023
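
In short (my summary of the diffs below, not text from the PR): for signed inputs the compiler front-end inserts a multiply-by-1 dot in front of each table lookup, and the optimizer's shape inference misclassified that dot, collapsing tensor shapes to vectors. The fix (a) exposes concrete_optimizer::weights::number through the C++ bridge so the front-end can pass a scalar weight instead of a one-element weight vector, and (b) makes DotKind::Broadcast carry the dot's actual output shape instead of callers assuming Shape::vector(first_dim).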
@@ -203,9 +203,8 @@ struct FunctionToDag {
     if (inputType.isSigned()) {
       // std::vector<std::int64_t> weights_vector{1};
       auto addIndex = dag->add_dot(slice(encrypted_inputs),
-                                   concrete_optimizer::weights::vector(
-                                       slice(std::vector<std::int64_t>{1})));
-      encrypted_inputs[0] = addIndex;
+                                   concrete_optimizer::weights::number(1));
+      encrypted_input = addIndex;
       operatorIndexes.push_back(addIndex.index);
     }
     auto lutIndex =
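
For context (my reading of the change, not text from the PR): FunctionToDag inserts this multiply-by-one dot only for signed inputs, just before the table lookup, so the optimizer accounts for it in its noise analysis. Switching from a one-element weights::vector to the newly exposed weights::number matters because the optimizer classifies the two weight forms differently; as the dot_kind changes below show, the vector form could collapse a tensor input's inferred output shape to a vector.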
@@ -574,6 +574,10 @@ fn vector(weights: &[i64]) -> Box<Weights> {
     Box::new(Weights(operator::Weights::vector(weights)))
 }
 
+fn number(weight: i64) -> Box<Weights> {
+    Box::new(Weights(operator::Weights::number(weight)))
+}
+
 impl From<OperatorIndex> for ffi::OperatorIndex {
     fn from(oi: OperatorIndex) -> Self {
         Self { index: oi.i }
@@ -671,6 +675,9 @@ mod ffi {
        #[namespace = "concrete_optimizer::weights"]
        fn vector(weights: &[i64]) -> Box<Weights>;
 
+       #[namespace = "concrete_optimizer::weights"]
+       fn number(weight: i64) -> Box<Weights>;
+
        fn optimize_multi(self: &OperationDag, _options: Options) -> CircuitSolution;
 
        fn NO_KEY_ID() -> u64;
@@ -1292,6 +1292,8 @@ ::std::size_t concrete_optimizer$cxxbridge1$Weights$operator$alignof() noexcept;
 namespace weights {
 extern "C" {
 ::concrete_optimizer::Weights *concrete_optimizer$weights$cxxbridge1$vector(::rust::Slice<::std::int64_t const> weights) noexcept;
+
+::concrete_optimizer::Weights *concrete_optimizer$weights$cxxbridge1$number(::std::int64_t weight) noexcept;
 } // extern "C"
 } // namespace weights
 
@@ -1391,6 +1393,10 @@ namespace weights {
 ::rust::Box<::concrete_optimizer::Weights> vector(::rust::Slice<::std::int64_t const> weights) noexcept {
   return ::rust::Box<::concrete_optimizer::Weights>::from_raw(concrete_optimizer$weights$cxxbridge1$vector(weights));
 }
+
+::rust::Box<::concrete_optimizer::Weights> number(::std::int64_t weight) noexcept {
+  return ::rust::Box<::concrete_optimizer::Weights>::from_raw(concrete_optimizer$weights$cxxbridge1$number(weight));
+}
 } // namespace weights
 
 ::concrete_optimizer::dag::CircuitSolution OperationDag::optimize_multi(::concrete_optimizer::Options _options) const noexcept {
@@ -1232,6 +1232,8 @@ ::rust::Box<::concrete_optimizer::OperationDag> empty() noexcept;
 
 namespace weights {
 ::rust::Box<::concrete_optimizer::Weights> vector(::rust::Slice<::std::int64_t const> weights) noexcept;
+
+::rust::Box<::concrete_optimizer::Weights> number(::std::int64_t weight) noexcept;
 } // namespace weights
 
 ::std::uint64_t NO_KEY_ID() noexcept;
@@ -11,7 +11,7 @@ pub enum DotKind {
     // inputs = [[x, y, z], [x, y, z]], weights = [[a,b,c]], = [same, same]
     // inputs = [[x, y, z], [u, v, w]], weights = [a, b], [x*a + u*b, y*a + v*b, z*a + w*b]
     // inputs = [[x, y, z]], weights = [a], [x*a, y*a, z*a]
-    Broadcast,
+    Broadcast { shape: Shape },
     Unsupported,
 }
@@ -25,13 +25,19 @@ pub fn dot_kind<W>(nb_inputs: u64, input_shape: &Shape, weights: &ClearTensor<W>
     } else if inputs_shape == weights.shape {
         DotKind::CompatibleTensor
     } else if nb_inputs == 1 && input_shape.erase_first_dim() == weights.shape {
-        DotKind::Broadcast
+        DotKind::Broadcast {
+            shape: Shape::vector(input_shape.first_dim_size()),
+        }
     } else if weights.shape.is_vector() && weights.shape.flat_size() == nb_inputs {
         // Same as simple but with tensor inputs
-        DotKind::Broadcast
+        DotKind::Broadcast {
+            shape: input_shape.clone(),
+        }
     } else if weights.shape.is_number() && nb_inputs == 1 {
         // Any input multiply by one number
-        DotKind::Broadcast
+        DotKind::Broadcast {
+            shape: input_shape.clone(),
+        }
     } else {
         DotKind::Unsupported
     }
@@ -65,7 +71,22 @@ mod tests {
         };
         assert_eq!(
             dot_kind(1, &s2x2, &Weights::vector([1, 2])),
-            DotKind::Broadcast
+            DotKind::Broadcast {
+                shape: Shape::vector(2)
+            }
         );
     }
 
+    #[test]
+    fn test_broadcast_scalar_mul() {
+        let s2x2 = Shape {
+            dimensions_size: vec![2, 2],
+        };
+        assert_eq!(
+            dot_kind(1, &s2x2, &Weights::number(1)),
+            DotKind::Broadcast {
+                shape: s2x2.clone()
+            }
+        );
+    }
+
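
Taken together (my summary): DotKind::Broadcast now records the dot's actual output shape at classification time, instead of every caller assuming Shape::vector(first_dim). A minimal sketch of the resulting behavior, reusing Shape, Weights, dot_kind, and DotKind from this file; the two assertions mirror the tests above:

let s2x2 = Shape {
    dimensions_size: vec![2, 2],
};

// One weight per first-axis entry: the [2, 2] input contracts to a vector of 2.
assert_eq!(
    dot_kind(1, &s2x2, &Weights::vector([1, 2])),
    DotKind::Broadcast {
        shape: Shape::vector(2)
    }
);

// Scalar weight: an element-wise multiply, so the full [2, 2] shape is kept.
// This is the case the old hard-coded Shape::vector(first_dim) got wrong.
assert_eq!(
    dot_kind(1, &s2x2, &Weights::number(1)),
    DotKind::Broadcast {
        shape: s2x2.clone()
    }
);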
@@ -260,7 +260,7 @@ impl OperationDag {
             DotKind::Simple | DotKind::Tensor | DotKind::CompatibleTensor => {
                 Shape::number()
             }
-            DotKind::Broadcast { .. } => Shape::vector(input_shape.first_dim_size()),
+            DotKind::Broadcast { shape } => shape,
             DotKind::Unsupported { .. } => {
                 let weights_shape = &weights.shape;
                 println!();
@@ -192,7 +192,7 @@ fn out_variance(
     let input_shape = first(inputs, out_shapes);
     let kind = dot_kind(inputs.len() as u64, input_shape, weights);
     match kind {
-        DK::Simple | DK::Tensor | DK::Broadcast => {
+        DK::Simple | DK::Tensor | DK::Broadcast { .. } => {
             let inputs_variance = (0..weights.values.len()).map(|j| {
                 let input = if inputs.len() > 1 {
                     inputs[j]
@@ -163,7 +163,7 @@ fn out_variance(
     let input_shape = first(inputs, out_shapes);
     let kind = dot_kind(inputs.len() as u64, input_shape, weights);
     match kind {
-        DK::Simple | DK::Tensor | DK::Broadcast => {
+        DK::Simple | DK::Tensor | DK::Broadcast { .. } => {
             let first_input = inputs[0];
             let mut out_variance = SymbolicVariance::ZERO;
             for (j, &weight) in weights.values.iter().enumerate() {
@@ -269,7 +269,7 @@ fn op_levelled_complexity(
     let input_shape = first(inputs, out_shapes);
     let kind = dot_kind(inputs.len() as u64, input_shape, weights);
     match kind {
-        DK::Simple | DK::Tensor | DK::Broadcast | DK::CompatibleTensor => {
+        DK::Simple | DK::Tensor | DK::Broadcast { .. } | DK::CompatibleTensor => {
             LevelledComplexity::ADDITION * (inputs.len() as u64) * input_shape.flat_size()
         }
         DK::Unsupported { .. } => panic!("Unsupported"),
@@ -883,10 +883,10 @@ pub mod tests {
         let shape = Shape {
             dimensions_size: vec![2, 2],
         };
-        let input1 = graph.add_input(1, shape);
+        let input1 = graph.add_input(1, &shape);
         let weights = &Weights::number(2);
         _ = graph.add_dot([input1], weights);
-        assert!(*graph.out_shapes.last().unwrap() == Shape::vector(2));
+        assert!(*graph.out_shapes.last().unwrap() == shape);
         let analysis = analyze(&graph);
         assert_f64_eq(analysis.out_variances.last().unwrap().input_coeff, 4.0);
     }
@@ -902,7 +902,7 @@ pub mod tests {
         let lut2 = graph.add_lut(input2, FunctionTable::UNKWOWN, 1);
         let weights = &Weights::vector([2, 3]);
         _ = graph.add_dot([input1, lut2], weights);
-        assert!(*graph.out_shapes.last().unwrap() == Shape::vector(2));
+        assert!(*graph.out_shapes.last().unwrap() == shape);
         let analysis = analyze(&graph);
         assert_f64_eq(analysis.out_variances.last().unwrap().input_coeff, 4.0);
         assert_f64_eq(analysis.out_variances.last().unwrap().lut_coeff, 9.0);
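
As a sanity check on the variance figures in these tests (my arithmetic, not stated in the PR): a dot scales each operand's noise variance by the square of its weight, so weight 2 on the encrypted input gives input_coeff 2² = 4.0 and weight 3 on the lookup-table output gives lut_coeff 3² = 9.0. The shape fix only changes what lands in out_shapes; the variance coefficients are asserted unchanged. Passing &shape instead of shape to add_input presumably just keeps the local shape available for the new assertion.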