From 57ae6eb7f93fca4040fd38bab5be496f21c322d8 Mon Sep 17 00:00:00 2001 From: TomAFrench Date: Wed, 21 Jun 2023 15:14:35 +0000 Subject: [PATCH] feat(acvm)!: Support stepwise execution of ACIR --- acvm/src/pwg/block.rs | 36 +++----- acvm/src/pwg/mod.rs | 203 ++++++++++++++++++++---------------------- acvm/tests/solver.rs | 45 ++++------ 3 files changed, 123 insertions(+), 161 deletions(-) diff --git a/acvm/src/pwg/block.rs b/acvm/src/pwg/block.rs index 05ef57ef4..463b8e589 100644 --- a/acvm/src/pwg/block.rs +++ b/acvm/src/pwg/block.rs @@ -16,20 +16,9 @@ use super::{OpcodeNotSolvable, OpcodeResolution, OpcodeResolutionError}; /// block_value is the value of the Block at the solved_operations step /// solved_operations is the number of solved elements in the block #[derive(Default)] -pub(super) struct BlockSolver { - block_value: HashMap, - solved_operations: usize, -} +pub(super) struct BlockSolver; impl BlockSolver { - fn insert_value(&mut self, index: u32, value: FieldElement) { - self.block_value.insert(index, value); - } - - fn get_value(&self, index: u32) -> Option { - self.block_value.get(&index).copied() - } - // Helper function which tries to solve a Block opcode // As long as operations are resolved, we update/read from the block_value // We stop when an operation cannot be resolved @@ -44,7 +33,9 @@ impl BlockSolver { )) }; - for block_op in trace.iter().skip(self.solved_operations) { + let mut block_value: HashMap = HashMap::new(); + + for block_op in trace.iter() { let op_expr = ArithmeticSolver::evaluate(&block_op.operation, initial_witness); let operation = op_expr.to_const().ok_or_else(|| { missing_assignment(ArithmeticSolver::any_witness_from_expression(&op_expr)) @@ -57,23 +48,24 @@ impl BlockSolver { let value = ArithmeticSolver::evaluate(&block_op.value, initial_witness); let value_witness = ArithmeticSolver::any_witness_from_expression(&value); if value.is_const() { - self.insert_value(index, value.q_c); + block_value.insert(index, value.q_c); } else if operation.is_zero() && value.is_linear() { match ArithmeticSolver::solve_fan_in_term(&value, initial_witness) { GateStatus::GateUnsolvable => return Err(missing_assignment(value_witness)), GateStatus::GateSolvable(sum, (coef, w)) => { - let map_value = - self.get_value(index).ok_or_else(|| missing_assignment(Some(w)))?; + let map_value = block_value + .get(&index) + .copied() + .ok_or_else(|| missing_assignment(Some(w)))?; insert_value(&w, (map_value - sum - value.q_c) / coef, initial_witness)?; } GateStatus::GateSatisfied(sum) => { - self.insert_value(index, sum + value.q_c); + block_value.insert(index, sum + value.q_c); } } } else { return Err(missing_assignment(value_witness)); } - self.solved_operations += 1; } Ok(()) } @@ -86,16 +78,10 @@ impl BlockSolver { initial_witness: &mut WitnessMap, trace: &[MemOp], ) -> Result { - let initial_solved_operations = self.solved_operations; - match self.solve_helper(initial_witness, trace) { Ok(()) => Ok(OpcodeResolution::Solved), Err(OpcodeResolutionError::OpcodeNotSolvable(err)) => { - if self.solved_operations > initial_solved_operations { - Ok(OpcodeResolution::InProgress) - } else { - Ok(OpcodeResolution::Stalled(err)) - } + Ok(OpcodeResolution::Stalled(err)) } Err(err) => Err(err), } diff --git a/acvm/src/pwg/mod.rs b/acvm/src/pwg/mod.rs index 617028b1a..af9a71a94 100644 --- a/acvm/src/pwg/mod.rs +++ b/acvm/src/pwg/mod.rs @@ -1,10 +1,8 @@ // Re-usable methods that backends can use to implement their PWG -use std::collections::HashMap; - use acir::{ brillig_vm::ForeignCallResult, - circuit::{brillig::Brillig, opcodes::BlockId, Opcode}, + circuit::{brillig::Brillig, Opcode}, native_types::{Expression, Witness, WitnessMap}, BlackBoxFunc, FieldElement, }; @@ -46,7 +44,7 @@ pub enum ACVMStatus { /// to the ACVM using [`ACVM::resolve_pending_foreign_call`]. /// /// Once this is done, the ACVM can be restarted to solve the remaining opcodes. - RequiresForeignCall, + RequiresForeignCall(UnresolvedBrilligCall), } #[derive(Debug, PartialEq)] @@ -95,17 +93,13 @@ pub struct ACVM { status: ACVMStatus, backend: B, - /// Stores the solver for each [block][`Opcode::Block`] opcode. This persists their internal state to prevent recomputation. - block_solvers: HashMap, + /// A list of opcodes which are to be executed by the ACVM. - /// - /// Note that this doesn't include any opcodes which are waiting on a pending foreign call. opcodes: Vec, + /// Index of the next opcode to be executed. + instruction_pointer: usize, witness_map: WitnessMap, - - /// A list of foreign calls which must be resolved before the ACVM can resume execution. - pending_foreign_calls: Vec, } impl ACVM { @@ -113,27 +107,12 @@ impl ACVM { ACVM { status: ACVMStatus::InProgress, backend, - block_solvers: HashMap::default(), opcodes, + instruction_pointer: 0, witness_map: initial_witness, - pending_foreign_calls: Vec::new(), } } - /// Returns a reference to the current state of the ACVM's [`WitnessMap`]. - /// - /// Once execution has completed, the witness map can be extracted using [`ACVM::finalize`] - pub fn witness_map(&self) -> &WitnessMap { - &self.witness_map - } - - /// Returns a slice containing the opcodes which remain to be solved. - /// - /// Note: this doesn't include any opcodes which are waiting on a pending foreign call. - pub fn unresolved_opcodes(&self) -> &[Opcode] { - &self.opcodes - } - /// Updates the current status of the VM. /// Returns the given status. fn status(&mut self, status: ACVMStatus) -> ACVMStatus { @@ -147,9 +126,41 @@ impl ACVM { self.status(ACVMStatus::Failure(error)) } + /// Sets the status of the VM to `RequiresForeignCall`. + /// Indicating that the VM is now waiting for a foreign call to be resolved. + fn wait_for_foreign_call( + &mut self, + opcode: Opcode, + foreign_call_wait_info: ForeignCallWaitInfo, + ) -> ACVMStatus { + let brillig = match opcode { + Opcode::Brillig(brillig) => brillig, + _ => unreachable!("Brillig resolution for non brillig opcode"), + }; + let foreign_call = UnresolvedBrilligCall { brillig, foreign_call_wait_info }; + self.status(ACVMStatus::RequiresForeignCall(foreign_call)) + } + + /// Returns a reference to the current state of the ACVM's [`WitnessMap`]. + /// + /// Once execution has completed, the witness map can be extracted using [`ACVM::finalize`] + pub fn witness_map(&self) -> &WitnessMap { + &self.witness_map + } + + /// Returns a slice containing the opcodes of the circuit being executed. + pub fn opcodes(&self) -> &[Opcode] { + &self.opcodes + } + + /// Returns the index of the current opcode to be executed. + pub fn instruction_pointer(&self) -> usize { + self.instruction_pointer + } + /// Finalize the ACVM execution, returning the resulting [`WitnessMap`]. pub fn finalize(self) -> WitnessMap { - if !matches!(self.status, ACVMStatus::Solved) { + if self.status != ACVMStatus::Solved { panic!("ACVM is not ready to be finalized"); } self.witness_map @@ -157,17 +168,28 @@ impl ACVM { /// Return a reference to the arguments for the next pending foreign call, if one exists. pub fn get_pending_foreign_call(&self) -> Option<&ForeignCallWaitInfo> { - self.pending_foreign_calls.first().map(|foreign_call| &foreign_call.foreign_call_wait_info) + if let ACVMStatus::RequiresForeignCall(foreign_call) = &self.status { + Some(&foreign_call.foreign_call_wait_info) + } else { + None + } } /// Resolves a pending foreign call using a result calculated outside of the ACVM. pub fn resolve_pending_foreign_call(&mut self, foreign_call_result: ForeignCallResult) { - // Remove the first foreign call and inject the result to create a new opcode. - let foreign_call = self.pending_foreign_calls.remove(0); + let foreign_call = if let ACVMStatus::RequiresForeignCall(foreign_call) = &self.status { + // TODO: We can avoid this clone + foreign_call.clone() + } else { + panic!("no foreign call") + }; let resolved_brillig = foreign_call.resolve(foreign_call_result); - // Mark this opcode to be executed next. - self.opcodes.insert(0, Opcode::Brillig(resolved_brillig)); + // Overwrite the brillig opcode with a new one with the foreign call response. + self.opcodes[self.instruction_pointer] = Opcode::Brillig(resolved_brillig); + + // Now that the foreign call has been resolved then we can resume execution. + self.status(ACVMStatus::InProgress); } /// Executes the ACVM's circuit until execution halts. @@ -177,86 +199,55 @@ impl ACVM { /// 2. The circuit has been found to be unsatisfiable. /// 2. A Brillig [foreign call][`UnresolvedBrilligCall`] has been encountered and must be resolved. pub fn solve(&mut self) -> ACVMStatus { - // TODO: Prevent execution with outstanding foreign calls? - let mut unresolved_opcodes: Vec = Vec::new(); - while !self.opcodes.is_empty() { - unresolved_opcodes.clear(); - let mut stalled = true; - let mut opcode_not_solvable = None; - for opcode in &self.opcodes { - let resolution = match opcode { - Opcode::Arithmetic(expr) => { - ArithmeticSolver::solve(&mut self.witness_map, expr) - } - Opcode::BlackBoxFuncCall(bb_func) => { - blackbox::solve(&self.backend, &mut self.witness_map, bb_func) - } - Opcode::Directive(directive) => { - solve_directives(&mut self.witness_map, directive) - } - Opcode::Block(block) | Opcode::ROM(block) | Opcode::RAM(block) => { - let solver = self.block_solvers.entry(block.id).or_default(); - solver.solve(&mut self.witness_map, &block.trace) - } - Opcode::Brillig(brillig) => { - BrilligSolver::solve(&mut self.witness_map, brillig) - } - }; + while self.status == ACVMStatus::InProgress { + self.solve_opcode(); + } + self.status.clone() + } + + pub fn solve_opcode(&mut self) -> ACVMStatus { + let opcode = &self.opcodes[self.instruction_pointer]; + + let resolution = match opcode { + Opcode::Arithmetic(expr) => ArithmeticSolver::solve(&mut self.witness_map, expr), + Opcode::BlackBoxFuncCall(bb_func) => { + blackbox::solve(&self.backend, &mut self.witness_map, bb_func) + } + Opcode::Directive(directive) => solve_directives(&mut self.witness_map, directive), + Opcode::Block(block) | Opcode::ROM(block) | Opcode::RAM(block) => { + BlockSolver.solve(&mut self.witness_map, &block.trace) + } + Opcode::Brillig(brillig) => { + let resolution = BrilligSolver::solve(&mut self.witness_map, brillig); + match resolution { - Ok(OpcodeResolution::Solved) => { - stalled = false; - } - Ok(OpcodeResolution::InProgress) => { - stalled = false; - unresolved_opcodes.push(opcode.clone()); - } - Ok(OpcodeResolution::InProgressBrillig(oracle_wait_info)) => { - stalled = false; - // InProgressBrillig Oracles must be externally re-solved - let brillig = match opcode { - Opcode::Brillig(brillig) => brillig.clone(), - _ => unreachable!("Brillig resolution for non brillig opcode"), - }; - self.pending_foreign_calls.push(UnresolvedBrilligCall { - brillig, - foreign_call_wait_info: oracle_wait_info, - }) - } - Ok(OpcodeResolution::Stalled(not_solvable)) => { - if opcode_not_solvable.is_none() { - // we keep track of the first unsolvable opcode - opcode_not_solvable = Some(not_solvable); - } - // We push those opcodes not solvable to the back as - // it could be because the opcodes are out of order, i.e. this assignment - // relies on a later opcodes' results - unresolved_opcodes.push(opcode.clone()); + Ok(OpcodeResolution::InProgressBrillig(foreign_call_wait_info)) => { + return self.wait_for_foreign_call(opcode.clone(), foreign_call_wait_info) } - Err(OpcodeResolutionError::OpcodeNotSolvable(_)) => { - unreachable!("ICE - Result should have been converted to GateResolution") - } - Err(error) => return self.fail(error), + res => res, } } - - // Before potentially ending execution, we must save the list of opcodes which remain to be solved. - std::mem::swap(&mut self.opcodes, &mut unresolved_opcodes); - - // We have oracles that must be externally resolved - if self.get_pending_foreign_call().is_some() { - return self.status(ACVMStatus::RequiresForeignCall); + }; + match resolution { + Ok(OpcodeResolution::Solved) => { + self.instruction_pointer += 1; + if self.instruction_pointer == self.opcodes.len() { + self.status(ACVMStatus::Solved) + } else { + self.status(ACVMStatus::InProgress) + } } - - // We are stalled because of an opcode being bad - if stalled && !self.opcodes.is_empty() { - let error = OpcodeResolutionError::OpcodeNotSolvable( - opcode_not_solvable - .expect("infallible: cannot be stalled and None at the same time"), - ); - return self.fail(error); + Ok(OpcodeResolution::InProgress) => { + unreachable!("Opcodes should be immediately solvable"); + } + Ok(OpcodeResolution::InProgressBrillig(_)) => { + unreachable!("Handled above") } + Ok(OpcodeResolution::Stalled(not_solvable)) => self.status(ACVMStatus::Failure( + OpcodeResolutionError::OpcodeNotSolvable(not_solvable), + )), + Err(error) => self.fail(error), } - self.status(ACVMStatus::Solved) } } diff --git a/acvm/tests/solver.rs b/acvm/tests/solver.rs index 50a544b6e..cab984c2a 100644 --- a/acvm/tests/solver.rs +++ b/acvm/tests/solver.rs @@ -129,12 +129,11 @@ fn inversion_brillig_oracle_equivalence() { // use the partial witness generation solver with our acir program let solver_status = acvm.solve(); - assert_eq!( - solver_status, - ACVMStatus::RequiresForeignCall, + assert!( + matches!(solver_status, ACVMStatus::RequiresForeignCall(_)), "should require foreign call response" ); - assert!(acvm.unresolved_opcodes().is_empty(), "brillig should have been removed"); + assert_eq!(acvm.instruction_pointer(), 0, "brillig should have been removed"); let foreign_call_wait_info: &ForeignCallWaitInfo = acvm.get_pending_foreign_call().expect("should have a brillig foreign call request"); @@ -259,12 +258,11 @@ fn double_inversion_brillig_oracle() { // use the partial witness generation solver with our acir program let solver_status = acvm.solve(); - assert_eq!( - solver_status, - ACVMStatus::RequiresForeignCall, + assert!( + matches!(solver_status, ACVMStatus::RequiresForeignCall(_)), "should require foreign call response" ); - assert!(acvm.unresolved_opcodes().is_empty(), "brillig should have been removed"); + assert_eq!(acvm.instruction_pointer(), 0, "should stall on brillig"); let foreign_call_wait_info: &ForeignCallWaitInfo = acvm.get_pending_foreign_call().expect("should have a brillig foreign call request"); @@ -277,12 +275,11 @@ fn double_inversion_brillig_oracle() { // After filling data request, continue solving let solver_status = acvm.solve(); - assert_eq!( - solver_status, - ACVMStatus::RequiresForeignCall, + assert!( + matches!(solver_status, ACVMStatus::RequiresForeignCall(_)), "should require foreign call response" ); - assert!(acvm.unresolved_opcodes().is_empty(), "should be fully solved"); + assert_eq!(acvm.instruction_pointer(), 0, "should stall on brillig"); let foreign_call_wait_info = acvm.get_pending_foreign_call().expect("should have a brillig foreign call request"); @@ -382,17 +379,11 @@ fn oracle_dependent_execution() { // use the partial witness generation solver with our acir program let solver_status = acvm.solve(); - assert_eq!( - solver_status, - ACVMStatus::RequiresForeignCall, + assert!( + matches!(solver_status, ACVMStatus::RequiresForeignCall(_)), "should require foreign call response" ); - assert_eq!(acvm.unresolved_opcodes().len(), 1, "brillig should have been removed"); - assert_eq!( - acvm.unresolved_opcodes()[0], - Opcode::Arithmetic(inverse_equality_check.clone()), - "Equality check of inverses should still be waiting to be resolved" - ); + assert_eq!(acvm.instruction_pointer(), 1, "should stall on brillig"); let foreign_call_wait_info: &ForeignCallWaitInfo = acvm.get_pending_foreign_call().expect("should have a brillig foreign call request"); @@ -404,17 +395,11 @@ fn oracle_dependent_execution() { // After filling data request, continue solving let solver_status = acvm.solve(); - assert_eq!( - solver_status, - ACVMStatus::RequiresForeignCall, + assert!( + matches!(solver_status, ACVMStatus::RequiresForeignCall(_)), "should require foreign call response" ); - assert_eq!(acvm.unresolved_opcodes().len(), 1, "brillig should have been removed"); - assert_eq!( - acvm.unresolved_opcodes()[0], - Opcode::Arithmetic(inverse_equality_check), - "Equality check of inverses should still be waiting to be resolved" - ); + assert_eq!(acvm.instruction_pointer(), 1, "should stall on brillig"); let foreign_call_wait_info: &ForeignCallWaitInfo = acvm.get_pending_foreign_call().expect("should have a brillig foreign call request");