Skip to content

Commit

Permalink
fix: avoid huge unrolling in hash_args (AztecProtocol/aztec-packages#…
Browse files Browse the repository at this point in the history
…5703)

After #4736 was fixed, now this
is possible :)

---------

Co-authored-by: Alvaro Rodriguez <sirasistant@MacBook-Pro-de-Alvaro.local>
  • Loading branch information
AztecBot and Alvaro Rodriguez committed Apr 12, 2024
1 parent 2bd006a commit 9e7d58b
Show file tree
Hide file tree
Showing 136 changed files with 9,908 additions and 453 deletions.
2 changes: 1 addition & 1 deletion .aztec-sync-commit
Original file line number Diff line number Diff line change
@@ -1 +1 @@
ff28080bcfb946177010960722925973ee19646b
10d9ad99200a5897417ff5669763ead4e38d87fa
58 changes: 4 additions & 54 deletions acvm-repo/acir/codegen/acir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,23 +159,14 @@ namespace Program {

struct Keccak256 {
std::vector<Program::FunctionInput> inputs;
Program::FunctionInput var_message_size;
std::vector<Program::Witness> outputs;

friend bool operator==(const Keccak256&, const Keccak256&);
std::vector<uint8_t> bincodeSerialize() const;
static Keccak256 bincodeDeserialize(std::vector<uint8_t>);
};

struct Keccak256VariableLength {
std::vector<Program::FunctionInput> inputs;
Program::FunctionInput var_message_size;
std::vector<Program::Witness> outputs;

friend bool operator==(const Keccak256VariableLength&, const Keccak256VariableLength&);
std::vector<uint8_t> bincodeSerialize() const;
static Keccak256VariableLength bincodeDeserialize(std::vector<uint8_t>);
};

struct Keccakf1600 {
std::vector<Program::FunctionInput> inputs;
std::vector<Program::Witness> outputs;
Expand Down Expand Up @@ -275,7 +266,7 @@ namespace Program {
static Sha256Compression bincodeDeserialize(std::vector<uint8_t>);
};

std::variant<AND, XOR, RANGE, SHA256, Blake2s, Blake3, SchnorrVerify, PedersenCommitment, PedersenHash, EcdsaSecp256k1, EcdsaSecp256r1, FixedBaseScalarMul, EmbeddedCurveAdd, Keccak256, Keccak256VariableLength, Keccakf1600, RecursiveAggregation, BigIntAdd, BigIntSub, BigIntMul, BigIntDiv, BigIntFromLeBytes, BigIntToLeBytes, Poseidon2Permutation, Sha256Compression> value;
std::variant<AND, XOR, RANGE, SHA256, Blake2s, Blake3, SchnorrVerify, PedersenCommitment, PedersenHash, EcdsaSecp256k1, EcdsaSecp256r1, FixedBaseScalarMul, EmbeddedCurveAdd, Keccak256, Keccakf1600, RecursiveAggregation, BigIntAdd, BigIntSub, BigIntMul, BigIntDiv, BigIntFromLeBytes, BigIntToLeBytes, Poseidon2Permutation, Sha256Compression> value;

friend bool operator==(const BlackBoxFuncCall&, const BlackBoxFuncCall&);
std::vector<uint8_t> bincodeSerialize() const;
Expand Down Expand Up @@ -2582,6 +2573,7 @@ namespace Program {

inline bool operator==(const BlackBoxFuncCall::Keccak256 &lhs, const BlackBoxFuncCall::Keccak256 &rhs) {
if (!(lhs.inputs == rhs.inputs)) { return false; }
if (!(lhs.var_message_size == rhs.var_message_size)) { return false; }
if (!(lhs.outputs == rhs.outputs)) { return false; }
return true;
}
Expand All @@ -2607,6 +2599,7 @@ template <>
template <typename Serializer>
void serde::Serializable<Program::BlackBoxFuncCall::Keccak256>::serialize(const Program::BlackBoxFuncCall::Keccak256 &obj, Serializer &serializer) {
serde::Serializable<decltype(obj.inputs)>::serialize(obj.inputs, serializer);
serde::Serializable<decltype(obj.var_message_size)>::serialize(obj.var_message_size, serializer);
serde::Serializable<decltype(obj.outputs)>::serialize(obj.outputs, serializer);
}

Expand All @@ -2615,49 +2608,6 @@ template <typename Deserializer>
Program::BlackBoxFuncCall::Keccak256 serde::Deserializable<Program::BlackBoxFuncCall::Keccak256>::deserialize(Deserializer &deserializer) {
Program::BlackBoxFuncCall::Keccak256 obj;
obj.inputs = serde::Deserializable<decltype(obj.inputs)>::deserialize(deserializer);
obj.outputs = serde::Deserializable<decltype(obj.outputs)>::deserialize(deserializer);
return obj;
}

namespace Program {

inline bool operator==(const BlackBoxFuncCall::Keccak256VariableLength &lhs, const BlackBoxFuncCall::Keccak256VariableLength &rhs) {
if (!(lhs.inputs == rhs.inputs)) { return false; }
if (!(lhs.var_message_size == rhs.var_message_size)) { return false; }
if (!(lhs.outputs == rhs.outputs)) { return false; }
return true;
}

inline std::vector<uint8_t> BlackBoxFuncCall::Keccak256VariableLength::bincodeSerialize() const {
auto serializer = serde::BincodeSerializer();
serde::Serializable<BlackBoxFuncCall::Keccak256VariableLength>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline BlackBoxFuncCall::Keccak256VariableLength BlackBoxFuncCall::Keccak256VariableLength::bincodeDeserialize(std::vector<uint8_t> input) {
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<BlackBoxFuncCall::Keccak256VariableLength>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw serde::deserialization_error("Some input bytes were not read");
}
return value;
}

} // end of namespace Program

template <>
template <typename Serializer>
void serde::Serializable<Program::BlackBoxFuncCall::Keccak256VariableLength>::serialize(const Program::BlackBoxFuncCall::Keccak256VariableLength &obj, Serializer &serializer) {
serde::Serializable<decltype(obj.inputs)>::serialize(obj.inputs, serializer);
serde::Serializable<decltype(obj.var_message_size)>::serialize(obj.var_message_size, serializer);
serde::Serializable<decltype(obj.outputs)>::serialize(obj.outputs, serializer);
}

template <>
template <typename Deserializer>
Program::BlackBoxFuncCall::Keccak256VariableLength serde::Deserializable<Program::BlackBoxFuncCall::Keccak256VariableLength>::deserialize(Deserializer &deserializer) {
Program::BlackBoxFuncCall::Keccak256VariableLength obj;
obj.inputs = serde::Deserializable<decltype(obj.inputs)>::deserialize(deserializer);
obj.var_message_size = serde::Deserializable<decltype(obj.var_message_size)>::deserialize(deserializer);
obj.outputs = serde::Deserializable<decltype(obj.outputs)>::deserialize(deserializer);
return obj;
Expand Down
11 changes: 2 additions & 9 deletions acvm-repo/acir/src/circuit/opcodes/black_box_function_call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,6 @@ pub enum BlackBoxFuncCall {
outputs: (Witness, Witness),
},
Keccak256 {
inputs: Vec<FunctionInput>,
outputs: Vec<Witness>,
},
Keccak256VariableLength {
inputs: Vec<FunctionInput>,
/// This is the number of bytes to take
/// from the input. Note: if `var_message_size`
Expand Down Expand Up @@ -183,7 +179,6 @@ impl BlackBoxFuncCall {
BlackBoxFuncCall::FixedBaseScalarMul { .. } => BlackBoxFunc::FixedBaseScalarMul,
BlackBoxFuncCall::EmbeddedCurveAdd { .. } => BlackBoxFunc::EmbeddedCurveAdd,
BlackBoxFuncCall::Keccak256 { .. } => BlackBoxFunc::Keccak256,
BlackBoxFuncCall::Keccak256VariableLength { .. } => BlackBoxFunc::Keccak256,
BlackBoxFuncCall::Keccakf1600 { .. } => BlackBoxFunc::Keccakf1600,
BlackBoxFuncCall::RecursiveAggregation { .. } => BlackBoxFunc::RecursiveAggregation,
BlackBoxFuncCall::BigIntAdd { .. } => BlackBoxFunc::BigIntAdd,
Expand All @@ -206,7 +201,6 @@ impl BlackBoxFuncCall {
BlackBoxFuncCall::SHA256 { inputs, .. }
| BlackBoxFuncCall::Blake2s { inputs, .. }
| BlackBoxFuncCall::Blake3 { inputs, .. }
| BlackBoxFuncCall::Keccak256 { inputs, .. }
| BlackBoxFuncCall::Keccakf1600 { inputs, .. }
| BlackBoxFuncCall::PedersenCommitment { inputs, .. }
| BlackBoxFuncCall::PedersenHash { inputs, .. }
Expand Down Expand Up @@ -280,7 +274,7 @@ impl BlackBoxFuncCall {
inputs.extend(hashed_message.iter().copied());
inputs
}
BlackBoxFuncCall::Keccak256VariableLength { inputs, var_message_size, .. } => {
BlackBoxFuncCall::Keccak256 { inputs, var_message_size, .. } => {
let mut inputs = inputs.clone();
inputs.push(*var_message_size);
inputs
Expand All @@ -306,9 +300,8 @@ impl BlackBoxFuncCall {
BlackBoxFuncCall::SHA256 { outputs, .. }
| BlackBoxFuncCall::Blake2s { outputs, .. }
| BlackBoxFuncCall::Blake3 { outputs, .. }
| BlackBoxFuncCall::Keccak256 { outputs, .. }
| BlackBoxFuncCall::Keccakf1600 { outputs, .. }
| BlackBoxFuncCall::Keccak256VariableLength { outputs, .. }
| BlackBoxFuncCall::Keccak256 { outputs, .. }
| BlackBoxFuncCall::Poseidon2Permutation { outputs, .. }
| BlackBoxFuncCall::Sha256Compression { outputs, .. } => outputs.to_vec(),
BlackBoxFuncCall::AND { output, .. }
Expand Down
10 changes: 1 addition & 9 deletions acvm-repo/acvm/src/pwg/blackbox/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,15 +95,7 @@ pub(crate) fn solve(
blake3,
bb_func.get_black_box_func(),
),
BlackBoxFuncCall::Keccak256 { inputs, outputs } => solve_generic_256_hash_opcode(
initial_witness,
inputs,
None,
outputs,
keccak256,
bb_func.get_black_box_func(),
),
BlackBoxFuncCall::Keccak256VariableLength { inputs, var_message_size, outputs } => {
BlackBoxFuncCall::Keccak256 { inputs, var_message_size, outputs } => {
solve_generic_256_hash_opcode(
initial_witness,
inputs,
Expand Down
6 changes: 3 additions & 3 deletions acvm-repo/acvm/src/pwg/brillig.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,13 +206,13 @@ impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> {
for output in brillig.outputs.iter() {
match output {
BrilligOutputs::Simple(witness) => {
insert_value(witness, memory[current_ret_data_idx].value, witness_map)?;
insert_value(witness, memory[current_ret_data_idx].to_field(), witness_map)?;
current_ret_data_idx += 1;
}
BrilligOutputs::Array(witness_arr) => {
for witness in witness_arr.iter() {
let value = memory[current_ret_data_idx];
insert_value(witness, value.value, witness_map)?;
let value = &memory[current_ret_data_idx];
insert_value(witness, value.to_field(), witness_map)?;
current_ret_data_idx += 1;
}
}
Expand Down
5 changes: 2 additions & 3 deletions acvm-repo/acvm_js/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ function run_if_available {
require_command jq
require_command cargo
require_command wasm-bindgen
require_command wasm-opt

self_path=$(dirname "$(readlink -f "$0")")
pname=$(cargo read-manifest | jq -r '.name')
Expand All @@ -49,5 +48,5 @@ BROWSER_WASM=${BROWSER_DIR}/${pname}_bg.wasm
run_or_fail cargo build --lib --release --target $TARGET --package ${pname}
run_or_fail wasm-bindgen $WASM_BINARY --out-dir $NODE_DIR --typescript --target nodejs
run_or_fail wasm-bindgen $WASM_BINARY --out-dir $BROWSER_DIR --typescript --target web
run_or_fail wasm-opt $NODE_WASM -o $NODE_WASM -O
run_or_fail wasm-opt $BROWSER_WASM -o $BROWSER_WASM -O
run_if_available wasm-opt $NODE_WASM -o $NODE_WASM -O
run_if_available wasm-opt $BROWSER_WASM -o $BROWSER_WASM -O
56 changes: 55 additions & 1 deletion acvm-repo/acvm_js/src/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ use wasm_bindgen::prelude::wasm_bindgen;

use crate::{
foreign_call::{resolve_brillig, ForeignCallHandler},
JsExecutionError, JsWitnessMap, JsWitnessStack,
public_witness::extract_indices,
JsExecutionError, JsSolvedAndReturnWitness, JsWitnessMap, JsWitnessStack,
};

#[wasm_bindgen]
Expand Down Expand Up @@ -58,6 +59,44 @@ pub async fn execute_circuit(
Ok(witness_map.into())
}

/// Executes an ACIR circuit to generate the solved witness from the initial witness.
/// This method also extracts the public return values from the solved witness into its own return witness.
///
/// @param {&WasmBlackBoxFunctionSolver} solver - A black box solver.
/// @param {Uint8Array} circuit - A serialized representation of an ACIR circuit
/// @param {WitnessMap} initial_witness - The initial witness map defining all of the inputs to `circuit`..
/// @param {ForeignCallHandler} foreign_call_handler - A callback to process any foreign calls from the circuit.
/// @returns {SolvedAndReturnWitness} The solved witness calculated by executing the circuit on the provided inputs, as well as the return witness indices as specified by the circuit.
#[wasm_bindgen(js_name = executeCircuitWithReturnWitness, skip_jsdoc)]
pub async fn execute_circuit_with_return_witness(
solver: &WasmBlackBoxFunctionSolver,
program: Vec<u8>,
initial_witness: JsWitnessMap,
foreign_call_handler: ForeignCallHandler,
) -> Result<JsSolvedAndReturnWitness, Error> {
console_error_panic_hook::set_once();

let program: Program = Program::deserialize_program(&program)
.map_err(|_| JsExecutionError::new("Failed to deserialize circuit. This is likely due to differing serialization formats between ACVM_JS and your compiler".to_string(), None))?;

let mut witness_stack = execute_program_with_native_program_and_return(
solver,
&program,
initial_witness,
&foreign_call_handler,
)
.await?;
let solved_witness =
witness_stack.pop().expect("Should have at least one witness on the stack").witness;

let main_circuit = &program.functions[0];
let return_witness =
extract_indices(&solved_witness, main_circuit.return_values.0.iter().copied().collect())
.map_err(|err| JsExecutionError::new(err, None))?;

Ok((solved_witness, return_witness).into())
}

/// Executes an ACIR circuit to generate the solved witness from the initial witness.
///
/// @param {&WasmBlackBoxFunctionSolver} solver - A black box solver.
Expand Down Expand Up @@ -127,6 +166,21 @@ async fn execute_program_with_native_type_return(
let program: Program = Program::deserialize_program(&program)
.map_err(|_| JsExecutionError::new("Failed to deserialize circuit. This is likely due to differing serialization formats between ACVM_JS and your compiler".to_string(), None))?;

execute_program_with_native_program_and_return(
solver,
&program,
initial_witness,
foreign_call_executor,
)
.await
}

async fn execute_program_with_native_program_and_return(
solver: &WasmBlackBoxFunctionSolver,
program: &Program,
initial_witness: JsWitnessMap,
foreign_call_executor: &ForeignCallHandler,
) -> Result<WitnessStack, Error> {
let executor = ProgramExecutor::new(&program.functions, &solver.0, foreign_call_executor);
let witness_stack = executor.execute(initial_witness.into()).await?;

Expand Down
38 changes: 37 additions & 1 deletion acvm-repo/acvm_js/src/js_witness_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,23 @@ use acvm::{
acir::native_types::{Witness, WitnessMap},
FieldElement,
};
use js_sys::{JsString, Map};
use js_sys::{JsString, Map, Object};
use wasm_bindgen::prelude::{wasm_bindgen, JsValue};

#[wasm_bindgen(typescript_custom_section)]
const WITNESS_MAP: &'static str = r#"
// Map from witness index to hex string value of witness.
export type WitnessMap = Map<number, string>;
/**
* An execution result containing two witnesses.
* 1. The full solved witness of the execution.
* 2. The return witness which contains the given public return values within the full witness.
*/
export type SolvedAndReturnWitness = {
solvedWitness: WitnessMap;
returnWitness: WitnessMap;
}
"#;

// WitnessMap
Expand All @@ -21,6 +31,12 @@ extern "C" {
#[wasm_bindgen(constructor, js_class = "Map")]
pub fn new() -> JsWitnessMap;

#[wasm_bindgen(extends = Object, js_name = "SolvedAndReturnWitness", typescript_type = "SolvedAndReturnWitness")]
#[derive(Clone, Debug, PartialEq, Eq)]
pub type JsSolvedAndReturnWitness;

#[wasm_bindgen(constructor, js_class = "Object")]
pub fn new() -> JsSolvedAndReturnWitness;
}

impl Default for JsWitnessMap {
Expand All @@ -29,6 +45,12 @@ impl Default for JsWitnessMap {
}
}

impl Default for JsSolvedAndReturnWitness {
fn default() -> Self {
Self::new()
}
}

impl From<WitnessMap> for JsWitnessMap {
fn from(witness_map: WitnessMap) -> Self {
let js_map = JsWitnessMap::new();
Expand All @@ -54,6 +76,20 @@ impl From<JsWitnessMap> for WitnessMap {
}
}

impl From<(WitnessMap, WitnessMap)> for JsSolvedAndReturnWitness {
fn from(witness_maps: (WitnessMap, WitnessMap)) -> Self {
let js_solved_witness = JsWitnessMap::from(witness_maps.0);
let js_return_witness = JsWitnessMap::from(witness_maps.1);

let entry_map = Map::new();
entry_map.set(&JsValue::from_str("solvedWitness"), &js_solved_witness);
entry_map.set(&JsValue::from_str("returnWitness"), &js_return_witness);

let solved_and_return_witness = Object::from_entries(&entry_map).unwrap();
JsSolvedAndReturnWitness { obj: solved_and_return_witness }
}
}

pub(crate) fn js_value_to_field_element(js_value: JsValue) -> Result<FieldElement, JsString> {
let hex_str = js_value.as_string().ok_or("failed to parse field element from non-string")?;

Expand Down
3 changes: 2 additions & 1 deletion acvm-repo/acvm_js/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ pub use compression::{
};
pub use execute::{
create_black_box_solver, execute_circuit, execute_circuit_with_black_box_solver,
execute_program, execute_program_with_black_box_solver,
execute_circuit_with_return_witness, execute_program, execute_program_with_black_box_solver,
};
pub use js_execution_error::JsExecutionError;
pub use js_witness_map::JsSolvedAndReturnWitness;
pub use js_witness_map::JsWitnessMap;
pub use js_witness_stack::JsWitnessStack;
pub use logging::init_log_level;
Expand Down
9 changes: 6 additions & 3 deletions acvm-repo/acvm_js/src/public_witness.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ use wasm_bindgen::prelude::wasm_bindgen;

use crate::JsWitnessMap;

fn extract_indices(witness_map: &WitnessMap, indices: Vec<Witness>) -> Result<WitnessMap, String> {
pub(crate) fn extract_indices(
witness_map: &WitnessMap,
indices: Vec<Witness>,
) -> Result<WitnessMap, String> {
let mut extracted_witness_map = WitnessMap::new();
for witness in indices {
let witness_value = witness_map.get(&witness).ok_or(format!(
Expand Down Expand Up @@ -44,7 +47,7 @@ pub fn get_return_witness(
let witness_map = WitnessMap::from(witness_map);

let return_witness =
extract_indices(&witness_map, circuit.return_values.0.clone().into_iter().collect())?;
extract_indices(&witness_map, circuit.return_values.0.iter().copied().collect())?;

Ok(JsWitnessMap::from(return_witness))
}
Expand All @@ -71,7 +74,7 @@ pub fn get_public_parameters_witness(
let witness_map = WitnessMap::from(solved_witness);

let public_params_witness =
extract_indices(&witness_map, circuit.public_parameters.0.clone().into_iter().collect())?;
extract_indices(&witness_map, circuit.public_parameters.0.iter().copied().collect())?;

Ok(JsWitnessMap::from(public_params_witness))
}
Expand Down
Loading

0 comments on commit 9e7d58b

Please sign in to comment.