Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement copy on write optimization for arrays in brillig #3118

Closed
wants to merge 12 commits into from
48 changes: 27 additions & 21 deletions acvm-repo/brillig_vm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@
// resolved inputs back to the caller. Once the caller pushes to `foreign_call_results`,
// they can then make another call to the VM that starts at this opcode
// but has the necessary results to proceed with execution.
let resolved_inputs = inputs

Check warning on line 204 in acvm-repo/brillig_vm/src/lib.rs

View workflow job for this annotation

GitHub Actions / Spellcheck / Spellcheck

Unknown word (callstack)
.iter()
.map(|input| self.get_register_value_or_memory_values(*input))
.collect::<Vec<_>>();
Expand All @@ -221,32 +221,38 @@
"Function result size does not match brillig bytecode (expected 1 result)"
),
},
RegisterOrMemory::HeapArray(HeapArray { pointer: pointer_index, size }) => {
RegisterOrMemory::HeapArray(HeapArray { pointer, size }) => {
match output {
ForeignCallParam::Array(values) => {
if values.len() != *size {
invalid_foreign_call_result = true;
break;
}
// Convert the destination pointer to a usize
let destination = self.registers.get(*pointer_index).to_usize();
// Write to our destination memory
self.memory.write_slice(destination, values);
let destination = self.registers.get(*pointer).to_usize();

// Write to the reference count
self.memory.write_slice(destination, &[1_usize.into()]);
// Then write to the rest of the array
self.memory.write_slice(destination + 1, values);
}
_ => {
unreachable!("Function result size does not match brillig bytecode size")
}
}
}
RegisterOrMemory::HeapVector(HeapVector { pointer: pointer_index, size: size_index }) => {
RegisterOrMemory::HeapVector(HeapVector { pointer, size }) => {
match output {
ForeignCallParam::Array(values) => {
// Set our size in the size register
self.registers.set(*size_index, Value::from(values.len()));
self.registers.set(*size, Value::from(values.len()));
// Convert the destination pointer to a usize
let destination = self.registers.get(*pointer_index).to_usize();
// Write to our destination memory
self.memory.write_slice(destination, values);
let destination = self.registers.get(*pointer).to_usize();

// Write to the reference count
self.memory.write_slice(destination, &[1_usize.into()]);
// Then write to the rest of the vector
self.memory.write_slice(destination + 1, values);
}
_ => {
unreachable!("Function result size does not match brillig bytecode size")
Expand Down Expand Up @@ -337,17 +343,16 @@
fn get_register_value_or_memory_values(&self, input: RegisterOrMemory) -> ForeignCallParam {
match input {
RegisterOrMemory::RegisterIndex(value_index) => self.registers.get(value_index).into(),
RegisterOrMemory::HeapArray(HeapArray { pointer: pointer_index, size }) => {
let start = self.registers.get(pointer_index);
self.memory.read_slice(start.to_usize(), size).to_vec().into()
RegisterOrMemory::HeapArray(HeapArray { pointer, size }) => {
// add 1 to the start to skip past the reference count field
let start = self.registers.get(pointer).to_usize() + 1;
self.memory.read_slice(start, size).to_vec().into()
}
RegisterOrMemory::HeapVector(HeapVector {
pointer: pointer_index,
size: size_index,
}) => {
let start = self.registers.get(pointer_index);
let size = self.registers.get(size_index);
self.memory.read_slice(start.to_usize(), size.to_usize()).to_vec().into()
RegisterOrMemory::HeapVector(HeapVector { pointer, size }) => {
// add 1 to the start to skip past the reference count field
let start = self.registers.get(pointer).to_usize() + 1;
let size = self.registers.get(size);
self.memory.read_slice(start, size.to_usize()).to_vec().into()
}
}
}
Expand Down Expand Up @@ -516,7 +521,7 @@
rhs: RegisterIndex::from(1),
destination: RegisterIndex::from(2),
};

Check warning on line 524 in acvm-repo/brillig_vm/src/lib.rs

View workflow job for this annotation

GitHub Actions / Spellcheck / Spellcheck

Unknown word (jmpifnot)
let jump_opcode = Opcode::Jump { location: 2 };

let jump_if_not_opcode =
Expand Down Expand Up @@ -943,6 +948,7 @@
// Ensure the foreign call counter has been incremented
assert_eq!(vm.foreign_call_counter, 1);
}

#[test]
fn foreign_call_opcode_memory_result() {
let r_input = RegisterIndex::from(0);
Expand Down Expand Up @@ -1009,8 +1015,8 @@
let r_input_pointer = RegisterIndex::from(0);
let r_input_size = RegisterIndex::from(1);
// We need to pass a location of appropriate size
let r_output_pointer = RegisterIndex::from(2);
let r_output_size = RegisterIndex::from(3);
let r_output_pointer = RegisterIndex::from(3);
let r_output_size = RegisterIndex::from(4);

// Our first string to use the identity function with
let input_string =
Expand Down
148 changes: 106 additions & 42 deletions compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,17 +213,15 @@ impl<'block> BrilligBlock<'block> {
// In the case of arrays, the values should already be in memory and the register should
// Be a valid pointer to the array.
// For slices, two registers are passed, the pointer to the data and a register holding the size of the slice.
Type::Numeric(_) | Type::Array(..) | Type::Slice(..) | Type::Reference => {
Type::Numeric(_) | Type::Array(..) | Type::Slice(..) | Type::Reference(_) => {
self.variables.get_block_param(
self.function_context,
self.block_id,
*param_id,
dfg,
);
}
_ => {
todo!("ICE: Param type not supported")
}
Type::Function => todo!("ICE: Param type not supported"),
}
}
}
Expand Down Expand Up @@ -540,21 +538,32 @@ impl<'block> BrilligBlock<'block> {
let value_variable = self.convert_ssa_value(*value, dfg);

let result_ids = dfg.instruction_results(instruction_id);
let destination_variable = self.variables.define_variable(
self.function_context,
self.brillig_context,
result_ids[0],
dfg,
);

self.convert_ssa_array_set(
source_variable,
destination_variable,
index_register,
value_variable,
result_ids[0],
dfg,
);
}
_ => todo!("ICE: Instruction not supported {instruction:?}"),
Instruction::IncrementRc { value } => {
let rc_register = match self.convert_ssa_value(*value, dfg) {
RegisterOrMemory::HeapArray(HeapArray { pointer, .. }) => {
todo!()
}
RegisterOrMemory::HeapVector(HeapVector { pointer, .. }) => {
todo!()
}
_ => unreachable!("ICE: array set on non-array"),
};

// TODO: Does this work for += 1?
self.brillig_context.increment(rc_register, rc_register);
}
Instruction::EnableSideEffects { .. } => {
todo!("ICE: Instruction not supported {instruction:?}")
}
};

let dead_variables = self
Expand Down Expand Up @@ -658,16 +667,11 @@ impl<'block> BrilligBlock<'block> {
fn convert_ssa_array_set(
&mut self,
source_variable: RegisterOrMemory,
destination_variable: RegisterOrMemory,
index_register: RegisterIndex,
value_variable: RegisterOrMemory,
destination_id: ValueId,
dfg: &DataFlowGraph,
) {
let destination_pointer = match destination_variable {
RegisterOrMemory::HeapArray(HeapArray { pointer, .. }) => pointer,
RegisterOrMemory::HeapVector(HeapVector { pointer, .. }) => pointer,
_ => unreachable!("ICE: array set returns non-array"),
};

// First issue a array copy to the destination
let (source_pointer, source_size_as_register) = match source_variable {
RegisterOrMemory::HeapArray(HeapArray { size, pointer }) => {
Expand All @@ -683,23 +687,63 @@ impl<'block> BrilligBlock<'block> {
_ => unreachable!("ICE: array set on non-array"),
};

self.brillig_context
.allocate_array_instruction(destination_pointer, source_size_as_register);

self.brillig_context.copy_array_instruction(
source_pointer,
destination_pointer,
source_size_as_register,
// Retrieve the reference count and check if it equals 1
let reference_count = self.brillig_context.allocate_register();
let zero = self.brillig_context.make_constant(0_usize.into());
self.brillig_context.array_get(source_pointer, zero, reference_count);

let one = self.brillig_context.make_constant(1_usize.into());
let condition = self.brillig_context.allocate_register();
self.brillig_context.binary_instruction(
reference_count,
one,
condition,
BrilligBinaryOp::Field { op: BinaryFieldOp::Equals },
);

if let RegisterOrMemory::HeapVector(HeapVector { size: target_size, .. }) =
destination_variable
{
self.brillig_context.mov_instruction(target_size, source_size_as_register);
}
self.brillig_context.branch_instruction(condition, |ctx, cond| {
if cond {
// Reference count is 1, we can mutate the array directly
Self::store_variable_in_array_with_ctx(
ctx,
source_pointer,
index_register,
value_variable,
);
} else {
// Reference count is not 1, so we need to copy the array then set on the copy
let destination_variable =
self.variables.define_variable(self.function_context, ctx, destination_id, dfg);

// Then set the value in the newly created array
self.store_variable_in_array(destination_pointer, index_register, value_variable);
let destination_pointer = match destination_variable {
RegisterOrMemory::HeapArray(HeapArray { pointer, .. }) => pointer,
RegisterOrMemory::HeapVector(HeapVector { pointer, .. }) => pointer,
_ => unreachable!("ICE: array set returns non-array"),
};

ctx.allocate_array_instruction(destination_pointer, source_size_as_register);

ctx.copy_array_instruction(
source_pointer,
destination_pointer,
source_size_as_register,
);

if let RegisterOrMemory::HeapVector(HeapVector { size: target_size, .. }) =
destination_variable
{
ctx.mov_instruction(target_size, source_size_as_register);
}

// Then set the value in the newly created array
Self::store_variable_in_array_with_ctx(
ctx,
destination_pointer,
index_register,
value_variable,
);
}
});

self.brillig_context.deallocate_register(source_size_as_register);
}
Expand All @@ -709,21 +753,35 @@ impl<'block> BrilligBlock<'block> {
destination_pointer: RegisterIndex,
index_register: RegisterIndex,
value_variable: RegisterOrMemory,
) {
Self::store_variable_in_array_with_ctx(
self.brillig_context,
destination_pointer,
index_register,
value_variable,
);
}

pub(crate) fn store_variable_in_array_with_ctx(
ctx: &mut BrilligContext,
destination_pointer: RegisterIndex,
index_register: RegisterIndex,
value_variable: RegisterOrMemory,
) {
match value_variable {
RegisterOrMemory::RegisterIndex(value_register) => {
self.brillig_context.array_set(destination_pointer, index_register, value_register);
ctx.array_set(destination_pointer, index_register, value_register);
}
RegisterOrMemory::HeapArray(HeapArray { pointer, .. }) => {
self.brillig_context.array_set(destination_pointer, index_register, pointer);
ctx.array_set(destination_pointer, index_register, pointer);
}
RegisterOrMemory::HeapVector(_) => {
// Vectors are stored as references inside arrays to be able to match SSA indexes
let reference = self.brillig_context.allocate_register();
self.brillig_context.allocate_variable_instruction(reference);
self.brillig_context.store_variable_instruction(reference, value_variable);
self.brillig_context.array_set(destination_pointer, index_register, reference);
self.brillig_context.deallocate_register(reference);
let reference = ctx.allocate_register();
ctx.allocate_variable_instruction(reference);
ctx.store_variable_instruction(reference, value_variable);
ctx.array_set(destination_pointer, index_register, reference);
ctx.deallocate_register(reference);
}
}
}
Expand Down Expand Up @@ -1086,6 +1144,11 @@ impl<'block> BrilligBlock<'block> {
RegisterOrMemory::HeapVector(heap_vector) => {
self.brillig_context
.const_instruction(heap_vector.size, array.len().into());

// Add one to the vector's size to account for the extra reference count field
let size = self.brillig_context.allocate_register();
self.brillig_context.increment(heap_vector.size, size);

self.brillig_context
.allocate_array_instruction(heap_vector.pointer, heap_vector.size);

Expand All @@ -1105,6 +1168,7 @@ impl<'block> BrilligBlock<'block> {
let element_variable = self.convert_ssa_value(*element_id, dfg);
// Store the item in memory
self.store_variable_in_array(pointer, iterator_register, element_variable);

// Increment the iterator
self.brillig_context.usize_op_in_place(
iterator_register,
Expand All @@ -1118,7 +1182,7 @@ impl<'block> BrilligBlock<'block> {
new_variable
}
}
_ => {
Value::Function(_) | Value::Intrinsic(_) | Value::ForeignFunction(_) => {
todo!("ICE: Cannot convert value {value:?}")
}
}
Expand Down Expand Up @@ -1220,7 +1284,7 @@ pub(crate) fn type_of_binary_operation(lhs_type: &Type, rhs_type: &Type) -> Type
(_, Type::Function) | (Type::Function, _) => {
unreachable!("Functions are invalid in binary operations")
}
(_, Type::Reference) | (Type::Reference, _) => {
(_, Type::Reference(_)) | (Type::Reference(_), _) => {
unreachable!("References are invalid in binary operations")
}
(_, Type::Array(..)) | (Type::Array(..), _) => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,22 +170,19 @@ pub(crate) fn allocate_value(
let typ = dfg.type_of_value(value_id);

match typ {
Type::Numeric(_) | Type::Reference => {
Type::Numeric(_) | Type::Reference(_) => {
let register = brillig_context.allocate_register();
RegisterOrMemory::RegisterIndex(register)
}
Type::Array(item_typ, elem_count) => {
let pointer_register = brillig_context.allocate_register();
let pointer = brillig_context.allocate_register();
let size = compute_array_length(&item_typ, elem_count);
RegisterOrMemory::HeapArray(HeapArray { pointer: pointer_register, size })
RegisterOrMemory::HeapArray(HeapArray { pointer, size })
}
Type::Slice(_) => {
let pointer_register = brillig_context.allocate_register();
let size_register = brillig_context.allocate_register();
RegisterOrMemory::HeapVector(HeapVector {
pointer: pointer_register,
size: size_register,
})
let pointer = brillig_context.allocate_register();
let size = brillig_context.allocate_register();
RegisterOrMemory::HeapVector(HeapVector { pointer, size })
}
Type::Function => {
unreachable!("ICE: Function values should have been removed from the SSA")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ impl FunctionContext {

fn ssa_type_to_parameter(typ: &Type) -> BrilligParameter {
match typ {
Type::Numeric(_) | Type::Reference => BrilligParameter::Simple,
Type::Numeric(_) | Type::Reference(_) => BrilligParameter::Simple,
Type::Array(item_type, size) => BrilligParameter::Array(
vecmap(item_type.iter(), |item_typ| {
FunctionContext::ssa_type_to_parameter(item_typ)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ mod test {
let v0 = builder.add_parameter(Type::field());
let v1 = builder.add_parameter(Type::field());

let v3 = builder.insert_allocate();
let v3 = builder.insert_allocate(Type::field());

let zero = builder.numeric_constant(0u128, Type::field());
builder.insert_store(v3, zero);
Expand Down Expand Up @@ -439,7 +439,7 @@ mod test {
let v0 = builder.add_parameter(Type::field());
let v1 = builder.add_parameter(Type::field());

let v3 = builder.insert_allocate();
let v3 = builder.insert_allocate(Type::field());

let zero = builder.numeric_constant(0u128, Type::field());
builder.insert_store(v3, zero);
Expand Down
Loading
Loading