Skip to content

Commit

Permalink
feat(ssa): array sort (#754)
Browse files Browse the repository at this point in the history
* Constrains for sorting network

* cargo fmt + (some) clippy

* Removing unreachable case

* rand in cargo.lock

* add length to array sort
  • Loading branch information
guipublic authored Feb 13, 2023
1 parent 48d92df commit 32e9320
Show file tree
Hide file tree
Showing 10 changed files with 260 additions and 22 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions crates/nargo/tests/test_data/array_len/src/main.nr
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,12 @@ fn main(len3: [u8; 3], len4: [Field; 4]) {

// std::array::len returns a comptime value
constrain len4[std::array::len(len3)] == 4;

// test for std::array::sort
let mut unsorted = len3;
unsorted[0] = len3[1];
unsorted[1] = len3[0];
constrain unsorted[0] > unsorted[1];
let sorted = std::array::sort(unsorted);
constrain sorted[0] < sorted[1];
}
3 changes: 3 additions & 0 deletions crates/noirc_evaluator/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,6 @@ iter-extended.workspace = true
thiserror.workspace = true
num-bigint = "0.4"
num-traits = "0.2.8"

[dev-dependencies]
rand="0.8.5"
13 changes: 12 additions & 1 deletion crates/noirc_evaluator/src/ssa/acir_gen/internal_var.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use crate::{
ssa::acir_gen::{const_from_expression, expression_to_witness, optional_expression_to_witness},
ssa::acir_gen::{
const_from_expression, expression_from_witness, expression_to_witness,
optional_expression_to_witness,
},
ssa::node::NodeId,
Evaluator,
};
Expand Down Expand Up @@ -46,6 +49,14 @@ impl InternalVar {
&self.cached_witness
}

pub fn to_expression(&self) -> Expression {
if let Some(w) = self.cached_witness {
expression_from_witness(w)
} else {
self.expression().clone()
}
}

/// If the InternalVar holds a constant expression
/// Return that constant.Otherwise, return None.
pub(super) fn to_const(&self) -> Option<FieldElement> {
Expand Down
1 change: 1 addition & 0 deletions crates/noirc_evaluator/src/ssa/acir_gen/operations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@ pub mod intrinsics;
pub mod load;
pub mod not;
pub mod r#return;
mod sort;
pub mod store;
pub mod truncate;
50 changes: 47 additions & 3 deletions crates/noirc_evaluator/src/ssa/acir_gen/operations/intrinsics.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
use crate::{
ssa::{
acir_gen::{constraints::to_radix_base, InternalVar, InternalVarCache, MemoryMap},
acir_gen::{
constraints::{bound_constraint_with_offset, to_radix_base},
expression_from_witness,
operations::sort::evaluate_permutation,
InternalVar, InternalVarCache, MemoryMap,
},
builtin,
context::SsaContext,
mem::ArrayId,
mem::{ArrayId, Memory},
node::{self, Instruction, Node, NodeId, ObjectType},
},
Evaluator,
};
use acvm::acir::{
circuit::opcodes::{BlackBoxFuncCall, FunctionInput, Opcode as AcirOpcode},
circuit::{
directives::Directive,
opcodes::{BlackBoxFuncCall, FunctionInput, Opcode as AcirOpcode},
},
native_types::{Expression, Witness},
};
use iter_extended::vecmap;
Expand Down Expand Up @@ -66,6 +74,42 @@ pub(crate) fn evaluate(
};
evaluator.opcodes.push(AcirOpcode::BlackBoxFuncCall(func_call));
}
Opcode::Sort => {
let mut in_expr = Vec::new();
let array_id = Memory::deref(ctx, args[0]).unwrap();
let array = &ctx.mem[array_id];
let num_bits = array.element_type.bits();
for i in 0..array.len {
in_expr.push(
memory_map.load_array_element_constant_index(array, i).unwrap().to_expression(),
);
}
outputs = prepare_outputs(memory_map, instruction_id, array.len, ctx, evaluator);
let out_expr: Vec<Expression> =
outputs.iter().map(|w| expression_from_witness(*w)).collect();
for i in 0..(out_expr.len() - 1) {
bound_constraint_with_offset(
&out_expr[i],
&out_expr[i + 1],
&Expression::zero(),
num_bits,
evaluator,
);
}
let bits = evaluate_permutation(&in_expr, &out_expr, evaluator);
let inputs = in_expr.iter().map(|a| vec![a.clone()]).collect();
evaluator.opcodes.push(AcirOpcode::Directive(Directive::PermutationSort {
inputs,
tuple: 1,
bits,
sort_by: vec![0],
}));
if let node::ObjectType::Pointer(a) = res_type {
memory_map.map_array(a, &outputs, ctx);
} else {
unreachable!();
}
}
}

// If more than witness is returned,
Expand Down
172 changes: 172 additions & 0 deletions crates/noirc_evaluator/src/ssa/acir_gen/operations/sort.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
use acvm::{
acir::native_types::Expression,
acir::{circuit::opcodes::Opcode as AcirOpcode, native_types::Witness},
FieldElement,
};

use crate::{
ssa::acir_gen::{
constraints::{add, mul_with_witness, subtract},
expression_from_witness,
},
Evaluator,
};

// Generate gates which ensure that out_expr is a permutation of in_expr
// Returns the control bits of the sorting network used to generate the constrains
pub fn evaluate_permutation(
in_expr: &Vec<Expression>,
out_expr: &Vec<Expression>,
evaluator: &mut Evaluator,
) -> Vec<Witness> {
let (w, b) = permutation_layer(in_expr, evaluator);
// we contrain the network output to out_expr
for (b, o) in b.iter().zip(out_expr) {
evaluator.opcodes.push(AcirOpcode::Arithmetic(subtract(b, FieldElement::one(), o)));
}
w
}

// Generates gates for a sorting network
// returns witness corresponding to the network configuration and the expressions corresponding to the network output
// in_expr: inputs of the sorting network
pub fn permutation_layer(
in_expr: &Vec<Expression>,
evaluator: &mut Evaluator,
) -> (Vec<Witness>, Vec<Expression>) {
let n = in_expr.len();
if n == 1 {
return (Vec::new(), in_expr.clone());
}
let n1 = n / 2;
let mut conf = Vec::new();
// witness for the input switches
for _ in 0..n1 {
conf.push(evaluator.add_witness_to_cs());
}
// compute expressions after the input switches
// If inputs are a1,a2, and the switch value is c, then we compute expresions b1,b2 where
// b1 = a1+q, b2 = a2-q, q = c(a2-a1)
let mut in_sub1 = Vec::new();
let mut in_sub2 = Vec::new();
for i in 0..n1 {
//q = c*(a2-a1);
let intermediate = mul_with_witness(
evaluator,
&expression_from_witness(conf[i]),
&subtract(&in_expr[2 * i + 1], FieldElement::one(), &in_expr[2 * i]),
);
//b1=a1+q
in_sub1.push(add(&intermediate, FieldElement::one(), &in_expr[2 * i]));
//b2=a2-q
in_sub2.push(subtract(&in_expr[2 * i + 1], FieldElement::one(), &intermediate));
}
if n % 2 == 1 {
in_sub2.push(in_expr.last().unwrap().clone());
}
let mut out_expr = Vec::new();
// compute results for the sub networks
let (w1, b1) = permutation_layer(&in_sub1, evaluator);
let (w2, b2) = permutation_layer(&in_sub2, evaluator);
// apply the output swithces
for i in 0..(n - 1) / 2 {
let c = evaluator.add_witness_to_cs();
conf.push(c);
let intermediate = mul_with_witness(
evaluator,
&expression_from_witness(c),
&subtract(&b2[i], FieldElement::one(), &b1[i]),
);
out_expr.push(add(&intermediate, FieldElement::one(), &b1[i]));
out_expr.push(subtract(&b2[i], FieldElement::one(), &intermediate));
}
if n % 2 == 0 {
out_expr.push(b1.last().unwrap().clone());
}
out_expr.push(b2.last().unwrap().clone());
conf.extend(w1);
conf.extend(w2);
(conf, out_expr)
}

#[cfg(test)]
mod test {
use std::collections::BTreeMap;

use acvm::{
acir::{circuit::opcodes::BlackBoxFuncCall, native_types::Witness},
FieldElement, OpcodeResolutionError, PartialWitnessGenerator,
};

use crate::{
ssa::acir_gen::{expression_from_witness, operations::sort::evaluate_permutation},
Evaluator,
};
use rand::prelude::*;

struct MockBackend {}
impl PartialWitnessGenerator for MockBackend {
fn solve_black_box_function_call(
_initial_witness: &mut BTreeMap<Witness, FieldElement>,
_func_call: &BlackBoxFuncCall,
) -> Result<(), OpcodeResolutionError> {
unreachable!();
}
}

// Check that a random network constrains its output to be a permutation of any random input
#[test]
fn test_permutation() {
let mut rng = rand::thread_rng();
for n in 2..50 {
let mut eval = Evaluator {
current_witness_index: 0,
num_witnesses_abi_len: 0,
public_inputs: Vec::new(),
opcodes: Vec::new(),
};

//we generate random inputs
let mut input = Vec::new();
let mut a_val = Vec::new();
let mut b_wit = Vec::new();
let mut solved_witness: BTreeMap<Witness, FieldElement> = BTreeMap::new();
for i in 0..n {
let w = eval.add_witness_to_cs();
input.push(expression_from_witness(w));
a_val.push(FieldElement::from(rng.next_u32() as i128));
solved_witness.insert(w, a_val[i]);
}

let mut output = Vec::new();
for _i in 0..n {
let w = eval.add_witness_to_cs();
b_wit.push(w);
output.push(expression_from_witness(w));
}
//generate constraints for the inputs
let w = evaluate_permutation(&input, &output, &mut eval);

//we generate random network
let mut c = Vec::new();
for _i in 0..w.len() {
c.push(rng.next_u32() % 2 != 0);
}
// intialise bits
for i in 0..w.len() {
solved_witness.insert(w[i], FieldElement::from(c[i] as i128));
}
// compute the network output by solving the constraints
let backend = MockBackend {};
backend
.solve(&mut solved_witness, eval.opcodes.clone())
.expect("Could not solve permutation constraints");
let mut b_val = Vec::new();
for i in 0..output.len() {
b_val.push(solved_witness[&b_wit[i]]);
}
// ensure the outputs are a permutation of the inputs
assert_eq!(a_val.sort(), b_val.sort());
}
}
}
16 changes: 13 additions & 3 deletions crates/noirc_evaluator/src/ssa/builtin.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use crate::ssa::node::ObjectType;
use crate::ssa::{
context::SsaContext,
node::{NodeId, ObjectType},
};
use acvm::{acir::BlackBoxFunc, FieldElement};
use num_bigint::BigUint;
use num_traits::{One, Zero};
Expand All @@ -12,6 +15,7 @@ pub enum Opcode {
LowLevel(BlackBoxFunc),
ToBits,
ToRadix,
Sort,
}

impl std::fmt::Display for Opcode {
Expand All @@ -30,6 +34,7 @@ impl Opcode {
match op_name {
"to_le_bits" => Some(Opcode::ToBits),
"to_radix" => Some(Opcode::ToRadix),
"arraysort" => Some(Opcode::Sort),
_ => BlackBoxFunc::lookup(op_name).map(Opcode::LowLevel),
}
}
Expand All @@ -39,6 +44,7 @@ impl Opcode {
Opcode::LowLevel(op) => op.name(),
Opcode::ToBits => "to_le_bits",
Opcode::ToRadix => "to_radix",
Opcode::Sort => "arraysort",
}
}

Expand All @@ -64,13 +70,13 @@ impl Opcode {
}
}
}
Opcode::ToBits | Opcode::ToRadix => BigUint::zero(), //pointers do not overflow
Opcode::ToBits | Opcode::ToRadix | Opcode::Sort => BigUint::zero(), //pointers do not overflow
}
}

/// Returns the number of elements that the `Opcode` should return
/// and the type.
pub fn get_result_type(&self) -> (u32, ObjectType) {
pub fn get_result_type(&self, args: &[NodeId], ctx: &SsaContext) -> (u32, ObjectType) {
match self {
Opcode::LowLevel(op) => {
match op {
Expand All @@ -90,6 +96,10 @@ impl Opcode {
}
Opcode::ToBits => (FieldElement::max_num_bits(), ObjectType::Boolean),
Opcode::ToRadix => (FieldElement::max_num_bits(), ObjectType::NativeField),
Opcode::Sort => {
let a = super::mem::Memory::deref(ctx, args[0]).unwrap();
(ctx.mem[a].len, ctx.mem[a].element_type)
}
}
}
}
3 changes: 1 addition & 2 deletions crates/noirc_evaluator/src/ssa/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ impl IrGenerator {
op: builtin::Opcode,
args: Vec<NodeId>,
) -> Result<Vec<NodeId>, RuntimeError> {
let (len, elem_type) = op.get_result_type();
let (len, elem_type) = op.get_result_type(&args, &self.context);

let result_type = if len > 1 {
//We create an array that will contain the result and set the res_type to point to that array
Expand All @@ -342,7 +342,6 @@ impl IrGenerator {
} else {
elem_type
};

//when the function returns an array, we use ins.res_type(array)
//else we map ins.id to the returned witness
let id = self.context.new_instruction(node::Operation::Intrinsic(op, args), result_type)?;
Expand Down
15 changes: 2 additions & 13 deletions noir_stdlib/src/array.nr
Original file line number Diff line number Diff line change
@@ -1,16 +1,5 @@
#[builtin(array_len)]
fn len<T>(_input : [T]) -> comptime Field {}

// insertion sort - n.b. it is a quadratic sort
fn sort<T>(mut a: [T]) -> [T] {
for i in 1..len(a) {
for j in 0..i {
if(a[i] < a[j]) {
let c = a[j];
a[j] = a[i];
a[i]= c;
}
};
};
a
}
#[builtin(arraysort)]
fn sort<T, N>(_a: [T; N]) -> [T; N] {}

0 comments on commit 32e9320

Please sign in to comment.