Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(ssa refactor): Handle Brillig calls #1744

Merged
merged 17 commits into from
Jun 20, 2023
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[package]
authors = [""]
compiler_version = "0.1"

[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
x = "0"
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Tests a very simple program.
//
// The features being tested is brillig calls
fn main(x: u32) {
assert(entry_point(x) == 2);
swap_entry_point(x, x + 1);
assert(deep_entry_point(x) == 4);
}

unconstrained fn inner(x : u32) -> u32 {
x + 1
}

unconstrained fn entry_point(x : u32) -> u32 {
inner(x + 1)
}

unconstrained fn swap(x: u32, y:u32) -> (u32, u32) {
(y, x)
}

unconstrained fn swap_entry_point(x: u32, y: u32) {
let swapped = swap(x, y);
assert(swapped.0 == y);
assert(swapped.1 == x);
let swapped_twice = swap(swapped.0, swapped.1);
assert(swapped_twice.0 == x);
assert(swapped_twice.1 == y);
}

unconstrained fn level_3(x : u32) -> u32 {
x + 1
}

unconstrained fn level_2(x : u32) -> u32 {
level_3(x + 1)
}

unconstrained fn level_1(x : u32) -> u32 {
level_2(x + 1)
}

unconstrained fn deep_entry_point(x : u32) -> u32 {
level_1(x + 1)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[package]
authors = [""]
compiler_version = "0.1"

[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
x = ["1","2","3"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Tests a very simple program.
//
// The features being tested is brillig calls passing arrays around
fn main(x: [u32; 3]) {
assert(entry_point(x) == 9);
}

unconstrained fn inner(x : [u32; 3]) -> [u32; 3] {
[x[0] + 1, x[1] + 1, x[2] + 1]
}

unconstrained fn entry_point(x : [u32; 3]) -> u32 {
let y = inner(x);
y[0] + y[1] + y[2]
}

13 changes: 5 additions & 8 deletions crates/noirc_evaluator/src/brillig/brillig_gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,15 @@ pub(crate) fn convert_ssa_function(func: &Function) -> BrilligArtifact {
let mut function_context =
FunctionContext { function_id: func.id(), ssa_value_to_register: HashMap::new() };

let mut brillig_context = BrilligContext::new();
let mut brillig_context = BrilligContext::new(
FunctionContext::parameters(func),
FunctionContext::return_values(func),
);

brillig_context.enter_context(FunctionContext::function_id_to_function_label(func.id()));
for block in reverse_post_order {
BrilligBlock::compile(&mut function_context, &mut brillig_context, block, &func.dfg);
}

brillig_context.artifact()
}

/// Creates an entry point artifact, that will be linked with the brillig functions being called
pub(crate) fn create_entry_point_function(num_arguments: usize) -> BrilligArtifact {
let mut brillig_context = BrilligContext::new();
brillig_context.entry_point_instruction(num_arguments);
brillig_context.artifact()
}
75 changes: 59 additions & 16 deletions crates/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::brillig::brillig_ir::{
BrilligBinaryOp, BrilligContext, BRILLIG_MEMORY_ADDRESSING_BIT_SIZE,
};
use crate::ssa_refactor::ir::function::FunctionId;
use crate::ssa_refactor::ir::types::CompositeType;
use crate::ssa_refactor::ir::{
basic_block::{BasicBlock, BasicBlockId},
Expand Down Expand Up @@ -39,7 +40,7 @@ impl<'block> BrilligBlock<'block> {

fn convert_block(&mut self, dfg: &DataFlowGraph) {
// Add a label for this block
let block_label = self.create_block_label(self.block_id);
let block_label = self.create_block_label_for_current_function(self.block_id);
self.brillig_context.enter_context(block_label);

// Convert the block parameters
Expand All @@ -57,9 +58,22 @@ impl<'block> BrilligBlock<'block> {
self.convert_ssa_terminator(terminator_instruction, dfg);
}

/// Creates a unique global label for a block
fn create_block_label(&self, block_id: BasicBlockId) -> String {
format!("{}-{}", self.function_context.function_id, block_id)
/// Creates a unique global label for a block.
///
/// This uses the current functions's function ID and the block ID
/// Making the assumption that the block ID passed in belongs to this
/// function.
fn create_block_label_for_current_function(&self, block_id: BasicBlockId) -> String {
Self::create_block_label(self.function_context.function_id, block_id)
}
/// Creates a unique label for a block using the function Id and the block ID.
///
/// We implicitly assume that the function ID and the block ID is enough
/// for us to create a unique label across functions and blocks.
///
/// This is so that during linking there are no duplicates or labels being overwritten.
fn create_block_label(function_id: FunctionId, block_id: BasicBlockId) -> String {
format!("{}-{}", function_id, block_id)
}

/// Converts an SSA terminator instruction into the necessary opcodes.
Expand All @@ -74,9 +88,13 @@ impl<'block> BrilligBlock<'block> {
match terminator_instruction {
TerminatorInstruction::JmpIf { condition, then_destination, else_destination } => {
let condition = self.convert_ssa_value(*condition, dfg);
self.brillig_context
.jump_if_instruction(condition, self.create_block_label(*then_destination));
self.brillig_context.jump_instruction(self.create_block_label(*else_destination));
self.brillig_context.jump_if_instruction(
condition,
self.create_block_label_for_current_function(*then_destination),
);
self.brillig_context.jump_instruction(
self.create_block_label_for_current_function(*else_destination),
);
}
TerminatorInstruction::Jmp { destination, arguments } => {
let target = &dfg[*destination];
Expand All @@ -85,7 +103,8 @@ impl<'block> BrilligBlock<'block> {
let source = self.convert_ssa_value(*src, dfg);
self.brillig_context.mov_instruction(destination, source);
}
self.brillig_context.jump_instruction(self.create_block_label(*destination));
self.brillig_context
.jump_instruction(self.create_block_label_for_current_function(*destination));
}
TerminatorInstruction::Return { return_values } => {
let return_registers: Vec<_> = return_values
Expand All @@ -106,15 +125,12 @@ impl<'block> BrilligBlock<'block> {
_ => unreachable!("ICE: Only Param type values should appear in block parameters"),
};
match param_type {
Type::Numeric(_) => {
// Simple parameters and arrays are passed as already filled registers
// In the case of arrays, the values should already be in memory and the register should
// Be a valid pointer to the array.
Type::Numeric(_) | Type::Array(..) => {
self.function_context.get_or_create_register(self.brillig_context, *param_id);
}
Type::Array(_, size) => {
let pointer_register = self
.function_context
.get_or_create_register(self.brillig_context, *param_id);
self.brillig_context.allocate_fixed_length_array(pointer_register, *size);
}
_ => {
todo!("ICE: Param type not supported")
}
Expand Down Expand Up @@ -189,6 +205,33 @@ impl<'block> BrilligBlock<'block> {
&output_registers,
);
}
Value::Function(func_id) => {
let function_arguments: Vec<RegisterIndex> =
vecmap(arguments, |arg| self.convert_ssa_value(*arg, dfg));
let result_ids = dfg.instruction_results(instruction_id);

// Create label for the function that will be called
let label_of_function_to_call =
FunctionContext::function_id_to_function_label(*func_id);

let saved_registers =
self.brillig_context.pre_call_save_registers_prep_args(&function_arguments);

// Call instruction, which will interpret above registers 0..num args
self.brillig_context.add_external_call_instruction(label_of_function_to_call);

// Important: resolve after pre_call_save_registers_prep_args
// This ensures we don't save the results to registers unnecessarily.
let result_registers = vecmap(result_ids, |a| {
self.function_context.get_or_create_register(self.brillig_context, *a)
});
assert!(
!saved_registers.iter().any(|x| result_registers.contains(x)),
"should not save registers used as function results"
);
self.brillig_context
.post_call_prep_returns_load_registers(&result_registers, &saved_registers);
}
_ => {
unreachable!("only foreign function calls supported in unconstrained functions")
}
Expand Down Expand Up @@ -502,7 +545,7 @@ fn compute_size_of_composite_type(typ: &CompositeType) -> usize {

/// Finds out the size of a given SSA type
/// This is needed to store values in memory
fn compute_size_of_type(typ: &Type) -> usize {
pub(crate) fn compute_size_of_type(typ: &Type) -> usize {
match typ {
Type::Numeric(_) => 1,
Type::Array(types, item_count) => compute_size_of_composite_type(types) * item_count,
Expand Down
59 changes: 57 additions & 2 deletions crates/noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,20 @@ use std::collections::HashMap;
use acvm::acir::brillig_vm::RegisterIndex;

use crate::{
brillig::brillig_ir::BrilligContext,
ssa_refactor::ir::{function::FunctionId, value::ValueId},
brillig::brillig_ir::{
artifact::{BrilligParameter, Label},
BrilligContext,
},
ssa_refactor::ir::{
function::{Function, FunctionId},
instruction::TerminatorInstruction,
types::Type,
value::ValueId,
},
};

use super::brillig_block::compute_size_of_type;

pub(crate) struct FunctionContext {
pub(crate) function_id: FunctionId,
/// Map from SSA values to Register Indices.
Expand Down Expand Up @@ -38,4 +48,49 @@ impl FunctionContext {

register
}

/// Creates a function label from a given SSA function id.
pub(crate) fn function_id_to_function_label(function_id: FunctionId) -> Label {
function_id.to_string()
}

/// Collects the parameters of a given function
pub(crate) fn parameters(func: &Function) -> Vec<BrilligParameter> {
func.parameters()
.iter()
.map(|&value_id| {
let typ = func.dfg.type_of_value(value_id);
match typ {
Type::Numeric(_) => BrilligParameter::Register,
Type::Array(..) => BrilligParameter::HeapArray(compute_size_of_type(&typ)),
_ => unimplemented!("Unsupported function parameter type {typ:?}"),
}
})
.collect()
}

/// Collects the return values of a given function
pub(crate) fn return_values(func: &Function) -> Vec<BrilligParameter> {
let blocks = func.reachable_blocks();
let mut function_return_values = None;
for block in blocks {
let terminator = func.dfg[block].terminator();
if let Some(TerminatorInstruction::Return { return_values }) = terminator {
function_return_values = Some(return_values);
break;
}
}
function_return_values
.expect("Expected a return instruction, as block is finished construction")
.iter()
.map(|&value_id| {
let typ = func.dfg.type_of_value(value_id);
match typ {
Type::Numeric(_) => BrilligParameter::Register,
Type::Array(..) => BrilligParameter::HeapArray(compute_size_of_type(&typ)),
_ => unimplemented!("Unsupported return value type {typ:?}"),
}
})
.collect()
}
}
Loading