Skip to content

Commit

Permalink
feat: better take at type hierarchy in SSA
Browse files Browse the repository at this point in the history
  • Loading branch information
sirasistant committed Jul 18, 2023
1 parent 211e251 commit c539c2d
Show file tree
Hide file tree
Showing 10 changed files with 381 additions and 130 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,71 +2,93 @@ use dep::std::slice;
use dep::std;

unconstrained fn main(x: Field, y: Field) {
// Mark it as mut so the compiler doesn't simplify the following operations
// But don't reuse the mut slice variable until this is fixed https://github.com/noir-lang/noir/issues/1931
let slice: [Field] = [y, x];
//Get the slice from a function to correct the type until https://github.com/noir-lang/noir/issues/1931 is fixed
let mut slice: [Field] = get_me_a_slice(y, x);
assert(slice.len() == 2);

let mut pushed_back_slice = slice.push_back(7);
assert(pushed_back_slice.len() == 3);
assert(pushed_back_slice[0] == y);
assert(pushed_back_slice[1] == x);
assert(pushed_back_slice[2] == 7);
slice = slice.push_back(7);
assert(slice.len() == 3);
assert(slice[0] == y);
assert(slice[1] == x);
assert(slice[2] == 7);

// Array set on slice target
pushed_back_slice[0] = x;
pushed_back_slice[1] = y;
pushed_back_slice[2] = 1;

assert(pushed_back_slice[0] == x);
assert(pushed_back_slice[1] == y);
assert(pushed_back_slice[2] == 1);

assert(slice.len() == 2);

let pushed_front_slice = pushed_back_slice.push_front(2);
assert(pushed_front_slice.len() == 4);
assert(pushed_front_slice[0] == 2);
assert(pushed_front_slice[1] == x);
assert(pushed_front_slice[2] == y);
assert(pushed_front_slice[3] == 1);

let (item, popped_front_slice) = pushed_front_slice.pop_front();
slice[0] = x;
slice[1] = y;
slice[2] = 1;

assert(slice[0] == x);
assert(slice[1] == y);
assert(slice[2] == 1);

slice = slice.push_front(2);
assert(slice.len() == 4);
assert(slice[0] == 2);
assert(slice[1] == x);
assert(slice[2] == y);
assert(slice[3] == 1);

let (item, popped_front_slice) = slice.pop_front();
assert(item == 2);
slice = popped_front_slice;

assert(popped_front_slice.len() == 3);
assert(popped_front_slice[0] == x);
assert(popped_front_slice[1] == y);
assert(popped_front_slice[2] == 1);
assert(slice.len() == 3);
assert(slice[0] == x);
assert(slice[1] == y);
assert(slice[2] == 1);

let (popped_back_slice, another_item) = popped_front_slice.pop_back();
let (popped_back_slice, another_item) = slice.pop_back();
assert(another_item == 1);
slice = popped_back_slice;

assert(popped_back_slice.len() == 2);
assert(popped_back_slice[0] == x);
assert(popped_back_slice[1] == y);
assert(slice.len() == 2);
assert(slice[0] == x);
assert(slice[1] == y);

let inserted_slice = popped_back_slice.insert(1, 2);
assert(inserted_slice.len() == 3);
assert(inserted_slice[0] == x);
assert(inserted_slice[1] == 2);
assert(inserted_slice[2] == y);
slice = slice.insert(1, 2);
assert(slice.len() == 3);
assert(slice[0] == x);
assert(slice[1] == 2);
assert(slice[2] == y);

let (removed_slice, should_be_2) = inserted_slice.remove(1);
let (removed_slice, should_be_2) = slice.remove(1);
assert(should_be_2 == 2);
slice = removed_slice;

assert(removed_slice.len() == 2);
assert(removed_slice[0] == x);
assert(removed_slice[1] == y);
assert(slice.len() == 2);
assert(slice[0] == x);
assert(slice[1] == y);

let (slice_with_only_x, should_be_y) = removed_slice.remove(1);
let (slice_with_only_x, should_be_y) = slice.remove(1);
assert(should_be_y == y);
slice = slice_with_only_x;

assert(slice_with_only_x.len() == 1);
assert(slice.len() == 1);
assert(removed_slice[0] == x);

let (empty_slice, should_be_x) = slice_with_only_x.remove(0);
let (empty_slice, should_be_x) = slice.remove(0);
assert(should_be_x == x);
assert(empty_slice.len() == 0);


let sequence = create_sequence([], 5);

assert(sequence.len() == 5);
assert(sequence[0] == 5);
assert(sequence[1] == 4);
assert(sequence[2] == 3);
assert(sequence[3] == 2);
assert(sequence[4] == 1);
}

unconstrained fn get_me_a_slice(x: Field, y: Field) -> [Field] {
[x, y]
}

unconstrained fn create_sequence(mut sequence: [Field], n: Field) -> [Field] {
if n != 0 {
sequence = create_sequence(sequence.push_back(n), n - 1);
}

sequence
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[package]
authors = [""]
compiler_version = "0.6.0"

[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
x = "5"
y = "10"
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
use dep::std::slice;
use dep::std;

unconstrained fn main(x: Field, y: Field) {
let mut array = [x, y];

// Dynamic dispatch typing
let mut test_slice = wrapper(push_accepts_array_returns_slice, array);
assert(test_slice[2] == 1);

test_slice = wrapper(push_accepts_slice_returns_slice, array);
assert(test_slice[2] == 1);


//Get the slice from a function to correct the type until https://github.com/noir-lang/noir/issues/1931 is fixed
let mut slice = get_me_a_slice(x, y);


// No cast
let another_array = push_accepts_array_returns_array(array);
assert(another_array[2] == 1);

slice = push_accepts_array_returns_slice(array);
assert(slice[2] == 1);

// Cast return value
slice = push_accepts_array_returns_array(array);
assert(slice[2] == 1);

// Cast param
slice = push_accepts_slice_returns_slice(array);
assert(slice[2] == 1);
}

unconstrained fn get_me_a_slice(x: Field, y: Field) -> [Field] {
[x, y]
}

unconstrained fn push_accepts_array_returns_slice<N>(array: [Field; N]) -> [Field] {
let slice: [Field] = array;
slice.push_back(1)
}

unconstrained fn push_accepts_array_returns_array(array: [Field; 2]) -> [Field; 3] {
[array[0], array[1], 1]
}

unconstrained fn push_accepts_slice_returns_slice(slice: [Field]) -> [Field] {
slice.push_back(1)
}

unconstrained fn wrapper(function: fn ([Field]) -> [Field], param: [Field; 2]) -> [Field] {
function(param)
}
18 changes: 15 additions & 3 deletions crates/noirc_evaluator/src/brillig/brillig_gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ pub(crate) mod brillig_directive;
pub(crate) mod brillig_fn;
pub(crate) mod brillig_slice_ops;

use crate::ssa_refactor::ir::{function::Function, post_order::PostOrder};
use crate::ssa_refactor::ir::{
function::{Function, FunctionId, Signature},
post_order::PostOrder,
};

use std::collections::HashMap;

Expand All @@ -13,7 +16,10 @@ use self::{brillig_block::BrilligBlock, brillig_fn::FunctionContext};
use super::brillig_ir::{artifact::BrilligArtifact, BrilligContext};

/// Converting an SSA function into Brillig bytecode.
pub(crate) fn convert_ssa_function(func: &Function) -> BrilligArtifact {
pub(crate) fn convert_ssa_function(
func: &Function,
function_to_signature: &HashMap<FunctionId, Signature>,
) -> BrilligArtifact {
let mut reverse_post_order = Vec::new();
reverse_post_order.extend_from_slice(PostOrder::with_function(func).as_slice());
reverse_post_order.reverse();
Expand All @@ -28,7 +34,13 @@ pub(crate) fn convert_ssa_function(func: &Function) -> BrilligArtifact {

brillig_context.enter_context(FunctionContext::function_id_to_function_label(func.id()));
for block in reverse_post_order {
BrilligBlock::compile(&mut function_context, &mut brillig_context, block, &func.dfg);
BrilligBlock::compile(
&mut function_context,
&mut brillig_context,
block,
&func.dfg,
function_to_signature,
);
}

brillig_context.artifact()
Expand Down
66 changes: 56 additions & 10 deletions crates/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use std::collections::HashMap;

use crate::brillig::brillig_gen::brillig_slice_ops::{
convert_array_or_vector_to_vector, slice_push_back_operation,
};
use crate::brillig::brillig_ir::{BrilligBinaryOp, BrilligContext};
use crate::ssa_refactor::ir::function::FunctionId;
use crate::ssa_refactor::ir::function::{FunctionId, Signature};
use crate::ssa_refactor::ir::instruction::Intrinsic;
use crate::ssa_refactor::ir::{
basic_block::{BasicBlock, BasicBlockId},
Expand All @@ -17,7 +19,7 @@ use acvm::FieldElement;
use iter_extended::vecmap;

use super::brillig_black_box::convert_black_box_call;
use super::brillig_fn::{compute_size_of_composite_type, FunctionContext};
use super::brillig_fn::{compute_size_of_composite_type, compute_size_of_type, FunctionContext};
use super::brillig_slice_ops::{
slice_insert_operation, slice_pop_back_operation, slice_pop_front_operation,
slice_push_front_operation, slice_remove_operation,
Expand All @@ -30,6 +32,8 @@ pub(crate) struct BrilligBlock<'block> {
block_id: BasicBlockId,
/// Context for creating brillig opcodes
brillig_context: &'block mut BrilligContext,
/// A map of function ids to their signatures
function_to_signature: &'block HashMap<FunctionId, Signature>,
}

impl<'block> BrilligBlock<'block> {
Expand All @@ -39,8 +43,10 @@ impl<'block> BrilligBlock<'block> {
brillig_context: &'block mut BrilligContext,
block_id: BasicBlockId,
dfg: &DataFlowGraph,
function_to_signature: &'block HashMap<FunctionId, Signature>,
) {
let mut brillig_block = BrilligBlock { function_context, block_id, brillig_context };
let mut brillig_block =
BrilligBlock { function_context, block_id, brillig_context, function_to_signature };

brillig_block.convert_block(dfg);
}
Expand Down Expand Up @@ -285,13 +291,22 @@ impl<'block> BrilligBlock<'block> {
}
}
Value::Function(func_id) => {
let signature_of_called_function = self
.function_to_signature
.get(func_id)
.expect("ICE: cannot find function signature");

let argument_registers: Vec<RegisterIndex> = arguments
.iter()
.flat_map(|arg| {
let arg = self.convert_ssa_value(*arg, dfg);
self.function_context.extract_registers(arg)
.zip(&signature_of_called_function.params)
.flat_map(|(argument_id, receiver_typ)| {
let variable_to_pass = self.convert_ssa_value(*argument_id, dfg);
let casted_to_param_type =
self.cast_variable_for_call(variable_to_pass, receiver_typ);
self.function_context.extract_registers(casted_to_param_type)
})
.collect();

let result_ids = dfg.instruction_results(instruction_id);

// Create label for the function that will be called
Expand All @@ -308,13 +323,16 @@ impl<'block> BrilligBlock<'block> {
// This ensures we don't save the results to registers unnecessarily.
let result_registers: Vec<RegisterIndex> = result_ids
.iter()
.flat_map(|arg| {
let arg = self.function_context.create_variable(
.zip(&signature_of_called_function.returns)
.flat_map(|(result_id, receiver_typ)| {
let variable_assigned_to = self.function_context.create_variable(
self.brillig_context,
*arg,
*result_id,
dfg,
);
self.function_context.extract_registers(arg)
let casted_to_return_type = self
.cast_back_variable_from_call(variable_assigned_to, receiver_typ);
self.function_context.extract_registers(casted_to_return_type)
})
.collect();

Expand Down Expand Up @@ -430,6 +448,34 @@ impl<'block> BrilligBlock<'block> {
};
}

fn cast_variable_for_call(
&mut self,
variable_to_pass: RegisterOrMemory,
param_type: &Type,
) -> RegisterOrMemory {
match (variable_to_pass, param_type) {
(RegisterOrMemory::HeapArray(array), Type::Slice(..)) => {
RegisterOrMemory::HeapVector(self.brillig_context.array_to_vector(&array))
}
(_, _) => variable_to_pass,
}
}

fn cast_back_variable_from_call(
&mut self,
variable_assigned_to: RegisterOrMemory,
return_type: &Type,
) -> RegisterOrMemory {
match (variable_assigned_to, return_type) {
(RegisterOrMemory::HeapVector(vector), Type::Array(..)) => {
let size = compute_size_of_type(return_type);
self.brillig_context.const_instruction(vector.size, size.into());
RegisterOrMemory::HeapArray(HeapArray { pointer: vector.pointer, size })
}
(_, _) => variable_assigned_to,
}
}

/// Array set operation in SSA returns a new array or slice that is a copy of the parameter array or slice
/// With a specific value changed.
fn convert_ssa_array_set(
Expand Down
Loading

0 comments on commit c539c2d

Please sign in to comment.