Skip to content

Commit

Permalink
std.crypto.chacha: support larger vectors on AVX2 and AVX512 targets (#…
Browse files Browse the repository at this point in the history
…15809)

* std.crypto.chacha: support larger vectors on AVX2 and AVX512 targets

Ryzen 7 7700, ChaCha20/8 stream, long outputs:

Generic: 3268 MiB/s
AVX2   : 6023 MiB/s
AVX512 : 8086 MiB/s

Bump the rand.chacha buffer a tiny bit to take advantage of this.
More than 8 blocks doesn't seem to make any measurable difference.

ChaChaPoly also gets a small performance boost from this, albeit
Poly1305 remains the bottleneck.

Generic:  707 MiB/s
AVX2   :  981 MiB/s
AVX512 : 1202 MiB/s

aarch64 appears to generally benefit from 4-way vectorization.

Verified on Apple Silicon, but also on a Cortex A72.
  • Loading branch information
jedisct1 authored May 22, 2023
1 parent eef9275 commit 5af89b3
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 52 deletions.
194 changes: 143 additions & 51 deletions lib/std/crypto/chacha20.zig
Original file line number Diff line number Diff line change
Expand Up @@ -76,30 +76,98 @@ pub const XChaCha12Poly1305 = XChaChaPoly1305(12);
pub const XChaCha8Poly1305 = XChaChaPoly1305(8);

// Vectorized implementation of the core function
fn ChaChaVecImpl(comptime rounds_nb: usize) type {
fn ChaChaVecImpl(comptime rounds_nb: usize, comptime degree: comptime_int) type {
return struct {
const Lane = @Vector(4, u32);
const Lane = @Vector(4 * degree, u32);
const BlockVec = [4]Lane;

fn initContext(key: [8]u32, d: [4]u32) BlockVec {
const c = "expand 32-byte k";
const constant_le = comptime Lane{
mem.readIntLittle(u32, c[0..4]),
mem.readIntLittle(u32, c[4..8]),
mem.readIntLittle(u32, c[8..12]),
mem.readIntLittle(u32, c[12..16]),
};
return BlockVec{
constant_le,
Lane{ key[0], key[1], key[2], key[3] },
Lane{ key[4], key[5], key[6], key[7] },
Lane{ d[0], d[1], d[2], d[3] },
};
switch (degree) {
1 => {
const constant_le = Lane{
mem.readIntLittle(u32, c[0..4]),
mem.readIntLittle(u32, c[4..8]),
mem.readIntLittle(u32, c[8..12]),
mem.readIntLittle(u32, c[12..16]),
};
return BlockVec{
constant_le,
Lane{ key[0], key[1], key[2], key[3] },
Lane{ key[4], key[5], key[6], key[7] },
Lane{ d[0], d[1], d[2], d[3] },
};
},
2 => {
const constant_le = Lane{
mem.readIntLittle(u32, c[0..4]),
mem.readIntLittle(u32, c[4..8]),
mem.readIntLittle(u32, c[8..12]),
mem.readIntLittle(u32, c[12..16]),
mem.readIntLittle(u32, c[0..4]),
mem.readIntLittle(u32, c[4..8]),
mem.readIntLittle(u32, c[8..12]),
mem.readIntLittle(u32, c[12..16]),
};
return BlockVec{
constant_le,
Lane{ key[0], key[1], key[2], key[3], key[0], key[1], key[2], key[3] },
Lane{ key[4], key[5], key[6], key[7], key[4], key[5], key[6], key[7] },
Lane{ d[0], d[1], d[2], d[3], d[0] +% 1, d[1], d[2], d[3] },
};
},
4 => {
const constant_le = Lane{
mem.readIntLittle(u32, c[0..4]),
mem.readIntLittle(u32, c[4..8]),
mem.readIntLittle(u32, c[8..12]),
mem.readIntLittle(u32, c[12..16]),
mem.readIntLittle(u32, c[0..4]),
mem.readIntLittle(u32, c[4..8]),
mem.readIntLittle(u32, c[8..12]),
mem.readIntLittle(u32, c[12..16]),
mem.readIntLittle(u32, c[0..4]),
mem.readIntLittle(u32, c[4..8]),
mem.readIntLittle(u32, c[8..12]),
mem.readIntLittle(u32, c[12..16]),
mem.readIntLittle(u32, c[0..4]),
mem.readIntLittle(u32, c[4..8]),
mem.readIntLittle(u32, c[8..12]),
mem.readIntLittle(u32, c[12..16]),
};
return BlockVec{
constant_le,
Lane{ key[0], key[1], key[2], key[3], key[0], key[1], key[2], key[3], key[0], key[1], key[2], key[3], key[0], key[1], key[2], key[3] },
Lane{ key[4], key[5], key[6], key[7], key[4], key[5], key[6], key[7], key[4], key[5], key[6], key[7], key[4], key[5], key[6], key[7] },
Lane{ d[0], d[1], d[2], d[3], d[0] +% 1, d[1], d[2], d[3], d[0] +% 2, d[1], d[2], d[3], d[0] +% 3, d[1], d[2], d[3] },
};
},
else => @panic("invalid degree"),
}
}

inline fn chacha20Core(x: *BlockVec, input: BlockVec) void {
x.* = input;

const m0 = switch (degree) {
1 => [_]i32{ 3, 0, 1, 2 },
2 => [_]i32{ 3, 0, 1, 2 } ++ [_]i32{ 7, 4, 5, 6 },
4 => [_]i32{ 3, 0, 1, 2 } ++ [_]i32{ 7, 4, 5, 6 } ++ [_]i32{ 11, 8, 9, 10 } ++ [_]i32{ 15, 12, 13, 14 },
else => @panic("invalid degree"),
};
const m1 = switch (degree) {
1 => [_]i32{ 2, 3, 0, 1 },
2 => [_]i32{ 2, 3, 0, 1 } ++ [_]i32{ 6, 7, 4, 5 },
4 => [_]i32{ 2, 3, 0, 1 } ++ [_]i32{ 6, 7, 4, 5 } ++ [_]i32{ 10, 11, 8, 9 } ++ [_]i32{ 14, 15, 12, 13 },
else => @panic("invalid degree"),
};
const m2 = switch (degree) {
1 => [_]i32{ 1, 2, 3, 0 },
2 => [_]i32{ 1, 2, 3, 0 } ++ [_]i32{ 5, 6, 7, 4 },
4 => [_]i32{ 1, 2, 3, 0 } ++ [_]i32{ 5, 6, 7, 4 } ++ [_]i32{ 9, 10, 11, 8 } ++ [_]i32{ 13, 14, 15, 12 },
else => @panic("invalid degree"),
};

var r: usize = 0;
while (r < rounds_nb) : (r += 2) {
x[0] +%= x[1];
Expand All @@ -112,13 +180,13 @@ fn ChaChaVecImpl(comptime rounds_nb: usize) type {

x[0] +%= x[1];
x[3] ^= x[0];
x[0] = @shuffle(u32, x[0], undefined, [_]i32{ 3, 0, 1, 2 });
x[0] = @shuffle(u32, x[0], undefined, m0);
x[3] = math.rotl(Lane, x[3], 8);

x[2] +%= x[3];
x[3] = @shuffle(u32, x[3], undefined, [_]i32{ 2, 3, 0, 1 });
x[3] = @shuffle(u32, x[3], undefined, m1);
x[1] ^= x[2];
x[2] = @shuffle(u32, x[2], undefined, [_]i32{ 1, 2, 3, 0 });
x[2] = @shuffle(u32, x[2], undefined, m2);
x[1] = math.rotl(Lane, x[1], 7);

x[0] +%= x[1];
Expand All @@ -131,24 +199,26 @@ fn ChaChaVecImpl(comptime rounds_nb: usize) type {

x[0] +%= x[1];
x[3] ^= x[0];
x[0] = @shuffle(u32, x[0], undefined, [_]i32{ 1, 2, 3, 0 });
x[0] = @shuffle(u32, x[0], undefined, m2);
x[3] = math.rotl(Lane, x[3], 8);

x[2] +%= x[3];
x[3] = @shuffle(u32, x[3], undefined, [_]i32{ 2, 3, 0, 1 });
x[3] = @shuffle(u32, x[3], undefined, m1);
x[1] ^= x[2];
x[2] = @shuffle(u32, x[2], undefined, [_]i32{ 3, 0, 1, 2 });
x[2] = @shuffle(u32, x[2], undefined, m0);
x[1] = math.rotl(Lane, x[1], 7);
}
}

inline fn hashToBytes(out: *[64]u8, x: BlockVec) void {
var i: usize = 0;
while (i < 4) : (i += 1) {
mem.writeIntLittle(u32, out[16 * i + 0 ..][0..4], x[i][0]);
mem.writeIntLittle(u32, out[16 * i + 4 ..][0..4], x[i][1]);
mem.writeIntLittle(u32, out[16 * i + 8 ..][0..4], x[i][2]);
mem.writeIntLittle(u32, out[16 * i + 12 ..][0..4], x[i][3]);
inline fn hashToBytes(comptime dm: usize, out: *[64 * dm]u8, x: BlockVec) void {
for (0..dm) |d| {
var i: usize = 0;
while (i < 4) : (i += 1) {
mem.writeIntLittle(u32, out[64 * d + 16 * i + 0 ..][0..4], x[i][0 + 4 * d]);
mem.writeIntLittle(u32, out[64 * d + 16 * i + 4 ..][0..4], x[i][1 + 4 * d]);
mem.writeIntLittle(u32, out[64 * d + 16 * i + 8 ..][0..4], x[i][2 + 4 * d]);
mem.writeIntLittle(u32, out[64 * d + 16 * i + 12 ..][0..4], x[i][3 + 4 * d]);
}
}
}

Expand All @@ -162,29 +232,33 @@ fn ChaChaVecImpl(comptime rounds_nb: usize) type {
fn chacha20Xor(out: []u8, in: []const u8, key: [8]u32, counter: [4]u32) void {
var ctx = initContext(key, counter);
var x: BlockVec = undefined;
var buf: [64]u8 = undefined;
var buf: [64 * degree]u8 = undefined;
var i: usize = 0;
while (i + 64 <= in.len) : (i += 64) {
chacha20Core(x[0..], ctx);
contextFeedback(&x, ctx);
hashToBytes(buf[0..], x);

var xout = out[i..];
const xin = in[i..];
var j: usize = 0;
while (j < 64) : (j += 1) {
xout[j] = xin[j];
}
j = 0;
while (j < 64) : (j += 1) {
xout[j] ^= buf[j];
inline for ([_]comptime_int{ 4, 2, 1 }) |d| {
while (degree >= d and i + 64 * d <= in.len) : (i += 64 * d) {
chacha20Core(x[0..], ctx);
contextFeedback(&x, ctx);
hashToBytes(d, buf[0 .. 64 * d], x);

var xout = out[i..];
const xin = in[i..];
var j: usize = 0;
while (j < 64 * d) : (j += 1) {
xout[j] = xin[j];
}
j = 0;
while (j < 64 * d) : (j += 1) {
xout[j] ^= buf[j];
}
inline for (0..d) |d_| {
ctx[3][4 * d_] += @intCast(u32, d);
}
}
ctx[3][0] += 1;
}
if (i < in.len) {
chacha20Core(x[0..], ctx);
contextFeedback(&x, ctx);
hashToBytes(buf[0..], x);
hashToBytes(1, buf[0..64], x);

var xout = out[i..];
const xin = in[i..];
Expand All @@ -199,18 +273,22 @@ fn ChaChaVecImpl(comptime rounds_nb: usize) type {
var ctx = initContext(key, counter);
var x: BlockVec = undefined;
var i: usize = 0;
while (i + 64 <= out.len) : (i += 64) {
chacha20Core(x[0..], ctx);
contextFeedback(&x, ctx);
hashToBytes(out[i..][0..64], x);
ctx[3][0] += 1;
inline for ([_]comptime_int{ 4, 2, 1 }) |d| {
while (degree >= d and i + 64 * d <= out.len) : (i += 64 * d) {
chacha20Core(x[0..], ctx);
contextFeedback(&x, ctx);
hashToBytes(d, out[i..][0 .. 64 * d], x);
inline for (0..d) |d_| {
ctx[3][4 * d_] += @intCast(u32, d);
}
}
}
if (i < out.len) {
chacha20Core(x[0..], ctx);
contextFeedback(&x, ctx);

var buf: [64]u8 = undefined;
hashToBytes(buf[0..], x);
hashToBytes(1, buf[0..], x);
@memcpy(out[i..], buf[0 .. out.len - i]);
}
}
Expand Down Expand Up @@ -399,7 +477,21 @@ fn ChaChaNonVecImpl(comptime rounds_nb: usize) type {
}

fn ChaChaImpl(comptime rounds_nb: usize) type {
return if (builtin.cpu.arch == .x86_64) ChaChaVecImpl(rounds_nb) else ChaChaNonVecImpl(rounds_nb);
switch (builtin.cpu.arch) {
.x86_64 => {
const has_avx2 = std.Target.x86.featureSetHas(builtin.cpu.features, .avx2);
const has_avx512f = std.Target.x86.featureSetHas(builtin.cpu.features, .avx512f);
if (has_avx512f) return ChaChaVecImpl(rounds_nb, 4);
if (has_avx2) return ChaChaVecImpl(rounds_nb, 2);
return ChaChaVecImpl(rounds_nb, 1);
},
.aarch64 => {
const has_neon = std.Target.aarch64.featureSetHas(builtin.cpu.features, .neon);
if (has_neon) return ChaChaVecImpl(rounds_nb, 4);
return ChaChaNonVecImpl(rounds_nb);
},
else => return ChaChaNonVecImpl(rounds_nb),
}
}

fn keyToWords(key: [32]u8) [8]u32 {
Expand Down
2 changes: 1 addition & 1 deletion lib/std/rand/ChaCha.zig
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ const Self = @This();

const Cipher = std.crypto.stream.chacha.ChaCha8IETF;

const State = [2 * Cipher.block_length]u8;
const State = [8 * Cipher.block_length]u8;

state: State,
offset: usize,
Expand Down

0 comments on commit 5af89b3

Please sign in to comment.