Skip to content

Commit 750a45f

Browse files
FiloSottilegopherbot
authored andcommitted
sha3: add MarshalBinary, AppendBinary, and UnmarshalBinary
Fixes golang/go#24617 Change-Id: I1d9d529950aa8a5953435e8d3412cda44b075d55 Reviewed-on: https://go-review.googlesource.com/c/crypto/+/616635 Reviewed-by: Roland Shoemaker <roland@golang.org> Auto-Submit: Filippo Valsorda <filippo@golang.org> LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Reviewed-by: Daniel McCarney <daniel@binaryparadox.net> Reviewed-by: Michael Pratt <mpratt@google.com>
1 parent 36b1725 commit 750a45f

File tree

5 files changed

+171
-20
lines changed

5 files changed

+171
-20
lines changed

sha3/doc.go

+4
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55
// Package sha3 implements the SHA-3 fixed-output-length hash functions and
66
// the SHAKE variable-output-length hash functions defined by FIPS-202.
77
//
8+
// All types in this package also implement [encoding.BinaryMarshaler],
9+
// [encoding.BinaryAppender] and [encoding.BinaryUnmarshaler] to marshal and
10+
// unmarshal the internal state of the hash.
11+
//
812
// Both types of hash function use the "sponge" construction and the Keccak
913
// permutation. For a detailed specification see http://keccak.noekeon.org/
1014
//

sha3/hashes.go

+25-6
Original file line numberDiff line numberDiff line change
@@ -48,33 +48,52 @@ func init() {
4848
crypto.RegisterHash(crypto.SHA3_512, New512)
4949
}
5050

51+
const (
52+
dsbyteSHA3 = 0b00000110
53+
dsbyteKeccak = 0b00000001
54+
dsbyteShake = 0b00011111
55+
dsbyteCShake = 0b00000100
56+
57+
// rateK[c] is the rate in bytes for Keccak[c] where c is the capacity in
58+
// bits. Given the sponge size is 1600 bits, the rate is 1600 - c bits.
59+
rateK256 = (1600 - 256) / 8
60+
rateK448 = (1600 - 448) / 8
61+
rateK512 = (1600 - 512) / 8
62+
rateK768 = (1600 - 768) / 8
63+
rateK1024 = (1600 - 1024) / 8
64+
)
65+
5166
func new224Generic() *state {
52-
return &state{rate: 144, outputLen: 28, dsbyte: 0x06}
67+
return &state{rate: rateK448, outputLen: 28, dsbyte: dsbyteSHA3}
5368
}
5469

5570
func new256Generic() *state {
56-
return &state{rate: 136, outputLen: 32, dsbyte: 0x06}
71+
return &state{rate: rateK512, outputLen: 32, dsbyte: dsbyteSHA3}
5772
}
5873

5974
func new384Generic() *state {
60-
return &state{rate: 104, outputLen: 48, dsbyte: 0x06}
75+
return &state{rate: rateK768, outputLen: 48, dsbyte: dsbyteSHA3}
6176
}
6277

6378
func new512Generic() *state {
64-
return &state{rate: 72, outputLen: 64, dsbyte: 0x06}
79+
return &state{rate: rateK1024, outputLen: 64, dsbyte: dsbyteSHA3}
6580
}
6681

6782
// NewLegacyKeccak256 creates a new Keccak-256 hash.
6883
//
6984
// Only use this function if you require compatibility with an existing cryptosystem
7085
// that uses non-standard padding. All other users should use New256 instead.
71-
func NewLegacyKeccak256() hash.Hash { return &state{rate: 136, outputLen: 32, dsbyte: 0x01} }
86+
func NewLegacyKeccak256() hash.Hash {
87+
return &state{rate: rateK512, outputLen: 32, dsbyte: dsbyteKeccak}
88+
}
7289

7390
// NewLegacyKeccak512 creates a new Keccak-512 hash.
7491
//
7592
// Only use this function if you require compatibility with an existing cryptosystem
7693
// that uses non-standard padding. All other users should use New512 instead.
77-
func NewLegacyKeccak512() hash.Hash { return &state{rate: 72, outputLen: 64, dsbyte: 0x01} }
94+
func NewLegacyKeccak512() hash.Hash {
95+
return &state{rate: rateK1024, outputLen: 64, dsbyte: dsbyteKeccak}
96+
}
7897

7998
// Sum224 returns the SHA3-224 digest of the data.
8099
func Sum224(data []byte) (digest [28]byte) {

sha3/sha3.go

+72
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package sha3
77
import (
88
"crypto/subtle"
99
"encoding/binary"
10+
"errors"
1011
"unsafe"
1112

1213
"golang.org/x/sys/cpu"
@@ -170,3 +171,74 @@ func (d *state) Sum(in []byte) []byte {
170171
dup.Read(hash)
171172
return append(in, hash...)
172173
}
174+
175+
const (
176+
magicSHA3 = "sha\x08"
177+
magicShake = "sha\x09"
178+
magicCShake = "sha\x0a"
179+
magicKeccak = "sha\x0b"
180+
// magic || rate || main state || n || sponge direction
181+
marshaledSize = len(magicSHA3) + 1 + 200 + 1 + 1
182+
)
183+
184+
func (d *state) MarshalBinary() ([]byte, error) {
185+
return d.AppendBinary(make([]byte, 0, marshaledSize))
186+
}
187+
188+
func (d *state) AppendBinary(b []byte) ([]byte, error) {
189+
switch d.dsbyte {
190+
case dsbyteSHA3:
191+
b = append(b, magicSHA3...)
192+
case dsbyteShake:
193+
b = append(b, magicShake...)
194+
case dsbyteCShake:
195+
b = append(b, magicCShake...)
196+
case dsbyteKeccak:
197+
b = append(b, magicKeccak...)
198+
default:
199+
panic("unknown dsbyte")
200+
}
201+
// rate is at most 168, and n is at most rate.
202+
b = append(b, byte(d.rate))
203+
b = append(b, d.a[:]...)
204+
b = append(b, byte(d.n), byte(d.state))
205+
return b, nil
206+
}
207+
208+
func (d *state) UnmarshalBinary(b []byte) error {
209+
if len(b) != marshaledSize {
210+
return errors.New("sha3: invalid hash state")
211+
}
212+
213+
magic := string(b[:len(magicSHA3)])
214+
b = b[len(magicSHA3):]
215+
switch {
216+
case magic == magicSHA3 && d.dsbyte == dsbyteSHA3:
217+
case magic == magicShake && d.dsbyte == dsbyteShake:
218+
case magic == magicCShake && d.dsbyte == dsbyteCShake:
219+
case magic == magicKeccak && d.dsbyte == dsbyteKeccak:
220+
default:
221+
return errors.New("sha3: invalid hash state identifier")
222+
}
223+
224+
rate := int(b[0])
225+
b = b[1:]
226+
if rate != d.rate {
227+
return errors.New("sha3: invalid hash state function")
228+
}
229+
230+
copy(d.a[:], b)
231+
b = b[len(d.a):]
232+
233+
n, state := int(b[0]), spongeDirection(b[1])
234+
if n > d.rate {
235+
return errors.New("sha3: invalid hash state")
236+
}
237+
d.n = n
238+
if state != spongeAbsorbing && state != spongeSqueezing {
239+
return errors.New("sha3: invalid hash state")
240+
}
241+
d.state = state
242+
243+
return nil
244+
}

sha3/sha3_test.go

+40-2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ package sha3
1313
import (
1414
"bytes"
1515
"compress/flate"
16+
"encoding"
1617
"encoding/hex"
1718
"encoding/json"
1819
"fmt"
@@ -421,11 +422,11 @@ func TestCSHAKEAccumulated(t *testing.T) {
421422
// console.log(bytesToHex(acc.xof(32)));
422423
//
423424
t.Run("cSHAKE128", func(t *testing.T) {
424-
testCSHAKEAccumulated(t, NewCShake128, rate128,
425+
testCSHAKEAccumulated(t, NewCShake128, rateK256,
425426
"bb14f8657c6ec5403d0b0e2ef3d3393497e9d3b1a9a9e8e6c81dbaa5fd809252")
426427
})
427428
t.Run("cSHAKE256", func(t *testing.T) {
428-
testCSHAKEAccumulated(t, NewCShake256, rate256,
429+
testCSHAKEAccumulated(t, NewCShake256, rateK512,
429430
"0baaf9250c6e25f0c14ea5c7f9bfde54c8a922c8276437db28f3895bdf6eeeef")
430431
})
431432
}
@@ -486,6 +487,43 @@ func TestCSHAKELargeS(t *testing.T) {
486487
}
487488
}
488489

490+
func TestMarshalUnmarshal(t *testing.T) {
491+
t.Run("SHA3-224", func(t *testing.T) { testMarshalUnmarshal(t, New224()) })
492+
t.Run("SHA3-256", func(t *testing.T) { testMarshalUnmarshal(t, New256()) })
493+
t.Run("SHA3-384", func(t *testing.T) { testMarshalUnmarshal(t, New384()) })
494+
t.Run("SHA3-512", func(t *testing.T) { testMarshalUnmarshal(t, New512()) })
495+
t.Run("SHAKE128", func(t *testing.T) { testMarshalUnmarshal(t, NewShake128()) })
496+
t.Run("SHAKE256", func(t *testing.T) { testMarshalUnmarshal(t, NewShake256()) })
497+
t.Run("cSHAKE128", func(t *testing.T) { testMarshalUnmarshal(t, NewCShake128([]byte("N"), []byte("S"))) })
498+
t.Run("cSHAKE256", func(t *testing.T) { testMarshalUnmarshal(t, NewCShake256([]byte("N"), []byte("S"))) })
499+
t.Run("Keccak-256", func(t *testing.T) { testMarshalUnmarshal(t, NewLegacyKeccak256()) })
500+
t.Run("Keccak-512", func(t *testing.T) { testMarshalUnmarshal(t, NewLegacyKeccak512()) })
501+
}
502+
503+
// TODO(filippo): move this to crypto/internal/cryptotest.
504+
func testMarshalUnmarshal(t *testing.T, h hash.Hash) {
505+
buf := make([]byte, 200)
506+
rand.Read(buf)
507+
n := rand.Intn(200)
508+
h.Write(buf)
509+
want := h.Sum(nil)
510+
h.Reset()
511+
h.Write(buf[:n])
512+
b, err := h.(encoding.BinaryMarshaler).MarshalBinary()
513+
if err != nil {
514+
t.Errorf("MarshalBinary: %v", err)
515+
}
516+
h.Write(bytes.Repeat([]byte{0}, 200))
517+
if err := h.(encoding.BinaryUnmarshaler).UnmarshalBinary(b); err != nil {
518+
t.Errorf("UnmarshalBinary: %v", err)
519+
}
520+
h.Write(buf[n:])
521+
got := h.Sum(nil)
522+
if !bytes.Equal(got, want) {
523+
t.Errorf("got %x, want %x", got, want)
524+
}
525+
}
526+
489527
// BenchmarkPermutationFunction measures the speed of the permutation function
490528
// with no input data.
491529
func BenchmarkPermutationFunction(b *testing.B) {

sha3/shake.go

+30-12
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ package sha3
1616
// [2] https://doi.org/10.6028/NIST.SP.800-185
1717

1818
import (
19+
"bytes"
1920
"encoding/binary"
21+
"errors"
2022
"hash"
2123
"io"
2224
"math/bits"
@@ -51,14 +53,6 @@ type cshakeState struct {
5153
initBlock []byte
5254
}
5355

54-
// Consts for configuring initial SHA-3 state
55-
const (
56-
dsbyteShake = 0x1f
57-
dsbyteCShake = 0x04
58-
rate128 = 168
59-
rate256 = 136
60-
)
61-
6256
func bytepad(data []byte, rate int) []byte {
6357
out := make([]byte, 0, 9+len(data)+rate-1)
6458
out = append(out, leftEncode(uint64(rate))...)
@@ -112,6 +106,30 @@ func (c *state) Clone() ShakeHash {
112106
return c.clone()
113107
}
114108

109+
func (c *cshakeState) MarshalBinary() ([]byte, error) {
110+
return c.AppendBinary(make([]byte, 0, marshaledSize+len(c.initBlock)))
111+
}
112+
113+
func (c *cshakeState) AppendBinary(b []byte) ([]byte, error) {
114+
b, err := c.state.AppendBinary(b)
115+
if err != nil {
116+
return nil, err
117+
}
118+
b = append(b, c.initBlock...)
119+
return b, nil
120+
}
121+
122+
func (c *cshakeState) UnmarshalBinary(b []byte) error {
123+
if len(b) <= marshaledSize {
124+
return errors.New("sha3: invalid hash state")
125+
}
126+
if err := c.state.UnmarshalBinary(b[:marshaledSize]); err != nil {
127+
return err
128+
}
129+
c.initBlock = bytes.Clone(b[marshaledSize:])
130+
return nil
131+
}
132+
115133
// NewShake128 creates a new SHAKE128 variable-output-length ShakeHash.
116134
// Its generic security strength is 128 bits against all attacks if at
117135
// least 32 bytes of its output are used.
@@ -127,11 +145,11 @@ func NewShake256() ShakeHash {
127145
}
128146

129147
func newShake128Generic() *state {
130-
return &state{rate: rate128, outputLen: 32, dsbyte: dsbyteShake}
148+
return &state{rate: rateK256, outputLen: 32, dsbyte: dsbyteShake}
131149
}
132150

133151
func newShake256Generic() *state {
134-
return &state{rate: rate256, outputLen: 64, dsbyte: dsbyteShake}
152+
return &state{rate: rateK512, outputLen: 64, dsbyte: dsbyteShake}
135153
}
136154

137155
// NewCShake128 creates a new instance of cSHAKE128 variable-output-length ShakeHash,
@@ -144,7 +162,7 @@ func NewCShake128(N, S []byte) ShakeHash {
144162
if len(N) == 0 && len(S) == 0 {
145163
return NewShake128()
146164
}
147-
return newCShake(N, S, rate128, 32, dsbyteCShake)
165+
return newCShake(N, S, rateK256, 32, dsbyteCShake)
148166
}
149167

150168
// NewCShake256 creates a new instance of cSHAKE256 variable-output-length ShakeHash,
@@ -157,7 +175,7 @@ func NewCShake256(N, S []byte) ShakeHash {
157175
if len(N) == 0 && len(S) == 0 {
158176
return NewShake256()
159177
}
160-
return newCShake(N, S, rate256, 64, dsbyteCShake)
178+
return newCShake(N, S, rateK512, 64, dsbyteCShake)
161179
}
162180

163181
// ShakeSum128 writes an arbitrary-length digest of data into hash.

0 commit comments

Comments
 (0)