diff --git a/src/crypto/cipher/benchmark_test.go b/src/crypto/cipher/benchmark_test.go
index 181d08c9b14699..1a5b1b1ddd552d 100644
--- a/src/crypto/cipher/benchmark_test.go
+++ b/src/crypto/cipher/benchmark_test.go
@@ -65,12 +65,12 @@ func BenchmarkAESGCM(b *testing.B) {
 	}
 }
 
-func benchmarkAESStream(b *testing.B, mode func(cipher.Block, []byte) cipher.Stream, buf []byte) {
+func benchmarkAESStream(b *testing.B, mode func(cipher.Block, []byte) cipher.Stream, buf []byte, keySize int) {
 	b.SetBytes(int64(len(buf)))
 
-	var key [16]byte
+	key := make([]byte, keySize)
 	var iv [16]byte
-	aes, _ := aes.NewCipher(key[:])
+	aes, _ := aes.NewCipher(key)
 	stream := mode(aes, iv[:])
 
 	b.ResetTimer()
@@ -87,15 +87,20 @@ const almost1K = 1024 - 5
 const almost8K = 8*1024 - 5
 
 func BenchmarkAESCTR(b *testing.B) {
-	b.Run("50", func(b *testing.B) {
-		benchmarkAESStream(b, cipher.NewCTR, make([]byte, 50))
-	})
-	b.Run("1K", func(b *testing.B) {
-		benchmarkAESStream(b, cipher.NewCTR, make([]byte, almost1K))
-	})
-	b.Run("8K", func(b *testing.B) {
-		benchmarkAESStream(b, cipher.NewCTR, make([]byte, almost8K))
-	})
+	for _, keyBits := range []int{128, 192, 256} {
+		keySize := keyBits / 8
+		b.Run(strconv.Itoa(keyBits), func(b *testing.B) {
+			b.Run("50", func(b *testing.B) {
+				benchmarkAESStream(b, cipher.NewCTR, make([]byte, 50), keySize)
+			})
+			b.Run("1K", func(b *testing.B) {
+				benchmarkAESStream(b, cipher.NewCTR, make([]byte, almost1K), keySize)
+			})
+			b.Run("8K", func(b *testing.B) {
+				benchmarkAESStream(b, cipher.NewCTR, make([]byte, almost8K), keySize)
+			})
+		})
+	}
 }
 
 func BenchmarkAESCBCEncrypt1K(b *testing.B) {
diff --git a/src/crypto/cipher/ctr_aes_test.go b/src/crypto/cipher/ctr_aes_test.go
index 9b7d30e2164422..1d8ae78674ebe2 100644
--- a/src/crypto/cipher/ctr_aes_test.go
+++ b/src/crypto/cipher/ctr_aes_test.go
@@ -17,6 +17,7 @@ import (
 	"crypto/internal/boring"
 	"crypto/internal/cryptotest"
 	fipsaes "crypto/internal/fips140/aes"
+	"encoding/binary"
 	"encoding/hex"
 	"fmt"
 	"math/rand"
@@ -117,6 +118,64 @@ func makeTestingCiphers(aesBlock cipher.Block, iv []byte) (genericCtr, multibloc
 	return cipher.NewCTR(wrap(aesBlock), iv), cipher.NewCTR(aesBlock, iv)
 }
 
+// TestCTR_AES_blocks8FastPathMatchesGeneric ensures the overflow-aware branch
+// produces identical keystreams to the generic counter implementation across
+// representative IVs, including near-overflow cases.
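+//
+// The fast path applies only while the low 64 bits of the initial counter can
+// absorb seven more increments without carrying into the high 64 bits, so the
+// IVs below exercise both sides of the lo = 2^64 - 8 boundary.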
+func TestCTR_AES_blocks8FastPathMatchesGeneric(t *testing.T) {
+	key := make([]byte, aes.BlockSize)
+	block, err := aes.NewCipher(key)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if _, ok := block.(*fipsaes.Block); !ok {
+		t.Skip("requires crypto/internal/fips140/aes")
+	}
+
+	keystream := make([]byte, 8*aes.BlockSize)
+
+	testCases := []struct {
+		name string
+		hi   uint64
+		lo   uint64
+	}{
+		{"Zero", 0, 0},
+		{"NearOverflowMinus7", 1, ^uint64(0) - 7},
+		{"NearOverflowMinus6", 2, ^uint64(0) - 6},
+		{"Overflow", 0, ^uint64(0)},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			var iv [aes.BlockSize]byte
+			binary.BigEndian.PutUint64(iv[0:8], tc.hi)
+			binary.BigEndian.PutUint64(iv[8:], tc.lo)
+
+			generic, multiblock := makeTestingCiphers(block, iv[:])
+
+			genericOut := make([]byte, len(keystream))
+			multiblockOut := make([]byte, len(keystream))
+
+			generic.XORKeyStream(genericOut, keystream)
+			multiblock.XORKeyStream(multiblockOut, keystream)
+
+			if !bytes.Equal(multiblockOut, genericOut) {
+				t.Fatalf("mismatch for iv %#x:%#x\n"+
+					"asm keystream: %x\n"+
+					"gen keystream: %x\n"+
+					"asm counters: %x\n"+
+					"gen counters: %x",
+					tc.hi, tc.lo, multiblockOut, genericOut,
+					extractCounters(block, multiblockOut),
+					extractCounters(block, genericOut))
+			}
+		})
+	}
+}
+
 func randBytes(t *testing.T, r *rand.Rand, count int) []byte {
 	t.Helper()
 	buf := make([]byte, count)
@@ -297,3 +356,12 @@ func TestCTR_AES_multiblock_XORKeyStreamAt(t *testing.T) {
 		})
 	}
 }
+
+func extractCounters(block cipher.Block, keystream []byte) []byte {
+	blockSize := block.BlockSize()
+	res := make([]byte, len(keystream))
+	for i := 0; i < len(keystream); i += blockSize {
+		block.Decrypt(res[i:i+blockSize], keystream[i:i+blockSize])
+	}
+	return res
+}
diff --git a/src/crypto/internal/fips140/aes/_asm/ctr/ctr_amd64_asm.go b/src/crypto/internal/fips140/aes/_asm/ctr/ctr_amd64_asm.go
index 775d4a8acc5969..e3dbdf66d70e1c 100644
--- a/src/crypto/internal/fips140/aes/_asm/ctr/ctr_amd64_asm.go
+++ b/src/crypto/internal/fips140/aes/_asm/ctr/ctr_amd64_asm.go
@@ -40,19 +40,82 @@ func ctrBlocks(numBlocks int) {
 	bswap := XMM()
 	MOVOU(bswapMask(), bswap)
 
-	blocks := make([]VecVirtual, 0, numBlocks)
+	blocks := make([]VecVirtual, numBlocks)
+
+	// For the 8-block case we optimize counter generation. We build the first
+	// counter as usual, then check whether the remaining seven increments will
+	// overflow. When they do not (the common case), we keep the work entirely
+	// in XMM registers to avoid expensive general-purpose -> XMM moves.
+	// Otherwise we fall back to the traditional scalar path.
+	if numBlocks == 8 {
+		for i := range blocks {
+			blocks[i] = XMM()
+		}
 
-	// Lay out counter block plaintext.
-	for i := 0; i < numBlocks; i++ {
-		x := XMM()
-		blocks = append(blocks, x)
-
-		MOVQ(ivlo, x)
-		PINSRQ(Imm(1), ivhi, x)
-		PSHUFB(bswap, x)
-		if i < numBlocks-1 {
-			ADDQ(Imm(1), ivlo)
-			ADCQ(Imm(0), ivhi)
+		base := XMM()
+		tmp := GP64()
+		addVec := XMM()
+
+		MOVQ(ivlo, blocks[0])
+		PINSRQ(Imm(1), ivhi, blocks[0])
+		MOVAPS(blocks[0], base)
+		PSHUFB(bswap, blocks[0])
+
+		// Check whether any of these eight counters will overflow the low 64 bits.
+		MOVQ(ivlo, tmp)
+		ADDQ(Imm(uint64(numBlocks-1)), tmp)
+		slowLabel := fmt.Sprintf("ctr%d_slow", numBlocks)
+		doneLabel := fmt.Sprintf("ctr%d_done", numBlocks)
+		JC(LabelRef(slowLabel))
+
+		// Fast branch: create an XMM increment vector containing the value 1.
+		// Adding it to the base counter yields each subsequent counter.
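+		// Note that PADDQ adds the two 64-bit lanes independently, with no
+		// carry between them, so bumping only the low quadword is correct
+		// precisely because the check above ruled out low-half overflow.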
+ XORQ(tmp, tmp) + INCQ(tmp) + PXOR(addVec, addVec) + PINSRQ(Imm(0), tmp, addVec) + + for i := 1; i < numBlocks; i++ { + PADDQ(addVec, base) + MOVAPS(base, blocks[i]) + } + JMP(LabelRef(doneLabel)) + + Label(slowLabel) + ADDQ(Imm(1), ivlo) + ADCQ(Imm(0), ivhi) + for i := 1; i < numBlocks; i++ { + MOVQ(ivlo, blocks[i]) + PINSRQ(Imm(1), ivhi, blocks[i]) + if i < numBlocks-1 { + ADDQ(Imm(1), ivlo) + ADCQ(Imm(0), ivhi) + } + } + + Label(doneLabel) + + // Convert little-endian counters to big-endian after the branch since + // both paths share the same shuffle sequence. + for i := 1; i < numBlocks; i++ { + PSHUFB(bswap, blocks[i]) + } + } else { + // Lay out counter block plaintext. + for i := 0; i < numBlocks; i++ { + x := XMM() + blocks[i] = x + + MOVQ(ivlo, x) + PINSRQ(Imm(1), ivhi, x) + PSHUFB(bswap, x) + if i < numBlocks-1 { + ADDQ(Imm(1), ivlo) + ADCQ(Imm(0), ivhi) + } } } diff --git a/src/crypto/internal/fips140/aes/ctr_amd64.s b/src/crypto/internal/fips140/aes/ctr_amd64.s index e6710834dd27e6..deef3e7705a5b3 100644 --- a/src/crypto/internal/fips140/aes/ctr_amd64.s +++ b/src/crypto/internal/fips140/aes/ctr_amd64.s @@ -286,41 +286,68 @@ TEXT ·ctrBlocks8Asm(SB), $0-48 MOVOU bswapMask<>+0(SB), X0 MOVQ SI, X1 PINSRQ $0x01, DI, X1 + MOVAPS X1, X8 PSHUFB X0, X1 + MOVQ SI, R8 + ADDQ $0x07, R8 + JC ctr8_slow + XORQ R8, R8 + INCQ R8 + PXOR X9, X9 + PINSRQ $0x00, R8, X9 + PADDQ X9, X8 + MOVAPS X8, X2 + PADDQ X9, X8 + MOVAPS X8, X3 + PADDQ X9, X8 + MOVAPS X8, X4 + PADDQ X9, X8 + MOVAPS X8, X5 + PADDQ X9, X8 + MOVAPS X8, X6 + PADDQ X9, X8 + MOVAPS X8, X7 + PADDQ X9, X8 + MOVAPS X8, X8 + JMP ctr8_done + +ctr8_slow: ADDQ $0x01, SI ADCQ $0x00, DI MOVQ SI, X2 PINSRQ $0x01, DI, X2 - PSHUFB X0, X2 ADDQ $0x01, SI ADCQ $0x00, DI MOVQ SI, X3 PINSRQ $0x01, DI, X3 - PSHUFB X0, X3 ADDQ $0x01, SI ADCQ $0x00, DI MOVQ SI, X4 PINSRQ $0x01, DI, X4 - PSHUFB X0, X4 ADDQ $0x01, SI ADCQ $0x00, DI MOVQ SI, X5 PINSRQ $0x01, DI, X5 - PSHUFB X0, X5 ADDQ $0x01, SI ADCQ $0x00, DI MOVQ SI, X6 PINSRQ $0x01, DI, X6 - PSHUFB X0, X6 ADDQ $0x01, SI ADCQ $0x00, DI MOVQ SI, X7 PINSRQ $0x01, DI, X7 - PSHUFB X0, X7 ADDQ $0x01, SI ADCQ $0x00, DI MOVQ SI, X8 PINSRQ $0x01, DI, X8 + +ctr8_done: + PSHUFB X0, X2 + PSHUFB X0, X3 + PSHUFB X0, X4 + PSHUFB X0, X5 + PSHUFB X0, X6 + PSHUFB X0, X7 PSHUFB X0, X8 MOVUPS (CX), X0 PXOR X0, X1
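
For readers who want to sanity-check the branch condition without running the assembly, the following is a minimal pure-Go model of the eight-block counter schedule. It is an illustrative sketch only, not part of this change: ctrCounters8 is a hypothetical helper, and its fast/slow split mirrors the ADDQ $0x07 / JC guard above.

package main

import (
	"encoding/binary"
	"fmt"
)

// ctrCounters8 models ctrBlocks8Asm's counter generation: the fast path
// applies when lo can absorb seven increments without carrying out of the
// low 64 bits, matching the ADDQ $0x07 / JC guard in the assembly.
func ctrCounters8(hi, lo uint64) [8][16]byte {
	var out [8][16]byte
	if lo <= ^uint64(0)-7 {
		// Fast path: only the low half changes, as with the PADDQ loop.
		for i := range out {
			binary.BigEndian.PutUint64(out[i][0:8], hi)
			binary.BigEndian.PutUint64(out[i][8:16], lo+uint64(i))
		}
		return out
	}
	// Slow path: full 128-bit increment with carry, as with ADDQ/ADCQ.
	for i := range out {
		binary.BigEndian.PutUint64(out[i][0:8], hi)
		binary.BigEndian.PutUint64(out[i][8:16], lo)
		lo++
		if lo == 0 {
			hi++
		}
	}
	return out
}

func main() {
	// An IV whose low half wraps mid-batch, forcing the slow path.
	for _, c := range ctrCounters8(1, ^uint64(0)-3) {
		fmt.Printf("%x\n", c)
	}
}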