Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ssa): array sort #754

Merged
merged 10 commits into from
Feb 13, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions crates/nargo/tests/test_data/array_len/src/main.nr
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,12 @@ fn main(len3: [u8; 3], len4: [Field; 4]) {

// std::array::len returns a comptime value
constrain len4[std::array::len(len3)] == 4;

// test for std::array::sort
let mut unsorted = len3;
unsorted[0] = len3[1];
unsorted[1] = len3[0];
constrain unsorted[0] > unsorted[1];
let sorted = std::array::sort(unsorted);
constrain sorted[0] < sorted[1];
}
3 changes: 3 additions & 0 deletions crates/noirc_evaluator/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,6 @@ iter-extended.workspace = true
thiserror.workspace = true
num-bigint = "0.4"
num-traits = "0.2.8"

[dev-dependencies]
rand="0.8.5"
13 changes: 12 additions & 1 deletion crates/noirc_evaluator/src/ssa/acir_gen/internal_var.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use crate::{
ssa::acir_gen::{const_from_expression, expression_to_witness, optional_expression_to_witness},
ssa::acir_gen::{
const_from_expression, expression_from_witness, expression_to_witness,
optional_expression_to_witness,
},
ssa::node::NodeId,
Evaluator,
};
Expand Down Expand Up @@ -46,6 +49,14 @@ impl InternalVar {
&self.cached_witness
}

pub fn to_expression(&self) -> Expression {
if let Some(w) = self.cached_witness {
expression_from_witness(w)
} else {
self.expression().clone()
}
}

/// If the InternalVar holds a constant expression
/// Return that constant.Otherwise, return None.
pub(super) fn to_const(&self) -> Option<FieldElement> {
Expand Down
1 change: 1 addition & 0 deletions crates/noirc_evaluator/src/ssa/acir_gen/operations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@ pub mod intrinsics;
pub mod load;
pub mod not;
pub mod r#return;
mod sort;
pub mod store;
pub mod truncate;
50 changes: 47 additions & 3 deletions crates/noirc_evaluator/src/ssa/acir_gen/operations/intrinsics.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
use crate::{
ssa::{
acir_gen::{constraints::to_radix_base, InternalVar, InternalVarCache, MemoryMap},
acir_gen::{
constraints::{bound_constraint_with_offset, to_radix_base},
expression_from_witness,
operations::sort::evaluate_permutation,
InternalVar, InternalVarCache, MemoryMap,
},
builtin,
context::SsaContext,
mem::ArrayId,
mem::{ArrayId, Memory},
node::{self, Instruction, Node, NodeId, ObjectType},
},
Evaluator,
};
use acvm::acir::{
circuit::opcodes::{BlackBoxFuncCall, FunctionInput, Opcode as AcirOpcode},
circuit::{
directives::Directive,
opcodes::{BlackBoxFuncCall, FunctionInput, Opcode as AcirOpcode},
},
native_types::{Expression, Witness},
};
use iter_extended::vecmap;
Expand Down Expand Up @@ -66,6 +74,42 @@ pub(crate) fn evaluate(
};
evaluator.opcodes.push(AcirOpcode::BlackBoxFuncCall(func_call));
}
Opcode::Sort => {
let mut in_expr = Vec::new();
let array_id = Memory::deref(ctx, args[0]).unwrap();
let array = &ctx.mem[array_id];
let num_bits = array.element_type.bits();
for i in 0..array.len {
in_expr.push(
memory_map.load_array_element_constant_index(array, i).unwrap().to_expression(),
);
}
outputs = prepare_outputs(memory_map, instruction_id, array.len, ctx, evaluator);
let out_expr: Vec<Expression> =
outputs.iter().map(|w| expression_from_witness(*w)).collect();
for i in 0..(out_expr.len() - 1) {
bound_constraint_with_offset(
&out_expr[i],
&out_expr[i + 1],
&Expression::zero(),
num_bits,
evaluator,
);
}
let bits = evaluate_permutation(&in_expr, &out_expr, evaluator);
let inputs = in_expr.iter().map(|a| vec![a.clone()]).collect();
evaluator.opcodes.push(AcirOpcode::Directive(Directive::PermutationSort {
inputs,
tuple: 1,
bits,
sort_by: vec![0],
}));
if let node::ObjectType::Pointer(a) = res_type {
memory_map.map_array(a, &outputs, ctx);
} else {
unreachable!();
}
}
}

// If more than witness is returned,
Expand Down
172 changes: 172 additions & 0 deletions crates/noirc_evaluator/src/ssa/acir_gen/operations/sort.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
use acvm::{
acir::native_types::Expression,
acir::{circuit::opcodes::Opcode as AcirOpcode, native_types::Witness},
FieldElement,
};

use crate::{
ssa::acir_gen::{
constraints::{add, mul_with_witness, subtract},
expression_from_witness,
},
Evaluator,
};

// Generate gates which ensure that out_expr is a permutation of in_expr
// Returns the control bits of the sorting network used to generate the constrains
pub fn evaluate_permutation(
in_expr: &Vec<Expression>,
out_expr: &Vec<Expression>,
evaluator: &mut Evaluator,
) -> Vec<Witness> {
let (w, b) = permutation_layer(in_expr, evaluator);
// we contrain the network output to out_expr
for (b, o) in b.iter().zip(out_expr) {
evaluator.opcodes.push(AcirOpcode::Arithmetic(subtract(b, FieldElement::one(), o)));
}
w
}

// Generates gates for a sorting network
// returns witness corresponding to the network configuration and the expressions corresponding to the network output
// in_expr: inputs of the sorting network
pub fn permutation_layer(
in_expr: &Vec<Expression>,
evaluator: &mut Evaluator,
) -> (Vec<Witness>, Vec<Expression>) {
let n = in_expr.len();
if n == 1 {
return (Vec::new(), in_expr.clone());
}
let n1 = n / 2;
let mut conf = Vec::new();
// witness for the input switches
for _ in 0..n1 {
conf.push(evaluator.add_witness_to_cs());
}
// compute expressions after the input switches
// If inputs are a1,a2, and the switch value is c, then we compute expresions b1,b2 where
// b1 = a1+q, b2 = a2-q, q = c(a2-a1)
let mut in_sub1 = Vec::new();
let mut in_sub2 = Vec::new();
for i in 0..n1 {
//q = c*(a2-a1);
let intermediate = mul_with_witness(
evaluator,
&expression_from_witness(conf[i]),
&subtract(&in_expr[2 * i + 1], FieldElement::one(), &in_expr[2 * i]),
);
//b1=a1+q
in_sub1.push(add(&intermediate, FieldElement::one(), &in_expr[2 * i]));
//b2=a2-q
in_sub2.push(subtract(&in_expr[2 * i + 1], FieldElement::one(), &intermediate));
}
if n % 2 == 1 {
in_sub2.push(in_expr.last().unwrap().clone());
}
let mut out_expr = Vec::new();
// compute results for the sub networks
let (w1, b1) = permutation_layer(&in_sub1, evaluator);
let (w2, b2) = permutation_layer(&in_sub2, evaluator);
// apply the output swithces
for i in 0..(n - 1) / 2 {
let c = evaluator.add_witness_to_cs();
conf.push(c);
let intermediate = mul_with_witness(
evaluator,
&expression_from_witness(c),
&subtract(&b2[i], FieldElement::one(), &b1[i]),
);
out_expr.push(add(&intermediate, FieldElement::one(), &b1[i]));
out_expr.push(subtract(&b2[i], FieldElement::one(), &intermediate));
}
if n % 2 == 0 {
out_expr.push(b1.last().unwrap().clone());
}
out_expr.push(b2.last().unwrap().clone());
conf.extend(w1);
conf.extend(w2);
(conf, out_expr)
}

#[cfg(test)]
mod test {
use std::collections::BTreeMap;

use acvm::{
acir::{circuit::opcodes::BlackBoxFuncCall, native_types::Witness},
FieldElement, OpcodeResolutionError, PartialWitnessGenerator,
};

use crate::{
ssa::acir_gen::{expression_from_witness, operations::sort::evaluate_permutation},
Evaluator,
};
use rand::prelude::*;

struct MockBackend {}
impl PartialWitnessGenerator for MockBackend {
fn solve_black_box_function_call(
_initial_witness: &mut BTreeMap<Witness, FieldElement>,
_func_call: &BlackBoxFuncCall,
) -> Result<(), OpcodeResolutionError> {
unreachable!();
}
}

// Check that a random network constrains its output to be a permutation of any random input
#[test]
fn test_permutation() {
let mut rng = rand::thread_rng();
for n in 2..50 {
let mut eval = Evaluator {
current_witness_index: 0,
num_witnesses_abi_len: 0,
public_inputs: Vec::new(),
opcodes: Vec::new(),
};

//we generate random inputs
let mut input = Vec::new();
let mut a_val = Vec::new();
let mut b_wit = Vec::new();
let mut solved_witness: BTreeMap<Witness, FieldElement> = BTreeMap::new();
for i in 0..n {
let w = eval.add_witness_to_cs();
input.push(expression_from_witness(w));
a_val.push(FieldElement::from(rng.next_u32() as i128));
solved_witness.insert(w, a_val[i]);
}

let mut output = Vec::new();
for _i in 0..n {
let w = eval.add_witness_to_cs();
b_wit.push(w);
output.push(expression_from_witness(w));
}
//generate constraints for the inputs
let w = evaluate_permutation(&input, &output, &mut eval);

//we generate random network
let mut c = Vec::new();
for _i in 0..w.len() {
c.push(rng.next_u32() % 2 != 0);
}
// intialise bits
for i in 0..w.len() {
solved_witness.insert(w[i], FieldElement::from(c[i] as i128));
}
// compute the network output by solving the constraints
let backend = MockBackend {};
backend
.solve(&mut solved_witness, eval.opcodes.clone())
.expect("Could not solve permutation constraints");
let mut b_val = Vec::new();
for i in 0..output.len() {
b_val.push(solved_witness[&b_wit[i]]);
}
// ensure the outputs are a permutation of the inputs
assert_eq!(a_val.sort(), b_val.sort());
}
}
}
16 changes: 13 additions & 3 deletions crates/noirc_evaluator/src/ssa/builtin.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use crate::ssa::node::ObjectType;
use crate::ssa::{
context::SsaContext,
node::{NodeId, ObjectType},
};
use acvm::{acir::BlackBoxFunc, FieldElement};
use num_bigint::BigUint;
use num_traits::{One, Zero};
Expand All @@ -12,6 +15,7 @@ pub enum Opcode {
LowLevel(BlackBoxFunc),
ToBits,
ToRadix,
Sort,
}

impl std::fmt::Display for Opcode {
Expand All @@ -30,6 +34,7 @@ impl Opcode {
match op_name {
"to_le_bits" => Some(Opcode::ToBits),
"to_radix" => Some(Opcode::ToRadix),
"arraysort" => Some(Opcode::Sort),
_ => BlackBoxFunc::lookup(op_name).map(Opcode::LowLevel),
}
}
Expand All @@ -39,6 +44,7 @@ impl Opcode {
Opcode::LowLevel(op) => op.name(),
Opcode::ToBits => "to_le_bits",
Opcode::ToRadix => "to_radix",
Opcode::Sort => "arraysort",
}
}

Expand All @@ -64,13 +70,13 @@ impl Opcode {
}
}
}
Opcode::ToBits | Opcode::ToRadix => BigUint::zero(), //pointers do not overflow
Opcode::ToBits | Opcode::ToRadix | Opcode::Sort => BigUint::zero(), //pointers do not overflow
}
}

/// Returns the number of elements that the `Opcode` should return
/// and the type.
pub fn get_result_type(&self) -> (u32, ObjectType) {
pub fn get_result_type(&self, args: &[NodeId], ctx: &SsaContext) -> (u32, ObjectType) {
match self {
Opcode::LowLevel(op) => {
match op {
Expand All @@ -90,6 +96,10 @@ impl Opcode {
}
Opcode::ToBits => (FieldElement::max_num_bits(), ObjectType::Boolean),
Opcode::ToRadix => (FieldElement::max_num_bits(), ObjectType::NativeField),
Opcode::Sort => {
let a = super::mem::Memory::deref(ctx, args[0]).unwrap();
(ctx.mem[a].len, ctx.mem[a].element_type)
}
}
}
}
3 changes: 1 addition & 2 deletions crates/noirc_evaluator/src/ssa/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ impl IrGenerator {
op: builtin::Opcode,
args: Vec<NodeId>,
) -> Result<Vec<NodeId>, RuntimeError> {
let (len, elem_type) = op.get_result_type();
let (len, elem_type) = op.get_result_type(&args, &self.context);

let result_type = if len > 1 {
//We create an array that will contain the result and set the res_type to point to that array
Expand All @@ -342,7 +342,6 @@ impl IrGenerator {
} else {
elem_type
};

//when the function returns an array, we use ins.res_type(array)
//else we map ins.id to the returned witness
let id = self.context.new_instruction(node::Operation::Intrinsic(op, args), result_type)?;
Expand Down
2 changes: 1 addition & 1 deletion crates/noirc_frontend/src/hir/resolution/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ impl From<ResolverError> for Diagnostic {
),
ResolverError::NonStructUsedInConstructor { typ, span } => Diagnostic::simple_error(
"Only struct types can be used in constructor expressions".into(),
format!("{} has no fields to construct it with", typ),
format!("{typ} has no fields to construct it with"),
span,
),
ResolverError::NonStructWithGenerics { span } => Diagnostic::simple_error(
Expand Down
Loading