Skip to content

Commit

Permalink
sha3: fix padding for long cSHAKE parameters
Browse files Browse the repository at this point in the history
We used to compute the incorrect value if len(initBlock) % rate == 0.

Also, add a test vector for golang/go#66232, confirmed to fail on
GOARCH=386 without CL 570876.

Fixes golang/go#69169

Change-Id: I3f2400926fca111dd0ca1327d6b5975e51b28f96
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/616576
Reviewed-by: Andrew Ekstedt <andrew.ekstedt@gmail.com>
Reviewed-by: Daniel McCarney <daniel@binaryparadox.net>
Reviewed-by: Michael Pratt <mpratt@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Auto-Submit: Filippo Valsorda <filippo@golang.org>
Reviewed-by: Roland Shoemaker <roland@golang.org>
  • Loading branch information
FiloSottile authored and gopherbot committed Oct 22, 2024
1 parent c17aa50 commit 80ea76e
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 22 deletions.
111 changes: 111 additions & 0 deletions sha3/sha3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"encoding/json"
"fmt"
"hash"
"io"
"math/rand"
"os"
"strings"
Expand Down Expand Up @@ -375,6 +376,116 @@ func TestClone(t *testing.T) {
}
}

func TestCSHAKEAccumulated(t *testing.T) {
// Generated with pycryptodome@3.20.0
//
// from Crypto.Hash import cSHAKE128
// rng = cSHAKE128.new()
// acc = cSHAKE128.new()
// for n in range(200):
// N = rng.read(n)
// for s in range(200):
// S = rng.read(s)
// c = cSHAKE128.cSHAKE_XOF(data=None, custom=S, capacity=256, function=N)
// c.update(rng.read(100))
// acc.update(c.read(200))
// c = cSHAKE128.cSHAKE_XOF(data=None, custom=S, capacity=256, function=N)
// c.update(rng.read(168))
// acc.update(c.read(200))
// c = cSHAKE128.cSHAKE_XOF(data=None, custom=S, capacity=256, function=N)
// c.update(rng.read(200))
// acc.update(c.read(200))
// print(acc.read(32).hex())
//
// and with @noble/hashes@v1.5.0
//
// import { bytesToHex } from "@noble/hashes/utils";
// import { cshake128 } from "@noble/hashes/sha3-addons";
// const rng = cshake128.create();
// const acc = cshake128.create();
// for (let n = 0; n < 200; n++) {
// const N = rng.xof(n);
// for (let s = 0; s < 200; s++) {
// const S = rng.xof(s);
// let c = cshake128.create({ NISTfn: N, personalization: S });
// c.update(rng.xof(100));
// acc.update(c.xof(200));
// c = cshake128.create({ NISTfn: N, personalization: S });
// c.update(rng.xof(168));
// acc.update(c.xof(200));
// c = cshake128.create({ NISTfn: N, personalization: S });
// c.update(rng.xof(200));
// acc.update(c.xof(200));
// }
// }
// console.log(bytesToHex(acc.xof(32)));
//
t.Run("cSHAKE128", func(t *testing.T) {
testCSHAKEAccumulated(t, NewCShake128, rate128,
"bb14f8657c6ec5403d0b0e2ef3d3393497e9d3b1a9a9e8e6c81dbaa5fd809252")
})
t.Run("cSHAKE256", func(t *testing.T) {
testCSHAKEAccumulated(t, NewCShake256, rate256,
"0baaf9250c6e25f0c14ea5c7f9bfde54c8a922c8276437db28f3895bdf6eeeef")
})
}

func testCSHAKEAccumulated(t *testing.T, newCShake func(N, S []byte) ShakeHash, rate int64, exp string) {
rnd := newCShake(nil, nil)
acc := newCShake(nil, nil)
for n := 0; n < 200; n++ {
N := make([]byte, n)
rnd.Read(N)
for s := 0; s < 200; s++ {
S := make([]byte, s)
rnd.Read(S)

c := newCShake(N, S)
io.CopyN(c, rnd, 100 /* < rate */)
io.CopyN(acc, c, 200)

c.Reset()
io.CopyN(c, rnd, rate)
io.CopyN(acc, c, 200)

c.Reset()
io.CopyN(c, rnd, 200 /* > rate */)
io.CopyN(acc, c, 200)
}
}
if got := hex.EncodeToString(acc.Sum(nil)[:32]); got != exp {
t.Errorf("got %s, want %s", got, exp)
}
}

func TestCSHAKELargeS(t *testing.T) {
if testing.Short() {
t.Skip("skipping test in short mode.")
}

// See https://go.dev/issue/66232.
const s = (1<<32)/8 + 1000 // s * 8 > 2^32
S := make([]byte, s)
rnd := NewShake128()
rnd.Read(S)
c := NewCShake128(nil, S)
io.CopyN(c, rnd, 1000)

// Generated with pycryptodome@3.20.0
//
// from Crypto.Hash import cSHAKE128
// rng = cSHAKE128.new()
// S = rng.read(536871912)
// c = cSHAKE128.new(custom=S)
// c.update(rng.read(1000))
// print(c.read(32).hex())
//
exp := "2cb9f237767e98f2614b8779cf096a52da9b3a849280bbddec820771ae529cf0"
if got := hex.EncodeToString(c.Sum(nil)); got != exp {
t.Errorf("got %s, want %s", got, exp)
}
}

// BenchmarkPermutationFunction measures the speed of the permutation function
// with no input data.
func BenchmarkPermutationFunction(b *testing.B) {
Expand Down
45 changes: 23 additions & 22 deletions sha3/shake.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"encoding/binary"
"hash"
"io"
"math/bits"
)

// ShakeHash defines the interface to hash functions that support
Expand Down Expand Up @@ -58,33 +59,33 @@ const (
rate256 = 136
)

func bytepad(input []byte, w int) []byte {
// leftEncode always returns max 9 bytes
buf := make([]byte, 0, 9+len(input)+w)
buf = append(buf, leftEncode(uint64(w))...)
buf = append(buf, input...)
padlen := w - (len(buf) % w)
return append(buf, make([]byte, padlen)...)
}

func leftEncode(value uint64) []byte {
var b [9]byte
binary.BigEndian.PutUint64(b[1:], value)
// Trim all but last leading zero bytes
i := byte(1)
for i < 8 && b[i] == 0 {
i++
func bytepad(data []byte, rate int) []byte {
out := make([]byte, 0, 9+len(data)+rate-1)
out = append(out, leftEncode(uint64(rate))...)
out = append(out, data...)
if padlen := rate - len(out)%rate; padlen < rate {
out = append(out, make([]byte, padlen)...)
}
// Prepend number of encoded bytes
b[i-1] = 9 - i
return b[i-1:]
return out
}

func leftEncode(x uint64) []byte {
// Let n be the smallest positive integer for which 2^(8n) > x.
n := (bits.Len64(x) + 7) / 8
if n == 0 {
n = 1
}
// Return n || x with n as a byte and x an n bytes in big-endian order.
b := make([]byte, 9)
binary.BigEndian.PutUint64(b[1:], x)
b = b[9-n-1:]
b[0] = byte(n)
return b
}

func newCShake(N, S []byte, rate, outputLen int, dsbyte byte) ShakeHash {
c := cshakeState{state: &state{rate: rate, outputLen: outputLen, dsbyte: dsbyte}}

// leftEncode returns max 9 bytes
c.initBlock = make([]byte, 0, 9*2+len(N)+len(S))
c.initBlock = make([]byte, 0, 9+len(N)+9+len(S)) // leftEncode returns max 9 bytes
c.initBlock = append(c.initBlock, leftEncode(uint64(len(N))*8)...)
c.initBlock = append(c.initBlock, N...)
c.initBlock = append(c.initBlock, leftEncode(uint64(len(S))*8)...)
Expand Down

0 comments on commit 80ea76e

Please sign in to comment.