Skip to content

Commit

Permalink
fix: call bug and allow loops with fields
Browse files Browse the repository at this point in the history
  • Loading branch information
sirasistant committed Jul 19, 2023
1 parent e3a87ad commit 1f37820
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 72 deletions.
198 changes: 129 additions & 69 deletions crates/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ use std::collections::HashMap;
use crate::brillig::brillig_gen::brillig_slice_ops::{
convert_array_or_vector_to_vector, slice_push_back_operation,
};
use crate::brillig::brillig_ir::{BrilligBinaryOp, BrilligContext};
use crate::brillig::brillig_ir::{
BrilligBinaryOp, BrilligContext, BRILLIG_INTEGER_ARITHMETIC_BIT_SIZE,
};
use crate::ssa_refactor::ir::function::{FunctionId, Signature};
use crate::ssa_refactor::ir::instruction::Intrinsic;
use crate::ssa_refactor::ir::{
Expand Down Expand Up @@ -291,57 +293,7 @@ impl<'block> BrilligBlock<'block> {
}
}
Value::Function(func_id) => {
let signature_of_called_function = self
.function_to_signature
.get(func_id)
.expect("ICE: cannot find function signature");

let argument_registers: Vec<RegisterIndex> = arguments
.iter()
.zip(&signature_of_called_function.params)
.flat_map(|(argument_id, receiver_typ)| {
let variable_to_pass = self.convert_ssa_value(*argument_id, dfg);
let casted_to_param_type =
self.cast_variable_for_call(variable_to_pass, receiver_typ);
self.function_context.extract_registers(casted_to_param_type)
})
.collect();

let result_ids = dfg.instruction_results(instruction_id);

// Create label for the function that will be called
let label_of_function_to_call =
FunctionContext::function_id_to_function_label(*func_id);

let saved_registers =
self.brillig_context.pre_call_save_registers_prep_args(&argument_registers);

// Call instruction, which will interpret above registers 0..num args
self.brillig_context.add_external_call_instruction(label_of_function_to_call);

// Important: resolve after pre_call_save_registers_prep_args
// This ensures we don't save the results to registers unnecessarily.
let result_registers: Vec<RegisterIndex> = result_ids
.iter()
.zip(&signature_of_called_function.returns)
.flat_map(|(result_id, receiver_typ)| {
let variable_assigned_to = self.function_context.create_variable(
self.brillig_context,
*result_id,
dfg,
);
let casted_to_return_type = self
.cast_back_variable_from_call(variable_assigned_to, receiver_typ);
self.function_context.extract_registers(casted_to_return_type)
})
.collect();

assert!(
!saved_registers.iter().any(|x| result_registers.contains(x)),
"should not save registers used as function results"
);
self.brillig_context
.post_call_prep_returns_load_registers(&result_registers, &saved_registers);
self.convert_ssa_function_call(*func_id, arguments, dfg, instruction_id);
}
Value::Intrinsic(Intrinsic::BlackBox(bb_func)) => {
let function_arguments =
Expand Down Expand Up @@ -448,6 +400,81 @@ impl<'block> BrilligBlock<'block> {
};
}

fn convert_ssa_function_call(
&mut self,
func_id: FunctionId,
arguments: &[ValueId],
dfg: &DataFlowGraph,
instruction_id: InstructionId,
) {
let signature_of_called_function =
self.function_to_signature.get(&func_id).expect("ICE: cannot find function signature");

// Convert the arguments to registers casting those to the types of the receiving function
let argument_registers: Vec<RegisterIndex> = arguments
.iter()
.zip(&signature_of_called_function.params)
.flat_map(|(argument_id, receiver_typ)| {
let variable_to_pass = self.convert_ssa_value(*argument_id, dfg);
let casted_to_param_type =
self.cast_variable_for_call(variable_to_pass, receiver_typ);
self.function_context.extract_registers(casted_to_param_type)
})
.collect();

let result_ids = dfg.instruction_results(instruction_id);

// Create label for the function that will be called
let label_of_function_to_call = FunctionContext::function_id_to_function_label(func_id);

let saved_registers =
self.brillig_context.pre_call_save_registers_prep_args(&argument_registers);

// Call instruction, which will interpret above registers 0..num args
self.brillig_context.add_external_call_instruction(label_of_function_to_call);

// Important: resolve after pre_call_save_registers_prep_args
// This ensures we don't save the results to registers unnecessarily.

// Allocate the registers for the variables where we are assigning the returns
let variables_assigned_to = vecmap(result_ids, |result_id| {
self.function_context.create_variable(self.brillig_context, *result_id, dfg)
});

// Transform the assigned to variables into the types of the called function returns
let returned_variables: Vec<RegisterOrMemory> = variables_assigned_to
.iter()
.zip(&signature_of_called_function.returns)
.map(|(variable_assigned_to, return_typ)| {
self.cast_back_variable_from_call(*variable_assigned_to, return_typ)
})
.collect();

// Collect the registers that should have been returned
let returned_registers: Vec<RegisterIndex> = returned_variables
.iter()
.flat_map(|casted_to_return_type| {
self.function_context.extract_registers(*casted_to_return_type)
})
.collect();

assert!(
!saved_registers.iter().any(|x| returned_registers.contains(x)),
"should not save registers used as function results"
);

// puts the returns into the returned_registers and restores saved_registers
self.brillig_context
.post_call_prep_returns_load_registers(&returned_registers, &saved_registers);

// Reconciliate the types of the variables that the returns are assigned to with the types of the returns
variables_assigned_to.iter().zip(returned_variables).for_each(
|(variable_assigned_to, return_variable)| {
self.reconciliate_from_call(*variable_assigned_to, return_variable);
},
);
}

fn cast_variable_for_call(
&mut self,
variable_to_pass: RegisterOrMemory,
Expand All @@ -468,9 +495,24 @@ impl<'block> BrilligBlock<'block> {
) -> RegisterOrMemory {
match (variable_assigned_to, return_type) {
(RegisterOrMemory::HeapVector(vector), Type::Array(..)) => {
let size = compute_size_of_type(return_type);
self.brillig_context.const_instruction(vector.size, size.into());
RegisterOrMemory::HeapArray(HeapArray { pointer: vector.pointer, size })
RegisterOrMemory::HeapArray(HeapArray {
pointer: vector.pointer,
size: compute_size_of_type(return_type),
})
}
(_, _) => variable_assigned_to,
}
}

fn reconciliate_from_call(
&mut self,
variable_assigned_to: RegisterOrMemory,
return_variable: RegisterOrMemory,
) -> RegisterOrMemory {
match (variable_assigned_to, return_variable) {
(RegisterOrMemory::HeapVector(vector), RegisterOrMemory::HeapArray(array)) => {
self.brillig_context.const_instruction(vector.size, array.size.into());
RegisterOrMemory::HeapVector(vector)
}
(_, _) => variable_assigned_to,
}
Expand Down Expand Up @@ -743,11 +785,27 @@ impl<'block> BrilligBlock<'block> {
let binary_type =
type_of_binary_operation(dfg[binary.lhs].get_type(), dfg[binary.rhs].get_type());

let left = self.convert_ssa_register_value(binary.lhs, dfg);
let right = self.convert_ssa_register_value(binary.rhs, dfg);
let mut left = self.convert_ssa_register_value(binary.lhs, dfg);
let mut right = self.convert_ssa_register_value(binary.rhs, dfg);

let brillig_binary_op =
convert_ssa_binary_op_to_brillig_binary_op(binary.operator, binary_type);
convert_ssa_binary_op_to_brillig_binary_op(binary.operator, &binary_type);

// Some binary operations with fields are issued by the compiler, such as loop comparisons, cast those to the bit size here
if let (
BrilligBinaryOp::Integer { bit_size, .. },
Type::Numeric(NumericType::NativeField),
) = (&brillig_binary_op, &binary_type)
{
let new_lhs = self.brillig_context.allocate_register();
let new_rhs = self.brillig_context.allocate_register();

self.brillig_context.cast_instruction(new_lhs, left, *bit_size);
self.brillig_context.cast_instruction(new_rhs, right, *bit_size);

left = new_lhs;
right = new_rhs;
}

self.brillig_context.binary_instruction(left, right, result_register, brillig_binary_op);
}
Expand Down Expand Up @@ -876,7 +934,7 @@ pub(crate) fn type_of_binary_operation(lhs_type: Type, rhs_type: Type) -> Type {
/// - Brillig Binary Field Op, if it is a field type
pub(crate) fn convert_ssa_binary_op_to_brillig_binary_op(
ssa_op: BinaryOp,
typ: Type,
typ: &Type,
) -> BrilligBinaryOp {
// First get the bit size and whether its a signed integer, if it is a numeric type
// if it is not,then we return None, indicating that
Expand All @@ -891,18 +949,20 @@ pub(crate) fn convert_ssa_binary_op_to_brillig_binary_op(
};

fn binary_op_to_field_op(op: BinaryOp) -> BrilligBinaryOp {
let operation = match op {
BinaryOp::Add => BinaryFieldOp::Add,
BinaryOp::Sub => BinaryFieldOp::Sub,
BinaryOp::Mul => BinaryFieldOp::Mul,
BinaryOp::Div => BinaryFieldOp::Div,
BinaryOp::Eq => BinaryFieldOp::Equals,
match op {
BinaryOp::Add => BrilligBinaryOp::Field { op: BinaryFieldOp::Add },
BinaryOp::Sub => BrilligBinaryOp::Field { op: BinaryFieldOp::Sub },
BinaryOp::Mul => BrilligBinaryOp::Field { op: BinaryFieldOp::Mul },
BinaryOp::Div => BrilligBinaryOp::Field { op: BinaryFieldOp::Div },
BinaryOp::Eq => BrilligBinaryOp::Field { op: BinaryFieldOp::Equals },
BinaryOp::Lt => BrilligBinaryOp::Integer {
op: BinaryIntOp::LessThan,
bit_size: BRILLIG_INTEGER_ARITHMETIC_BIT_SIZE,
},
_ => unreachable!(
"Field type cannot be used with {op}. This should have been caught by the frontend"
),
};

BrilligBinaryOp::Field { op: operation }
}
}

fn binary_op_to_int_op(op: BinaryOp, bit_size: u32, is_signed: bool) -> BrilligBinaryOp {
Expand Down Expand Up @@ -934,7 +994,7 @@ pub(crate) fn convert_ssa_binary_op_to_brillig_binary_op(

// If bit size is available then it is a binary integer operation
match bit_size_signedness {
Some((bit_size, is_signed)) => binary_op_to_int_op(ssa_op, bit_size, is_signed),
Some((bit_size, is_signed)) => binary_op_to_int_op(ssa_op, *bit_size, is_signed),
None => binary_op_to_field_op(ssa_op),
}
}
10 changes: 7 additions & 3 deletions crates/noirc_evaluator/src/brillig/brillig_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,7 @@ impl BrilligContext {
.collect();
for (new_source, destination) in new_sources.iter().zip(destinations.iter()) {
self.mov_instruction(*destination, *new_source);
self.deallocate_register(*new_source);
}
}

Expand Down Expand Up @@ -821,9 +822,12 @@ impl BrilligContext {
) {
// Allocate our result registers and write into them
// We assume the return values of our call are held in 0..num results register indices
for (i, result_register) in result_registers.iter().enumerate() {
self.mov_instruction(*result_register, self.register(i));
}
let (sources, destinations) = result_registers
.iter()
.enumerate()
.map(|(i, result_register)| (self.register(i), *result_register))
.unzip();
self.mov_registers_to_registers_instruction(sources, destinations);

// Restore all the same registers we have, in exact reverse order.
// Note that we have allocated some registers above, which we will not be handling here,
Expand Down

0 comments on commit 1f37820

Please sign in to comment.