From 8da1a6f6dfd3b87b86ffd27c3b7987fa83cac2bc Mon Sep 17 00:00:00 2001 From: Tom French Date: Fri, 8 Dec 2023 19:26:23 +0000 Subject: [PATCH 1/2] feat: optimize the stdlib sha2 implementations --- noir_stdlib/src/sha256.nr | 14 ++++++++------ noir_stdlib/src/sha512.nr | 14 ++++++++------ 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/noir_stdlib/src/sha256.nr b/noir_stdlib/src/sha256.nr index 2f686a64165..58f5288f23e 100644 --- a/noir_stdlib/src/sha256.nr +++ b/noir_stdlib/src/sha256.nr @@ -6,9 +6,11 @@ fn msg_u8_to_u32(msg: [u8; 64]) -> [u32; 16] { let mut msg32: [u32; 16] = [0; 16]; for i in 0..16 { + let mut msg_field: Field = 0; for j in 0..4 { - msg32[15 - i] = (msg32[15 - i] << 8) + msg[64 - 4*(i + 1) + j] as u32; + msg_field = msg_field * 256 + msg[64 - 4*(i + 1) + j] as Field; } + msg32[15 - i] = msg_field as u32; } msg32 @@ -51,16 +53,16 @@ pub fn digest(msg: [u8; N]) -> [u8; 32] { i = 0; } + let len = 8 * msg.len(); + let len_bytes = (len as Field).to_le_bytes(8); for _i in 0..64 { // In any case, fill blocks up with zeros until the last 64 (i.e. until i = 56). if i < 56 { msg_block[i as Field] = 0; i = i + 1; } else if i < 64 { - let mut len = 8 * msg.len(); for j in 0..8 { - msg_block[63 - j] = len as u8; - len >>= 8; + msg_block[63 - j] = len_bytes[j]; } i += 8; } @@ -70,9 +72,9 @@ pub fn digest(msg: [u8; N]) -> [u8; 32] { // Return final hash as byte array for j in 0..8 { + let h_bytes = (h[7 - j] as Field).to_le_bytes(4); for k in 0..4 { - out_h[31 - 4*j - k] = h[7 - j] as u8; - h[7-j] >>= 8; + out_h[31 - 4*j - k] = h_bytes[k]; } } diff --git a/noir_stdlib/src/sha512.nr b/noir_stdlib/src/sha512.nr index 4dfe78308e2..5a4e169d416 100644 --- a/noir_stdlib/src/sha512.nr +++ b/noir_stdlib/src/sha512.nr @@ -77,9 +77,11 @@ fn msg_u8_to_u64(msg: [u8; 128]) -> [u64; 16] { let mut msg64: [u64; 16] = [0; 16]; for i in 0..16 { + let mut msg_field: Field = 0; for j in 0..8 { - msg64[15 - i] = (msg64[15 - i] << 8) + msg[128 - 8*(i + 1) + j] as u64; + msg_field = msg_field * 256 + msg[128 - 8*(i + 1) + j] as Field; } + msg64[15 - i] = msg_field as u64; } msg64 @@ -130,16 +132,16 @@ pub fn digest(msg: [u8; N]) -> [u8; 64] { i = 0; } + let len = 8 * msg.len(); + let len_bytes = (len as Field).to_le_bytes(16); for _i in 0..128 { // In any case, fill blocks up with zeros until the last 128 (i.e. until i = 112). if i < 112 { msg_block[i as Field] = 0; i += 1; } else if i < 128 { - let mut len = 8 * msg.len(); for j in 0..16 { - msg_block[127 - j] = len as u8; - len >>= 8; + msg_block[127 - j] = len_bytes[j]; } i += 16; // Done. } @@ -151,9 +153,9 @@ pub fn digest(msg: [u8; N]) -> [u8; 64] { } // Return final hash as byte array for j in 0..8 { + let h_bytes = (h[7 - j] as Field).to_le_bytes(8); for k in 0..8 { - out_h[63 - 8*j - k] = h[7 - j] as u8; - h[7-j] >>= 8; + out_h[63 - 8*j - k] = h_bytes[k]; } } From 86e38dd31c9750f030c714234b67ee075030c1d3 Mon Sep 17 00:00:00 2001 From: Tom French Date: Tue, 27 Feb 2024 17:21:15 +0000 Subject: [PATCH 2/2] chore: remove unnecessary casts --- noir_stdlib/src/sha256.nr | 8 ++++---- noir_stdlib/src/sha512.nr | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/noir_stdlib/src/sha256.nr b/noir_stdlib/src/sha256.nr index 58f5288f23e..8ca6808568d 100644 --- a/noir_stdlib/src/sha256.nr +++ b/noir_stdlib/src/sha256.nr @@ -23,7 +23,7 @@ pub fn digest(msg: [u8; N]) -> [u8; 32] { let mut i: u64 = 0; // Message byte pointer for k in 0..N { // Populate msg_block - msg_block[i as Field] = msg[k]; + msg_block[i] = msg[k]; i = i + 1; if i == 64 { // Enough to hash block @@ -34,7 +34,7 @@ pub fn digest(msg: [u8; N]) -> [u8; 32] { } // Pad the rest such that we have a [u32; 2] block at the end representing the length // of the message, and a block of 1 0 ... 0 following the message (i.e. [1 << 7, 0, ..., 0]). - msg_block[i as Field] = 1 << 7; + msg_block[i] = 1 << 7; i = i + 1; // If i >= 57, there aren't enough bits in the current message block to accomplish this, so // the 1 and 0s fill up the current block, which we then compress accordingly. @@ -43,7 +43,7 @@ pub fn digest(msg: [u8; N]) -> [u8; 32] { if i < 64 { for _i in 57..64 { if i <= 63 { - msg_block[i as Field] = 0; + msg_block[i] = 0; i += 1; } } @@ -58,7 +58,7 @@ pub fn digest(msg: [u8; N]) -> [u8; 32] { for _i in 0..64 { // In any case, fill blocks up with zeros until the last 64 (i.e. until i = 56). if i < 56 { - msg_block[i as Field] = 0; + msg_block[i] = 0; i = i + 1; } else if i < 64 { for j in 0..8 { diff --git a/noir_stdlib/src/sha512.nr b/noir_stdlib/src/sha512.nr index 5a4e169d416..a766ae50d55 100644 --- a/noir_stdlib/src/sha512.nr +++ b/noir_stdlib/src/sha512.nr @@ -96,7 +96,7 @@ pub fn digest(msg: [u8; N]) -> [u8; 64] { let mut i: u64 = 0; // Message byte pointer for k in 0..msg.len() { // Populate msg_block - msg_block[i as Field] = msg[k]; + msg_block[i] = msg[k]; i = i + 1; if i == 128 { // Enough to hash block @@ -110,7 +110,7 @@ pub fn digest(msg: [u8; N]) -> [u8; 64] { } // Pad the rest such that we have a [u64; 2] block at the end representing the length // of the message, and a block of 1 0 ... 0 following the message (i.e. [1 << 7, 0, ..., 0]). - msg_block[i as Field] = 1 << 7; + msg_block[i] = 1 << 7; i += 1; // If i >= 113, there aren't enough bits in the current message block to accomplish this, so // the 1 and 0s fill up the current block, which we then compress accordingly. @@ -119,7 +119,7 @@ pub fn digest(msg: [u8; N]) -> [u8; 64] { if i < 128 { for _i in 113..128 { if i <= 127 { - msg_block[i as Field] = 0; + msg_block[i] = 0; i += 1; } } @@ -137,7 +137,7 @@ pub fn digest(msg: [u8; N]) -> [u8; 64] { for _i in 0..128 { // In any case, fill blocks up with zeros until the last 128 (i.e. until i = 112). if i < 112 { - msg_block[i as Field] = 0; + msg_block[i] = 0; i += 1; } else if i < 128 { for j in 0..16 {