Skip to content

Commit

Permalink
feat: Add support for nested arrays on brillig gen (#2029)
Browse files Browse the repository at this point in the history
* feat: support for nested arrays inside brillig

* refactor: extract entry point management to a mod

* feat: add flatten/deflatten to brillig entrypoint

* style: remove unused import

* perf: deallocate registers on entrypoint

* style: improve error message
  • Loading branch information
sirasistant authored Jul 25, 2023
1 parent 2d87c87 commit 8adc57c
Show file tree
Hide file tree
Showing 12 changed files with 654 additions and 293 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[package]
authors = [""]
compiler_version = "0.6.0"

[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
x = "0"
y = "1"
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
struct Header {
params: [Field; 3],
}

struct MyNote {
plain: Field,
array: [Field; 2],
header: Header,
}

unconstrained fn access_nested(notes: [MyNote; 2], x: Field, y: Field) -> Field {
notes[x].array[y] + notes[y].array[x] + notes[x].plain + notes[y].header.params[x]
}

unconstrained fn create_inside_brillig(x: Field, y: Field) {
let header = Header { params: [1, 2, 3]};
let note0 = MyNote { array: [1, 2], plain : 3, header };
let note1 = MyNote { array: [4, 5], plain : 6, header };
assert(access_nested([note0, note1], x, y) == (2 + 4 + 3 + 1));
}

fn main(x: Field, y: Field) {
let header = Header { params: [1, 2, 3]};
let note0 = MyNote { array: [1, 2], plain : 3, header };
let note1 = MyNote { array: [4, 5], plain : 6, header };

create_inside_brillig(x, y);
assert(access_nested([note0, note1], x, y) == (2 + 4 + 3 + 1));
}

6 changes: 1 addition & 5 deletions crates/noirc_evaluator/src/brillig/brillig_gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,7 @@ pub(crate) fn convert_ssa_function(func: &Function, enable_debug_trace: bool) ->
let mut function_context =
FunctionContext { function_id: func.id(), ssa_value_to_brillig_variable: HashMap::new() };

let mut brillig_context = BrilligContext::new(
FunctionContext::parameters(func),
FunctionContext::return_values(func),
enable_debug_trace,
);
let mut brillig_context = BrilligContext::new(enable_debug_trace);

brillig_context.enter_context(FunctionContext::function_id_to_function_label(func.id()));
for block in reverse_post_order {
Expand Down
124 changes: 68 additions & 56 deletions crates/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use acvm::FieldElement;
use iter_extended::vecmap;

use super::brillig_black_box::convert_black_box_call;
use super::brillig_fn::{compute_size_of_composite_type, FunctionContext};
use super::brillig_fn::FunctionContext;
use super::brillig_slice_ops::{
slice_insert_operation, slice_pop_back_operation, slice_pop_front_operation,
slice_push_front_operation, slice_remove_operation,
Expand Down Expand Up @@ -397,11 +397,8 @@ impl<'block> BrilligBlock<'block> {
}
Instruction::ArrayGet { array, index } => {
let result_ids = dfg.instruction_results(instruction_id);
let destination_register = self.function_context.create_register_variable(
self.brillig_context,
result_ids[0],
dfg,
);
let destination_variable =
self.function_context.create_variable(self.brillig_context, result_ids[0], dfg);

let array_variable = self.convert_ssa_value(*array, dfg);
let array_pointer = match array_variable {
Expand All @@ -411,12 +408,12 @@ impl<'block> BrilligBlock<'block> {
};

let index_register = self.convert_ssa_register_value(*index, dfg);
self.brillig_context.array_get(array_pointer, index_register, destination_register);
self.convert_ssa_array_get(array_pointer, index_register, destination_variable);
}
Instruction::ArraySet { array, index, value } => {
let source_variable = self.convert_ssa_value(*array, dfg);
let index_register = self.convert_ssa_register_value(*index, dfg);
let value_register = self.convert_ssa_register_value(*value, dfg);
let value_variable = self.convert_ssa_value(*value, dfg);

let result_ids = dfg.instruction_results(instruction_id);
let destination_variable =
Expand All @@ -426,7 +423,7 @@ impl<'block> BrilligBlock<'block> {
source_variable,
destination_variable,
index_register,
value_register,
value_variable,
);
}
_ => todo!("ICE: Instruction not supported {instruction:?}"),
Expand Down Expand Up @@ -486,14 +483,31 @@ impl<'block> BrilligBlock<'block> {
.post_call_prep_returns_load_registers(&returned_registers, &saved_registers);
}

fn convert_ssa_array_get(
&mut self,
array_pointer: RegisterIndex,
index_register: RegisterIndex,
destination_variable: RegisterOrMemory,
) {
match destination_variable {
RegisterOrMemory::RegisterIndex(destination_register) => {
self.brillig_context.array_get(array_pointer, index_register, destination_register);
}
RegisterOrMemory::HeapArray(HeapArray { pointer, .. }) => {
self.brillig_context.array_get(array_pointer, index_register, pointer);
}
RegisterOrMemory::HeapVector(_) => unimplemented!("ICE: Array get for vector"),
}
}

/// Array set operation in SSA returns a new array or slice that is a copy of the parameter array or slice
/// With a specific value changed.
fn convert_ssa_array_set(
&mut self,
source_variable: RegisterOrMemory,
destination_variable: RegisterOrMemory,
index_register: RegisterIndex,
value_register: RegisterIndex,
value_variable: RegisterOrMemory,
) {
let destination_pointer = match destination_variable {
RegisterOrMemory::HeapArray(HeapArray { pointer, .. }) => pointer,
Expand Down Expand Up @@ -532,11 +546,30 @@ impl<'block> BrilligBlock<'block> {
}

// Then set the value in the newly created array
self.brillig_context.array_set(destination_pointer, index_register, value_register);
self.store_variable_in_array(destination_pointer, index_register, value_variable);

self.brillig_context.deallocate_register(source_size_as_register);
}

fn store_variable_in_array(
&mut self,
destination_pointer: RegisterIndex,
index_register: RegisterIndex,
value_variable: RegisterOrMemory,
) {
match value_variable {
RegisterOrMemory::RegisterIndex(value_register) => {
self.brillig_context.array_set(destination_pointer, index_register, value_register);
}
RegisterOrMemory::HeapArray(HeapArray { pointer, .. }) => {
self.brillig_context.array_set(destination_pointer, index_register, pointer);
}
RegisterOrMemory::HeapVector(_) => {
unimplemented!("ICE: cannot store a vector in array")
}
}
}

/// Convert the SSA slice operations to brillig slice operations
fn convert_ssa_slice_intrinsic_call(
&mut self,
Expand Down Expand Up @@ -661,47 +694,6 @@ impl<'block> BrilligBlock<'block> {
}
}

/// This function allows storing a Value in memory starting at the address specified by the
/// address_register. The value can be a single value or an array. The function will recursively
/// store the value in memory.
fn store_in_memory(
&mut self,
address_register: RegisterIndex,
value_id: ValueId,
dfg: &DataFlowGraph,
) {
let value = &dfg[value_id];
match value {
Value::Param { .. } | Value::Instruction { .. } | Value::NumericConstant { .. } => {
let value_register = self.convert_ssa_register_value(value_id, dfg);
self.brillig_context.store_instruction(address_register, value_register);
}
Value::Array { array, element_type } => {
// Allocate a register for the iterator
let iterator_register = self.brillig_context.allocate_register();
// Set the iterator to the address of the array
self.brillig_context.mov_instruction(iterator_register, address_register);

let size_of_item_register = self
.brillig_context
.make_constant(compute_size_of_composite_type(element_type).into());

for element_id in array.iter() {
// Store the item in memory
self.store_in_memory(iterator_register, *element_id, dfg);
// Increment the iterator by the size of the items
self.brillig_context.memory_op(
iterator_register,
size_of_item_register,
iterator_register,
BinaryIntOp::Add,
);
}
}
_ => unimplemented!("ICE: Value {:?} not storeable in memory", value),
}
}

/// Converts an SSA cast to a sequence of Brillig opcodes.
/// Casting is only necessary when shrinking the bit size of a numeric value.
fn convert_cast(
Expand Down Expand Up @@ -802,14 +794,34 @@ impl<'block> BrilligBlock<'block> {
self.brillig_context.const_instruction(register_index, (*constant).into());
new_variable
}
Value::Array { .. } => {
let new_variable =
self.function_context.create_variable(self.brillig_context, value_id, dfg);
Value::Array { array, .. } => {
let new_variable = self.function_context.get_or_create_variable(
self.brillig_context,
value_id,
dfg,
);
let heap_array = self.function_context.extract_heap_array(new_variable);

self.brillig_context
.allocate_fixed_length_array(heap_array.pointer, heap_array.size);
self.store_in_memory(heap_array.pointer, value_id, dfg);

// Allocate a register for the iterator
let iterator_register = self.brillig_context.make_constant(0_usize.into());

for element_id in array.iter() {
let element_variable = self.convert_ssa_value(*element_id, dfg);
// Store the item in memory
self.store_variable_in_array(
heap_array.pointer,
iterator_register,
element_variable,
);
// Increment the iterator
self.brillig_context.usize_op_in_place(iterator_register, BinaryIntOp::Add, 1);
}

self.brillig_context.deallocate_register(iterator_register);

new_variable
}
_ => {
Expand Down
53 changes: 26 additions & 27 deletions crates/noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::collections::HashMap;

use acvm::brillig_vm::brillig::{HeapArray, HeapVector, RegisterIndex, RegisterOrMemory};
use iter_extended::vecmap;

use crate::{
brillig::brillig_ir::{
Expand Down Expand Up @@ -37,9 +38,9 @@ impl FunctionContext {
let register = brillig_context.allocate_register();
RegisterOrMemory::RegisterIndex(register)
}
Type::Array(_, _) => {
Type::Array(item_typ, elem_count) => {
let pointer_register = brillig_context.allocate_register();
let size = compute_size_of_type(&typ);
let size = compute_array_length(&item_typ, elem_count);
RegisterOrMemory::HeapArray(HeapArray { pointer: pointer_register, size })
}
Type::Slice(_) => {
Expand Down Expand Up @@ -139,18 +140,31 @@ impl FunctionContext {
function_id.to_string()
}

fn ssa_type_to_parameter(typ: &Type) -> BrilligParameter {
match typ {
Type::Numeric(_) | Type::Reference => BrilligParameter::Simple,
Type::Array(item_type, size) => BrilligParameter::Array(
vecmap(item_type.iter(), |item_typ| {
FunctionContext::ssa_type_to_parameter(item_typ)
}),
*size,
),
Type::Slice(item_type) => {
BrilligParameter::Slice(vecmap(item_type.iter(), |item_typ| {
FunctionContext::ssa_type_to_parameter(item_typ)
}))
}
_ => unimplemented!("Unsupported function parameter/return type {typ:?}"),
}
}

/// Collects the parameters of a given function
pub(crate) fn parameters(func: &Function) -> Vec<BrilligParameter> {
func.parameters()
.iter()
.map(|&value_id| {
let typ = func.dfg.type_of_value(value_id);
match typ {
Type::Numeric(_) | Type::Reference => BrilligParameter::Register,
Type::Array(..) => BrilligParameter::HeapArray(compute_size_of_type(&typ)),
Type::Slice(_) => BrilligParameter::HeapVector,
_ => unimplemented!("Unsupported function parameter type {typ:?}"),
}
FunctionContext::ssa_type_to_parameter(&typ)
})
.collect()
}
Expand All @@ -161,28 +175,13 @@ impl FunctionContext {
.iter()
.map(|&value_id| {
let typ = func.dfg.type_of_value(value_id);
match typ {
Type::Numeric(_) | Type::Reference => BrilligParameter::Register,
Type::Array(..) => BrilligParameter::HeapArray(compute_size_of_type(&typ)),
Type::Slice(_) => BrilligParameter::HeapVector,
_ => unimplemented!("Unsupported return value type {typ:?}"),
}
FunctionContext::ssa_type_to_parameter(&typ)
})
.collect()
}
}

/// Computes the size of an SSA composite type
pub(crate) fn compute_size_of_composite_type(typ: &CompositeType) -> usize {
typ.iter().map(compute_size_of_type).sum()
}

/// Finds out the size of a given SSA type
/// This is needed to store values in memory
pub(crate) fn compute_size_of_type(typ: &Type) -> usize {
match typ {
Type::Numeric(_) => 1,
Type::Array(types, item_count) => compute_size_of_composite_type(types) * item_count,
_ => todo!("ICE: Type not supported {typ:?}"),
}
/// Computes the length of an array. This will match with the indexes that SSA will issue
pub(crate) fn compute_array_length(item_typ: &CompositeType, elem_count: usize) -> usize {
item_typ.len() * elem_count
}
Loading

0 comments on commit 8adc57c

Please sign in to comment.