From 9dc05bc3d47c57981e584661fcc7b5480e21d7d8 Mon Sep 17 00:00:00 2001 From: guipublic <47281315+guipublic@users.noreply.github.com> Date: Tue, 13 Feb 2024 11:43:43 +0100 Subject: [PATCH] chore: uses sha256compression opcode in Noir and implements acvm solver for it (#4511) This PR should be merged **after** PR #4503 , because it uses the opcode implemented by the latter. In this PR, I add the implementation of the ACVM solver for sha256compression opcode, and use it in Noir implementation of sha256. This gives us 3 ways of doing sha256. You can see below the resulting circuit size for hashing 1 byte with each of them: - The full Noir implementation : 17161 ACIR Opcodes, Circuit size is 65065 - The full BB implementation: 75 ACIR Opcodes, Circuit size is 38799 - Mixed Noir+sha256compression opcode: 351 ACIR Occodes, Circuit size is 15495 The sha256compression opcode is a clear winner, and this is because it uses UltraPlonk lookup-gates. As a result, I have removed the 2 other methods in the stdlib. The stdlib sha256 is now calling the Noir implementation which is using the sha256compression opcodes. The old opcode should be removed in a future PR. --------- Co-authored-by: kevaundray --- noir/acvm-repo/acvm/src/pwg/blackbox/hash.rs | 47 +++++++- noir/acvm-repo/acvm/src/pwg/blackbox/mod.rs | 12 ++- noir/acvm-repo/blackbox_solver/Cargo.toml | 2 +- noir/acvm-repo/blackbox_solver/src/lib.rs | 10 ++ noir/acvm-repo/brillig_vm/src/black_box.rs | 33 +++++- noir/noir_stdlib/src/sha256.nr | 106 +------------------ 6 files changed, 103 insertions(+), 107 deletions(-) diff --git a/noir/acvm-repo/acvm/src/pwg/blackbox/hash.rs b/noir/acvm-repo/acvm/src/pwg/blackbox/hash.rs index 1ada397fc59..06489822c92 100644 --- a/noir/acvm-repo/acvm/src/pwg/blackbox/hash.rs +++ b/noir/acvm-repo/acvm/src/pwg/blackbox/hash.rs @@ -3,7 +3,7 @@ use acir::{ native_types::{Witness, WitnessMap}, BlackBoxFunc, FieldElement, }; -use acvm_blackbox_solver::BlackBoxResolutionError; +use acvm_blackbox_solver::{sha256compression, BlackBoxResolutionError}; use crate::pwg::{insert_value, witness_to_value}; use crate::OpcodeResolutionError; @@ -86,3 +86,48 @@ fn write_digest_to_outputs( Ok(()) } + +pub(crate) fn solve_sha_256_permutation_opcode( + initial_witness: &mut WitnessMap, + inputs: &[FunctionInput], + hash_values: &[FunctionInput], + outputs: &[Witness], + black_box_func: BlackBoxFunc, +) -> Result<(), OpcodeResolutionError> { + let mut message = [0; 16]; + if inputs.len() != 16 { + return Err(OpcodeResolutionError::BlackBoxFunctionFailed( + black_box_func, + format!("Expected 16 inputs but encountered {}", &message.len()), + )); + } + for (i, input) in inputs.iter().enumerate() { + let value = witness_to_value(initial_witness, input.witness)?; + message[i] = value.to_u128() as u32; + } + + if hash_values.len() != 8 { + return Err(OpcodeResolutionError::BlackBoxFunctionFailed( + black_box_func, + format!("Expected 8 values but encountered {}", hash_values.len()), + )); + } + let mut state = [0; 8]; + for (i, hash) in hash_values.iter().enumerate() { + let value = witness_to_value(initial_witness, hash.witness)?; + state[i] = value.to_u128() as u32; + } + + sha256compression(&mut state, &message); + let outputs: [Witness; 8] = outputs.try_into().map_err(|_| { + OpcodeResolutionError::BlackBoxFunctionFailed( + black_box_func, + format!("Expected 8 outputs but encountered {}", outputs.len()), + ) + })?; + for (output_witness, value) in outputs.iter().zip(state.into_iter()) { + insert_value(output_witness, FieldElement::from(value as u128), initial_witness)?; + } + + Ok(()) +} diff --git a/noir/acvm-repo/acvm/src/pwg/blackbox/mod.rs b/noir/acvm-repo/acvm/src/pwg/blackbox/mod.rs index 7146dff87e0..7ae92fd84fc 100644 --- a/noir/acvm-repo/acvm/src/pwg/blackbox/mod.rs +++ b/noir/acvm-repo/acvm/src/pwg/blackbox/mod.rs @@ -20,7 +20,7 @@ mod signature; use fixed_base_scalar_mul::{embedded_curve_add, fixed_base_scalar_mul}; // Hash functions should eventually be exposed for external consumers. -use hash::solve_generic_256_hash_opcode; +use hash::{solve_generic_256_hash_opcode, solve_sha_256_permutation_opcode}; use logic::{and, xor}; use pedersen::pedersen; use range::solve_range_opcode; @@ -205,6 +205,14 @@ pub(crate) fn solve( bigint_solver.bigint_to_bytes(*input, outputs, initial_witness) } BlackBoxFuncCall::Poseidon2Permutation { .. } => todo!(), - BlackBoxFuncCall::Sha256Compression { .. } => todo!(), + BlackBoxFuncCall::Sha256Compression { inputs, hash_values, outputs } => { + solve_sha_256_permutation_opcode( + initial_witness, + inputs, + hash_values, + outputs, + bb_func.get_black_box_func(), + ) + } } } diff --git a/noir/acvm-repo/blackbox_solver/Cargo.toml b/noir/acvm-repo/blackbox_solver/Cargo.toml index 7359cf307e4..a13f496f34f 100644 --- a/noir/acvm-repo/blackbox_solver/Cargo.toml +++ b/noir/acvm-repo/blackbox_solver/Cargo.toml @@ -18,7 +18,7 @@ thiserror.workspace = true blake2 = "0.10.6" blake3 = "1.5.0" -sha2 = "0.10.6" +sha2 = { version="0.10.6", features = ["compress",] } sha3 = "0.10.6" keccak = "0.1.4" k256 = { version = "0.11.0", features = [ diff --git a/noir/acvm-repo/blackbox_solver/src/lib.rs b/noir/acvm-repo/blackbox_solver/src/lib.rs index afba4eff17c..e033344fefa 100644 --- a/noir/acvm-repo/blackbox_solver/src/lib.rs +++ b/noir/acvm-repo/blackbox_solver/src/lib.rs @@ -43,6 +43,16 @@ pub fn keccak256(inputs: &[u8]) -> Result<[u8; 32], BlackBoxResolutionError> { .map_err(|err| BlackBoxResolutionError::Failed(BlackBoxFunc::Keccak256, err)) } +pub fn sha256compression(state: &mut [u32; 8], msg_blocks: &[u32; 16]) { + let mut blocks = [0_u8; 64]; + for (i, block) in msg_blocks.iter().enumerate() { + let bytes = block.to_be_bytes(); + blocks[i * 4..i * 4 + 4].copy_from_slice(&bytes); + } + let blocks: GenericArray = blocks.into(); + sha2::compress256(state, &[blocks]); +} + const KECCAK_LANES: usize = 25; pub fn keccakf1600( diff --git a/noir/acvm-repo/brillig_vm/src/black_box.rs b/noir/acvm-repo/brillig_vm/src/black_box.rs index 04aa2bcf9af..5b2680465ab 100644 --- a/noir/acvm-repo/brillig_vm/src/black_box.rs +++ b/noir/acvm-repo/brillig_vm/src/black_box.rs @@ -2,7 +2,7 @@ use acir::brillig::{BlackBoxOp, HeapArray, HeapVector, Value}; use acir::{BlackBoxFunc, FieldElement}; use acvm_blackbox_solver::{ blake2s, blake3, ecdsa_secp256k1_verify, ecdsa_secp256r1_verify, keccak256, keccakf1600, - sha256, BlackBoxFunctionSolver, BlackBoxResolutionError, + sha256, sha256compression, BlackBoxFunctionSolver, BlackBoxResolutionError, }; use crate::Memory; @@ -185,7 +185,36 @@ pub(crate) fn evaluate_black_box( BlackBoxOp::BigIntFromLeBytes { .. } => todo!(), BlackBoxOp::BigIntToLeBytes { .. } => todo!(), BlackBoxOp::Poseidon2Permutation { .. } => todo!(), - BlackBoxOp::Sha256Compression { .. } => todo!(), + BlackBoxOp::Sha256Compression { input, hash_values, output } => { + let mut message = [0; 16]; + let inputs = read_heap_vector(memory, input); + if inputs.len() != 16 { + return Err(BlackBoxResolutionError::Failed( + BlackBoxFunc::Sha256Compression, + format!("Expected 16 inputs but encountered {}", &inputs.len()), + )); + } + for (i, input) in inputs.iter().enumerate() { + message[i] = input.to_u128() as u32; + } + let mut state = [0; 8]; + let values = read_heap_vector(memory, hash_values); + if values.len() != 8 { + return Err(BlackBoxResolutionError::Failed( + BlackBoxFunc::Sha256Compression, + format!("Expected 8 values but encountered {}", &values.len()), + )); + } + for (i, value) in values.iter().enumerate() { + state[i] = value.to_u128() as u32; + } + + sha256compression(&mut state, &message); + let state = state.map(|x| Value::from(x as u128)); + + memory.write_slice(memory.read_ref(output.pointer), &state); + Ok(()) + } } } diff --git a/noir/noir_stdlib/src/sha256.nr b/noir/noir_stdlib/src/sha256.nr index 39e39b8cb6e..6bcc5ea74c6 100644 --- a/noir/noir_stdlib/src/sha256.nr +++ b/noir/noir_stdlib/src/sha256.nr @@ -1,91 +1,6 @@ // Implementation of SHA-256 mapping a byte array of variable length to // 32 bytes. -// Internal functions act on 32-bit unsigned integers for simplicity. -// Auxiliary mappings; names as in FIPS PUB 180-4 -fn rotr32(a: u32, b: u32) -> u32 // 32-bit right rotation -{ - // None of the bits overlap between `(a >> b)` and `(a << (32 - b))` - // Addition is then equivalent to OR, with fewer constraints. - (a >> b) + (a << (32 - b)) -} - -fn ch(x: u32, y: u32, z: u32) -> u32 { - (x & y) ^ ((!x) & z) -} - -fn maj(x: u32, y: u32, z: u32) -> u32 { - (x & y) ^ (x & z) ^ (y & z) -} - -fn bigma0(x: u32) -> u32 { - rotr32(x, 2) ^ rotr32(x, 13) ^ rotr32(x, 22) -} -fn bigma1(x: u32) -> u32 { - rotr32(x, 6) ^ rotr32(x, 11) ^ rotr32(x, 25) -} - -fn sigma0(x: u32) -> u32 { - rotr32(x, 7) ^ rotr32(x, 18) ^ (x >> 3) -} - -fn sigma1(x: u32) -> u32 { - rotr32(x, 17) ^ rotr32(x, 19) ^ (x >> 10) -} - -fn sha_w(msg: [u32; 16]) -> [u32; 64] // Expanded message blocks -{ - let mut w: [u32;64] = [0; 64]; - - for j in 0..16 { - w[j] = msg[j]; - } - - for j in 16..64 { - w[j] = crate::wrapping_add( - crate::wrapping_add(sigma1(w[j-2]), w[j-7]), - crate::wrapping_add(sigma0(w[j-15]), w[j-16]), - ); - } - - w -} -// SHA-256 compression function -fn sha_c(msg: [u32; 16], hash: [u32; 8]) -> [u32; 8] { - let K: [u32; 64] = [ - 1116352408, 1899447441, 3049323471, 3921009573, 961987163, 1508970993, 2453635748, - 2870763221, 3624381080, 310598401, 607225278, 1426881987, 1925078388, 2162078206, - 2614888103, 3248222580, 3835390401, 4022224774, 264347078, 604807628, 770255983, 1249150122, - 1555081692, 1996064986, 2554220882, 2821834349, 2952996808, 3210313671, 3336571891, - 3584528711, 113926993, 338241895, 666307205, 773529912, 1294757372, 1396182291, 1695183700, - 1986661051, 2177026350, 2456956037, 2730485921, 2820302411, 3259730800, 3345764771, - 3516065817, 3600352804, 4094571909, 275423344, 430227734, 506948616, 659060556, 883997877, - 958139571, 1322822218, 1537002063, 1747873779, 1955562222, 2024104815, 2227730452, - 2361852424, 2428436474, 2756734187, 3204031479, 3329325298 - ]; // first 32 bits of fractional parts of cube roots of first 64 primes - let mut out_h: [u32; 8] = hash; - let w = sha_w(msg); - for j in 0..64 { - let t1 = crate::wrapping_add( - crate::wrapping_add( - crate::wrapping_add(out_h[7], bigma1(out_h[4])), - ch(out_h[4], out_h[5], out_h[6]) - ), - crate::wrapping_add(K[j], w[j]) - ); - let t2 = crate::wrapping_add(bigma0(out_h[0]), maj(out_h[0], out_h[1], out_h[2])); - out_h[7] = out_h[6]; - out_h[6] = out_h[5]; - out_h[5] = out_h[4]; - out_h[4] = crate::wrapping_add(out_h[3], t1); - out_h[3] = out_h[2]; - out_h[2] = out_h[1]; - out_h[1] = out_h[0]; - out_h[0] = crate::wrapping_add(t1, t2); - } - - out_h -} // Convert 64-byte array to array of 16 u32s fn msg_u8_to_u32(msg: [u8; 64]) -> [u32; 16] { let mut msg32: [u32; 16] = [0; 16]; @@ -102,19 +17,15 @@ fn msg_u8_to_u32(msg: [u8; 64]) -> [u32; 16] { pub fn digest(msg: [u8; N]) -> [u8; 32] { let mut msg_block: [u8; 64] = [0; 64]; let mut h: [u32; 8] = [1779033703, 3144134277, 1013904242, 2773480762, 1359893119, 2600822924, 528734635, 1541459225]; // Intermediate hash, starting with the canonical initial value - let mut c: [u32; 8] = [0; 8]; // Compression of current message block as sequence of u32 let mut out_h: [u8; 32] = [0; 32]; // Digest as sequence of bytes let mut i: u64 = 0; // Message byte pointer - for k in 0..msg.len() { + for k in 0..N { // Populate msg_block msg_block[i as Field] = msg[k]; i = i + 1; if i == 64 { // Enough to hash block - c = sha_c(msg_u8_to_u32(msg_block), h); - for j in 0..8 { - h[j] = crate::wrapping_add(c[j], h[j]); - } + h = crate::hash::sha256_compression(msg_u8_to_u32(msg_block), h); i = 0; } @@ -135,11 +46,7 @@ pub fn digest(msg: [u8; N]) -> [u8; 32] { } } } - c = h; - c = sha_c(msg_u8_to_u32(msg_block), c); - for j in 0..8 { - h[j] = crate::wrapping_add(h[j], c[j]); - } + h = crate::hash::sha256_compression(msg_u8_to_u32(msg_block), h); i = 0; } @@ -159,11 +66,8 @@ pub fn digest(msg: [u8; N]) -> [u8; 32] { } } // Hash final padded block - c = h; - c = sha_c(msg_u8_to_u32(msg_block), c); - for j in 0..8 { - h[j] = crate::wrapping_add(h[j], c[j]); - } + h = crate::hash::sha256_compression(msg_u8_to_u32(msg_block), h); + // Return final hash as byte array for j in 0..8 { for k in 0..4 {