From 5af89b3dccf7ee375f68e9cd3ee4980fef89e38f Mon Sep 17 00:00:00 2001
From: Frank Denis <124872+jedisct1@users.noreply.github.com>
Date: Mon, 22 May 2023 20:33:35 +0200
Subject: [PATCH] std.crypto.chacha: support larger vectors on AVX2 and AVX512
 targets (#15809)

* std.crypto.chacha: support larger vectors on AVX2 and AVX512 targets

Ryzen 7 7700, ChaCha20/8 stream, long outputs:

Generic: 3268 MiB/s
AVX2   : 6023 MiB/s
AVX512 : 8086 MiB/s

Bump the rand.chacha buffer a tiny bit to take advantage of this.
More than 8 blocks doesn't seem to make any measurable difference.

ChaChaPoly also gets a small performance boost from this, albeit
Poly1305 remains the bottleneck.

Generic: 707 MiB/s
AVX2   : 981 MiB/s
AVX512 : 1202 MiB/s

aarch64 appears to generally benefit from 4-way vectorization.
Verified on Apple Silicon, but also on a Cortex A72.
---
 lib/std/crypto/chacha20.zig | 194 ++++++++++++++++++++++++++----------
 lib/std/rand/ChaCha.zig     |   2 +-
 2 files changed, 144 insertions(+), 52 deletions(-)

diff --git a/lib/std/crypto/chacha20.zig b/lib/std/crypto/chacha20.zig
index bffc70f50075..5915a16ece57 100644
--- a/lib/std/crypto/chacha20.zig
+++ b/lib/std/crypto/chacha20.zig
@@ -76,30 +76,98 @@ pub const XChaCha12Poly1305 = XChaChaPoly1305(12);
 pub const XChaCha8Poly1305 = XChaChaPoly1305(8);
 
 // Vectorized implementation of the core function
-fn ChaChaVecImpl(comptime rounds_nb: usize) type {
+fn ChaChaVecImpl(comptime rounds_nb: usize, comptime degree: comptime_int) type {
     return struct {
-        const Lane = @Vector(4, u32);
+        const Lane = @Vector(4 * degree, u32);
         const BlockVec = [4]Lane;
 
         fn initContext(key: [8]u32, d: [4]u32) BlockVec {
             const c = "expand 32-byte k";
-            const constant_le = comptime Lane{
-                mem.readIntLittle(u32, c[0..4]),
-                mem.readIntLittle(u32, c[4..8]),
-                mem.readIntLittle(u32, c[8..12]),
-                mem.readIntLittle(u32, c[12..16]),
-            };
-            return BlockVec{
-                constant_le,
-                Lane{ key[0], key[1], key[2], key[3] },
-                Lane{ key[4], key[5], key[6], key[7] },
-                Lane{ d[0], d[1], d[2], d[3] },
-            };
+            switch (degree) {
+                1 => {
+                    const constant_le = Lane{
+                        mem.readIntLittle(u32, c[0..4]),
+                        mem.readIntLittle(u32, c[4..8]),
+                        mem.readIntLittle(u32, c[8..12]),
+                        mem.readIntLittle(u32, c[12..16]),
+                    };
+                    return BlockVec{
+                        constant_le,
+                        Lane{ key[0], key[1], key[2], key[3] },
+                        Lane{ key[4], key[5], key[6], key[7] },
+                        Lane{ d[0], d[1], d[2], d[3] },
+                    };
+                },
+                2 => {
+                    const constant_le = Lane{
+                        mem.readIntLittle(u32, c[0..4]),
+                        mem.readIntLittle(u32, c[4..8]),
+                        mem.readIntLittle(u32, c[8..12]),
+                        mem.readIntLittle(u32, c[12..16]),
+                        mem.readIntLittle(u32, c[0..4]),
+                        mem.readIntLittle(u32, c[4..8]),
+                        mem.readIntLittle(u32, c[8..12]),
+                        mem.readIntLittle(u32, c[12..16]),
+                    };
+                    return BlockVec{
+                        constant_le,
+                        Lane{ key[0], key[1], key[2], key[3], key[0], key[1], key[2], key[3] },
+                        Lane{ key[4], key[5], key[6], key[7], key[4], key[5], key[6], key[7] },
+                        Lane{ d[0], d[1], d[2], d[3], d[0] +% 1, d[1], d[2], d[3] },
+                    };
+                },
+                4 => {
+                    const constant_le = Lane{
+                        mem.readIntLittle(u32, c[0..4]),
+                        mem.readIntLittle(u32, c[4..8]),
+                        mem.readIntLittle(u32, c[8..12]),
+                        mem.readIntLittle(u32, c[12..16]),
+                        mem.readIntLittle(u32, c[0..4]),
+                        mem.readIntLittle(u32, c[4..8]),
+                        mem.readIntLittle(u32, c[8..12]),
+                        mem.readIntLittle(u32, c[12..16]),
+                        mem.readIntLittle(u32, c[0..4]),
+                        mem.readIntLittle(u32, c[4..8]),
+                        mem.readIntLittle(u32, c[8..12]),
+                        mem.readIntLittle(u32, c[12..16]),
+                        mem.readIntLittle(u32, c[0..4]),
+                        mem.readIntLittle(u32, c[4..8]),
+                        mem.readIntLittle(u32, c[8..12]),
+                        mem.readIntLittle(u32, c[12..16]),
+                    };
+                    return BlockVec{
+                        constant_le,
+                        Lane{ key[0], key[1], key[2], key[3], key[0], key[1], key[2], key[3], key[0], key[1], key[2], key[3], key[0], key[1], key[2], key[3] },
+                        Lane{ key[4], key[5], key[6], key[7], key[4], key[5], key[6], key[7], key[4], key[5], key[6], key[7], key[4], key[5], key[6], key[7] },
+                        Lane{ d[0], d[1], d[2], d[3], d[0] +% 1, d[1], d[2], d[3], d[0] +% 2, d[1], d[2], d[3], d[0] +% 3, d[1], d[2], d[3] },
+                    };
+                },
+                else => @panic("invalid degree"),
+            }
         }
 
         inline fn chacha20Core(x: *BlockVec, input: BlockVec) void {
             x.* = input;
 
+            const m0 = switch (degree) {
+                1 => [_]i32{ 3, 0, 1, 2 },
+                2 => [_]i32{ 3, 0, 1, 2 } ++ [_]i32{ 7, 4, 5, 6 },
+                4 => [_]i32{ 3, 0, 1, 2 } ++ [_]i32{ 7, 4, 5, 6 } ++ [_]i32{ 11, 8, 9, 10 } ++ [_]i32{ 15, 12, 13, 14 },
+                else => @panic("invalid degree"),
+            };
+            const m1 = switch (degree) {
+                1 => [_]i32{ 2, 3, 0, 1 },
+                2 => [_]i32{ 2, 3, 0, 1 } ++ [_]i32{ 6, 7, 4, 5 },
+                4 => [_]i32{ 2, 3, 0, 1 } ++ [_]i32{ 6, 7, 4, 5 } ++ [_]i32{ 10, 11, 8, 9 } ++ [_]i32{ 14, 15, 12, 13 },
+                else => @panic("invalid degree"),
+            };
+            const m2 = switch (degree) {
+                1 => [_]i32{ 1, 2, 3, 0 },
+                2 => [_]i32{ 1, 2, 3, 0 } ++ [_]i32{ 5, 6, 7, 4 },
+                4 => [_]i32{ 1, 2, 3, 0 } ++ [_]i32{ 5, 6, 7, 4 } ++ [_]i32{ 9, 10, 11, 8 } ++ [_]i32{ 13, 14, 15, 12 },
+                else => @panic("invalid degree"),
+            };
+
             var r: usize = 0;
             while (r < rounds_nb) : (r += 2) {
                 x[0] +%= x[1];
@@ -112,13 +180,13 @@ fn ChaChaVecImpl(comptime rounds_nb: usize) type {
 
                 x[0] +%= x[1];
                 x[3] ^= x[0];
-                x[0] = @shuffle(u32, x[0], undefined, [_]i32{ 3, 0, 1, 2 });
+                x[0] = @shuffle(u32, x[0], undefined, m0);
                 x[3] = math.rotl(Lane, x[3], 8);
 
                 x[2] +%= x[3];
-                x[3] = @shuffle(u32, x[3], undefined, [_]i32{ 2, 3, 0, 1 });
+                x[3] = @shuffle(u32, x[3], undefined, m1);
                 x[1] ^= x[2];
-                x[2] = @shuffle(u32, x[2], undefined, [_]i32{ 1, 2, 3, 0 });
+                x[2] = @shuffle(u32, x[2], undefined, m2);
                 x[1] = math.rotl(Lane, x[1], 7);
 
                 x[0] +%= x[1];
@@ -131,24 +199,26 @@ fn ChaChaVecImpl(comptime rounds_nb: usize) type {
 
                 x[0] +%= x[1];
                 x[3] ^= x[0];
-                x[0] = @shuffle(u32, x[0], undefined, [_]i32{ 1, 2, 3, 0 });
+                x[0] = @shuffle(u32, x[0], undefined, m2);
                 x[3] = math.rotl(Lane, x[3], 8);
 
                 x[2] +%= x[3];
-                x[3] = @shuffle(u32, x[3], undefined, [_]i32{ 2, 3, 0, 1 });
+                x[3] = @shuffle(u32, x[3], undefined, m1);
                 x[1] ^= x[2];
-                x[2] = @shuffle(u32, x[2], undefined, [_]i32{ 3, 0, 1, 2 });
+                x[2] = @shuffle(u32, x[2], undefined, m0);
                 x[1] = math.rotl(Lane, x[1], 7);
             }
         }
 
-        inline fn hashToBytes(out: *[64]u8, x: BlockVec) void {
-            var i: usize = 0;
-            while (i < 4) : (i += 1) {
-                mem.writeIntLittle(u32, out[16 * i + 0 ..][0..4], x[i][0]);
-                mem.writeIntLittle(u32, out[16 * i + 4 ..][0..4], x[i][1]);
-                mem.writeIntLittle(u32, out[16 * i + 8 ..][0..4], x[i][2]);
-                mem.writeIntLittle(u32, out[16 * i + 12 ..][0..4], x[i][3]);
+        inline fn hashToBytes(comptime dm: usize, out: *[64 * dm]u8, x: BlockVec) void {
+            for (0..dm) |d| {
+                var i: usize = 0;
+                while (i < 4) : (i += 1) {
+                    mem.writeIntLittle(u32, out[64 * d + 16 * i + 0 ..][0..4], x[i][0 + 4 * d]);
+                    mem.writeIntLittle(u32, out[64 * d + 16 * i + 4 ..][0..4], x[i][1 + 4 * d]);
+                    mem.writeIntLittle(u32, out[64 * d + 16 * i + 8 ..][0..4], x[i][2 + 4 * d]);
+                    mem.writeIntLittle(u32, out[64 * d + 16 * i + 12 ..][0..4], x[i][3 + 4 * d]);
+                }
             }
         }
 
@@ -162,29 +232,33 @@ fn ChaChaVecImpl(comptime rounds_nb: usize) type {
         fn chacha20Xor(out: []u8, in: []const u8, key: [8]u32, counter: [4]u32) void {
             var ctx = initContext(key, counter);
             var x: BlockVec = undefined;
-            var buf: [64]u8 = undefined;
+            var buf: [64 * degree]u8 = undefined;
             var i: usize = 0;
-            while (i + 64 <= in.len) : (i += 64) {
-                chacha20Core(x[0..], ctx);
-                contextFeedback(&x, ctx);
-                hashToBytes(buf[0..], x);
-
-                var xout = out[i..];
-                const xin = in[i..];
-                var j: usize = 0;
-                while (j < 64) : (j += 1) {
-                    xout[j] = xin[j];
-                }
-                j = 0;
-                while (j < 64) : (j += 1) {
-                    xout[j] ^= buf[j];
+            inline for ([_]comptime_int{ 4, 2, 1 }) |d| {
+                while (degree >= d and i + 64 * d <= in.len) : (i += 64 * d) {
+                    chacha20Core(x[0..], ctx);
+                    contextFeedback(&x, ctx);
+                    hashToBytes(d, buf[0 .. 64 * d], x);
+
+                    var xout = out[i..];
+                    const xin = in[i..];
+                    var j: usize = 0;
+                    while (j < 64 * d) : (j += 1) {
+                        xout[j] = xin[j];
+                    }
+                    j = 0;
+                    while (j < 64 * d) : (j += 1) {
+                        xout[j] ^= buf[j];
+                    }
+                    inline for (0..d) |d_| {
+                        ctx[3][4 * d_] += @intCast(u32, d);
+                    }
                 }
-                ctx[3][0] += 1;
             }
             if (i < in.len) {
                 chacha20Core(x[0..], ctx);
                 contextFeedback(&x, ctx);
-                hashToBytes(buf[0..], x);
+                hashToBytes(1, buf[0..64], x);
 
                 var xout = out[i..];
                 const xin = in[i..];
@@ -199,18 +273,22 @@ fn ChaChaVecImpl(comptime rounds_nb: usize) type {
             var ctx = initContext(key, counter);
             var x: BlockVec = undefined;
             var i: usize = 0;
-            while (i + 64 <= out.len) : (i += 64) {
-                chacha20Core(x[0..], ctx);
-                contextFeedback(&x, ctx);
-                hashToBytes(out[i..][0..64], x);
-                ctx[3][0] += 1;
+            inline for ([_]comptime_int{ 4, 2, 1 }) |d| {
+                while (degree >= d and i + 64 * d <= out.len) : (i += 64 * d) {
+                    chacha20Core(x[0..], ctx);
+                    contextFeedback(&x, ctx);
+                    hashToBytes(d, out[i..][0 .. 64 * d], x);
+                    inline for (0..d) |d_| {
+                        ctx[3][4 * d_] += @intCast(u32, d);
+                    }
+                }
             }
             if (i < out.len) {
                 chacha20Core(x[0..], ctx);
                 contextFeedback(&x, ctx);
 
                 var buf: [64]u8 = undefined;
-                hashToBytes(buf[0..], x);
+                hashToBytes(1, buf[0..], x);
                 @memcpy(out[i..], buf[0 .. out.len - i]);
             }
         }
@@ -399,7 +477,21 @@ fn ChaChaNonVecImpl(comptime rounds_nb: usize) type {
 }
 
 fn ChaChaImpl(comptime rounds_nb: usize) type {
-    return if (builtin.cpu.arch == .x86_64) ChaChaVecImpl(rounds_nb) else ChaChaNonVecImpl(rounds_nb);
+    switch (builtin.cpu.arch) {
+        .x86_64 => {
+            const has_avx2 = std.Target.x86.featureSetHas(builtin.cpu.features, .avx2);
+            const has_avx512f = std.Target.x86.featureSetHas(builtin.cpu.features, .avx512f);
+            if (has_avx512f) return ChaChaVecImpl(rounds_nb, 4);
+            if (has_avx2) return ChaChaVecImpl(rounds_nb, 2);
+            return ChaChaVecImpl(rounds_nb, 1);
+        },
+        .aarch64 => {
+            const has_neon = std.Target.aarch64.featureSetHas(builtin.cpu.features, .neon);
+            if (has_neon) return ChaChaVecImpl(rounds_nb, 4);
+            return ChaChaNonVecImpl(rounds_nb);
+        },
+        else => return ChaChaNonVecImpl(rounds_nb),
+    }
 }
 
 fn keyToWords(key: [32]u8) [8]u32 {
diff --git a/lib/std/rand/ChaCha.zig b/lib/std/rand/ChaCha.zig
index 3878fb25c862..75f62c9a4723 100644
--- a/lib/std/rand/ChaCha.zig
+++ b/lib/std/rand/ChaCha.zig
@@ -10,7 +10,7 @@ const Self = @This();
 
 const Cipher = std.crypto.stream.chacha.ChaCha8IETF;
 
-const State = [2 * Cipher.block_length]u8;
+const State = [8 * Cipher.block_length]u8;
 
 state: State,
 offset: usize,
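
Usage sketch (illustrative, not part of the patch above): the vector degree is selected at comptime from builtin.cpu.features, so callers keep using the existing stream API unchanged; only outputs spanning several 64-byte blocks reach the new 2-block and 4-block loops. The test below assumes the ChaCha8IETF xor(out, in, counter, key, nonce) API referenced in lib/std/crypto/chacha20.zig; the key, nonce and buffer length are arbitrary values picked for this example.

const std = @import("std");

test "ChaCha8 stream round-trips over a long output" {
    const ChaCha8 = std.crypto.stream.chacha.ChaCha8IETF;

    // Arbitrary key and nonce, used only for this illustration.
    const key = [_]u8{0x42} ** ChaCha8.key_length;
    const nonce = [_]u8{0x24} ** ChaCha8.nonce_length;

    // 8 KiB spans 128 blocks, so AVX512/NEON builds go through the 4-block
    // loop (and AVX2 builds through the 2-block loop) many times.
    var msg: [8192]u8 = undefined;
    std.crypto.random.bytes(&msg);

    var ct: [8192]u8 = undefined;
    var pt: [8192]u8 = undefined;

    // A stream-cipher XOR is its own inverse: applying it twice with the same
    // counter, key and nonce must return the original message.
    ChaCha8.xor(&ct, &msg, 0, key, nonce);
    ChaCha8.xor(&pt, &ct, 0, key, nonce);
    try std.testing.expectEqualSlices(u8, &msg, &pt);
}

The same applies transparently to the std.rand ChaCha-based generator: with State widened to 8 blocks, each refill now runs the vectorized loops instead of two single-block iterations.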