From 3c778b73d9458ab708df21c850468d708676cde4 Mon Sep 17 00:00:00 2001 From: Maxim Vezenov Date: Tue, 27 Aug 2024 11:54:43 -0400 Subject: [PATCH] chore(perf): Simplify poseidon2 algorithm (#5811) # Description ## Problem\* Resolves Optimizations found while looking exploring other Brillig opts. ## Summary\* There are a couple optimizations here: 1. I noticed that we loop over the cache and do some resetting inside of `squeeze` of `Poseidon2`. However, we our `Hasher` always creates a fresh Poseidon2 object so it seems unnecessary to reset the cache in this way. In Brillig this leads to an extra loop that is essentially unused and blows up the code size of any programs using poseidon in an unconstrained environment. 2. We were writing into a `result` array and returning it from `perform_duplex`. This result was unused inside of `absorb` and we can directly access `self.state` inside of `squeeze`. I no longer return anything from `perform_duplex`. ## Additional Context ## Documentation\* Check one: - [ ] No documentation needed. - [ ] Documentation included in this PR. - [ ] **[For Experimental Features]** Documentation to be submitted in a separate PR. # PR Checklist\* - [ ] I have tested the changes locally. - [ ] I have formatted the changes with [Prettier](https://prettier.io/) and/or `cargo fmt` on default settings. --- noir_stdlib/src/hash/poseidon2.nr | 55 +++++++------------------------ 1 file changed, 12 insertions(+), 43 deletions(-) diff --git a/noir_stdlib/src/hash/poseidon2.nr b/noir_stdlib/src/hash/poseidon2.nr index 9626da0cf97..cf820f86370 100644 --- a/noir_stdlib/src/hash/poseidon2.nr +++ b/noir_stdlib/src/hash/poseidon2.nr @@ -26,7 +26,7 @@ impl Poseidon2 { result } - fn perform_duplex(&mut self) -> [Field; RATE] { + fn perform_duplex(&mut self) { // zero-pad the cache for i in 0..RATE { if i >= self.cache_size { @@ -38,61 +38,30 @@ impl Poseidon2 { self.state[i] += self.cache[i]; } self.state = crate::hash::poseidon2_permutation(self.state, 4); - // return `RATE` number of field elements from the sponge state. - let mut result = [0; RATE]; - for i in 0..RATE { - result[i] = self.state[i]; - } - result } fn absorb(&mut self, input: Field) { - if (!self.squeeze_mode) & (self.cache_size == RATE) { + assert(!self.squeeze_mode); + if self.cache_size == RATE { // If we're absorbing, and the cache is full, apply the sponge permutation to compress the cache - let _ = self.perform_duplex(); + self.perform_duplex(); self.cache[0] = input; self.cache_size = 1; - } else if (!self.squeeze_mode) & (self.cache_size != RATE) { + } else { // If we're absorbing, and the cache is not full, add the input into the cache self.cache[self.cache_size] = input; self.cache_size += 1; - } else if self.squeeze_mode { - // If we're in squeeze mode, switch to absorb mode and add the input into the cache. - // N.B. I don't think this code path can be reached?! - self.cache[0] = input; - self.cache_size = 1; - self.squeeze_mode = false; } } fn squeeze(&mut self) -> Field { - if self.squeeze_mode & (self.cache_size == 0) { - // If we're in squeze mode and the cache is empty, there is nothing left to squeeze out of the sponge! - // Switch to absorb mode. - self.squeeze_mode = false; - self.cache_size = 0; - } - if !self.squeeze_mode { - // If we're in absorb mode, apply sponge permutation to compress the cache, populate cache with compressed - // state and switch to squeeze mode. Note: this code block will execute if the previous `if` condition was - // matched - let new_output_elements = self.perform_duplex(); - self.squeeze_mode = true; - for i in 0..RATE { - self.cache[i] = new_output_elements[i]; - } - self.cache_size = RATE; - } - // By this point, we should have a non-empty cache. Pop one item off the top of the cache and return it. - let result = self.cache[0]; - for i in 1..RATE { - if i < self.cache_size { - self.cache[i - 1] = self.cache[i]; - } - } - self.cache_size -= 1; - self.cache[self.cache_size] = 0; - result + assert(!self.squeeze_mode); + // If we're in absorb mode, apply sponge permutation to compress the cache. + self.perform_duplex(); + self.squeeze_mode = true; + + // Pop one item off the top of the permutation and return it. + self.state[0] } fn hash_internal(input: [Field; N], in_len: u32, is_variable_length: bool) -> Field {