Skip to content

Commit

Permalink
feat: place return value witnesses directly after function arguments (#…
Browse files Browse the repository at this point in the history
…5142)

# Description

## Problem\*

Resolves #5104 
## Summary\*

This PR preallocates some witnesses to hold the return values at the
beginning of ACIR gen and then adds assertions to fill these witnesses
with the return values. This ensures that the return values will be
placed in the witness map directly after any function inputs (reasons
for this being desirable are laid out in #5104)

## Additional Context



## Documentation\*

Check one:
- [x] No documentation needed.
- [ ] Documentation included in this PR.
- [ ] **[For Experimental Features]** Documentation to be submitted in a
separate PR.

# PR Checklist\*

- [x] I have tested the changes locally.
- [x] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.

---------

Co-authored-by: jfecher <jake@aztecprotocol.com>
  • Loading branch information
TomAFrench and jfecher authored May 31, 2024
1 parent d6122eb commit 1252b5f
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 68 deletions.
13 changes: 3 additions & 10 deletions compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ impl AcirContext {
}

/// Converts an [`AcirVar`] to a [`Witness`]
fn var_to_witness(&mut self, var: AcirVar) -> Result<Witness, InternalError> {
pub(crate) fn var_to_witness(&mut self, var: AcirVar) -> Result<Witness, InternalError> {
let expression = self.var_to_expression(var)?;
let witness = if let Some(constant) = expression.to_const() {
// Check if a witness has been assigned this value already, if so reuse it.
Expand Down Expand Up @@ -1027,15 +1027,6 @@ impl AcirContext {
Ok(remainder)
}

/// Converts the `AcirVar` to a `Witness` if it hasn't been already, and appends it to the
/// `GeneratedAcir`'s return witnesses.
pub(crate) fn return_var(&mut self, acir_var: AcirVar) -> Result<(), InternalError> {
let return_var = self.get_or_create_witness_var(acir_var)?;
let witness = self.var_to_witness(return_var)?;
self.acir_ir.push_return_witness(witness);
Ok(())
}

/// Constrains the `AcirVar` variable to be of type `NumericType`.
pub(crate) fn range_constrain_var(
&mut self,
Expand Down Expand Up @@ -1538,9 +1529,11 @@ impl AcirContext {
pub(crate) fn finish(
mut self,
inputs: Vec<Witness>,
return_values: Vec<Witness>,
warnings: Vec<SsaReport>,
) -> GeneratedAcir {
self.acir_ir.input_witnesses = inputs;
self.acir_ir.return_witnesses = return_values;
self.acir_ir.warnings = warnings;
self.acir_ir
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,6 @@ pub(crate) struct GeneratedAcir {
opcodes: Vec<AcirOpcode<FieldElement>>,

/// All witness indices that comprise the final return value of the program
///
/// Note: This may contain repeated indices, which is necessary for later mapping into the
/// abi's return type.
pub(crate) return_witnesses: Vec<Witness>,

/// All witness indices which are inputs to the main function
Expand Down Expand Up @@ -164,11 +161,6 @@ impl GeneratedAcir {

fresh_witness
}

/// Adds a witness index to the program's return witnesses.
pub(crate) fn push_return_witness(&mut self, witness: Witness) {
self.return_witnesses.push(witness);
}
}

impl GeneratedAcir {
Expand Down
122 changes: 72 additions & 50 deletions compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,7 @@ use acvm::acir::circuit::brillig::BrilligBytecode;
use acvm::acir::circuit::{AssertionPayload, ErrorSelector, OpcodeLocation};
use acvm::acir::native_types::Witness;
use acvm::acir::BlackBoxFunc;
use acvm::{
acir::AcirField,
acir::{circuit::opcodes::BlockId, native_types::Expression},
FieldElement,
};
use acvm::{acir::circuit::opcodes::BlockId, acir::AcirField, FieldElement};
use fxhash::FxHashMap as HashMap;
use im::Vector;
use iter_extended::{try_vecmap, vecmap};
Expand Down Expand Up @@ -330,38 +326,10 @@ impl Ssa {
bytecode: brillig.byte_code,
});

let runtime_types = self.functions.values().map(|function| function.runtime());
for (acir, runtime_type) in acirs.iter_mut().zip(runtime_types) {
if matches!(runtime_type, RuntimeType::Acir(_)) {
generate_distinct_return_witnesses(acir);
}
}

Ok((acirs, brillig, self.error_selector_to_type))
}
}

fn generate_distinct_return_witnesses(acir: &mut GeneratedAcir) {
// Create a witness for each return witness we have to guarantee that the return witnesses match the standard
// layout for serializing those types as if they were being passed as inputs.
//
// This is required for recursion as otherwise in situations where we cannot make use of the program's ABI
// (e.g. for `std::verify_proof` or the solidity verifier), we need extra knowledge about the program we're
// working with rather than following the standard ABI encoding rules.
//
// TODO: We're being conservative here by generating a new witness for every expression.
// This means that we're likely to get a number of constraints which are just renumbering witnesses.
// This can be tackled by:
// - Tracking the last assigned public input witness and only renumbering a witness if it is below this value.
// - Modifying existing constraints to rearrange their outputs so they are suitable
// - See: https://github.com/noir-lang/noir/pull/4467
let distinct_return_witness = vecmap(acir.return_witnesses.clone(), |return_witness| {
acir.create_witness_for_expression(&Expression::from(return_witness))
});

acir.return_witnesses = distinct_return_witness;
}

impl<'a> Context<'a> {
fn new(shared_context: &'a mut SharedContext) -> Context<'a> {
let mut acir_context = AcirContext::default();
Expand Down Expand Up @@ -422,15 +390,45 @@ impl<'a> Context<'a> {
let dfg = &main_func.dfg;
let entry_block = &dfg[main_func.entry_block()];
let input_witness = self.convert_ssa_block_params(entry_block.parameters(), dfg)?;
let num_return_witnesses =
self.get_num_return_witnesses(entry_block.unwrap_terminator(), dfg);

// Create a witness for each return witness we have to guarantee that the return witnesses match the standard
// layout for serializing those types as if they were being passed as inputs.
//
// This is required for recursion as otherwise in situations where we cannot make use of the program's ABI
// (e.g. for `std::verify_proof` or the solidity verifier), we need extra knowledge about the program we're
// working with rather than following the standard ABI encoding rules.
//
// We allocate these witnesses now before performing ACIR gen for the rest of the program as the location of
// the function's return values can then be determined through knowledge of its ABI alone.
let return_witness_vars =
vecmap(0..num_return_witnesses, |_| self.acir_context.add_variable());

let return_witnesses = vecmap(&return_witness_vars, |return_var| {
let expr = self.acir_context.var_to_expression(*return_var).unwrap();
expr.to_witness().expect("return vars should be witnesses")
});

self.data_bus = dfg.data_bus.to_owned();
let mut warnings = Vec::new();
for instruction_id in entry_block.instructions() {
warnings.extend(self.convert_ssa_instruction(*instruction_id, dfg, ssa, brillig)?);
}

warnings.extend(self.convert_ssa_return(entry_block.unwrap_terminator(), dfg)?);
Ok(self.acir_context.finish(input_witness, warnings))
let (return_vars, return_warnings) =
self.convert_ssa_return(entry_block.unwrap_terminator(), dfg)?;

// TODO: This is a naive method of assigning the return values to their witnesses as
// we're likely to get a number of constraints which are asserting one witness to be equal to another.
//
// We should search through the program and relabel these witnesses so we can remove this constraint.
for (witness_var, return_var) in return_witness_vars.iter().zip(return_vars) {
self.acir_context.assert_eq_var(*witness_var, return_var, None)?;
}

warnings.extend(return_warnings);
Ok(self.acir_context.finish(input_witness, return_witnesses, warnings))
}

fn convert_brillig_main(
Expand Down Expand Up @@ -468,17 +466,13 @@ impl<'a> Context<'a> {
)?;
self.shared_context.insert_generated_brillig(main_func.id(), arguments, 0, code);

let output_vars: Vec<_> = output_values
let return_witnesses: Vec<Witness> = output_values
.iter()
.flat_map(|value| value.clone().flatten())
.map(|value| value.0)
.collect();
.map(|(value, _)| self.acir_context.var_to_witness(value))
.collect::<Result<_, _>>()?;

for acir_var in output_vars {
self.acir_context.return_var(acir_var)?;
}

let generated_acir = self.acir_context.finish(witness_inputs, Vec::new());
let generated_acir = self.acir_context.finish(witness_inputs, return_witnesses, Vec::new());

assert_eq!(
generated_acir.opcodes().len(),
Expand Down Expand Up @@ -1724,12 +1718,39 @@ impl<'a> Context<'a> {
self.define_result(dfg, instruction, AcirValue::Var(result, typ));
}

/// Converts an SSA terminator's return values into their ACIR representations
fn get_num_return_witnesses(
&mut self,
terminator: &TerminatorInstruction,
dfg: &DataFlowGraph,
) -> usize {
let return_values = match terminator {
TerminatorInstruction::Return { return_values, .. } => return_values,
// TODO(https://github.com/noir-lang/noir/issues/4616): Enable recursion on foldable/non-inlined ACIR functions
_ => unreachable!("ICE: Program must have a singular return"),
};

return_values.iter().fold(0, |acc, value_id| {
let is_databus = self
.data_bus
.return_data
.map_or(false, |return_databus| dfg[*value_id] == dfg[return_databus]);

if is_databus {
// We do not return value for the data bus.
acc
} else {
acc + dfg.type_of_value(*value_id).flattened_size()
}
})
}

/// Converts an SSA terminator's return values into their ACIR representations
fn convert_ssa_return(
&mut self,
terminator: &TerminatorInstruction,
dfg: &DataFlowGraph,
) -> Result<Vec<SsaReport>, RuntimeError> {
) -> Result<(Vec<AcirVar>, Vec<SsaReport>), RuntimeError> {
let (return_values, call_stack) = match terminator {
TerminatorInstruction::Return { return_values, call_stack } => {
(return_values, call_stack.clone())
Expand All @@ -1739,6 +1760,7 @@ impl<'a> Context<'a> {
};

let mut has_constant_return = false;
let mut return_vars: Vec<AcirVar> = Vec::new();
for value_id in return_values {
let is_databus = self
.data_bus
Expand All @@ -1759,7 +1781,7 @@ impl<'a> Context<'a> {
dfg,
)?;
} else {
self.acir_context.return_var(acir_var)?;
return_vars.push(acir_var);
}
}
}
Expand All @@ -1770,7 +1792,7 @@ impl<'a> Context<'a> {
Vec::new()
};

Ok(warnings)
Ok((return_vars, warnings))
}

/// Gets the cached `AcirVar` that was converted from the corresponding `ValueId`. If it does
Expand Down Expand Up @@ -3079,8 +3101,8 @@ mod test {
check_call_opcode(
&func_with_nested_call_opcodes[1],
2,
vec![Witness(2), Witness(1)],
vec![Witness(3)],
vec![Witness(3), Witness(1)],
vec![Witness(4)],
);
}

Expand All @@ -3100,13 +3122,13 @@ mod test {
for (expected_input, input) in expected_inputs.iter().zip(inputs) {
assert_eq!(
expected_input, input,
"Expected witness {expected_input:?} but got {input:?}"
"Expected input witness {expected_input:?} but got {input:?}"
);
}
for (expected_output, output) in expected_outputs.iter().zip(outputs) {
assert_eq!(
expected_output, output,
"Expected witness {expected_output:?} but got {output:?}"
"Expected output witness {expected_output:?} but got {output:?}"
);
}
}
Expand Down

0 comments on commit 1252b5f

Please sign in to comment.