Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: place return value witnesses directly after function arguments #5142

Merged
merged 6 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 3 additions & 10 deletions compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ impl AcirContext {
}

/// Converts an [`AcirVar`] to a [`Witness`]
fn var_to_witness(&mut self, var: AcirVar) -> Result<Witness, InternalError> {
pub(crate) fn var_to_witness(&mut self, var: AcirVar) -> Result<Witness, InternalError> {
let expression = self.var_to_expression(var)?;
let witness = if let Some(constant) = expression.to_const() {
// Check if a witness has been assigned this value already, if so reuse it.
Expand Down Expand Up @@ -1017,15 +1017,6 @@ impl AcirContext {
Ok(remainder)
}

/// Converts the `AcirVar` to a `Witness` if it hasn't been already, and appends it to the
/// `GeneratedAcir`'s return witnesses.
pub(crate) fn return_var(&mut self, acir_var: AcirVar) -> Result<(), InternalError> {
let return_var = self.get_or_create_witness_var(acir_var)?;
let witness = self.var_to_witness(return_var)?;
self.acir_ir.push_return_witness(witness);
Ok(())
}

/// Constrains the `AcirVar` variable to be of type `NumericType`.
pub(crate) fn range_constrain_var(
&mut self,
Expand Down Expand Up @@ -1528,9 +1519,11 @@ impl AcirContext {
pub(crate) fn finish(
mut self,
inputs: Vec<Witness>,
return_values: Vec<Witness>,
warnings: Vec<SsaReport>,
) -> GeneratedAcir {
self.acir_ir.input_witnesses = inputs;
self.acir_ir.return_witnesses = return_values;
self.acir_ir.warnings = warnings;
self.acir_ir
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,6 @@ pub(crate) struct GeneratedAcir {
opcodes: Vec<AcirOpcode<FieldElement>>,

/// All witness indices that comprise the final return value of the program
///
/// Note: This may contain repeated indices, which is necessary for later mapping into the
/// abi's return type.
pub(crate) return_witnesses: Vec<Witness>,

/// All witness indices which are inputs to the main function
Expand Down Expand Up @@ -164,11 +161,6 @@ impl GeneratedAcir {

fresh_witness
}

/// Adds a witness index to the program's return witnesses.
pub(crate) fn push_return_witness(&mut self, witness: Witness) {
self.return_witnesses.push(witness);
}
}

impl GeneratedAcir {
Expand Down
122 changes: 72 additions & 50 deletions compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,7 @@
use acvm::acir::circuit::{AssertionPayload, ErrorSelector, OpcodeLocation};
use acvm::acir::native_types::Witness;
use acvm::acir::BlackBoxFunc;
use acvm::{
acir::AcirField,
acir::{circuit::opcodes::BlockId, native_types::Expression},
FieldElement,
};
use acvm::{acir::circuit::opcodes::BlockId, acir::AcirField, FieldElement};
use fxhash::FxHashMap as HashMap;
use im::Vector;
use iter_extended::{try_vecmap, vecmap};
Expand Down Expand Up @@ -285,7 +281,7 @@
#[tracing::instrument(level = "trace", skip_all)]
pub(crate) fn into_acir(self, brillig: &Brillig) -> Result<Artifacts, RuntimeError> {
let mut acirs = Vec::new();
// TODO: can we parallelise this?

Check warning on line 284 in compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (parallelise)
let mut shared_context = SharedContext::default();
for function in self.functions.values() {
let context = Context::new(&mut shared_context);
Expand Down Expand Up @@ -330,38 +326,10 @@
bytecode: brillig.byte_code,
});

let runtime_types = self.functions.values().map(|function| function.runtime());
for (acir, runtime_type) in acirs.iter_mut().zip(runtime_types) {
if matches!(runtime_type, RuntimeType::Acir(_)) {
generate_distinct_return_witnesses(acir);
}
}

Ok((acirs, brillig, self.error_selector_to_type))
}
}

fn generate_distinct_return_witnesses(acir: &mut GeneratedAcir) {
// Create a witness for each return witness we have to guarantee that the return witnesses match the standard
// layout for serializing those types as if they were being passed as inputs.
//
// This is required for recursion as otherwise in situations where we cannot make use of the program's ABI
// (e.g. for `std::verify_proof` or the solidity verifier), we need extra knowledge about the program we're
// working with rather than following the standard ABI encoding rules.
//
// TODO: We're being conservative here by generating a new witness for every expression.
// This means that we're likely to get a number of constraints which are just renumbering witnesses.
// This can be tackled by:
// - Tracking the last assigned public input witness and only renumbering a witness if it is below this value.
// - Modifying existing constraints to rearrange their outputs so they are suitable
// - See: https://github.com/noir-lang/noir/pull/4467
let distinct_return_witness = vecmap(acir.return_witnesses.clone(), |return_witness| {
acir.create_witness_for_expression(&Expression::from(return_witness))
});

acir.return_witnesses = distinct_return_witness;
}

impl<'a> Context<'a> {
fn new(shared_context: &'a mut SharedContext) -> Context<'a> {
let mut acir_context = AcirContext::default();
Expand Down Expand Up @@ -422,15 +390,45 @@
let dfg = &main_func.dfg;
let entry_block = &dfg[main_func.entry_block()];
let input_witness = self.convert_ssa_block_params(entry_block.parameters(), dfg)?;
let num_return_witnesses =
self.get_num_return_witnesses(entry_block.unwrap_terminator(), dfg);

// Create a witness for each return witness we have to guarantee that the return witnesses match the standard
// layout for serializing those types as if they were being passed as inputs.
//
// This is required for recursion as otherwise in situations where we cannot make use of the program's ABI
// (e.g. for `std::verify_proof` or the solidity verifier), we need extra knowledge about the program we're
// working with rather than following the standard ABI encoding rules.
//
// We allocate these witnesses now before performing ACIR gen for the rest of the program as the location of
// the function's return values can then be determined through knowledge of its ABI alone.
let return_witness_vars =
vecmap(0..num_return_witnesses, |_| self.acir_context.add_variable());

let return_witnesses = vecmap(&return_witness_vars, |return_var| {
let expr = self.acir_context.var_to_expression(*return_var).unwrap();
expr.to_witness().expect("return vars should be witnesses")
});

self.data_bus = dfg.data_bus.to_owned();
let mut warnings = Vec::new();
for instruction_id in entry_block.instructions() {
warnings.extend(self.convert_ssa_instruction(*instruction_id, dfg, ssa, brillig)?);
}

warnings.extend(self.convert_ssa_return(entry_block.unwrap_terminator(), dfg)?);
Ok(self.acir_context.finish(input_witness, warnings))
let (return_vars, return_warnings) =
self.convert_ssa_return(entry_block.unwrap_terminator(), dfg)?;

// TODO: This is a naive method of assigning the return values to their witnesses as
// we're likely to get a number of constraints which are asserting one witness to be equal to another.
//
// We should search through the program and relabel these witnesses so we can remove this constraint.
for (witness_var, return_var) in return_witness_vars.iter().zip(return_vars) {
self.acir_context.assert_eq_var(*witness_var, return_var, None)?;
}

warnings.extend(return_warnings);
Ok(self.acir_context.finish(input_witness, return_witnesses, warnings))
}

fn convert_brillig_main(
Expand Down Expand Up @@ -468,17 +466,13 @@
)?;
self.shared_context.insert_generated_brillig(main_func.id(), arguments, 0, code);

let output_vars: Vec<_> = output_values
let return_witnesses: Vec<Witness> = output_values
.iter()
.flat_map(|value| value.clone().flatten())
.map(|value| value.0)
.collect();
.map(|(value, _)| self.acir_context.var_to_witness(value))
.collect::<Result<_, _>>()?;

for acir_var in output_vars {
self.acir_context.return_var(acir_var)?;
}

let generated_acir = self.acir_context.finish(witness_inputs, Vec::new());
let generated_acir = self.acir_context.finish(witness_inputs, return_witnesses, Vec::new());

assert_eq!(
generated_acir.opcodes().len(),
Expand Down Expand Up @@ -1724,12 +1718,39 @@
self.define_result(dfg, instruction, AcirValue::Var(result, typ));
}

/// Converts an SSA terminator's return values into their ACIR representations
fn get_num_return_witnesses(
&mut self,
terminator: &TerminatorInstruction,
dfg: &DataFlowGraph,
) -> usize {
let return_values = match terminator {
TerminatorInstruction::Return { return_values, .. } => return_values,
// TODO(https://github.com/noir-lang/noir/issues/4616): Enable recursion on foldable/non-inlined ACIR functions
_ => unreachable!("ICE: Program must have a singular return"),
};

return_values.iter().fold(0, |acc, value_id| {
let is_databus = self
.data_bus
.return_data
.map_or(false, |return_databus| dfg[*value_id] == dfg[return_databus]);

if is_databus {
// We do not return value for the data bus.
acc
} else {
acc + dfg.type_of_value(*value_id).flattened_size()
}
})
}

/// Converts an SSA terminator's return values into their ACIR representations
fn convert_ssa_return(
&mut self,
terminator: &TerminatorInstruction,
dfg: &DataFlowGraph,
) -> Result<Vec<SsaReport>, RuntimeError> {
) -> Result<(Vec<AcirVar>, Vec<SsaReport>), RuntimeError> {
let (return_values, call_stack) = match terminator {
TerminatorInstruction::Return { return_values, call_stack } => {
(return_values, call_stack.clone())
Expand All @@ -1739,6 +1760,7 @@
};

let mut has_constant_return = false;
let mut return_vars: Vec<AcirVar> = Vec::new();
for value_id in return_values {
let is_databus = self
.data_bus
Expand All @@ -1759,7 +1781,7 @@
dfg,
)?;
} else {
self.acir_context.return_var(acir_var)?;
return_vars.push(acir_var);
}
}
}
Expand All @@ -1770,7 +1792,7 @@
Vec::new()
};

Ok(warnings)
Ok((return_vars, warnings))
}

/// Gets the cached `AcirVar` that was converted from the corresponding `ValueId`. If it does
Expand Down Expand Up @@ -3079,8 +3101,8 @@
check_call_opcode(
&func_with_nested_call_opcodes[1],
2,
vec![Witness(2), Witness(1)],
vec![Witness(3)],
vec![Witness(3), Witness(1)],
vec![Witness(4)],
);
}

Expand All @@ -3100,13 +3122,13 @@
for (expected_input, input) in expected_inputs.iter().zip(inputs) {
assert_eq!(
expected_input, input,
"Expected witness {expected_input:?} but got {input:?}"
"Expected input witness {expected_input:?} but got {input:?}"
);
}
for (expected_output, output) in expected_outputs.iter().zip(outputs) {
assert_eq!(
expected_output, output,
"Expected witness {expected_output:?} but got {output:?}"
"Expected output witness {expected_output:?} but got {output:?}"
);
}
}
Expand Down
Loading