Skip to content

Commit

Permalink
chore: pull out foreign call nested array changes (#7216)
Browse files Browse the repository at this point in the history
This PR pulls out the changes related to returning nested arrays from
foreign calls.
  • Loading branch information
TomAFrench authored Jun 27, 2024
1 parent 412f02e commit 1faaaf5
Show file tree
Hide file tree
Showing 4 changed files with 234 additions and 45 deletions.
155 changes: 112 additions & 43 deletions noir/noir-repo/acvm-repo/brillig_vm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -482,60 +482,69 @@ impl<'a, F: AcirField, B: BlackBoxFunctionSolver<F>> VM<'a, F, B> {
destinations.iter().zip(destination_value_types).zip(&values)
{
match (destination, value_type) {
(ValueOrArray::MemoryAddress(value_index), HeapValueType::Simple(bit_size)) => {
match output {
ForeignCallParam::Single(value) => {
self.write_value_to_memory(*value_index, value, *bit_size)?;
}
_ => return Err(format!(
"Function result size does not match brillig bytecode. Expected 1 result but got {output:?}")
),
(ValueOrArray::MemoryAddress(value_index), HeapValueType::Simple(bit_size)) => {
match output {
ForeignCallParam::Single(value) => {
self.write_value_to_memory(*value_index, value, *bit_size)?;
}
_ => return Err(format!(
"Function result size does not match brillig bytecode. Expected 1 result but got {output:?}")
),
}
(
ValueOrArray::HeapArray(HeapArray { pointer: pointer_index, size }),
HeapValueType::Array { value_types, size: type_size },
) if size == type_size => {
if HeapValueType::all_simple(value_types) {
match output {
ForeignCallParam::Array(values) => {
if values.len() != *size {
return Err("Foreign call result array doesn't match expected size".to_string());
}
}
(
ValueOrArray::HeapArray(HeapArray { pointer: pointer_index, size }),
HeapValueType::Array { value_types, size: type_size },
) if size == type_size => {
if HeapValueType::all_simple(value_types) {
match output {
ForeignCallParam::Array(values) => {
if values.len() != *size {
// foreign call returning flattened values into a nested type, so the sizes do not match
let destination = self.memory.read_ref(*pointer_index);
let return_type = value_type;
let mut flatten_values_idx = 0; //index of values read from flatten_values
self.write_slice_of_values_to_memory(destination, &output.fields(), &mut flatten_values_idx, return_type)?;
} else {
self.write_values_to_memory_slice(*pointer_index, values, value_types)?;
}
_ => {
return Err("Function result size does not match brillig bytecode size".to_string());
}
}
} else {
unimplemented!("deflattening heap arrays from foreign calls");
_ => {
return Err("Function result size does not match brillig bytecode size".to_string());
}
}
}
(
ValueOrArray::HeapVector(HeapVector {pointer: pointer_index, size: size_index }),
HeapValueType::Vector { value_types },
) => {
if HeapValueType::all_simple(value_types) {
match output {
ForeignCallParam::Array(values) => {
// Set our size in the size address
self.memory.write(*size_index, values.len().into());
} else {
// foreign call returning flattened values into a nested type, so the sizes do not match
let destination = self.memory.read_ref(*pointer_index);
let return_type = value_type;
let mut flatten_values_idx = 0; //index of values read from flatten_values
self.write_slice_of_values_to_memory(destination, &output.fields(), &mut flatten_values_idx, return_type)?;
}
}
(
ValueOrArray::HeapVector(HeapVector {pointer: pointer_index, size: size_index }),
HeapValueType::Vector { value_types },
) => {
if HeapValueType::all_simple(value_types) {
match output {
ForeignCallParam::Array(values) => {
// Set our size in the size address
self.memory.write(*size_index, values.len().into());
self.write_values_to_memory_slice(*pointer_index, values, value_types)?;

self.write_values_to_memory_slice(*pointer_index, values, value_types)?;
}
_ => {
return Err("Function result size does not match brillig bytecode size".to_string());
}
}
} else {
unimplemented!("deflattening heap vectors from foreign calls");
_ => {
return Err("Function result size does not match brillig bytecode size".to_string());
}
}
}
_ => {
return Err(format!("Unexpected value type {value_type:?} for destination {destination:?}"));
} else {
unimplemented!("deflattening heap vectors from foreign calls");
}
}
_ => {
return Err(format!("Unexpected value type {value_type:?} for destination {destination:?}"));
}
}
}

let _ =
Expand Down Expand Up @@ -596,6 +605,66 @@ impl<'a, F: AcirField, B: BlackBoxFunctionSolver<F>> VM<'a, F, B> {
Ok(())
}

/// Writes flattened values to memory, using the provided type
/// Function calls itself recursively in order to work with recursive types (nested arrays)
/// values_idx is the current index in the values vector and is incremented every time
/// a value is written to memory
/// The function returns the address of the next value to be written
fn write_slice_of_values_to_memory(
&mut self,
destination: MemoryAddress,
values: &Vec<F>,
values_idx: &mut usize,
value_type: &HeapValueType,
) -> Result<MemoryAddress, String> {
let mut current_pointer = destination;
match value_type {
HeapValueType::Simple(bit_size) => {
self.write_value_to_memory(destination, &values[*values_idx], *bit_size)?;
*values_idx += 1;
Ok(MemoryAddress(destination.to_usize() + 1))
}
HeapValueType::Array { value_types, size } => {
for _ in 0..*size {
for typ in value_types {
match typ {
HeapValueType::Simple(len) => {
self.write_value_to_memory(
current_pointer,
&values[*values_idx],
*len,
)?;
*values_idx += 1;
current_pointer = MemoryAddress(current_pointer.to_usize() + 1);
}
HeapValueType::Array { .. } => {
let destination = self.memory.read_ref(current_pointer);
let destination = self.memory.read_ref(destination);
self.write_slice_of_values_to_memory(
destination,
values,
values_idx,
typ,
)?;
current_pointer = MemoryAddress(current_pointer.to_usize() + 1);
}
HeapValueType::Vector { .. } => {
return Err(format!(
"Unsupported returned type in foreign calls {:?}",
typ
));
}
}
}
}
Ok(current_pointer)
}
HeapValueType::Vector { .. } => {
Err(format!("Unsupported returned type in foreign calls {:?}", value_type))
}
}
}

/// Process a binary operation.
/// This method will not modify the program counter.
fn process_binary_field_op(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1737,8 +1737,7 @@ impl<'block> BrilligBlock<'block> {
dfg,
);
let array = variable.extract_array();
self.brillig_context.codegen_allocate_fixed_length_array(array.pointer, array.size);
self.brillig_context.usize_const_instruction(array.rc, 1_usize.into());
self.allocate_nested_array(typ, Some(array));

variable
}
Expand All @@ -1765,6 +1764,43 @@ impl<'block> BrilligBlock<'block> {
}
}

fn allocate_nested_array(
&mut self,
typ: &Type,
array: Option<BrilligArray>,
) -> BrilligVariable {
match typ {
Type::Array(types, size) => {
let array = array.unwrap_or(BrilligArray {
pointer: self.brillig_context.allocate_register(),
size: *size,
rc: self.brillig_context.allocate_register(),
});
self.brillig_context.codegen_allocate_fixed_length_array(array.pointer, array.size);
self.brillig_context.usize_const_instruction(array.rc, 1_usize.into());

let mut index = 0_usize;
for _ in 0..*size {
for element_type in types.iter() {
match element_type {
Type::Array(_, _) => {
let inner_array = self.allocate_nested_array(element_type, None);
let idx =
self.brillig_context.make_usize_constant_instruction(index.into());
self.store_variable_in_array(array.pointer, idx, inner_array);
}
Type::Slice(_) => unreachable!("ICE: unsupported slice type in allocate_nested_array(), expects an array or a numeric type"),
_ => (),
}
index += 1;
}
}
BrilligVariable::BrilligArray(array)
}
_ => unreachable!("ICE: allocate_nested_array() expects an array, got {typ:?}"),
}
}

/// Gets the "user-facing" length of an array.
/// An array of structs with two fields would be stored as an 2 * array.len() array/vector.
/// So we divide the length by the number of subitems in an item to get the user-facing length.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[package]
name = "regression_4561"
type = "bin"
authors = [""]

[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// Regression test for issue #4561
use std::test::OracleMock;

type TReturnElem = [Field; 3];
type TReturn = [TReturnElem; 2];

#[oracle(simple_nested_return)]
unconstrained fn simple_nested_return_oracle() -> TReturn {}

unconstrained fn simple_nested_return_unconstrained() -> TReturn {
simple_nested_return_oracle()
}

#[test]
fn test_simple_nested_return() {
OracleMock::mock("simple_nested_return").returns([1, 2, 3, 4, 5, 6]);
assert_eq(simple_nested_return_unconstrained(), [[1, 2, 3], [4, 5, 6]]);
}

#[oracle(nested_with_fields_return)]
unconstrained fn nested_with_fields_return_oracle() -> (Field, TReturn, Field) {}

unconstrained fn nested_with_fields_return_unconstrained() -> (Field, TReturn, Field) {
nested_with_fields_return_oracle()
}

#[test]
fn test_nested_with_fields_return() {
OracleMock::mock("nested_with_fields_return").returns((0, [1, 2, 3, 4, 5, 6], 7));
assert_eq(nested_with_fields_return_unconstrained(), (0, [[1, 2, 3], [4, 5, 6]], 7));
}

#[oracle(two_nested_return)]
unconstrained fn two_nested_return_oracle() -> (Field, TReturn, Field, TReturn) {}

unconstrained fn two_nested_return_unconstrained() -> (Field, TReturn, Field, TReturn) {
two_nested_return_oracle()
}

#[test]
fn two_nested_return() {
OracleMock::mock("two_nested_return").returns((0, [1, 2, 3, 4, 5, 6], 7, [1, 2, 3, 4, 5, 6]));
assert_eq(two_nested_return_unconstrained(), (0, [[1, 2, 3], [4, 5, 6]], 7, [[1, 2, 3], [4, 5, 6]]));
}

#[oracle(foo_return)]
unconstrained fn foo_return() -> (Field, TReturn, TestTypeFoo) {}
unconstrained fn foo_return_unconstrained() -> (Field, TReturn, TestTypeFoo) {
foo_return()
}

struct TestTypeFoo {
a: Field,
b: [[[Field; 3]; 4]; 2],
c: [TReturnElem; 2],
d: TReturnElem,
}

#[test]
fn complexe_struct_return() {
OracleMock::mock("foo_return").returns(
(
0, [1, 2, 3, 4, 5, 6], 7, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [1, 2, 3, 4, 5, 6]
)
);
let foo_x = foo_return_unconstrained();
assert_eq((foo_x.0, foo_x.1), (0, [[1, 2, 3], [4, 5, 6]]));
assert_eq(foo_x.2.a, 7);
assert_eq(
foo_x.2.b, [
[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], [[13, 14, 15], [16, 17, 18], [19, 20, 21], [22, 23, 24]]
]
);
let a: TReturnElem = [1, 2, 3];
let b: TReturnElem = [4, 5, 6];
assert_eq(foo_x.2.c, [a, b]);
assert_eq(foo_x.2.d, a);
}

0 comments on commit 1faaaf5

Please sign in to comment.