diff --git a/build.zig b/build.zig index 06d3603..39f014c 100644 --- a/build.zig +++ b/build.zig @@ -48,7 +48,7 @@ pub fn build(b: *std.Build) void { // Creates a step for unit testing. This only builds the test executable // but does not run it. const lib_unit_tests = b.addTest(.{ - .root_source_file = b.path("src/root.zig"), + .root_source_file = b.path("src/main.zig"), .target = target, .optimize = optimize, }); diff --git a/src/io.zig b/src/io.zig index 88af440..b4d1841 100644 --- a/src/io.zig +++ b/src/io.zig @@ -1,7 +1,7 @@ const std = @import("std"); -const Ctx = @import("std/http/Client.zig").Ctx; -const Cbk = @import("std/http/Client.zig").Cbk; +pub const Ctx = @import("std/http/Client.zig").Ctx; +pub const Cbk = @import("std/http/Client.zig").Cbk; pub const Blocking = struct { pub fn connect( diff --git a/src/main.zig b/src/main.zig index 485e5db..0d05e54 100644 --- a/src/main.zig +++ b/src/main.zig @@ -31,7 +31,7 @@ pub fn main() !void { var ctx = try Client.Ctx.init(&loop, &req); defer ctx.deinit(); - var server_header_buffer: [2048]u8 = undefined; + var server_header_buffer: [1024 * 1024]u8 = undefined; try client.async_open( .GET, diff --git a/src/std/crypto/tls/Client.zig b/src/std/crypto/tls/Client.zig deleted file mode 100644 index cad730f..0000000 --- a/src/std/crypto/tls/Client.zig +++ /dev/null @@ -1,1968 +0,0 @@ -const std = @import("std"); -const tls = std.crypto.tls; -pub const Client = @This(); -const net = std.net; -const mem = std.mem; -const crypto = std.crypto; -const assert = std.debug.assert; -const Certificate = std.crypto.Certificate; - -const max_ciphertext_len = tls.max_ciphertext_len; -const hkdfExpandLabel = tls.hkdfExpandLabel; -const int2 = tls.int2; -const int3 = tls.int3; -const array = tls.array; -const enum_array = tls.enum_array; - -const Cbk = @import("../../http/Client.zig").Cbk; -const Loop = @import("../../http/Client.zig").Loop; -const Ctx = @import("../../http/Client.zig").Ctx; - -read_seq: u64, -write_seq: u64, -/// The starting index of cleartext bytes inside `partially_read_buffer`. -partial_cleartext_idx: u15, -/// The ending index of cleartext bytes inside `partially_read_buffer` as well -/// as the starting index of ciphertext bytes. -partial_ciphertext_idx: u15, -/// The ending index of ciphertext bytes inside `partially_read_buffer`. -partial_ciphertext_end: u15, -/// When this is true, the stream may still not be at the end because there -/// may be data in `partially_read_buffer`. -received_close_notify: bool, -/// By default, reaching the end-of-stream when reading from the server will -/// cause `error.TlsConnectionTruncated` to be returned, unless a close_notify -/// message has been received. By setting this flag to `true`, instead, the -/// end-of-stream will be forwarded to the application layer above TLS. -/// This makes the application vulnerable to truncation attacks unless the -/// application layer itself verifies that the amount of data received equals -/// the amount of data expected, such as HTTP with the Content-Length header. -allow_truncation_attacks: bool = false, -application_cipher: tls.ApplicationCipher, -/// The size is enough to contain exactly one TLSCiphertext record. -/// This buffer is segmented into four parts: -/// 0. unused -/// 1. cleartext -/// 2. ciphertext -/// 3. unused -/// The fields `partial_cleartext_idx`, `partial_ciphertext_idx`, and -/// `partial_ciphertext_end` describe the span of the segments. -partially_read_buffer: [tls.max_ciphertext_record_len]u8, - -/// This is an example of the type that is needed by the read and write -/// functions. It can have any fields but it must at least have these -/// functions. -/// -/// Note that `std.net.Stream` conforms to this interface. -/// -/// This declaration serves as documentation only. -pub const StreamInterface = struct { - /// Can be any error set. - pub const ReadError = error{}; - - /// Returns the number of bytes read. The number read may be less than the - /// buffer space provided. End-of-stream is indicated by a return value of 0. - /// - /// The `iovecs` parameter is mutable because so that function may to - /// mutate the fields in order to handle partial reads from the underlying - /// stream layer. - pub fn readv(this: @This(), iovecs: []std.posix.iovec) ReadError!usize { - _ = .{ this, iovecs }; - @panic("unimplemented"); - } - - /// Can be any error set. - pub const WriteError = error{}; - - /// Returns the number of bytes read, which may be less than the buffer - /// space provided. A short read does not indicate end-of-stream. - pub fn writev(this: @This(), iovecs: []const std.posix.iovec_const) WriteError!usize { - _ = .{ this, iovecs }; - @panic("unimplemented"); - } - - /// Returns the number of bytes read, which may be less than the buffer - /// space provided, indicating end-of-stream. - /// The `iovecs` parameter is mutable in case this function needs to mutate - /// the fields in order to handle partial writes from the underlying layer. - pub fn writevAll(this: @This(), iovecs: []std.posix.iovec_const) WriteError!usize { - // This can be implemented in terms of writev, or specialized if desired. - _ = .{ this, iovecs }; - @panic("unimplemented"); - } -}; - -pub fn InitError(comptime Stream: type) type { - return std.mem.Allocator.Error || Stream.WriteError || Stream.ReadError || tls.AlertDescription.Error || error{ - InsufficientEntropy, - DiskQuota, - LockViolation, - NotOpenForWriting, - TlsUnexpectedMessage, - TlsIllegalParameter, - TlsDecryptFailure, - TlsRecordOverflow, - TlsBadRecordMac, - CertificateFieldHasInvalidLength, - CertificateHostMismatch, - CertificatePublicKeyInvalid, - CertificateExpired, - CertificateFieldHasWrongDataType, - CertificateIssuerMismatch, - CertificateNotYetValid, - CertificateSignatureAlgorithmMismatch, - CertificateSignatureAlgorithmUnsupported, - CertificateSignatureInvalid, - CertificateSignatureInvalidLength, - CertificateSignatureNamedCurveUnsupported, - CertificateSignatureUnsupportedBitCount, - TlsCertificateNotVerified, - TlsBadSignatureScheme, - TlsBadRsaSignatureBitCount, - InvalidEncoding, - IdentityElement, - SignatureVerificationFailed, - TlsDecryptError, - TlsConnectionTruncated, - TlsDecodeError, - UnsupportedCertificateVersion, - CertificateTimeInvalid, - CertificateHasUnrecognizedObjectId, - CertificateHasInvalidBitString, - MessageTooLong, - NegativeIntoUnsigned, - TargetTooSmall, - BufferTooSmall, - InvalidSignature, - NotSquare, - NonCanonical, - WeakPublicKey, - }; -} - -/// Initiates a TLS handshake and establishes a TLSv1.3 session with `stream`, which -/// must conform to `StreamInterface`. -/// -/// `host` is only borrowed during this function call. -pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) InitError(@TypeOf(stream))!Client { - const host_len: u16 = @intCast(host.len); - - var random_buffer: [128]u8 = undefined; - crypto.random.bytes(&random_buffer); - const hello_rand = random_buffer[0..32].*; - const legacy_session_id = random_buffer[32..64].*; - const x25519_kp_seed = random_buffer[64..96].*; - const secp256r1_kp_seed = random_buffer[96..128].*; - - const x25519_kp = crypto.dh.X25519.KeyPair.create(x25519_kp_seed) catch |err| switch (err) { - // Only possible to happen if the private key is all zeroes. - error.IdentityElement => return error.InsufficientEntropy, - }; - const secp256r1_kp = crypto.sign.ecdsa.EcdsaP256Sha256.KeyPair.create(secp256r1_kp_seed) catch |err| switch (err) { - // Only possible to happen if the private key is all zeroes. - error.IdentityElement => return error.InsufficientEntropy, - }; - const kyber768_kp = crypto.kem.kyber_d00.Kyber768.KeyPair.create(null) catch {}; - - const extensions_payload = - tls.extension(.supported_versions, [_]u8{ - 0x02, // byte length of supported versions - 0x03, 0x04, // TLS 1.3 - }) ++ tls.extension(.signature_algorithms, enum_array(tls.SignatureScheme, &.{ - .ecdsa_secp256r1_sha256, - .ecdsa_secp384r1_sha384, - .rsa_pss_rsae_sha256, - .rsa_pss_rsae_sha384, - .rsa_pss_rsae_sha512, - .ed25519, - })) ++ tls.extension(.supported_groups, enum_array(tls.NamedGroup, &.{ - .x25519_kyber768d00, - .secp256r1, - .x25519, - })) ++ tls.extension( - .key_share, - array(1, int2(@intFromEnum(tls.NamedGroup.x25519)) ++ - array(1, x25519_kp.public_key) ++ - int2(@intFromEnum(tls.NamedGroup.secp256r1)) ++ - array(1, secp256r1_kp.public_key.toUncompressedSec1()) ++ - int2(@intFromEnum(tls.NamedGroup.x25519_kyber768d00)) ++ - array(1, x25519_kp.public_key ++ kyber768_kp.public_key.toBytes())), - ) ++ - int2(@intFromEnum(tls.ExtensionType.server_name)) ++ - int2(host_len + 5) ++ // byte length of this extension payload - int2(host_len + 3) ++ // server_name_list byte count - [1]u8{0x00} ++ // name_type - int2(host_len); - - const extensions_header = - int2(@intCast(extensions_payload.len + host_len)) ++ - extensions_payload; - - const legacy_compression_methods = 0x0100; - - const client_hello = - int2(@intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ - hello_rand ++ - [1]u8{32} ++ legacy_session_id ++ - cipher_suites ++ - int2(legacy_compression_methods) ++ - extensions_header; - - const out_handshake = - [_]u8{@intFromEnum(tls.HandshakeType.client_hello)} ++ - int3(@intCast(client_hello.len + host_len)) ++ - client_hello; - - const plaintext_header = [_]u8{ - @intFromEnum(tls.ContentType.handshake), - 0x03, 0x01, // legacy_record_version - } ++ int2(@intCast(out_handshake.len + host_len)) ++ out_handshake; - - { - var iovecs = [_]std.posix.iovec_const{ - .{ - .base = &plaintext_header, - .len = plaintext_header.len, - }, - .{ - .base = host.ptr, - .len = host.len, - }, - }; - try stream.writevAll(&iovecs); - } - - const client_hello_bytes1 = plaintext_header[5..]; - - var handshake_cipher: tls.HandshakeCipher = undefined; - var handshake_buffer: [8000]u8 = undefined; - var d: tls.Decoder = .{ .buf = &handshake_buffer }; - { - try d.readAtLeastOurAmt(stream, tls.record_header_len); - const ct = d.decode(tls.ContentType); - d.skip(2); // legacy_record_version - const record_len = d.decode(u16); - try d.readAtLeast(stream, record_len); - const server_hello_fragment = d.buf[d.idx..][0..record_len]; - var ptd = try d.sub(record_len); - switch (ct) { - .alert => { - try ptd.ensure(2); - const level = ptd.decode(tls.AlertLevel); - const desc = ptd.decode(tls.AlertDescription); - _ = level; - - // if this isn't a error alert, then it's a closure alert, which makes no sense in a handshake - try desc.toError(); - // TODO: handle server-side closures - return error.TlsUnexpectedMessage; - }, - .handshake => { - try ptd.ensure(4); - const handshake_type = ptd.decode(tls.HandshakeType); - if (handshake_type != .server_hello) return error.TlsUnexpectedMessage; - const length = ptd.decode(u24); - var hsd = try ptd.sub(length); - try hsd.ensure(2 + 32 + 1 + 32 + 2 + 1 + 2); - const legacy_version = hsd.decode(u16); - const random = hsd.array(32); - if (mem.eql(u8, random, &tls.hello_retry_request_sequence)) { - // This is a HelloRetryRequest message. This client implementation - // does not expect to get one. - return error.TlsUnexpectedMessage; - } - const legacy_session_id_echo_len = hsd.decode(u8); - if (legacy_session_id_echo_len != 32) return error.TlsIllegalParameter; - const legacy_session_id_echo = hsd.array(32); - if (!mem.eql(u8, legacy_session_id_echo, &legacy_session_id)) - return error.TlsIllegalParameter; - const cipher_suite_tag = hsd.decode(tls.CipherSuite); - hsd.skip(1); // legacy_compression_method - const extensions_size = hsd.decode(u16); - var all_extd = try hsd.sub(extensions_size); - var supported_version: u16 = 0; - var shared_key: []const u8 = undefined; - var have_shared_key = false; - while (!all_extd.eof()) { - try all_extd.ensure(2 + 2); - const et = all_extd.decode(tls.ExtensionType); - const ext_size = all_extd.decode(u16); - var extd = try all_extd.sub(ext_size); - switch (et) { - .supported_versions => { - if (supported_version != 0) return error.TlsIllegalParameter; - try extd.ensure(2); - supported_version = extd.decode(u16); - }, - .key_share => { - if (have_shared_key) return error.TlsIllegalParameter; - have_shared_key = true; - try extd.ensure(4); - const named_group = extd.decode(tls.NamedGroup); - const key_size = extd.decode(u16); - try extd.ensure(key_size); - switch (named_group) { - .x25519_kyber768d00 => { - const xksl = crypto.dh.X25519.public_length; - const hksl = xksl + crypto.kem.kyber_d00.Kyber768.ciphertext_length; - if (key_size != hksl) - return error.TlsIllegalParameter; - const server_ks = extd.array(hksl); - - shared_key = &((crypto.dh.X25519.scalarmult( - x25519_kp.secret_key, - server_ks[0..xksl].*, - ) catch return error.TlsDecryptFailure) ++ (kyber768_kp.secret_key.decaps( - server_ks[xksl..hksl], - ) catch return error.TlsDecryptFailure)); - }, - .x25519 => { - const ksl = crypto.dh.X25519.public_length; - if (key_size != ksl) return error.TlsIllegalParameter; - const server_pub_key = extd.array(ksl); - - shared_key = &(crypto.dh.X25519.scalarmult( - x25519_kp.secret_key, - server_pub_key.*, - ) catch return error.TlsDecryptFailure); - }, - .secp256r1 => { - const server_pub_key = extd.slice(key_size); - - const PublicKey = crypto.sign.ecdsa.EcdsaP256Sha256.PublicKey; - const pk = PublicKey.fromSec1(server_pub_key) catch { - return error.TlsDecryptFailure; - }; - const mul = pk.p.mulPublic(secp256r1_kp.secret_key.bytes, .big) catch { - return error.TlsDecryptFailure; - }; - shared_key = &mul.affineCoordinates().x.toBytes(.big); - }, - else => { - return error.TlsIllegalParameter; - }, - } - }, - else => {}, - } - } - if (!have_shared_key) return error.TlsIllegalParameter; - - const tls_version = if (supported_version == 0) legacy_version else supported_version; - if (tls_version != @intFromEnum(tls.ProtocolVersion.tls_1_3)) - return error.TlsIllegalParameter; - - switch (cipher_suite_tag) { - inline .AES_128_GCM_SHA256, - .AES_256_GCM_SHA384, - .CHACHA20_POLY1305_SHA256, - .AEGIS_256_SHA512, - .AEGIS_128L_SHA256, - => |tag| { - const P = std.meta.TagPayloadByName(tls.HandshakeCipher, @tagName(tag)); - handshake_cipher = @unionInit(tls.HandshakeCipher, @tagName(tag), .{ - .handshake_secret = undefined, - .master_secret = undefined, - .client_handshake_key = undefined, - .server_handshake_key = undefined, - .client_finished_key = undefined, - .server_finished_key = undefined, - .client_handshake_iv = undefined, - .server_handshake_iv = undefined, - .transcript_hash = P.Hash.init(.{}), - }); - const p = &@field(handshake_cipher, @tagName(tag)); - p.transcript_hash.update(client_hello_bytes1); // Client Hello part 1 - p.transcript_hash.update(host); // Client Hello part 2 - p.transcript_hash.update(server_hello_fragment); - const hello_hash = p.transcript_hash.peek(); - const zeroes = [1]u8{0} ** P.Hash.digest_length; - const early_secret = P.Hkdf.extract(&[1]u8{0}, &zeroes); - const empty_hash = tls.emptyHash(P.Hash); - const hs_derived_secret = hkdfExpandLabel(P.Hkdf, early_secret, "derived", &empty_hash, P.Hash.digest_length); - p.handshake_secret = P.Hkdf.extract(&hs_derived_secret, shared_key); - const ap_derived_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "derived", &empty_hash, P.Hash.digest_length); - p.master_secret = P.Hkdf.extract(&ap_derived_secret, &zeroes); - const client_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "c hs traffic", &hello_hash, P.Hash.digest_length); - const server_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "s hs traffic", &hello_hash, P.Hash.digest_length); - p.client_finished_key = hkdfExpandLabel(P.Hkdf, client_secret, "finished", "", P.Hmac.key_length); - p.server_finished_key = hkdfExpandLabel(P.Hkdf, server_secret, "finished", "", P.Hmac.key_length); - p.client_handshake_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length); - p.server_handshake_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length); - p.client_handshake_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length); - p.server_handshake_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length); - }, - else => { - return error.TlsIllegalParameter; - }, - } - }, - else => return error.TlsUnexpectedMessage, - } - } - - // This is used for two purposes: - // * Detect whether a certificate is the first one presented, in which case - // we need to verify the host name. - // * Flip back and forth between the two cleartext buffers in order to keep - // the previous certificate in memory so that it can be verified by the - // next one. - var cert_index: usize = 0; - var read_seq: u64 = 0; - var prev_cert: Certificate.Parsed = undefined; - // Set to true once a trust chain has been established from the first - // certificate to a root CA. - const HandshakeState = enum { - /// In this state we expect only an encrypted_extensions message. - encrypted_extensions, - /// In this state we expect certificate messages. - certificate, - /// In this state we expect certificate or certificate_verify messages. - /// certificate messages are ignored since the trust chain is already - /// established. - trust_chain_established, - /// In this state, we expect only the finished message. - finished, - }; - var handshake_state: HandshakeState = .encrypted_extensions; - var cleartext_bufs: [2][8000]u8 = undefined; - var main_cert_pub_key_algo: Certificate.AlgorithmCategory = undefined; - var main_cert_pub_key_buf: [600]u8 = undefined; - var main_cert_pub_key_len: u16 = undefined; - const now_sec = std.time.timestamp(); - - while (true) { - try d.readAtLeastOurAmt(stream, tls.record_header_len); - const record_header = d.buf[d.idx..][0..5]; - const ct = d.decode(tls.ContentType); - d.skip(2); // legacy_version - const record_len = d.decode(u16); - try d.readAtLeast(stream, record_len); - var record_decoder = try d.sub(record_len); - switch (ct) { - .change_cipher_spec => { - try record_decoder.ensure(1); - if (record_decoder.decode(u8) != 0x01) return error.TlsIllegalParameter; - }, - .application_data => { - const cleartext_buf = &cleartext_bufs[cert_index % 2]; - - const cleartext = switch (handshake_cipher) { - inline else => |*p| c: { - const P = @TypeOf(p.*); - const ciphertext_len = record_len - P.AEAD.tag_length; - try record_decoder.ensure(ciphertext_len + P.AEAD.tag_length); - const ciphertext = record_decoder.slice(ciphertext_len); - if (ciphertext.len > cleartext_buf.len) return error.TlsRecordOverflow; - const cleartext = cleartext_buf[0..ciphertext.len]; - const auth_tag = record_decoder.array(P.AEAD.tag_length).*; - const nonce = if (builtin.zig_backend == .stage2_x86_64 and - P.AEAD.nonce_length > comptime std.simd.suggestVectorLength(u8) orelse 1) - nonce: { - var nonce = p.server_handshake_iv; - const operand = std.mem.readInt(u64, nonce[nonce.len - 8 ..], .big); - std.mem.writeInt(u64, nonce[nonce.len - 8 ..], operand ^ read_seq, .big); - break :nonce nonce; - } else nonce: { - const V = @Vector(P.AEAD.nonce_length, u8); - const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); - const operand: V = pad ++ @as([8]u8, @bitCast(big(read_seq))); - break :nonce @as(V, p.server_handshake_iv) ^ operand; - }; - read_seq += 1; - P.AEAD.decrypt(cleartext, ciphertext, auth_tag, record_header, nonce, p.server_handshake_key) catch - return error.TlsBadRecordMac; - break :c cleartext; - }, - }; - - const inner_ct: tls.ContentType = @enumFromInt(cleartext[cleartext.len - 1]); - if (inner_ct != .handshake) return error.TlsUnexpectedMessage; - - var ctd = tls.Decoder.fromTheirSlice(cleartext[0 .. cleartext.len - 1]); - while (true) { - try ctd.ensure(4); - const handshake_type = ctd.decode(tls.HandshakeType); - const handshake_len = ctd.decode(u24); - var hsd = try ctd.sub(handshake_len); - const wrapped_handshake = ctd.buf[ctd.idx - handshake_len - 4 .. ctd.idx]; - const handshake = ctd.buf[ctd.idx - handshake_len .. ctd.idx]; - switch (handshake_type) { - .encrypted_extensions => { - if (handshake_state != .encrypted_extensions) return error.TlsUnexpectedMessage; - handshake_state = .certificate; - switch (handshake_cipher) { - inline else => |*p| p.transcript_hash.update(wrapped_handshake), - } - try hsd.ensure(2); - const total_ext_size = hsd.decode(u16); - var all_extd = try hsd.sub(total_ext_size); - while (!all_extd.eof()) { - try all_extd.ensure(4); - const et = all_extd.decode(tls.ExtensionType); - const ext_size = all_extd.decode(u16); - const extd = try all_extd.sub(ext_size); - _ = extd; - switch (et) { - .server_name => {}, - else => {}, - } - } - }, - .certificate => cert: { - switch (handshake_cipher) { - inline else => |*p| p.transcript_hash.update(wrapped_handshake), - } - switch (handshake_state) { - .certificate => {}, - .trust_chain_established => break :cert, - else => return error.TlsUnexpectedMessage, - } - try hsd.ensure(1 + 4); - const cert_req_ctx_len = hsd.decode(u8); - if (cert_req_ctx_len != 0) return error.TlsIllegalParameter; - const certs_size = hsd.decode(u24); - var certs_decoder = try hsd.sub(certs_size); - while (!certs_decoder.eof()) { - try certs_decoder.ensure(3); - const cert_size = certs_decoder.decode(u24); - const certd = try certs_decoder.sub(cert_size); - - const subject_cert: Certificate = .{ - .buffer = certd.buf, - .index = @intCast(certd.idx), - }; - const subject = try subject_cert.parse(); - if (cert_index == 0) { - // Verify the host on the first certificate. - try subject.verifyHostName(host); - - // Keep track of the public key for the - // certificate_verify message later. - main_cert_pub_key_algo = subject.pub_key_algo; - const pub_key = subject.pubKey(); - if (pub_key.len > main_cert_pub_key_buf.len) - return error.CertificatePublicKeyInvalid; - @memcpy(main_cert_pub_key_buf[0..pub_key.len], pub_key); - main_cert_pub_key_len = @intCast(pub_key.len); - } else { - try prev_cert.verify(subject, now_sec); - } - - if (ca_bundle.verify(subject, now_sec)) |_| { - handshake_state = .trust_chain_established; - break :cert; - } else |err| switch (err) { - error.CertificateIssuerNotFound => {}, - else => |e| return e, - } - - prev_cert = subject; - cert_index += 1; - - try certs_decoder.ensure(2); - const total_ext_size = certs_decoder.decode(u16); - const all_extd = try certs_decoder.sub(total_ext_size); - _ = all_extd; - } - }, - .certificate_verify => { - switch (handshake_state) { - .trust_chain_established => handshake_state = .finished, - .certificate => return error.TlsCertificateNotVerified, - else => return error.TlsUnexpectedMessage, - } - - try hsd.ensure(4); - const scheme = hsd.decode(tls.SignatureScheme); - const sig_len = hsd.decode(u16); - try hsd.ensure(sig_len); - const encoded_sig = hsd.slice(sig_len); - const max_digest_len = 64; - var verify_buffer: [64 + 34 + max_digest_len]u8 = - ([1]u8{0x20} ** 64) ++ - "TLS 1.3, server CertificateVerify\x00".* ++ - @as([max_digest_len]u8, undefined); - - const verify_bytes = switch (handshake_cipher) { - inline else => |*p| v: { - const transcript_digest = p.transcript_hash.peek(); - verify_buffer[verify_buffer.len - max_digest_len ..][0..transcript_digest.len].* = transcript_digest; - p.transcript_hash.update(wrapped_handshake); - break :v verify_buffer[0 .. verify_buffer.len - max_digest_len + transcript_digest.len]; - }, - }; - const main_cert_pub_key = main_cert_pub_key_buf[0..main_cert_pub_key_len]; - - switch (scheme) { - inline .ecdsa_secp256r1_sha256, - .ecdsa_secp384r1_sha384, - => |comptime_scheme| { - if (main_cert_pub_key_algo != .X9_62_id_ecPublicKey) - return error.TlsBadSignatureScheme; - const Ecdsa = SchemeEcdsa(comptime_scheme); - const sig = try Ecdsa.Signature.fromDer(encoded_sig); - const key = try Ecdsa.PublicKey.fromSec1(main_cert_pub_key); - try sig.verify(verify_bytes, key); - }, - inline .rsa_pss_rsae_sha256, - .rsa_pss_rsae_sha384, - .rsa_pss_rsae_sha512, - => |comptime_scheme| { - if (main_cert_pub_key_algo != .rsaEncryption) - return error.TlsBadSignatureScheme; - - const Hash = SchemeHash(comptime_scheme); - const rsa = Certificate.rsa; - const components = try rsa.PublicKey.parseDer(main_cert_pub_key); - const exponent = components.exponent; - const modulus = components.modulus; - switch (modulus.len) { - inline 128, 256, 512 => |modulus_len| { - const key = try rsa.PublicKey.fromBytes(exponent, modulus); - const sig = rsa.PSSSignature.fromBytes(modulus_len, encoded_sig); - try rsa.PSSSignature.verify(modulus_len, sig, verify_bytes, key, Hash); - }, - else => { - return error.TlsBadRsaSignatureBitCount; - }, - } - }, - inline .ed25519 => |comptime_scheme| { - if (main_cert_pub_key_algo != .curveEd25519) return error.TlsBadSignatureScheme; - const Eddsa = SchemeEddsa(comptime_scheme); - if (encoded_sig.len != Eddsa.Signature.encoded_length) return error.InvalidEncoding; - const sig = Eddsa.Signature.fromBytes(encoded_sig[0..Eddsa.Signature.encoded_length].*); - if (main_cert_pub_key.len != Eddsa.PublicKey.encoded_length) return error.InvalidEncoding; - const key = try Eddsa.PublicKey.fromBytes(main_cert_pub_key[0..Eddsa.PublicKey.encoded_length].*); - try sig.verify(verify_bytes, key); - }, - else => { - return error.TlsBadSignatureScheme; - }, - } - }, - .finished => { - if (handshake_state != .finished) return error.TlsUnexpectedMessage; - // This message is to trick buggy proxies into behaving correctly. - const client_change_cipher_spec_msg = [_]u8{ - @intFromEnum(tls.ContentType.change_cipher_spec), - 0x03, 0x03, // legacy protocol version - 0x00, 0x01, // length - 0x01, - }; - const app_cipher = switch (handshake_cipher) { - inline else => |*p, tag| c: { - const P = @TypeOf(p.*); - const finished_digest = p.transcript_hash.peek(); - p.transcript_hash.update(wrapped_handshake); - const expected_server_verify_data = tls.hmac(P.Hmac, &finished_digest, p.server_finished_key); - if (!mem.eql(u8, &expected_server_verify_data, handshake)) - return error.TlsDecryptError; - const handshake_hash = p.transcript_hash.finalResult(); - const verify_data = tls.hmac(P.Hmac, &handshake_hash, p.client_finished_key); - const out_cleartext = [_]u8{ - @intFromEnum(tls.HandshakeType.finished), - 0, 0, verify_data.len, // length - } ++ verify_data ++ [1]u8{@intFromEnum(tls.ContentType.handshake)}; - - const wrapped_len = out_cleartext.len + P.AEAD.tag_length; - - var finished_msg = [_]u8{ - @intFromEnum(tls.ContentType.application_data), - 0x03, 0x03, // legacy protocol version - 0, wrapped_len, // byte length of encrypted record - } ++ @as([wrapped_len]u8, undefined); - - const ad = finished_msg[0..5]; - const ciphertext = finished_msg[5..][0..out_cleartext.len]; - const auth_tag = finished_msg[finished_msg.len - P.AEAD.tag_length ..]; - const nonce = p.client_handshake_iv; - P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, p.client_handshake_key); - - const both_msgs = client_change_cipher_spec_msg ++ finished_msg; - var both_msgs_vec = [_]std.posix.iovec_const{.{ - .base = &both_msgs, - .len = both_msgs.len, - }}; - try stream.writevAll(&both_msgs_vec); - - const client_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length); - const server_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length); - break :c @unionInit(tls.ApplicationCipher, @tagName(tag), .{ - .client_secret = client_secret, - .server_secret = server_secret, - .client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length), - .server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length), - .client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length), - .server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length), - }); - }, - }; - const leftover = d.rest(); - var client: Client = .{ - .read_seq = 0, - .write_seq = 0, - .partial_cleartext_idx = 0, - .partial_ciphertext_idx = 0, - .partial_ciphertext_end = @intCast(leftover.len), - .received_close_notify = false, - .application_cipher = app_cipher, - .partially_read_buffer = undefined, - }; - @memcpy(client.partially_read_buffer[0..leftover.len], leftover); - return client; - }, - else => { - return error.TlsUnexpectedMessage; - }, - } - if (ctd.eof()) break; - } - }, - else => { - return error.TlsUnexpectedMessage; - }, - } - } -} - -/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. -/// Returns the number of plaintext bytes sent, which may be fewer than `bytes.len`. -pub fn write(c: *Client, stream: anytype, bytes: []const u8) !usize { - return writeEnd(c, stream, bytes, false); -} - -/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. -pub fn writeAll(c: *Client, stream: anytype, bytes: []const u8) !void { - var index: usize = 0; - while (index < bytes.len) { - index += try c.write(stream, bytes[index..]); - } -} - -/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. -/// If `end` is true, then this function additionally sends a `close_notify` alert, -/// which is necessary for the server to distinguish between a properly finished -/// TLS session, or a truncation attack. -pub fn writeAllEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !void { - var index: usize = 0; - while (index < bytes.len) { - index += try c.writeEnd(stream, bytes[index..], end); - } -} - -// TODO: should it be better to use iovecs (as with the non-async version)? -pub fn async_writeAll( - c: *Client, - stream: anytype, - bytes: []const u8, - ctx: *Ctx, - comptime cbk: Cbk, -) !void { - var ciphertext_buf: [tls.max_ciphertext_record_len * 4]u8 = undefined; - const prepared = prepareCiphertextRecordAsBuf(c, &ciphertext_buf, bytes, .application_data); - _ = prepareCiphertextRecordAsBuf( - c, - ciphertext_buf[prepared.ciphertext_end..], - &tls.close_notify_alert, - .alert, - ); - - return try stream.async_writeAll(ciphertext_buf[0..prepared.ciphertext_end], ctx, cbk); -} - -/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. -/// Returns the number of plaintext bytes sent, which may be fewer than `bytes.len`. -/// If `end` is true, then this function additionally sends a `close_notify` alert, -/// which is necessary for the server to distinguish between a properly finished -/// TLS session, or a truncation attack. -pub fn writeEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !usize { - var ciphertext_buf: [tls.max_ciphertext_record_len * 4]u8 = undefined; - var iovecs_buf: [6]std.posix.iovec_const = undefined; - var prepared = prepareCiphertextRecord(c, &iovecs_buf, &ciphertext_buf, bytes, .application_data); - if (end) { - prepared.iovec_end += prepareCiphertextRecord( - c, - iovecs_buf[prepared.iovec_end..], - ciphertext_buf[prepared.ciphertext_end..], - &tls.close_notify_alert, - .alert, - ).iovec_end; - } - - const iovec_end = prepared.iovec_end; - const overhead_len = prepared.overhead_len; - - // Ideally we would call writev exactly once here, however, we must ensure - // that we don't return with a record partially written. - var i: usize = 0; - var total_amt: usize = 0; - while (true) { - var amt = try stream.writev(iovecs_buf[i..iovec_end]); - while (amt >= iovecs_buf[i].len) { - const encrypted_amt = iovecs_buf[i].len; - total_amt += encrypted_amt - overhead_len; - amt -= encrypted_amt; - i += 1; - // Rely on the property that iovecs delineate records, meaning that - // if amt equals zero here, we have fortunately found ourselves - // with a short read that aligns at the record boundary. - if (i >= iovec_end) return total_amt; - // We also cannot return on a vector boundary if the final close_notify is - // not sent; otherwise the caller would not know to retry the call. - if (amt == 0 and (!end or i < iovec_end - 1)) return total_amt; - } - iovecs_buf[i].base += amt; - iovecs_buf[i].len -= amt; - } -} - -fn prepareCiphertextRecord( - c: *Client, - iovecs: []std.posix.iovec_const, - ciphertext_buf: []u8, - bytes: []const u8, - inner_content_type: tls.ContentType, -) struct { - iovec_end: usize, - ciphertext_end: usize, - /// How many bytes are taken up by overhead per record. - overhead_len: usize, -} { - // Due to the trailing inner content type byte in the ciphertext, we need - // an additional buffer for storing the cleartext into before encrypting. - var cleartext_buf: [max_ciphertext_len]u8 = undefined; - var ciphertext_end: usize = 0; - var iovec_end: usize = 0; - var bytes_i: usize = 0; - switch (c.application_cipher) { - inline else => |*p| { - const P = @TypeOf(p.*); - const overhead_len = tls.record_header_len + P.AEAD.tag_length + 1; - const close_notify_alert_reserved = tls.close_notify_alert.len + overhead_len; - while (true) { - const encrypted_content_len: u16 = @intCast(@min( - @min(bytes.len - bytes_i, tls.max_cipertext_inner_record_len), - ciphertext_buf.len -| - (close_notify_alert_reserved + overhead_len + ciphertext_end), - )); - if (encrypted_content_len == 0) return .{ - .iovec_end = iovec_end, - .ciphertext_end = ciphertext_end, - .overhead_len = overhead_len, - }; - - @memcpy(cleartext_buf[0..encrypted_content_len], bytes[bytes_i..][0..encrypted_content_len]); - cleartext_buf[encrypted_content_len] = @intFromEnum(inner_content_type); - bytes_i += encrypted_content_len; - const ciphertext_len = encrypted_content_len + 1; - const cleartext = cleartext_buf[0..ciphertext_len]; - - const record_start = ciphertext_end; - const ad = ciphertext_buf[ciphertext_end..][0..5]; - ad.* = - [_]u8{@intFromEnum(tls.ContentType.application_data)} ++ - int2(@intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ - int2(ciphertext_len + P.AEAD.tag_length); - ciphertext_end += ad.len; - const ciphertext = ciphertext_buf[ciphertext_end..][0..ciphertext_len]; - ciphertext_end += ciphertext_len; - const auth_tag = ciphertext_buf[ciphertext_end..][0..P.AEAD.tag_length]; - ciphertext_end += auth_tag.len; - const nonce = if (builtin.zig_backend == .stage2_x86_64 and - P.AEAD.nonce_length > comptime std.simd.suggestVectorLength(u8) orelse 1) - nonce: { - var nonce = p.client_iv; - const operand = std.mem.readInt(u64, nonce[nonce.len - 8 ..], .big); - std.mem.writeInt(u64, nonce[nonce.len - 8 ..], operand ^ c.write_seq, .big); - break :nonce nonce; - } else nonce: { - const V = @Vector(P.AEAD.nonce_length, u8); - const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); - const operand: V = pad ++ @as([8]u8, @bitCast(big(c.write_seq))); - break :nonce @as(V, p.client_iv) ^ operand; - }; - c.write_seq += 1; // TODO send key_update on overflow - P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, p.client_key); - - const record = ciphertext_buf[record_start..ciphertext_end]; - iovecs[iovec_end] = .{ - .base = record.ptr, - .len = record.len, - }; - iovec_end += 1; - } - }, - } -} - -fn prepareCiphertextRecordAsBuf( - c: *Client, - ciphertext_buf: []u8, - bytes: []const u8, - inner_content_type: tls.ContentType, -) struct { - ciphertext_end: usize, - /// How many bytes are taken up by overhead per record. - overhead_len: usize, -} { - // Due to the trailing inner content type byte in the ciphertext, we need - // an additional buffer for storing the cleartext into before encrypting. - var cleartext_buf: [max_ciphertext_len]u8 = undefined; - var ciphertext_end: usize = 0; - var bytes_i: usize = 0; - switch (c.application_cipher) { - inline else => |*p| { - const P = @TypeOf(p.*); - const V = @Vector(P.AEAD.nonce_length, u8); - const overhead_len = tls.record_header_len + P.AEAD.tag_length + 1; - const close_notify_alert_reserved = tls.close_notify_alert.len + overhead_len; - while (true) { - const encrypted_content_len: u16 = @intCast(@min( - @min(bytes.len - bytes_i, max_ciphertext_len - 1), - ciphertext_buf.len - close_notify_alert_reserved - - overhead_len - ciphertext_end, - )); - if (encrypted_content_len == 0) return .{ - .ciphertext_end = ciphertext_end, - .overhead_len = overhead_len, - }; - - @memcpy(cleartext_buf[0..encrypted_content_len], bytes[bytes_i..][0..encrypted_content_len]); - cleartext_buf[encrypted_content_len] = @intFromEnum(inner_content_type); - bytes_i += encrypted_content_len; - const ciphertext_len = encrypted_content_len + 1; - const cleartext = cleartext_buf[0..ciphertext_len]; - - const ad = ciphertext_buf[ciphertext_end..][0..5]; - ad.* = - [_]u8{@intFromEnum(tls.ContentType.application_data)} ++ - int2(@intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ - int2(ciphertext_len + P.AEAD.tag_length); - ciphertext_end += ad.len; - const ciphertext = ciphertext_buf[ciphertext_end..][0..ciphertext_len]; - ciphertext_end += ciphertext_len; - const auth_tag = ciphertext_buf[ciphertext_end..][0..P.AEAD.tag_length]; - ciphertext_end += auth_tag.len; - const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); - const operand: V = pad ++ @as([8]u8, @bitCast(big(c.write_seq))); - c.write_seq += 1; // TODO send key_update on overflow - const nonce = @as(V, p.client_iv) ^ operand; - P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, p.client_key); - } - }, - } -} - -pub fn eof(c: Client) bool { - return c.received_close_notify and - c.partial_cleartext_idx >= c.partial_ciphertext_idx and - c.partial_ciphertext_idx >= c.partial_ciphertext_end; -} - -/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. -/// Returns the number of bytes read, calling the underlying read function the -/// minimal number of times until the buffer has at least `len` bytes filled. -/// If the number read is less than `len` it means the stream reached the end. -/// Reaching the end of the stream is not an error condition. -pub fn readAtLeast(c: *Client, stream: anytype, buffer: []u8, len: usize) !usize { - var iovecs = [1]std.posix.iovec{.{ .base = buffer.ptr, .len = buffer.len }}; - return readvAtLeast(c, stream, &iovecs, len); -} - -/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. -pub fn read(c: *Client, stream: anytype, buffer: []u8) !usize { - return readAtLeast(c, stream, buffer, 1); -} - -/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. -/// Returns the number of bytes read. If the number read is smaller than -/// `buffer.len`, it means the stream reached the end. Reaching the end of the -/// stream is not an error condition. -pub fn readAll(c: *Client, stream: anytype, buffer: []u8) !usize { - return readAtLeast(c, stream, buffer, buffer.len); -} - -pub fn async_readv( - c: *Client, - stream: anytype, - iovecs: []std.posix.iovec, - ctx: *Ctx, - comptime cbk: Cbk, -) !void { - return async_readvAtLeast(c, stream, iovecs, 1, ctx, cbk); -} - -/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. -/// Returns the number of bytes read. If the number read is less than the space -/// provided it means the stream reached the end. Reaching the end of the -/// stream is not an error condition. -/// The `iovecs` parameter is mutable because this function needs to mutate the fields in -/// order to handle partial reads from the underlying stream layer. -pub fn readv(c: *Client, stream: anytype, iovecs: []std.posix.iovec) !usize { - return readvAtLeast(c, stream, iovecs, 1); -} - -fn onReadvAtLeast(ctx: *Ctx, res: anyerror!void) !void { - res catch |err| return ctx.pop(err); - - const c = ctx.conn().tls_client; - - var amt = ctx.len(); - ctx._off_i += amt; - if (c.eof() or ctx._off_i >= ctx._tls_len) { - ctx.setLen(ctx._off_i); - return ctx.pop({}); - } - while (amt >= ctx._iovecs[ctx._vec_i].len) { - amt -= ctx._iovecs[ctx._vec_i].len; - ctx._vec_i += 1; - } - ctx._iovecs[ctx._vec_i].base += amt; - ctx._iovecs[ctx._vec_i].len -= amt; - - c.async_readvAdvanced( - ctx.stream(), - ctx._iovecs[ctx._vec_i..], - ctx, - onReadvAtLeast, - ) catch |err| return ctx.pop(err); -} - -pub fn async_readvAtLeast( - c: *Client, - stream: anytype, - iovecs: []std.posix.iovec, - len: usize, - ctx: *Ctx, - comptime cbk: Cbk, -) !void { - if (c.eof()) return; - - ctx._tls_len = len; - ctx._off_i = 0; - ctx._vec_i = 0; - - try ctx.push(cbk); - return c.async_readvAdvanced(stream, iovecs[ctx._vec_i..], ctx, onReadvAtLeast); -} - -/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. -/// Returns the number of bytes read, calling the underlying read function the -/// minimal number of times until the iovecs have at least `len` bytes filled. -/// If the number read is less than `len` it means the stream reached the end. -/// Reaching the end of the stream is not an error condition. -/// The `iovecs` parameter is mutable because this function needs to mutate the fields in -/// order to handle partial reads from the underlying stream layer. -pub fn readvAtLeast(c: *Client, stream: anytype, iovecs: []std.posix.iovec, len: usize) !usize { - if (c.eof()) return 0; - - var off_i: usize = 0; - var vec_i: usize = 0; - while (true) { - var amt = try c.readvAdvanced(stream, iovecs[vec_i..]); - off_i += amt; - if (c.eof() or off_i >= len) return off_i; - while (amt >= iovecs[vec_i].len) { - amt -= iovecs[vec_i].len; - vec_i += 1; - } - iovecs[vec_i].base += amt; - iovecs[vec_i].len -= amt; - } -} - -pub fn async_readvAdvanced( - c: *Client, - stream: anytype, - iovecs: []const std.posix.iovec, - ctx: *Ctx, - comptime cbk: Cbk, -) !void { - ctx._vp = try ctx.alloc().create(VecPut); - errdefer ctx.alloc().destroy(ctx._vp); - - ctx._vp.* = .{ .iovecs = iovecs }; - - // Give away the buffered cleartext we have, if any. - const partial_cleartext = c.partially_read_buffer[c.partial_cleartext_idx..c.partial_ciphertext_idx]; - if (partial_cleartext.len > 0) { - const amt: u15 = @intCast(ctx._vp.put(partial_cleartext)); - c.partial_cleartext_idx += amt; - - if (c.partial_cleartext_idx == c.partial_ciphertext_idx and - c.partial_ciphertext_end == c.partial_ciphertext_idx) - { - // The buffer is now empty. - c.partial_cleartext_idx = 0; - c.partial_ciphertext_idx = 0; - c.partial_ciphertext_end = 0; - } - - if (c.received_close_notify) { - c.partial_ciphertext_end = 0; - assert(ctx._vp.total == amt); - ctx.setLen(amt); - return ctx.pop({}); - } else if (amt > 0) { - // We don't need more data, so don't call read. - assert(ctx._vp.total == amt); - ctx.setLen(amt); - return ctx.pop({}); - } - } - - assert(!c.received_close_notify); - - // Ideally, this buffer would never be used. It is needed when `iovecs` are - // too small to fit the cleartext, which may be as large as `max_ciphertext_len`. - ctx._cleartext_stack_buffer = try ctx.alloc().alloc(u8, max_ciphertext_len); - // Temporarily stores ciphertext before decrypting it and giving it to `iovecs`. - ctx._in_stack_buffer = try ctx.alloc().alloc(u8, max_ciphertext_len * 4); - // How many bytes left in the user's buffer. - const free_size = ctx._vp.freeSize(); - // The amount of the user's buffer that we need to repurpose for storing - // ciphertext. The end of the buffer will be used for such purposes. - const ciphertext_buf_len = (free_size / 2) -| ctx._in_stack_buffer.len; - // The amount of the user's buffer that will be used to give cleartext. The - // beginning of the buffer will be used for such purposes. - const cleartext_buf_len = free_size - ciphertext_buf_len; - - // Recoup `partially_read_buffer space`. This is necessary because it is assumed - // below that `frag0` is big enough to hold at least one record. - limitedOverlapCopy(c.partially_read_buffer[0..c.partial_ciphertext_end], c.partial_ciphertext_idx); - c.partial_ciphertext_end -= c.partial_ciphertext_idx; - c.partial_ciphertext_idx = 0; - c.partial_cleartext_idx = 0; - const first_iov = c.partially_read_buffer[c.partial_ciphertext_end..]; - ctx._first_iov = try ctx.alloc().alloc(u8, first_iov.len); - @memcpy(ctx._first_iov, first_iov); - - var ask_iovecs_buf: [2]std.posix.iovec = .{ - .{ - .base = first_iov.ptr, - .len = first_iov.len, - }, - .{ - .base = ctx._in_stack_buffer.ptr, - .len = ctx._in_stack_buffer.len, - }, - }; - - // Cleartext capacity of output buffer, in records. Minimum one full record. - const buf_cap = @max(cleartext_buf_len / max_ciphertext_len, 1); - const wanted_read_len = buf_cap * (max_ciphertext_len + tls.record_header_len); - const ask_len = @max(wanted_read_len, ctx._cleartext_stack_buffer.len); - const ask_iovecs = limitVecs(&ask_iovecs_buf, ask_len); - - try ctx.push(cbk); - return stream.async_readv(ask_iovecs, ctx, setDecode); -} - -fn setDecode(ctx: *Ctx, res: anyerror!void) anyerror!void { - res catch |err| { - // reset ctx - ctx.alloc().destroy(ctx._vp); - ctx.alloc().free(ctx._cleartext_stack_buffer); - ctx.alloc().free(ctx._in_stack_buffer); - ctx.alloc().free(ctx._first_iov); - - return ctx.pop(err); - }; - - const len = decode( - ctx.conn().tls_client, - ctx.len(), - ctx._in_stack_buffer, - ctx._first_iov, - ctx._vp, - ctx._cleartext_stack_buffer, - ) catch |err| return ctx.pop(err); - - // reset ctx - ctx.alloc().destroy(ctx._vp); - ctx.alloc().free(ctx._cleartext_stack_buffer); - ctx.alloc().free(ctx._in_stack_buffer); - ctx.alloc().free(ctx._first_iov); - - ctx.setLen(len); - return ctx.pop({}); -} - -fn decode( - c: *Client, - actual_read_len: usize, - in_stack_buffer: []u8, - first_iov: []u8, - vp: *VecPut, - cleartext_stack_buffer: []u8, -) !usize { - if (actual_read_len == 0) { - // This is either a truncation attack, a bug in the server, or an - // intentional omission of the close_notify message due to truncation - // detection handled above the TLS layer. - if (c.allow_truncation_attacks) { - c.received_close_notify = true; - } else { - return error.TlsConnectionTruncated; - } - } - - // There might be more bytes inside `in_stack_buffer` that need to be processed, - // but at least frag0 will have one complete ciphertext record. - const frag0_end = @min(c.partially_read_buffer.len, c.partial_ciphertext_end + actual_read_len); - const frag0 = c.partially_read_buffer[c.partial_ciphertext_idx..frag0_end]; - var frag1 = in_stack_buffer[0..actual_read_len -| first_iov.len]; - // We need to decipher frag0 and frag1 but there may be a ciphertext record - // straddling the boundary. We can handle this with two memcpy() calls to - // assemble the straddling record in between handling the two sides. - var frag = frag0; - var in: usize = 0; - while (true) { - if (in == frag.len) { - // Perfect split. - if (frag.ptr == frag1.ptr) { - c.partial_ciphertext_end = c.partial_ciphertext_idx; - return vp.total; - } - frag = frag1; - in = 0; - continue; - } - - if (in + tls.record_header_len > frag.len) { - if (frag.ptr == frag1.ptr) - return finishRead(c, frag, in, vp.total); - - const first = frag[in..]; - - if (frag1.len < tls.record_header_len) - return finishRead2(c, first, frag1, vp.total); - - // A record straddles the two fragments. Copy into the now-empty first fragment. - const record_len_byte_0: u16 = straddleByte(frag, frag1, in + 3); - const record_len_byte_1: u16 = straddleByte(frag, frag1, in + 4); - const record_len = (record_len_byte_0 << 8) | record_len_byte_1; - if (record_len > max_ciphertext_len) return error.TlsRecordOverflow; - - const full_record_len = record_len + tls.record_header_len; - const second_len = full_record_len - first.len; - if (frag1.len < second_len) - return finishRead2(c, first, frag1, vp.total); - - limitedOverlapCopy(frag, in); - @memcpy(frag[first.len..][0..second_len], frag1[0..second_len]); - frag = frag[0..full_record_len]; - frag1 = frag1[second_len..]; - in = 0; - continue; - } - const ct: tls.ContentType = @enumFromInt(frag[in]); - in += 1; - const legacy_version = mem.readInt(u16, frag[in..][0..2], .big); - in += 2; - _ = legacy_version; - const record_len = mem.readInt(u16, frag[in..][0..2], .big); - if (record_len > max_ciphertext_len) return error.TlsRecordOverflow; - in += 2; - const end = in + record_len; - if (end > frag.len) { - // We need the record header on the next iteration of the loop. - in -= tls.record_header_len; - - if (frag.ptr == frag1.ptr) - return finishRead(c, frag, in, vp.total); - - // A record straddles the two fragments. Copy into the now-empty first fragment. - const first = frag[in..]; - const full_record_len = record_len + tls.record_header_len; - const second_len = full_record_len - first.len; - if (frag1.len < second_len) - return finishRead2(c, first, frag1, vp.total); - - limitedOverlapCopy(frag, in); - @memcpy(frag[first.len..][0..second_len], frag1[0..second_len]); - frag = frag[0..full_record_len]; - frag1 = frag1[second_len..]; - in = 0; - continue; - } - switch (ct) { - .alert => { - if (in + 2 > frag.len) return error.TlsDecodeError; - const level: tls.AlertLevel = @enumFromInt(frag[in]); - const desc: tls.AlertDescription = @enumFromInt(frag[in + 1]); - _ = level; - - try desc.toError(); - // TODO: handle server-side closures - return error.TlsUnexpectedMessage; - }, - .application_data => { - const cleartext = switch (c.application_cipher) { - inline else => |*p| c: { - const P = @TypeOf(p.*); - const V = @Vector(P.AEAD.nonce_length, u8); - const ad = frag[in - 5 ..][0..5]; - const ciphertext_len = record_len - P.AEAD.tag_length; - const ciphertext = frag[in..][0..ciphertext_len]; - in += ciphertext_len; - const auth_tag = frag[in..][0..P.AEAD.tag_length].*; - const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); - const operand: V = pad ++ @as([8]u8, @bitCast(big(c.read_seq))); - const nonce: [P.AEAD.nonce_length]u8 = @as(V, p.server_iv) ^ operand; - const out_buf = vp.peek(); - const cleartext_buf = if (ciphertext.len <= out_buf.len) - out_buf - else - cleartext_stack_buffer; - const cleartext = cleartext_buf[0..ciphertext.len]; - P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, p.server_key) catch - return error.TlsBadRecordMac; - break :c cleartext; - }, - }; - - c.read_seq = try std.math.add(u64, c.read_seq, 1); - - const inner_ct: tls.ContentType = @enumFromInt(cleartext[cleartext.len - 1]); - switch (inner_ct) { - .alert => { - const level: tls.AlertLevel = @enumFromInt(cleartext[0]); - const desc: tls.AlertDescription = @enumFromInt(cleartext[1]); - if (desc == .close_notify) { - c.received_close_notify = true; - c.partial_ciphertext_end = c.partial_ciphertext_idx; - return vp.total; - } - _ = level; - - try desc.toError(); - // TODO: handle server-side closures - return error.TlsUnexpectedMessage; - }, - .handshake => { - var ct_i: usize = 0; - while (true) { - const handshake_type: tls.HandshakeType = @enumFromInt(cleartext[ct_i]); - ct_i += 1; - const handshake_len = mem.readInt(u24, cleartext[ct_i..][0..3], .big); - ct_i += 3; - const next_handshake_i = ct_i + handshake_len; - if (next_handshake_i > cleartext.len - 1) - return error.TlsBadLength; - const handshake = cleartext[ct_i..next_handshake_i]; - switch (handshake_type) { - .new_session_ticket => { - // This client implementation ignores new session tickets. - }, - .key_update => { - switch (c.application_cipher) { - inline else => |*p| { - const P = @TypeOf(p.*); - const server_secret = hkdfExpandLabel(P.Hkdf, p.server_secret, "traffic upd", "", P.Hash.digest_length); - p.server_secret = server_secret; - p.server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length); - p.server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length); - }, - } - c.read_seq = 0; - - switch (@as(tls.KeyUpdateRequest, @enumFromInt(handshake[0]))) { - .update_requested => { - switch (c.application_cipher) { - inline else => |*p| { - const P = @TypeOf(p.*); - const client_secret = hkdfExpandLabel(P.Hkdf, p.client_secret, "traffic upd", "", P.Hash.digest_length); - p.client_secret = client_secret; - p.client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length); - p.client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length); - }, - } - c.write_seq = 0; - }, - .update_not_requested => {}, - _ => return error.TlsIllegalParameter, - } - }, - else => { - return error.TlsUnexpectedMessage; - }, - } - ct_i = next_handshake_i; - if (ct_i >= cleartext.len - 1) break; - } - }, - .application_data => { - // Determine whether the output buffer or a stack - // buffer was used for storing the cleartext. - if (cleartext.ptr == cleartext_stack_buffer.ptr) { - // Stack buffer was used, so we must copy to the output buffer. - const msg = cleartext[0 .. cleartext.len - 1]; - if (c.partial_ciphertext_idx > c.partial_cleartext_idx) { - // We have already run out of room in iovecs. Continue - // appending to `partially_read_buffer`. - @memcpy( - c.partially_read_buffer[c.partial_ciphertext_idx..][0..msg.len], - msg, - ); - c.partial_ciphertext_idx = @intCast(c.partial_ciphertext_idx + msg.len); - } else { - const amt = vp.put(msg); - if (amt < msg.len) { - const rest = msg[amt..]; - c.partial_cleartext_idx = 0; - c.partial_ciphertext_idx = @intCast(rest.len); - @memcpy(c.partially_read_buffer[0..rest.len], rest); - } - } - } else { - // Output buffer was used directly which means no - // memory copying needs to occur, and we can move - // on to the next ciphertext record. - vp.next(cleartext.len - 1); - } - }, - else => { - return error.TlsUnexpectedMessage; - }, - } - }, - else => { - return error.TlsUnexpectedMessage; - }, - } - in = end; - } -} - -/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. -/// Returns number of bytes that have been read, populated inside `iovecs`. A -/// return value of zero bytes does not mean end of stream. Instead, check the `eof()` -/// for the end of stream. The `eof()` may be true after any call to -/// `read`, including when greater than zero bytes are returned, and this -/// function asserts that `eof()` is `false`. -/// See `readv` for a higher level function that has the same, familiar API as -/// other read functions, such as `std.fs.File.read`. -pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.posix.iovec) !usize { - var vp: VecPut = .{ .iovecs = iovecs }; - - // Give away the buffered cleartext we have, if any. - const partial_cleartext = c.partially_read_buffer[c.partial_cleartext_idx..c.partial_ciphertext_idx]; - if (partial_cleartext.len > 0) { - const amt: u15 = @intCast(vp.put(partial_cleartext)); - c.partial_cleartext_idx += amt; - - if (c.partial_cleartext_idx == c.partial_ciphertext_idx and - c.partial_ciphertext_end == c.partial_ciphertext_idx) - { - // The buffer is now empty. - c.partial_cleartext_idx = 0; - c.partial_ciphertext_idx = 0; - c.partial_ciphertext_end = 0; - } - - if (c.received_close_notify) { - c.partial_ciphertext_end = 0; - assert(vp.total == amt); - return amt; - } else if (amt > 0) { - // We don't need more data, so don't call read. - assert(vp.total == amt); - return amt; - } - } - - assert(!c.received_close_notify); - - // Ideally, this buffer would never be used. It is needed when `iovecs` are - // too small to fit the cleartext, which may be as large as `max_ciphertext_len`. - var cleartext_stack_buffer: [max_ciphertext_len]u8 = undefined; - // Temporarily stores ciphertext before decrypting it and giving it to `iovecs`. - var in_stack_buffer: [max_ciphertext_len * 4]u8 = undefined; - // How many bytes left in the user's buffer. - const free_size = vp.freeSize(); - // The amount of the user's buffer that we need to repurpose for storing - // ciphertext. The end of the buffer will be used for such purposes. - const ciphertext_buf_len = (free_size / 2) -| in_stack_buffer.len; - // The amount of the user's buffer that will be used to give cleartext. The - // beginning of the buffer will be used for such purposes. - const cleartext_buf_len = free_size - ciphertext_buf_len; - - // Recoup `partially_read_buffer space`. This is necessary because it is assumed - // below that `frag0` is big enough to hold at least one record. - limitedOverlapCopy(c.partially_read_buffer[0..c.partial_ciphertext_end], c.partial_ciphertext_idx); - c.partial_ciphertext_end -= c.partial_ciphertext_idx; - c.partial_ciphertext_idx = 0; - c.partial_cleartext_idx = 0; - const first_iov = c.partially_read_buffer[c.partial_ciphertext_end..]; - - var ask_iovecs_buf: [2]std.posix.iovec = .{ - .{ - .base = first_iov.ptr, - .len = first_iov.len, - }, - .{ - .base = &in_stack_buffer, - .len = in_stack_buffer.len, - }, - }; - - // Cleartext capacity of output buffer, in records. Minimum one full record. - const buf_cap = @max(cleartext_buf_len / max_ciphertext_len, 1); - const wanted_read_len = buf_cap * (max_ciphertext_len + tls.record_header_len); - const ask_len = @max(wanted_read_len, cleartext_stack_buffer.len); - const ask_iovecs = limitVecs(&ask_iovecs_buf, ask_len); - const actual_read_len = try stream.readv(ask_iovecs); - if (actual_read_len == 0) { - // This is either a truncation attack, a bug in the server, or an - // intentional omission of the close_notify message due to truncation - // detection handled above the TLS layer. - if (c.allow_truncation_attacks) { - c.received_close_notify = true; - } else { - return error.TlsConnectionTruncated; - } - } - - // There might be more bytes inside `in_stack_buffer` that need to be processed, - // but at least frag0 will have one complete ciphertext record. - const frag0_end = @min(c.partially_read_buffer.len, c.partial_ciphertext_end + actual_read_len); - const frag0 = c.partially_read_buffer[c.partial_ciphertext_idx..frag0_end]; - var frag1 = in_stack_buffer[0..actual_read_len -| first_iov.len]; - // We need to decipher frag0 and frag1 but there may be a ciphertext record - // straddling the boundary. We can handle this with two memcpy() calls to - // assemble the straddling record in between handling the two sides. - var frag = frag0; - var in: usize = 0; - while (true) { - if (in == frag.len) { - // Perfect split. - if (frag.ptr == frag1.ptr) { - c.partial_ciphertext_end = c.partial_ciphertext_idx; - return vp.total; - } - frag = frag1; - in = 0; - continue; - } - - if (in + tls.record_header_len > frag.len) { - if (frag.ptr == frag1.ptr) - return finishRead(c, frag, in, vp.total); - - const first = frag[in..]; - - if (frag1.len < tls.record_header_len) - return finishRead2(c, first, frag1, vp.total); - - // A record straddles the two fragments. Copy into the now-empty first fragment. - const record_len_byte_0: u16 = straddleByte(frag, frag1, in + 3); - const record_len_byte_1: u16 = straddleByte(frag, frag1, in + 4); - const record_len = (record_len_byte_0 << 8) | record_len_byte_1; - if (record_len > max_ciphertext_len) return error.TlsRecordOverflow; - - const full_record_len = record_len + tls.record_header_len; - const second_len = full_record_len - first.len; - if (frag1.len < second_len) - return finishRead2(c, first, frag1, vp.total); - - limitedOverlapCopy(frag, in); - @memcpy(frag[first.len..][0..second_len], frag1[0..second_len]); - frag = frag[0..full_record_len]; - frag1 = frag1[second_len..]; - in = 0; - continue; - } - const ct: tls.ContentType = @enumFromInt(frag[in]); - in += 1; - const legacy_version = mem.readInt(u16, frag[in..][0..2], .big); - in += 2; - _ = legacy_version; - const record_len = mem.readInt(u16, frag[in..][0..2], .big); - if (record_len > max_ciphertext_len) return error.TlsRecordOverflow; - in += 2; - const end = in + record_len; - if (end > frag.len) { - // We need the record header on the next iteration of the loop. - in -= tls.record_header_len; - - if (frag.ptr == frag1.ptr) - return finishRead(c, frag, in, vp.total); - - // A record straddles the two fragments. Copy into the now-empty first fragment. - const first = frag[in..]; - const full_record_len = record_len + tls.record_header_len; - const second_len = full_record_len - first.len; - if (frag1.len < second_len) - return finishRead2(c, first, frag1, vp.total); - - limitedOverlapCopy(frag, in); - @memcpy(frag[first.len..][0..second_len], frag1[0..second_len]); - frag = frag[0..full_record_len]; - frag1 = frag1[second_len..]; - in = 0; - continue; - } - switch (ct) { - .alert => { - if (in + 2 > frag.len) return error.TlsDecodeError; - const level: tls.AlertLevel = @enumFromInt(frag[in]); - const desc: tls.AlertDescription = @enumFromInt(frag[in + 1]); - _ = level; - - try desc.toError(); - // TODO: handle server-side closures - return error.TlsUnexpectedMessage; - }, - .application_data => { - const cleartext = switch (c.application_cipher) { - inline else => |*p| c: { - const P = @TypeOf(p.*); - const ad = frag[in - 5 ..][0..5]; - const ciphertext_len = record_len - P.AEAD.tag_length; - const ciphertext = frag[in..][0..ciphertext_len]; - in += ciphertext_len; - const auth_tag = frag[in..][0..P.AEAD.tag_length].*; - const nonce = if (builtin.zig_backend == .stage2_x86_64 and - P.AEAD.nonce_length > comptime std.simd.suggestVectorLength(u8) orelse 1) - nonce: { - var nonce = p.server_iv; - const operand = std.mem.readInt(u64, nonce[nonce.len - 8 ..], .big); - std.mem.writeInt(u64, nonce[nonce.len - 8 ..], operand ^ c.read_seq, .big); - break :nonce nonce; - } else nonce: { - const V = @Vector(P.AEAD.nonce_length, u8); - const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); - const operand: V = pad ++ @as([8]u8, @bitCast(big(c.read_seq))); - break :nonce @as(V, p.server_iv) ^ operand; - }; - const out_buf = vp.peek(); - const cleartext_buf = if (ciphertext.len <= out_buf.len) - out_buf - else - &cleartext_stack_buffer; - const cleartext = cleartext_buf[0..ciphertext.len]; - P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, p.server_key) catch - return error.TlsBadRecordMac; - break :c cleartext; - }, - }; - - c.read_seq = try std.math.add(u64, c.read_seq, 1); - - const inner_ct: tls.ContentType = @enumFromInt(cleartext[cleartext.len - 1]); - switch (inner_ct) { - .alert => { - const level: tls.AlertLevel = @enumFromInt(cleartext[0]); - const desc: tls.AlertDescription = @enumFromInt(cleartext[1]); - if (desc == .close_notify) { - c.received_close_notify = true; - c.partial_ciphertext_end = c.partial_ciphertext_idx; - return vp.total; - } - _ = level; - - try desc.toError(); - // TODO: handle server-side closures - return error.TlsUnexpectedMessage; - }, - .handshake => { - var ct_i: usize = 0; - while (true) { - const handshake_type: tls.HandshakeType = @enumFromInt(cleartext[ct_i]); - ct_i += 1; - const handshake_len = mem.readInt(u24, cleartext[ct_i..][0..3], .big); - ct_i += 3; - const next_handshake_i = ct_i + handshake_len; - if (next_handshake_i > cleartext.len - 1) - return error.TlsBadLength; - const handshake = cleartext[ct_i..next_handshake_i]; - switch (handshake_type) { - .new_session_ticket => { - // This client implementation ignores new session tickets. - }, - .key_update => { - switch (c.application_cipher) { - inline else => |*p| { - const P = @TypeOf(p.*); - const server_secret = hkdfExpandLabel(P.Hkdf, p.server_secret, "traffic upd", "", P.Hash.digest_length); - p.server_secret = server_secret; - p.server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length); - p.server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length); - }, - } - c.read_seq = 0; - - switch (@as(tls.KeyUpdateRequest, @enumFromInt(handshake[0]))) { - .update_requested => { - switch (c.application_cipher) { - inline else => |*p| { - const P = @TypeOf(p.*); - const client_secret = hkdfExpandLabel(P.Hkdf, p.client_secret, "traffic upd", "", P.Hash.digest_length); - p.client_secret = client_secret; - p.client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length); - p.client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length); - }, - } - c.write_seq = 0; - }, - .update_not_requested => {}, - _ => return error.TlsIllegalParameter, - } - }, - else => { - return error.TlsUnexpectedMessage; - }, - } - ct_i = next_handshake_i; - if (ct_i >= cleartext.len - 1) break; - } - }, - .application_data => { - // Determine whether the output buffer or a stack - // buffer was used for storing the cleartext. - if (cleartext.ptr == &cleartext_stack_buffer) { - // Stack buffer was used, so we must copy to the output buffer. - const msg = cleartext[0 .. cleartext.len - 1]; - if (c.partial_ciphertext_idx > c.partial_cleartext_idx) { - // We have already run out of room in iovecs. Continue - // appending to `partially_read_buffer`. - @memcpy( - c.partially_read_buffer[c.partial_ciphertext_idx..][0..msg.len], - msg, - ); - c.partial_ciphertext_idx = @intCast(c.partial_ciphertext_idx + msg.len); - } else { - const amt = vp.put(msg); - if (amt < msg.len) { - const rest = msg[amt..]; - c.partial_cleartext_idx = 0; - c.partial_ciphertext_idx = @intCast(rest.len); - @memcpy(c.partially_read_buffer[0..rest.len], rest); - } - } - } else { - // Output buffer was used directly which means no - // memory copying needs to occur, and we can move - // on to the next ciphertext record. - vp.next(cleartext.len - 1); - } - }, - else => { - return error.TlsUnexpectedMessage; - }, - } - }, - else => { - return error.TlsUnexpectedMessage; - }, - } - in = end; - } -} - -fn finishRead(c: *Client, frag: []const u8, in: usize, out: usize) usize { - const saved_buf = frag[in..]; - if (c.partial_ciphertext_idx > c.partial_cleartext_idx) { - // There is cleartext at the beginning already which we need to preserve. - c.partial_ciphertext_end = @intCast(c.partial_ciphertext_idx + saved_buf.len); - @memcpy(c.partially_read_buffer[c.partial_ciphertext_idx..][0..saved_buf.len], saved_buf); - } else { - c.partial_cleartext_idx = 0; - c.partial_ciphertext_idx = 0; - c.partial_ciphertext_end = @intCast(saved_buf.len); - @memcpy(c.partially_read_buffer[0..saved_buf.len], saved_buf); - } - return out; -} - -/// Note that `first` usually overlaps with `c.partially_read_buffer`. -fn finishRead2(c: *Client, first: []const u8, frag1: []const u8, out: usize) usize { - if (c.partial_ciphertext_idx > c.partial_cleartext_idx) { - // There is cleartext at the beginning already which we need to preserve. - c.partial_ciphertext_end = @intCast(c.partial_ciphertext_idx + first.len + frag1.len); - // TODO: eliminate this call to copyForwards - std.mem.copyForwards(u8, c.partially_read_buffer[c.partial_ciphertext_idx..][0..first.len], first); - @memcpy(c.partially_read_buffer[c.partial_ciphertext_idx + first.len ..][0..frag1.len], frag1); - } else { - c.partial_cleartext_idx = 0; - c.partial_ciphertext_idx = 0; - c.partial_ciphertext_end = @intCast(first.len + frag1.len); - // TODO: eliminate this call to copyForwards - std.mem.copyForwards(u8, c.partially_read_buffer[0..first.len], first); - @memcpy(c.partially_read_buffer[first.len..][0..frag1.len], frag1); - } - return out; -} - -fn limitedOverlapCopy(frag: []u8, in: usize) void { - const first = frag[in..]; - if (first.len <= in) { - // A single, non-overlapping memcpy suffices. - @memcpy(frag[0..first.len], first); - } else { - // One memcpy call would overlap, so just do this instead. - std.mem.copyForwards(u8, frag, first); - } -} - -fn straddleByte(s1: []const u8, s2: []const u8, index: usize) u8 { - if (index < s1.len) { - return s1[index]; - } else { - return s2[index - s1.len]; - } -} - -const builtin = @import("builtin"); -const native_endian = builtin.cpu.arch.endian(); - -inline fn big(x: anytype) @TypeOf(x) { - return switch (native_endian) { - .big => x, - .little => @byteSwap(x), - }; -} - -fn SchemeEcdsa(comptime scheme: tls.SignatureScheme) type { - return switch (scheme) { - .ecdsa_secp256r1_sha256 => crypto.sign.ecdsa.EcdsaP256Sha256, - .ecdsa_secp384r1_sha384 => crypto.sign.ecdsa.EcdsaP384Sha384, - else => @compileError("bad scheme"), - }; -} - -fn SchemeHash(comptime scheme: tls.SignatureScheme) type { - return switch (scheme) { - .rsa_pss_rsae_sha256 => crypto.hash.sha2.Sha256, - .rsa_pss_rsae_sha384 => crypto.hash.sha2.Sha384, - .rsa_pss_rsae_sha512 => crypto.hash.sha2.Sha512, - else => @compileError("bad scheme"), - }; -} - -fn SchemeEddsa(comptime scheme: tls.SignatureScheme) type { - return switch (scheme) { - .ed25519 => crypto.sign.Ed25519, - else => @compileError("bad scheme"), - }; -} - -/// Abstraction for sending multiple byte buffers to a slice of iovecs. -pub const VecPut = struct { - iovecs: []const std.posix.iovec, - idx: usize = 0, - off: usize = 0, - total: usize = 0, - - /// Returns the amount actually put which is always equal to bytes.len - /// unless the vectors ran out of space. - fn put(vp: *VecPut, bytes: []const u8) usize { - if (vp.idx >= vp.iovecs.len) return 0; - var bytes_i: usize = 0; - while (true) { - const v = vp.iovecs[vp.idx]; - const dest = v.base[vp.off..v.len]; - const src = bytes[bytes_i..][0..@min(dest.len, bytes.len - bytes_i)]; - @memcpy(dest[0..src.len], src); - bytes_i += src.len; - vp.off += src.len; - if (vp.off >= v.len) { - vp.off = 0; - vp.idx += 1; - if (vp.idx >= vp.iovecs.len) { - vp.total += bytes_i; - return bytes_i; - } - } - if (bytes_i >= bytes.len) { - vp.total += bytes_i; - return bytes_i; - } - } - } - - /// Returns the next buffer that consecutive bytes can go into. - fn peek(vp: VecPut) []u8 { - if (vp.idx >= vp.iovecs.len) return &.{}; - const v = vp.iovecs[vp.idx]; - return v.base[vp.off..v.len]; - } - - // After writing to the result of peek(), one can call next() to - // advance the cursor. - fn next(vp: *VecPut, len: usize) void { - vp.total += len; - vp.off += len; - if (vp.off >= vp.iovecs[vp.idx].len) { - vp.off = 0; - vp.idx += 1; - } - } - - fn freeSize(vp: VecPut) usize { - if (vp.idx >= vp.iovecs.len) return 0; - var total: usize = 0; - total += vp.iovecs[vp.idx].len - vp.off; - if (vp.idx + 1 >= vp.iovecs.len) return total; - for (vp.iovecs[vp.idx + 1 ..]) |v| total += v.len; - return total; - } -}; - -/// Limit iovecs to a specific byte size. -fn limitVecs(iovecs: []std.posix.iovec, len: usize) []std.posix.iovec { - var bytes_left: usize = len; - for (iovecs, 0..) |*iovec, vec_i| { - if (bytes_left <= iovec.len) { - iovec.len = bytes_left; - return iovecs[0 .. vec_i + 1]; - } - bytes_left -= iovec.len; - } - return iovecs; -} - -/// The priority order here is chosen based on what crypto algorithms Zig has -/// available in the standard library as well as what is faster. Following are -/// a few data points on the relative performance of these algorithms. -/// -/// Measurement taken with 0.11.0-dev.810+c2f5848fe -/// on x86_64-linux Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz: -/// zig run .lib/std/crypto/benchmark.zig -OReleaseFast -/// aegis-128l: 15382 MiB/s -/// aegis-256: 9553 MiB/s -/// aes128-gcm: 3721 MiB/s -/// aes256-gcm: 3010 MiB/s -/// chacha20Poly1305: 597 MiB/s -/// -/// Measurement taken with 0.11.0-dev.810+c2f5848fe -/// on x86_64-linux Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz: -/// zig run .lib/std/crypto/benchmark.zig -OReleaseFast -mcpu=baseline -/// aegis-128l: 629 MiB/s -/// chacha20Poly1305: 529 MiB/s -/// aegis-256: 461 MiB/s -/// aes128-gcm: 138 MiB/s -/// aes256-gcm: 120 MiB/s -const cipher_suites = if (crypto.core.aes.has_hardware_support) - enum_array(tls.CipherSuite, &.{ - .AEGIS_128L_SHA256, - .AEGIS_256_SHA512, - .AES_128_GCM_SHA256, - .AES_256_GCM_SHA384, - .CHACHA20_POLY1305_SHA256, - }) -else - enum_array(tls.CipherSuite, &.{ - .CHACHA20_POLY1305_SHA256, - .AEGIS_128L_SHA256, - .AEGIS_256_SHA512, - .AES_128_GCM_SHA256, - .AES_256_GCM_SHA384, - }); - -test { - _ = StreamInterface; -} diff --git a/src/std/http.zig b/src/std/http.zig index af966d8..f027d44 100644 --- a/src/std/http.zig +++ b/src/std/http.zig @@ -1,9 +1,9 @@ pub const Client = @import("http/Client.zig"); pub const Server = @import("http/Server.zig"); pub const protocol = @import("http/protocol.zig"); -pub const HeadParser = @import("http/HeadParser.zig"); -pub const ChunkParser = @import("http/ChunkParser.zig"); -pub const HeaderIterator = @import("http/HeaderIterator.zig"); +pub const HeadParser = std.http.HeadParser; +pub const ChunkParser = std.http.ChunkParser; +pub const HeaderIterator = std.http.HeaderIterator; pub const Version = enum { @"HTTP/1.0", @@ -308,16 +308,11 @@ pub const Header = struct { }; const builtin = @import("builtin"); -const std = @import("std.zig"); +const std = @import("std"); test { _ = Client; _ = Method; _ = Server; _ = Status; - _ = HeadParser; - _ = ChunkParser; - if (builtin.os.tag != .wasi) { - _ = @import("http/test.zig"); - } } diff --git a/src/std/http/Client.zig b/src/std/http/Client.zig index e443f3a..aca572c 100644 --- a/src/std/http/Client.zig +++ b/src/std/http/Client.zig @@ -18,12 +18,14 @@ const use_vectors = builtin.zig_backend != .stage2_x86_64; const Client = @This(); const proto = @import("protocol.zig"); -const async_net = @import("../net.zig"); -const async_tls = @import("../crypto/tls/Client.zig"); +const tls23 = @import("../../tls.zig/main.zig"); +const VecPut = @import("../../tls.zig/connection.zig").VecPut; const GenericStack = @import("../../stack.zig").Stack; const async_io = @import("../../io.zig"); const Loop = async_io.Blocking; +const cipher = @import("../../tls.zig/cipher.zig"); + pub const disable_tls = std.options.http_disable_tls; /// Used for all client allocations. Must be thread-safe. @@ -196,9 +198,9 @@ pub const ConnectionPool = struct { /// An interface to either a plain or TLS connection. pub const Connection = struct { - stream: async_net.Stream, + stream: net.Stream, /// undefined unless protocol is tls. - tls_client: if (!disable_tls) *async_tls.Client else void, + tls_client: if (!disable_tls) *tls23.Connection(net.Stream) else void, /// The protocol that this connection is using. protocol: Protocol, @@ -244,12 +246,12 @@ pub const Connection = struct { } pub fn readvDirectTls(conn: *Connection, buffers: []std.posix.iovec) ReadError!usize { - return conn.tls_client.readv(conn.stream, buffers) catch |err| { + return conn.tls_client.readv(buffers) catch |err| { // https://github.com/ziglang/zig/issues/2473 if (mem.startsWith(u8, @errorName(err), "TlsAlert")) return error.TlsAlert; switch (err) { - error.TlsConnectionTruncated, error.TlsRecordOverflow, error.TlsDecodeError, error.TlsBadRecordMac, error.TlsBadLength, error.TlsIllegalParameter, error.TlsUnexpectedMessage => return error.TlsFailure, + error.TlsRecordOverflow, error.TlsBadRecordMac, error.TlsUnexpectedMessage => return error.TlsFailure, error.ConnectionTimedOut => return error.ConnectionTimedOut, error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer, else => return error.UnexpectedReadFailure, @@ -289,6 +291,7 @@ pub const Connection = struct { if (conn.read_end != conn.read_start) return; ctx._iovecs = try ctx.alloc().alloc(std.posix.iovec, 1); + errdefer ctx.alloc().free(ctx._iovecs); const iovecs = [1]std.posix.iovec{ .{ .base = &conn.read_buf, .len = conn.read_buf.len }, }; @@ -369,7 +372,7 @@ pub const Connection = struct { } pub fn writeAllDirectTls(conn: *Connection, buffer: []const u8) WriteError!void { - return conn.tls_client.writeAll(conn.stream, buffer) catch |err| switch (err) { + return conn.tls_client.writeAll(buffer) catch |err| switch (err) { error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, else => return error.UnexpectedWriteFailure, }; @@ -476,7 +479,7 @@ pub const Connection = struct { if (disable_tls) unreachable; // try to cleanly close the TLS connection, for any server that cares. - _ = conn.tls_client.writeEnd(conn.stream, "", true) catch {}; + conn.tls_client.close() catch {}; allocator.destroy(conn.tls_client); } @@ -1644,13 +1647,14 @@ fn setConnection(ctx: *Ctx, res: anyerror!void) !void { if (ctx.data.conn.protocol == .tls) { if (disable_tls) unreachable; - ctx.data.conn.tls_client = try ctx.alloc().create(async_tls.Client); + ctx.data.conn.tls_client = try ctx.alloc().create(tls23.Connection(net.Stream)); errdefer ctx.alloc().destroy(ctx.data.conn.tls_client); - ctx.data.conn.tls_client.* = async_tls.Client.init(ctx.data.conn.stream, ctx.req.client.ca_bundle, ctx.data.conn.host) catch return error.TlsInitializationFailed; - // This is appropriate for HTTPS because the HTTP headers contain - // the content length which is used to detect truncation attacks. - ctx.data.conn.tls_client.allow_truncation_attacks = true; + // TODO tls23.client does an handshake to pick a cipher. + ctx.data.conn.tls_client.* = tls23.client(ctx.data.conn.stream, .{ + .host = ctx.data.conn.host, + .root_ca = .{ .bundle = ctx.req.client.ca_bundle }, + }) catch return error.TlsInitializationFailed; } // add connection node in pool @@ -1720,13 +1724,14 @@ pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connec if (protocol == .tls) { if (disable_tls) unreachable; - conn.data.tls_client = try client.allocator.create(async_tls.Client); + conn.data.tls_client = try client.allocator.create(tls23.Connection(net.Stream)); errdefer client.allocator.destroy(conn.data.tls_client); - conn.data.tls_client.* = async_tls.Client.init(stream, client.ca_bundle, host) catch return error.TlsInitializationFailed; - // This is appropriate for HTTPS because the HTTP headers contain - // the content length which is used to detect truncation attacks. - conn.data.tls_client.allow_truncation_attacks = true; + // TODO tls23.client does an handshake to pick a cipher. + conn.data.tls_client.* = tls23.client(stream, .{ + .host = host, + .root_ca = .{ .bundle = client.ca_bundle }, + }) catch return error.TlsInitializationFailed; } client.connection_pool.addUsed(conn); @@ -1755,7 +1760,7 @@ pub fn async_connectTcp( if (disable_tls and protocol == .tls) return error.TlsInitializationFailed; - return async_net.async_tcpConnectToHost( + return net.async_tcpConnectToHost( client.allocator, host, port, @@ -2317,17 +2322,26 @@ pub const Ctx = struct { _buffer: ?[]const u8 = null, _len: ?usize = null, - // TLS readvAtLeast - _off_i: usize = 0, - _vec_i: usize = 0, - _tls_len: usize = 0, _iovecs: []std.posix.iovec = undefined, - // TLS readvAdvanced - _vp: *async_tls.VecPut = undefined, - _cleartext_stack_buffer: []u8 = undefined, - _in_stack_buffer: []u8 = undefined, - _first_iov: []u8 = undefined, + // TLS readvAtLeast + // _off_i: usize = 0, + // _vec_i: usize = 0, + // _tls_len: usize = 0, + + // TLS readv + _vp: VecPut = undefined, + // _tls_read_buf contains the next decrypted buffer + _tls_read_buf: ?[]u8 = undefined, + _tls_read_content_type: tls23.proto.ContentType = undefined, + + // _tls_read_record contains the crypted record + _tls_read_record: ?tls23.record.Record = null, + + // TLS writeAll + _tls_write_bytes: []const u8 = undefined, + _tls_write_index: usize = 0, + _tls_write_buf: [cipher.max_ciphertext_record_len]u8 = undefined, pub fn init(loop: *Loop, req: *Request) !Ctx { const connection = try req.client.allocator.create(Connection); @@ -2404,7 +2418,7 @@ pub const Ctx = struct { return self.req.connection.?; } - pub fn stream(self: Ctx) async_net.Stream { + pub fn stream(self: Ctx) net.Stream { return self.conn().stream; } }; @@ -2424,7 +2438,7 @@ fn onRequestWait(ctx: *Ctx, res: anyerror!void) !void { }; std.log.debug("REQUEST WAITED", .{}); std.log.debug("Status code: {any}", .{ctx.req.response.status}); - const body = try ctx.req.reader().readAllAlloc(ctx.alloc(), 2000); + const body = try ctx.req.reader().readAllAlloc(ctx.alloc(), 1024 * 1024); defer ctx.alloc().free(body); std.log.debug("Body: \n{s}", .{body}); } diff --git a/src/std/http/Server.zig b/src/std/http/Server.zig index 25d3e44..38d3f13 100644 --- a/src/std/http/Server.zig +++ b/src/std/http/Server.zig @@ -1137,7 +1137,7 @@ fn rebase(s: *Server, index: usize) void { s.read_buffer_len = index + leftover.len; } -const std = @import("../std.zig"); +const std = @import("std"); const http = std.http; const mem = std.mem; const net = std.net; diff --git a/src/tls.zig/PrivateKey.zig b/src/tls.zig/PrivateKey.zig new file mode 100644 index 0000000..0e2b944 --- /dev/null +++ b/src/tls.zig/PrivateKey.zig @@ -0,0 +1,260 @@ +const std = @import("std"); +const Allocator = std.mem.Allocator; +const Certificate = std.crypto.Certificate; +const der = Certificate.der; +const rsa = @import("rsa/rsa.zig"); +const base64 = std.base64.standard.decoderWithIgnore(" \t\r\n"); +const proto = @import("protocol.zig"); + +const max_ecdsa_key_len = 66; + +signature_scheme: proto.SignatureScheme, + +key: union { + rsa: rsa.KeyPair, + ecdsa: [max_ecdsa_key_len]u8, +}, + +const PrivateKey = @This(); + +pub fn fromFile(gpa: Allocator, file: std.fs.File) !PrivateKey { + const buf = try file.readToEndAlloc(gpa, 1024 * 1024); + defer gpa.free(buf); + return try parsePem(buf); +} + +pub fn parsePem(buf: []const u8) !PrivateKey { + const key_start, const key_end, const marker_version = try findKey(buf); + const encoded = std.mem.trim(u8, buf[key_start..key_end], " \t\r\n"); + + // required bytes: + // 2412, 1821, 1236 for rsa 4096, 3072, 2048 bits size keys + var decoded: [4096]u8 = undefined; + const n = try base64.decode(&decoded, encoded); + + if (marker_version == 2) { + return try parseEcDer(decoded[0..n]); + } + return try parseDer(decoded[0..n]); +} + +fn findKey(buf: []const u8) !struct { usize, usize, usize } { + const markers = [_]struct { + begin: []const u8, + end: []const u8, + }{ + .{ .begin = "-----BEGIN PRIVATE KEY-----", .end = "-----END PRIVATE KEY-----" }, + .{ .begin = "-----BEGIN EC PRIVATE KEY-----", .end = "-----END EC PRIVATE KEY-----" }, + }; + + for (markers, 1..) |marker, ver| { + const begin_marker_start = std.mem.indexOfPos(u8, buf, 0, marker.begin) orelse continue; + const key_start = begin_marker_start + marker.begin.len; + const key_end = std.mem.indexOfPos(u8, buf, key_start, marker.end) orelse continue; + + return .{ key_start, key_end, ver }; + } + return error.MissingEndMarker; +} + +// ref: https://asn1js.eu/#MIG2AgEAMBAGByqGSM49AgEGBSuBBAAiBIGeMIGbAgEBBDBKFkVJCtU9FR6egz3yNxKBwXd86cFzMYqyGb8hRc1zVvLdw-So_2FBtITp6jzYmFShZANiAAQ-CH3a1R0V6dFlTK8Rs4M4egrpPtdta0osysO0Zl8mkBiDsTlvJNqeAp7L2ItHgFW8k_CfhgQT6iLDacNMhKC4XOV07r_ePD-mmkvqvRmzfOowHUoVRhCKrOTmF_J9Syc +pub fn parseDer(buf: []const u8) !PrivateKey { + const info = try der.Element.parse(buf, 0); + const version = try der.Element.parse(buf, info.slice.start); + + const algo_seq = try der.Element.parse(buf, version.slice.end); + const algo_cat = try der.Element.parse(buf, algo_seq.slice.start); + + const key_str = try der.Element.parse(buf, algo_seq.slice.end); + const key_seq = try der.Element.parse(buf, key_str.slice.start); + const key_int = try der.Element.parse(buf, key_seq.slice.start); + + const category = try Certificate.parseAlgorithmCategory(buf, algo_cat); + switch (category) { + .rsaEncryption => { + const modulus = try der.Element.parse(buf, key_int.slice.end); + const public_exponent = try der.Element.parse(buf, modulus.slice.end); + const private_exponent = try der.Element.parse(buf, public_exponent.slice.end); + + const public_key = try rsa.PublicKey.fromBytes(content(buf, modulus), content(buf, public_exponent)); + const secret_key = try rsa.SecretKey.fromBytes(public_key.modulus, content(buf, private_exponent)); + const key_pair = rsa.KeyPair{ .public = public_key, .secret = secret_key }; + + return .{ + .signature_scheme = switch (key_pair.public.modulus.bits()) { + 4096 => .rsa_pss_rsae_sha512, + 3072 => .rsa_pss_rsae_sha384, + else => .rsa_pss_rsae_sha256, + }, + .key = .{ .rsa = key_pair }, + }; + }, + .X9_62_id_ecPublicKey => { + const key = try der.Element.parse(buf, key_int.slice.end); + const algo_param = try der.Element.parse(buf, algo_cat.slice.end); + const named_curve = try Certificate.parseNamedCurve(buf, algo_param); + return .{ + .signature_scheme = signatureScheme(named_curve), + .key = .{ .ecdsa = ecdsaKey(buf, key) }, + }; + }, + else => unreachable, + } +} + +// References: +// https://asn1js.eu/#MHcCAQEEINJSRKv8kSKEzLHptfAlg-LGh4_pHHlq0XLf30Q9pcztoAoGCCqGSM49AwEHoUQDQgAEJpmLyp8aGCgyMcFIJaIq_-4V1K6nPpeoih3bT2npeplF9eyXj7rm8eW9Ua6VLhq71mqtMC-YLm-IkORBVq1cuA +// https://www.rfc-editor.org/rfc/rfc5915 +pub fn parseEcDer(bytes: []const u8) !PrivateKey { + const pki_msg = try der.Element.parse(bytes, 0); + const version = try der.Element.parse(bytes, pki_msg.slice.start); + const key = try der.Element.parse(bytes, version.slice.end); + const parameters = try der.Element.parse(bytes, key.slice.end); + const curve = try der.Element.parse(bytes, parameters.slice.start); + const named_curve = try Certificate.parseNamedCurve(bytes, curve); + return .{ + .signature_scheme = signatureScheme(named_curve), + .key = .{ .ecdsa = ecdsaKey(bytes, key) }, + }; +} + +fn signatureScheme(named_curve: Certificate.NamedCurve) proto.SignatureScheme { + return switch (named_curve) { + .X9_62_prime256v1 => .ecdsa_secp256r1_sha256, + .secp384r1 => .ecdsa_secp384r1_sha384, + .secp521r1 => .ecdsa_secp521r1_sha512, + }; +} + +fn ecdsaKey(bytes: []const u8, e: der.Element) [max_ecdsa_key_len]u8 { + const data = content(bytes, e); + var ecdsa_key: [max_ecdsa_key_len]u8 = undefined; + @memcpy(ecdsa_key[0..data.len], data); + return ecdsa_key; +} + +fn content(bytes: []const u8, e: der.Element) []const u8 { + return bytes[e.slice.start..e.slice.end]; +} + +const testing = std.testing; +const testu = @import("testu.zig"); + +test "parse ec pem" { + const data = @embedFile("testdata/ec_private_key.pem"); + var pk = try parsePem(data); + const priv_key = &testu.hexToBytes( + \\ 10 35 3d ca 1b 15 1d 06 aa 71 b8 ef f3 19 22 + \\ 43 78 f3 20 98 1e b1 2f 2b 64 7e 71 d0 30 2a + \\ 90 aa e5 eb 99 c3 90 65 3d c1 26 19 be 3f 08 + \\ 20 9b 01 + ); + try testing.expectEqualSlices(u8, priv_key, pk.key.ecdsa[0..priv_key.len]); + try testing.expectEqual(.ecdsa_secp384r1_sha384, pk.signature_scheme); +} + +test "parse ec prime256v1" { + const data = @embedFile("testdata/ec_prime256v1_private_key.pem"); + var pk = try parsePem(data); + const priv_key = &testu.hexToBytes( + \\ d2 52 44 ab fc 91 22 84 cc b1 e9 b5 f0 25 83 + \\ e2 c6 87 8f e9 1c 79 6a d1 72 df df 44 3d a5 + \\ cc ed + ); + try testing.expectEqualSlices(u8, priv_key, pk.key.ecdsa[0..priv_key.len]); + try testing.expectEqual(.ecdsa_secp256r1_sha256, pk.signature_scheme); +} + +test "parse ec secp384r1" { + const data = @embedFile("testdata/ec_secp384r1_private_key.pem"); + var pk = try parsePem(data); + const priv_key = &testu.hexToBytes( + \\ ee 6d 8a 5e 0d d3 b0 c6 4b 32 40 80 e2 3a de + \\ 8b 1e dd e2 92 db 36 1c db 91 ea ba a1 06 0d + \\ 42 2d d9 a9 dc 05 43 29 f1 78 7c f9 08 af c5 + \\ 03 1f 6d + ); + try testing.expectEqualSlices(u8, priv_key, pk.key.ecdsa[0..priv_key.len]); + try testing.expectEqual(.ecdsa_secp384r1_sha384, pk.signature_scheme); +} + +test "parse ec secp521r1" { + const data = @embedFile("testdata/ec_secp521r1_private_key.pem"); + var pk = try parsePem(data); + const priv_key = &testu.hexToBytes( + \\ 01 f0 2f 5a c7 24 18 ea 68 23 8c 2e a1 b4 b8 + \\ dc f2 11 b2 96 b0 ec 87 80 42 bf de ba f4 96 + \\ 83 8f 9b db c6 60 a7 4c d9 60 3a e4 ba 0b df + \\ ae 24 d3 1b c2 6e 82 a0 88 c1 ed 17 20 0d 3a + \\ f1 c5 7e e8 0b 27 + ); + try testing.expectEqualSlices(u8, priv_key, pk.key.ecdsa[0..priv_key.len]); + try testing.expectEqual(.ecdsa_secp521r1_sha512, pk.signature_scheme); +} + +test "parse rsa pem" { + const data = @embedFile("testdata/rsa_private_key.pem"); + const pk = try parsePem(data); + + // expected results from: + // $ openssl pkey -in testdata/rsa_private_key.pem -text -noout + const modulus = &testu.hexToBytes( + \\ 00 de f7 23 e6 75 cc 6f dd d5 6e 0f 8c 09 f8 + \\ 62 e3 60 1b c0 7d 8c d5 04 50 2c 36 e2 3b f7 + \\ 33 9f a1 14 af be cf 1a 0f 4c f5 cb 39 70 0e + \\ 3b 97 d6 21 f7 48 91 79 ca 7c 68 fc ea 62 a1 + \\ 5a 72 4f 78 57 0e cc f2 a3 50 05 f1 4c ca 51 + \\ 73 10 9a 18 8e 71 f5 b4 c7 3e be 4c ef 37 d4 + \\ 84 4b 82 1c ec 08 a3 cc 07 3d 5c 0b e5 85 3f + \\ fe b6 44 77 8f 3c 6a 2f 33 c3 5d f6 f2 29 46 + \\ 04 25 7e 05 d9 f8 3b 2d a4 40 66 9f 0d 6d 1a + \\ fa bc 0a c5 8b 86 43 30 ef 14 20 41 9d b5 cc + \\ 3e 63 b5 48 04 27 c9 5c d3 62 28 5f f5 b6 e4 + \\ 77 49 99 ac 84 4a a6 67 a5 9a 1a 37 c7 60 4c + \\ ba c1 70 cf 57 64 4a 21 ea 05 53 10 ec 94 71 + \\ 4a 43 04 83 00 aa 5a 28 bc f2 8c 58 14 92 d2 + \\ 83 17 f4 7b 29 0f e7 87 a2 47 b2 53 19 12 23 + \\ fb 4b ce 5a f8 a1 84 f9 b1 f3 bf e3 fa 10 f8 + \\ ad af 87 ce 03 0e a0 2c 13 71 57 c4 55 44 48 + \\ 44 cb + ); + const public_exponent = &testu.hexToBytes("01 00 01"); + const private_exponent = &testu.hexToBytes( + \\ 50 3b 80 98 aa a5 11 50 33 40 32 aa 02 e0 75 + \\ bd 3a 55 62 34 0b 9c 8f bb c5 dd 4e 15 a4 03 + \\ d8 9a 5f 56 4a 84 3d ed 69 95 3d 37 03 02 ac + \\ 21 1c 36 06 c4 ff 4c 63 37 d7 93 c3 48 10 a5 + \\ fa 62 6c 7c 6f 60 02 a4 0f e4 c3 8b 0d 76 b7 + \\ c0 2e a3 4d 86 e6 92 d1 eb db 10 d6 38 31 ea + \\ 15 3d d1 e8 81 c7 67 60 e7 8c 9a df 51 ce d0 + \\ 7a 88 32 b9 c1 54 b8 7d 98 fc d4 23 1a 05 0e + \\ f2 ea e1 72 29 28 2a 68 b7 90 18 80 1c 21 d6 + \\ 36 a8 6b 4a 9c dd 14 b8 9f 85 ee 95 0b f4 c6 + \\ 17 02 aa 4d ea 4d f9 39 d7 dd 9d b4 1d d2 f8 + \\ 92 46 0f 18 41 80 f4 ea 27 55 29 f8 37 59 bf + \\ 43 ec a3 eb 19 ba bc 13 06 95 3d 25 4b c9 72 + \\ cf 41 0a 6f aa cb 79 d4 7b fa b1 09 7c e2 2f + \\ 85 51 44 8b c6 97 8e 46 f9 6b ac 08 87 92 ce + \\ af 0b bf 8c bd 27 51 8f 09 e4 d3 f9 04 ac fa + \\ f2 04 70 3e d9 a6 28 17 c2 2d 74 e9 25 40 02 + \\ 49 + ); + + try testing.expectEqual(.rsa_pss_rsae_sha256, pk.signature_scheme); + const kp = pk.key.rsa; + { + var bytes: [modulus.len]u8 = undefined; + try kp.public.modulus.toBytes(&bytes, .big); + try testing.expectEqualSlices(u8, modulus, &bytes); + } + { + var bytes: [private_exponent.len]u8 = undefined; + try kp.public.public_exponent.toBytes(&bytes, .big); + try testing.expectEqualSlices(u8, public_exponent, bytes[bytes.len - public_exponent.len .. bytes.len]); + } + { + var btytes: [private_exponent.len]u8 = undefined; + try kp.secret.private_exponent.toBytes(&btytes, .big); + try testing.expectEqualSlices(u8, private_exponent, &btytes); + } +} diff --git a/src/tls.zig/cbc/main.zig b/src/tls.zig/cbc/main.zig new file mode 100644 index 0000000..2503844 --- /dev/null +++ b/src/tls.zig/cbc/main.zig @@ -0,0 +1,148 @@ +// This file is originally copied from: https://github.com/jedisct1/zig-cbc. +// +// It is modified then to have TLS padding insead of PKCS#7 padding. +// Reference: +// https://datatracker.ietf.org/doc/html/rfc5246/#section-6.2.3.2 +// https://crypto.stackexchange.com/questions/98917/on-the-correctness-of-the-padding-example-of-rfc-5246 +// +// If required padding i n bytes +// PKCS#7 padding is (n...n) +// TLS padding is (n-1...n-1) - n times of n-1 value +// +const std = @import("std"); +const aes = std.crypto.core.aes; +const mem = std.mem; +const debug = std.debug; + +/// CBC mode with TLS 1.2 padding +/// +/// Important: the counter mode doesn't provide authenticated encryption: the ciphertext can be trivially modified without this being detected. +/// If you need authenticated encryption, use anything from `std.crypto.aead` instead. +/// If you really need to use CBC mode, make sure to use a MAC to authenticate the ciphertext. +pub fn CBC(comptime BlockCipher: anytype) type { + const EncryptCtx = aes.AesEncryptCtx(BlockCipher); + const DecryptCtx = aes.AesDecryptCtx(BlockCipher); + + return struct { + const Self = @This(); + + enc_ctx: EncryptCtx, + dec_ctx: DecryptCtx, + + /// Initialize the CBC context with the given key. + pub fn init(key: [BlockCipher.key_bits / 8]u8) Self { + const enc_ctx = BlockCipher.initEnc(key); + const dec_ctx = DecryptCtx.initFromEnc(enc_ctx); + + return Self{ .enc_ctx = enc_ctx, .dec_ctx = dec_ctx }; + } + + /// Return the length of the ciphertext given the length of the plaintext. + pub fn paddedLength(length: usize) usize { + return (std.math.divCeil(usize, length + 1, EncryptCtx.block_length) catch unreachable) * EncryptCtx.block_length; + } + + /// Encrypt the given plaintext for the given IV. + /// The destination buffer must be large enough to hold the padded plaintext. + /// Use the `paddedLength()` function to compute the ciphertext size. + /// IV must be secret and unpredictable. + pub fn encrypt(self: Self, dst: []u8, src: []const u8, iv: [EncryptCtx.block_length]u8) void { + // Note: encryption *could* be parallelized, see https://research.kudelskisecurity.com/2022/11/17/some-aes-cbc-encryption-myth-busting/ + const block_length = EncryptCtx.block_length; + const padded_length = paddedLength(src.len); + debug.assert(dst.len == padded_length); // destination buffer must hold the padded plaintext + var cv = iv; + var i: usize = 0; + while (i + block_length <= src.len) : (i += block_length) { + const in = src[i..][0..block_length]; + for (cv[0..], in) |*x, y| x.* ^= y; + self.enc_ctx.encrypt(&cv, &cv); + @memcpy(dst[i..][0..block_length], &cv); + } + // Last block + var in = [_]u8{0} ** block_length; + const padding_length: u8 = @intCast(padded_length - src.len - 1); + @memset(&in, padding_length); + @memcpy(in[0 .. src.len - i], src[i..]); + for (cv[0..], in) |*x, y| x.* ^= y; + self.enc_ctx.encrypt(&cv, &cv); + @memcpy(dst[i..], cv[0 .. dst.len - i]); + } + + /// Decrypt the given ciphertext for the given IV. + /// The destination buffer must be large enough to hold the plaintext. + /// IV must be secret, unpredictable and match the one used for encryption. + pub fn decrypt(self: Self, dst: []u8, src: []const u8, iv: [DecryptCtx.block_length]u8) !void { + const block_length = DecryptCtx.block_length; + if (src.len != dst.len) { + return error.EncodingError; + } + debug.assert(src.len % block_length == 0); + var i: usize = 0; + var cv = iv; + var out: [block_length]u8 = undefined; + // Decryption could be parallelized + while (i + block_length <= dst.len) : (i += block_length) { + const in = src[i..][0..block_length]; + self.dec_ctx.decrypt(&out, in); + for (&out, cv) |*x, y| x.* ^= y; + cv = in.*; + @memcpy(dst[i..][0..block_length], &out); + } + // Last block - We intentionally don't check the padding to mitigate timing attacks + if (i < dst.len) { + const in = src[i..][0..block_length]; + @memset(&out, 0); + self.dec_ctx.decrypt(&out, in); + for (&out, cv) |*x, y| x.* ^= y; + @memcpy(dst[i..], out[0 .. dst.len - i]); + } + } + }; +} + +test "CBC mode" { + const M = CBC(aes.Aes128); + const key = [_]u8{ 0x2b, 0x7e, 0x15, 0x16, 0x28, 0xae, 0xd2, 0xa6, 0xab, 0xf7, 0x15, 0x88, 0x09, 0xcf, 0x4f, 0x3c }; + const iv = [_]u8{ 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f }; + const src_ = "This is a test of AES-CBC that goes on longer than a couple blocks. It is a somewhat long test case to type out!"; + const expected = "\xA0\x8C\x09\x7D\xFF\x42\xB6\x65\x4D\x4B\xC6\x90\x90\x39\xDE\x3D\xC7\xCA\xEB\xF6\x9A\x4F\x09\x97\xC9\x32\xAB\x75\x88\xB7\x57\x17"; + var res: [32]u8 = undefined; + + try comptime std.testing.expect(src_.len / M.paddedLength(1) >= 3); // Ensure that we have at least 3 blocks + + const z = M.init(key); + + // Test encryption and decryption with distinct buffers + var h = std.crypto.hash.sha2.Sha256.init(.{}); + inline for (0..src_.len) |len| { + const src = src_[0..len]; + var dst = [_]u8{0} ** M.paddedLength(src.len); + + z.encrypt(&dst, src, iv); + h.update(&dst); + + var decrypted = [_]u8{0} ** dst.len; + try z.decrypt(&decrypted, &dst, iv); + + const padding = decrypted[decrypted.len - 1] + 1; + try std.testing.expectEqualSlices(u8, src, decrypted[0 .. decrypted.len - padding]); + } + h.final(&res); + try std.testing.expectEqualSlices(u8, expected, &res); + + // Test encryption and decryption with the same buffer + h = std.crypto.hash.sha2.Sha256.init(.{}); + inline for (0..src_.len) |len| { + var buf = [_]u8{0} ** M.paddedLength(len); + @memcpy(buf[0..len], src_[0..len]); + z.encrypt(&buf, buf[0..len], iv); + h.update(&buf); + + try z.decrypt(&buf, &buf, iv); + + try std.testing.expectEqualSlices(u8, src_[0..len], buf[0..len]); + } + h.final(&res); + try std.testing.expectEqualSlices(u8, expected, &res); +} diff --git a/src/tls.zig/cipher.zig b/src/tls.zig/cipher.zig new file mode 100644 index 0000000..dbf4a07 --- /dev/null +++ b/src/tls.zig/cipher.zig @@ -0,0 +1,1004 @@ +const std = @import("std"); +const crypto = std.crypto; +const hkdfExpandLabel = crypto.tls.hkdfExpandLabel; + +const Sha1 = crypto.hash.Sha1; +const Sha256 = crypto.hash.sha2.Sha256; +const Sha384 = crypto.hash.sha2.Sha384; + +const record = @import("record.zig"); +const Record = record.Record; +const Transcript = @import("transcript.zig").Transcript; +const proto = @import("protocol.zig"); + +// tls 1.2 cbc cipher types +const CbcAes128Sha1 = CbcType(crypto.core.aes.Aes128, Sha1); +const CbcAes128Sha256 = CbcType(crypto.core.aes.Aes128, Sha256); +const CbcAes256Sha256 = CbcType(crypto.core.aes.Aes256, Sha256); +const CbcAes256Sha384 = CbcType(crypto.core.aes.Aes256, Sha384); +// tls 1.2 gcm cipher types +const Aead12Aes128Gcm = Aead12Type(crypto.aead.aes_gcm.Aes128Gcm); +const Aead12Aes256Gcm = Aead12Type(crypto.aead.aes_gcm.Aes256Gcm); +// tls 1.2 chacha cipher type +const Aead12ChaCha = Aead12ChaChaType(crypto.aead.chacha_poly.ChaCha20Poly1305); +// tls 1.3 cipher types +const Aes128GcmSha256 = Aead13Type(crypto.aead.aes_gcm.Aes128Gcm, Sha256); +const Aes256GcmSha384 = Aead13Type(crypto.aead.aes_gcm.Aes256Gcm, Sha384); +const ChaChaSha256 = Aead13Type(crypto.aead.chacha_poly.ChaCha20Poly1305, Sha256); +const Aegis128Sha256 = Aead13Type(crypto.aead.aegis.Aegis128L, Sha256); + +pub const encrypt_overhead_tls_12: comptime_int = @max( + CbcAes128Sha1.encrypt_overhead, + CbcAes128Sha256.encrypt_overhead, + CbcAes256Sha256.encrypt_overhead, + CbcAes256Sha384.encrypt_overhead, + Aead12Aes128Gcm.encrypt_overhead, + Aead12Aes256Gcm.encrypt_overhead, + Aead12ChaCha.encrypt_overhead, +); +pub const encrypt_overhead_tls_13: comptime_int = @max( + Aes128GcmSha256.encrypt_overhead, + Aes256GcmSha384.encrypt_overhead, + ChaChaSha256.encrypt_overhead, + Aegis128Sha256.encrypt_overhead, +); + +// ref (length): https://www.rfc-editor.org/rfc/rfc8446#section-5.1 +pub const max_cleartext_len = 1 << 14; +// ref (length): https://www.rfc-editor.org/rfc/rfc8446#section-5.2 +// The sum of the lengths of the content and the padding, plus one for the inner +// content type, plus any expansion added by the AEAD algorithm. +pub const max_ciphertext_len = max_cleartext_len + 256; +pub const max_ciphertext_record_len = record.header_len + max_ciphertext_len; + +/// Returns type for cipher suite tag. +fn CipherType(comptime tag: CipherSuite) type { + return switch (tag) { + // tls 1.2 cbc + .ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + .ECDHE_RSA_WITH_AES_128_CBC_SHA, + .RSA_WITH_AES_128_CBC_SHA, + => CbcAes128Sha1, + .ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, + .ECDHE_RSA_WITH_AES_128_CBC_SHA256, + .RSA_WITH_AES_128_CBC_SHA256, + => CbcAes128Sha256, + .ECDHE_ECDSA_WITH_AES_256_CBC_SHA384, + .ECDHE_RSA_WITH_AES_256_CBC_SHA384, + => CbcAes256Sha384, + + // tls 1.2 gcm + .ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + .ECDHE_RSA_WITH_AES_128_GCM_SHA256, + => Aead12Aes128Gcm, + .ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + .ECDHE_RSA_WITH_AES_256_GCM_SHA384, + => Aead12Aes256Gcm, + + // tls 1.2 chacha + .ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, + .ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + => Aead12ChaCha, + + // tls 1.3 + .AES_128_GCM_SHA256 => Aes128GcmSha256, + .AES_256_GCM_SHA384 => Aes256GcmSha384, + .CHACHA20_POLY1305_SHA256 => ChaChaSha256, + .AEGIS_128L_SHA256 => Aegis128Sha256, + + else => unreachable, + }; +} + +/// Provides initialization and common encrypt/decrypt methods for all supported +/// ciphers. Tls 1.2 has only application cipher, tls 1.3 has separate cipher +/// for handshake and application. +pub const Cipher = union(CipherSuite) { + // tls 1.2 cbc + ECDHE_ECDSA_WITH_AES_128_CBC_SHA: CipherType(.ECDHE_ECDSA_WITH_AES_128_CBC_SHA), + ECDHE_RSA_WITH_AES_128_CBC_SHA: CipherType(.ECDHE_RSA_WITH_AES_128_CBC_SHA), + RSA_WITH_AES_128_CBC_SHA: CipherType(.RSA_WITH_AES_128_CBC_SHA), + + ECDHE_ECDSA_WITH_AES_128_CBC_SHA256: CipherType(.ECDHE_ECDSA_WITH_AES_128_CBC_SHA256), + ECDHE_RSA_WITH_AES_128_CBC_SHA256: CipherType(.ECDHE_RSA_WITH_AES_128_CBC_SHA256), + RSA_WITH_AES_128_CBC_SHA256: CipherType(.RSA_WITH_AES_128_CBC_SHA256), + + ECDHE_ECDSA_WITH_AES_256_CBC_SHA384: CipherType(.ECDHE_ECDSA_WITH_AES_256_CBC_SHA384), + ECDHE_RSA_WITH_AES_256_CBC_SHA384: CipherType(.ECDHE_RSA_WITH_AES_256_CBC_SHA384), + // tls 1.2 gcm + ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: CipherType(.ECDHE_ECDSA_WITH_AES_128_GCM_SHA256), + ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: CipherType(.ECDHE_RSA_WITH_AES_256_GCM_SHA384), + ECDHE_RSA_WITH_AES_128_GCM_SHA256: CipherType(.ECDHE_RSA_WITH_AES_128_GCM_SHA256), + ECDHE_RSA_WITH_AES_256_GCM_SHA384: CipherType(.ECDHE_RSA_WITH_AES_256_GCM_SHA384), + // tls 1.2 chacha + ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256: CipherType(.ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256), + ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256: CipherType(.ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256), + // tls 1.3 + AES_128_GCM_SHA256: CipherType(.AES_128_GCM_SHA256), + AES_256_GCM_SHA384: CipherType(.AES_256_GCM_SHA384), + CHACHA20_POLY1305_SHA256: CipherType(.CHACHA20_POLY1305_SHA256), + AEGIS_128L_SHA256: CipherType(.AEGIS_128L_SHA256), + + // tls 1.2 application cipher + pub fn initTls12(tag: CipherSuite, key_material: []const u8, side: proto.Side) !Cipher { + switch (tag) { + inline .ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + .ECDHE_RSA_WITH_AES_128_CBC_SHA, + .RSA_WITH_AES_128_CBC_SHA, + .ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, + .ECDHE_RSA_WITH_AES_128_CBC_SHA256, + .RSA_WITH_AES_128_CBC_SHA256, + .ECDHE_ECDSA_WITH_AES_256_CBC_SHA384, + .ECDHE_RSA_WITH_AES_256_CBC_SHA384, + .ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + .ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + .ECDHE_RSA_WITH_AES_128_GCM_SHA256, + .ECDHE_RSA_WITH_AES_256_GCM_SHA384, + .ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, + .ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + => |comptime_tag| { + return @unionInit(Cipher, @tagName(comptime_tag), CipherType(comptime_tag).init(key_material, side)); + }, + else => return error.TlsIllegalParameter, + } + } + + // tls 1.3 handshake or application cipher + pub fn initTls13(tag: CipherSuite, secret: Transcript.Secret, side: proto.Side) !Cipher { + return switch (tag) { + inline .AES_128_GCM_SHA256, + .AES_256_GCM_SHA384, + .CHACHA20_POLY1305_SHA256, + .AEGIS_128L_SHA256, + => |comptime_tag| { + return @unionInit(Cipher, @tagName(comptime_tag), CipherType(comptime_tag).init(secret, side)); + }, + else => return error.TlsIllegalParameter, + }; + } + + pub fn encrypt( + c: *Cipher, + buf: []u8, + content_type: proto.ContentType, + cleartext: []const u8, + ) ![]const u8 { + return switch (c.*) { + inline else => |*f| try f.encrypt(buf, content_type, cleartext), + }; + } + + pub fn decrypt( + c: *Cipher, + buf: []u8, + rec: Record, + ) !struct { proto.ContentType, []u8 } { + return switch (c.*) { + inline else => |*f| { + const content_type, const cleartext = try f.decrypt(buf, rec); + if (cleartext.len > max_cleartext_len) return error.TlsRecordOverflow; + return .{ content_type, cleartext }; + }, + }; + } + + pub fn encryptSeq(c: Cipher) u64 { + return switch (c) { + inline else => |f| f.encrypt_seq, + }; + } + + pub fn keyUpdateEncrypt(c: *Cipher) !void { + return switch (c.*) { + inline .AES_128_GCM_SHA256, + .AES_256_GCM_SHA384, + .CHACHA20_POLY1305_SHA256, + .AEGIS_128L_SHA256, + => |*f| f.keyUpdateEncrypt(), + // can't happen on tls 1.2 + else => return error.TlsUnexpectedMessage, + }; + } + pub fn keyUpdateDecrypt(c: *Cipher) !void { + return switch (c.*) { + inline .AES_128_GCM_SHA256, + .AES_256_GCM_SHA384, + .CHACHA20_POLY1305_SHA256, + .AEGIS_128L_SHA256, + => |*f| f.keyUpdateDecrypt(), + // can't happen on tls 1.2 + else => return error.TlsUnexpectedMessage, + }; + } +}; + +fn Aead12Type(comptime AeadType: type) type { + return struct { + const explicit_iv_len = 8; + const key_len = AeadType.key_length; + const auth_tag_len = AeadType.tag_length; + const nonce_len = AeadType.nonce_length; + const iv_len = AeadType.nonce_length - explicit_iv_len; + const encrypt_overhead = record.header_len + explicit_iv_len + auth_tag_len; + + encrypt_key: [key_len]u8, + decrypt_key: [key_len]u8, + encrypt_iv: [iv_len]u8, + decrypt_iv: [iv_len]u8, + encrypt_seq: u64 = 0, + decrypt_seq: u64 = 0, + rnd: std.Random = crypto.random, + + const Self = @This(); + + fn init(key_material: []const u8, side: proto.Side) Self { + const client_key = key_material[0..key_len].*; + const server_key = key_material[key_len..][0..key_len].*; + const client_iv = key_material[2 * key_len ..][0..iv_len].*; + const server_iv = key_material[2 * key_len + iv_len ..][0..iv_len].*; + return .{ + .encrypt_key = if (side == .client) client_key else server_key, + .decrypt_key = if (side == .client) server_key else client_key, + .encrypt_iv = if (side == .client) client_iv else server_iv, + .decrypt_iv = if (side == .client) server_iv else client_iv, + }; + } + + /// Returns encrypted tls record in format: + /// ----------------- buf ---------------------- + /// header | explicit_iv | ciphertext | auth_tag + /// + /// tls record header: 5 bytes + /// explicit_iv: 8 bytes + /// ciphertext: same length as cleartext + /// auth_tag: 16 bytes + pub fn encrypt( + self: *Self, + buf: []u8, + content_type: proto.ContentType, + cleartext: []const u8, + ) ![]const u8 { + const record_len = record.header_len + explicit_iv_len + cleartext.len + auth_tag_len; + if (buf.len < record_len) return error.BufferOverflow; + + const header = buf[0..record.header_len]; + const explicit_iv = buf[record.header_len..][0..explicit_iv_len]; + const ciphertext = buf[record.header_len + explicit_iv_len ..][0..cleartext.len]; + const auth_tag = buf[record.header_len + explicit_iv_len + cleartext.len ..][0..auth_tag_len]; + + header.* = record.header(content_type, explicit_iv_len + cleartext.len + auth_tag_len); + self.rnd.bytes(explicit_iv); + const iv = self.encrypt_iv ++ explicit_iv.*; + const ad = additionalData(self.encrypt_seq, content_type, cleartext.len); + AeadType.encrypt(ciphertext, auth_tag, cleartext, &ad, iv, self.encrypt_key); + self.encrypt_seq +%= 1; + + return buf[0..record_len]; + } + + /// Decrypts payload into cleartext. Returns tls record content type and + /// cleartext. + /// Accepts tls record header and payload: + /// header | ----------- payload --------------- + /// header | explicit_iv | ciphertext | auth_tag + pub fn decrypt( + self: *Self, + buf: []u8, + rec: Record, + ) !struct { proto.ContentType, []u8 } { + const overhead = explicit_iv_len + auth_tag_len; + if (rec.payload.len < overhead) return error.TlsDecryptError; + const cleartext_len = rec.payload.len - overhead; + if (buf.len < cleartext_len) return error.BufferOverflow; + + const explicit_iv = rec.payload[0..explicit_iv_len]; + const ciphertext = rec.payload[explicit_iv_len..][0..cleartext_len]; + const auth_tag = rec.payload[explicit_iv_len + cleartext_len ..][0..auth_tag_len]; + + const iv = self.decrypt_iv ++ explicit_iv.*; + const ad = additionalData(self.decrypt_seq, rec.content_type, cleartext_len); + const cleartext = buf[0..cleartext_len]; + AeadType.decrypt(cleartext, ciphertext, auth_tag.*, &ad, iv, self.decrypt_key) catch return error.TlsDecryptError; + self.decrypt_seq +%= 1; + return .{ rec.content_type, cleartext }; + } + }; +} + +fn Aead12ChaChaType(comptime AeadType: type) type { + return struct { + const key_len = AeadType.key_length; + const auth_tag_len = AeadType.tag_length; + const nonce_len = AeadType.nonce_length; + const encrypt_overhead = record.header_len + auth_tag_len; + + encrypt_key: [key_len]u8, + decrypt_key: [key_len]u8, + encrypt_iv: [nonce_len]u8, + decrypt_iv: [nonce_len]u8, + encrypt_seq: u64 = 0, + decrypt_seq: u64 = 0, + + const Self = @This(); + + fn init(key_material: []const u8, side: proto.Side) Self { + const client_key = key_material[0..key_len].*; + const server_key = key_material[key_len..][0..key_len].*; + const client_iv = key_material[2 * key_len ..][0..nonce_len].*; + const server_iv = key_material[2 * key_len + nonce_len ..][0..nonce_len].*; + return .{ + .encrypt_key = if (side == .client) client_key else server_key, + .decrypt_key = if (side == .client) server_key else client_key, + .encrypt_iv = if (side == .client) client_iv else server_iv, + .decrypt_iv = if (side == .client) server_iv else client_iv, + }; + } + + /// Returns encrypted tls record in format: + /// ------------ buf ------------- + /// header | ciphertext | auth_tag + /// + /// tls record header: 5 bytes + /// ciphertext: same length as cleartext + /// auth_tag: 16 bytes + pub fn encrypt( + self: *Self, + buf: []u8, + content_type: proto.ContentType, + cleartext: []const u8, + ) ![]const u8 { + const record_len = record.header_len + cleartext.len + auth_tag_len; + if (buf.len < record_len) return error.BufferOverflow; + + const ciphertext = buf[record.header_len..][0..cleartext.len]; + const auth_tag = buf[record.header_len + ciphertext.len ..][0..auth_tag_len]; + + const ad = additionalData(self.encrypt_seq, content_type, cleartext.len); + const iv = ivWithSeq(nonce_len, self.encrypt_iv, self.encrypt_seq); + AeadType.encrypt(ciphertext, auth_tag, cleartext, &ad, iv, self.encrypt_key); + self.encrypt_seq +%= 1; + + buf[0..record.header_len].* = record.header(content_type, ciphertext.len + auth_tag.len); + return buf[0..record_len]; + } + + /// Decrypts payload into cleartext. Returns tls record content type and + /// cleartext. + /// Accepts tls record header and payload: + /// header | ----- payload ------- + /// header | ciphertext | auth_tag + pub fn decrypt( + self: *Self, + buf: []u8, + rec: Record, + ) !struct { proto.ContentType, []u8 } { + const overhead = auth_tag_len; + if (rec.payload.len < overhead) return error.TlsDecryptError; + const cleartext_len = rec.payload.len - overhead; + if (buf.len < cleartext_len) return error.BufferOverflow; + + const ciphertext = rec.payload[0..cleartext_len]; + const auth_tag = rec.payload[cleartext_len..][0..auth_tag_len]; + const cleartext = buf[0..cleartext_len]; + + const ad = additionalData(self.decrypt_seq, rec.content_type, cleartext_len); + const iv = ivWithSeq(nonce_len, self.decrypt_iv, self.decrypt_seq); + AeadType.decrypt(cleartext, ciphertext, auth_tag.*, &ad, iv, self.decrypt_key) catch return error.TlsDecryptError; + self.decrypt_seq +%= 1; + return .{ rec.content_type, cleartext }; + } + }; +} + +fn Aead13Type(comptime AeadType: type, comptime Hash: type) type { + return struct { + const Hmac = crypto.auth.hmac.Hmac(Hash); + const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); + + const key_len = AeadType.key_length; + const auth_tag_len = AeadType.tag_length; + const nonce_len = AeadType.nonce_length; + const digest_len = Hash.digest_length; + const encrypt_overhead = record.header_len + 1 + auth_tag_len; + + encrypt_secret: [digest_len]u8, + decrypt_secret: [digest_len]u8, + encrypt_key: [key_len]u8, + decrypt_key: [key_len]u8, + encrypt_iv: [nonce_len]u8, + decrypt_iv: [nonce_len]u8, + encrypt_seq: u64 = 0, + decrypt_seq: u64 = 0, + + const Self = @This(); + + pub fn init(secret: Transcript.Secret, side: proto.Side) Self { + var self = Self{ + .encrypt_secret = if (side == .client) secret.client[0..digest_len].* else secret.server[0..digest_len].*, + .decrypt_secret = if (side == .server) secret.client[0..digest_len].* else secret.server[0..digest_len].*, + .encrypt_key = undefined, + .decrypt_key = undefined, + .encrypt_iv = undefined, + .decrypt_iv = undefined, + }; + self.keyGenerate(); + return self; + } + + fn keyGenerate(self: *Self) void { + self.encrypt_key = hkdfExpandLabel(Hkdf, self.encrypt_secret, "key", "", key_len); + self.decrypt_key = hkdfExpandLabel(Hkdf, self.decrypt_secret, "key", "", key_len); + self.encrypt_iv = hkdfExpandLabel(Hkdf, self.encrypt_secret, "iv", "", nonce_len); + self.decrypt_iv = hkdfExpandLabel(Hkdf, self.decrypt_secret, "iv", "", nonce_len); + } + + pub fn keyUpdateEncrypt(self: *Self) void { + self.encrypt_secret = hkdfExpandLabel(Hkdf, self.encrypt_secret, "traffic upd", "", digest_len); + self.encrypt_seq = 0; + self.keyGenerate(); + } + + pub fn keyUpdateDecrypt(self: *Self) void { + self.decrypt_secret = hkdfExpandLabel(Hkdf, self.decrypt_secret, "traffic upd", "", digest_len); + self.decrypt_seq = 0; + self.keyGenerate(); + } + + /// Returns encrypted tls record in format: + /// ------------ buf ------------- + /// header | ciphertext | auth_tag + /// + /// tls record header: 5 bytes + /// ciphertext: cleartext len + 1 byte content type + /// auth_tag: 16 bytes + pub fn encrypt( + self: *Self, + buf: []u8, + content_type: proto.ContentType, + cleartext: []const u8, + ) ![]const u8 { + const payload_len = cleartext.len + 1 + auth_tag_len; + const record_len = record.header_len + payload_len; + if (buf.len < record_len) return error.BufferOverflow; + + const header = buf[0..record.header_len]; + header.* = record.header(.application_data, payload_len); + + // Skip @memcpy if cleartext is already part of the buf at right position + if (&cleartext[0] != &buf[record.header_len]) { + @memcpy(buf[record.header_len..][0..cleartext.len], cleartext); + } + buf[record.header_len + cleartext.len] = @intFromEnum(content_type); + const ciphertext = buf[record.header_len..][0 .. cleartext.len + 1]; + const auth_tag = buf[record.header_len + ciphertext.len ..][0..auth_tag_len]; + + const iv = ivWithSeq(nonce_len, self.encrypt_iv, self.encrypt_seq); + AeadType.encrypt(ciphertext, auth_tag, ciphertext, header, iv, self.encrypt_key); + self.encrypt_seq +%= 1; + return buf[0..record_len]; + } + + /// Decrypts payload into cleartext. Returns tls record content type and + /// cleartext. + /// Accepts tls record header and payload: + /// header | ------- payload --------- + /// header | ciphertext | auth_tag + /// header | cleartext + ct | auth_tag + /// Ciphertext after decryption contains cleartext and content type (1 byte). + pub fn decrypt( + self: *Self, + buf: []u8, + rec: Record, + ) !struct { proto.ContentType, []u8 } { + const overhead = auth_tag_len + 1; + if (rec.payload.len < overhead) return error.TlsDecryptError; + const ciphertext_len = rec.payload.len - auth_tag_len; + if (buf.len < ciphertext_len) return error.BufferOverflow; + + const ciphertext = rec.payload[0..ciphertext_len]; + const auth_tag = rec.payload[ciphertext_len..][0..auth_tag_len]; + + const iv = ivWithSeq(nonce_len, self.decrypt_iv, self.decrypt_seq); + AeadType.decrypt(buf[0..ciphertext_len], ciphertext, auth_tag.*, rec.header, iv, self.decrypt_key) catch return error.TlsBadRecordMac; + + // Remove zero bytes padding + var content_type_idx: usize = ciphertext_len - 1; + while (buf[content_type_idx] == 0 and content_type_idx > 0) : (content_type_idx -= 1) {} + + const cleartext = buf[0..content_type_idx]; + const content_type: proto.ContentType = @enumFromInt(buf[content_type_idx]); + self.decrypt_seq +%= 1; + return .{ content_type, cleartext }; + } + }; +} + +fn CbcType(comptime BlockCipher: type, comptime HashType: type) type { + const CBC = @import("cbc/main.zig").CBC(BlockCipher); + return struct { + const mac_len = HashType.digest_length; // 20, 32, 48 bytes for sha1, sha256, sha384 + const key_len = BlockCipher.key_bits / 8; // 16, 32 for Aes128, Aes256 + const iv_len = 16; + const encrypt_overhead = record.header_len + iv_len + mac_len + max_padding; + + pub const Hmac = crypto.auth.hmac.Hmac(HashType); + const paddedLength = CBC.paddedLength; + const max_padding = 16; + + encrypt_secret: [mac_len]u8, + decrypt_secret: [mac_len]u8, + encrypt_key: [key_len]u8, + decrypt_key: [key_len]u8, + encrypt_seq: u64 = 0, + decrypt_seq: u64 = 0, + rnd: std.Random = crypto.random, + + const Self = @This(); + + fn init(key_material: []const u8, side: proto.Side) Self { + const client_secret = key_material[0..mac_len].*; + const server_secret = key_material[mac_len..][0..mac_len].*; + const client_key = key_material[2 * mac_len ..][0..key_len].*; + const server_key = key_material[2 * mac_len + key_len ..][0..key_len].*; + return .{ + .encrypt_secret = if (side == .client) client_secret else server_secret, + .decrypt_secret = if (side == .client) server_secret else client_secret, + .encrypt_key = if (side == .client) client_key else server_key, + .decrypt_key = if (side == .client) server_key else client_key, + }; + } + + /// Returns encrypted tls record in format: + /// ----------------- buf ----------------- + /// header | iv | ------ ciphertext ------- + /// header | iv | cleartext | mac | padding + /// + /// tls record header: 5 bytes + /// iv: 16 bytes + /// ciphertext: cleartext length + mac + padding + /// mac: 20, 32 or 48 (sha1, sha256, sha384) + /// padding: 1-16 bytes + /// + /// Max encrypt buf overhead = iv + mac + padding (1-16) + /// aes_128_cbc_sha => 16 + 20 + 16 = 52 + /// aes_128_cbc_sha256 => 16 + 32 + 16 = 64 + /// aes_256_cbc_sha384 => 16 + 48 + 16 = 80 + pub fn encrypt( + self: *Self, + buf: []u8, + content_type: proto.ContentType, + cleartext: []const u8, + ) ![]const u8 { + const max_record_len = record.header_len + iv_len + cleartext.len + mac_len + max_padding; + if (buf.len < max_record_len) return error.BufferOverflow; + const cleartext_idx = record.header_len + iv_len; // position of cleartext in buf + @memcpy(buf[cleartext_idx..][0..cleartext.len], cleartext); + + { // calculate mac from (ad + cleartext) + // ... | ad | cleartext | mac | ... + // | -- mac msg -- | mac | + const ad = additionalData(self.encrypt_seq, content_type, cleartext.len); + const mac_msg = buf[cleartext_idx - ad.len ..][0 .. ad.len + cleartext.len]; + @memcpy(mac_msg[0..ad.len], &ad); + const mac = buf[cleartext_idx + cleartext.len ..][0..mac_len]; + Hmac.create(mac, mac_msg, &self.encrypt_secret); + self.encrypt_seq +%= 1; + } + + // ... | cleartext | mac | ... + // ... | -- plaintext --- ... + // ... | cleartext | mac | padding + // ... | ------ ciphertext ------- + const unpadded_len = cleartext.len + mac_len; + const padded_len = paddedLength(unpadded_len); + const plaintext = buf[cleartext_idx..][0..unpadded_len]; + const ciphertext = buf[cleartext_idx..][0..padded_len]; + + // Add header and iv at the buf start + // header | iv | ... + buf[0..record.header_len].* = record.header(content_type, iv_len + ciphertext.len); + const iv = buf[record.header_len..][0..iv_len]; + self.rnd.bytes(iv); + + // encrypt plaintext into ciphertext + CBC.init(self.encrypt_key).encrypt(ciphertext, plaintext, iv[0..iv_len].*); + + // header | iv | ------ ciphertext ------- + return buf[0 .. record.header_len + iv_len + ciphertext.len]; + } + + /// Decrypts payload into cleartext. Returns tls record content type and + /// cleartext. + pub fn decrypt( + self: *Self, + buf: []u8, + rec: Record, + ) !struct { proto.ContentType, []u8 } { + if (rec.payload.len < iv_len + mac_len + 1) return error.TlsDecryptError; + + // --------- payload ------------ + // iv | ------ ciphertext ------- + // iv | cleartext | mac | padding + const iv = rec.payload[0..iv_len]; + const ciphertext = rec.payload[iv_len..]; + + if (buf.len < ciphertext.len + additional_data_len) return error.BufferOverflow; + // ---------- buf --------------- + // ad | ------ plaintext -------- + // ad | cleartext | mac | padding + const plaintext = buf[additional_data_len..][0..ciphertext.len]; + // decrypt ciphertext -> plaintext + CBC.init(self.decrypt_key).decrypt(plaintext, ciphertext, iv[0..iv_len].*) catch return error.TlsDecryptError; + + // get padding len from last padding byte + const padding_len = plaintext[plaintext.len - 1] + 1; + if (plaintext.len < mac_len + padding_len) return error.TlsDecryptError; + // split plaintext into cleartext and mac + const cleartext_len = plaintext.len - mac_len - padding_len; + const cleartext = plaintext[0..cleartext_len]; + const mac = plaintext[cleartext_len..][0..mac_len]; + + // write ad to the buf + var ad = additionalData(self.decrypt_seq, rec.content_type, cleartext_len); + @memcpy(buf[0..ad.len], &ad); + const mac_msg = buf[0 .. ad.len + cleartext_len]; + self.decrypt_seq +%= 1; + + // calculate expected mac and compare with received mac + var expected_mac: [mac_len]u8 = undefined; + Hmac.create(&expected_mac, mac_msg, &self.decrypt_secret); + if (!std.mem.eql(u8, &expected_mac, mac)) + return error.TlsBadRecordMac; + + return .{ rec.content_type, cleartext }; + } + }; +} + +// xor lower 8 iv bytes with seq +fn ivWithSeq(comptime nonce_len: usize, iv: [nonce_len]u8, seq: u64) [nonce_len]u8 { + var res = iv; + const buf = res[nonce_len - 8 ..]; + const operand = std.mem.readInt(u64, buf, .big); + std.mem.writeInt(u64, buf, operand ^ seq, .big); + return res; +} + +pub const additional_data_len = record.header_len + @sizeOf(u64); + +fn additionalData(seq: u64, content_type: proto.ContentType, payload_len: usize) [additional_data_len]u8 { + const header = record.header(content_type, payload_len); + var seq_buf: [8]u8 = undefined; + std.mem.writeInt(u64, &seq_buf, seq, .big); + return seq_buf ++ header; +} + +// Cipher suites lists. In the order of preference. +// For the preference using grades priority and rules from Go project. +// https://ciphersuite.info/page/faq/ +// https://github.com/golang/go/blob/73186ba00251b3ed8baaab36e4f5278c7681155b/src/crypto/tls/cipher_suites.go#L226 +pub const cipher_suites = struct { + const tls12_secure = if (crypto.core.aes.has_hardware_support) [_]CipherSuite{ + // recommended + .ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + .ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + .ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, + // secure + .ECDHE_RSA_WITH_AES_128_GCM_SHA256, + .ECDHE_RSA_WITH_AES_256_GCM_SHA384, + .ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + } else [_]CipherSuite{ + // recommended + .ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, + .ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + .ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + + // secure + .ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + .ECDHE_RSA_WITH_AES_128_GCM_SHA256, + .ECDHE_RSA_WITH_AES_256_GCM_SHA384, + }; + const tls12_week = [_]CipherSuite{ + // week + .ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, + .ECDHE_ECDSA_WITH_AES_256_CBC_SHA384, + .ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + .ECDHE_RSA_WITH_AES_128_CBC_SHA256, + .ECDHE_RSA_WITH_AES_256_CBC_SHA384, + .ECDHE_RSA_WITH_AES_128_CBC_SHA, + + .RSA_WITH_AES_128_CBC_SHA256, + .RSA_WITH_AES_128_CBC_SHA, + }; + pub const tls13_ = if (crypto.core.aes.has_hardware_support) [_]CipherSuite{ + .AES_128_GCM_SHA256, + .AES_256_GCM_SHA384, + .CHACHA20_POLY1305_SHA256, + // Excluded because didn't find server which supports it to test + // .AEGIS_128L_SHA256 + } else [_]CipherSuite{ + .CHACHA20_POLY1305_SHA256, + .AES_128_GCM_SHA256, + .AES_256_GCM_SHA384, + }; + + pub const tls13 = &tls13_; + pub const tls12 = &(tls12_secure ++ tls12_week); + pub const secure = &(tls13_ ++ tls12_secure); + pub const all = &(tls13_ ++ tls12_secure ++ tls12_week); + + pub fn includes(list: []const CipherSuite, cs: CipherSuite) bool { + for (list) |s| { + if (cs == s) return true; + } + return false; + } +}; + +// Week, secure, recommended grades are from https://ciphersuite.info/page/faq/ +pub const CipherSuite = enum(u16) { + // tls 1.2 cbc sha1 + ECDHE_ECDSA_WITH_AES_128_CBC_SHA = 0xc009, // week + ECDHE_RSA_WITH_AES_128_CBC_SHA = 0xc013, // week + RSA_WITH_AES_128_CBC_SHA = 0x002F, // week + // tls 1.2 cbc sha256 + ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 = 0xc023, // week + ECDHE_RSA_WITH_AES_128_CBC_SHA256 = 0xc027, // week + RSA_WITH_AES_128_CBC_SHA256 = 0x003c, // week + // tls 1.2 cbc sha384 + ECDHE_ECDSA_WITH_AES_256_CBC_SHA384 = 0xc024, // week + ECDHE_RSA_WITH_AES_256_CBC_SHA384 = 0xc028, // week + // tls 1.2 gcm + ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 = 0xc02b, // recommended + ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 = 0xc02c, // recommended + ECDHE_RSA_WITH_AES_128_GCM_SHA256 = 0xc02f, // secure + ECDHE_RSA_WITH_AES_256_GCM_SHA384 = 0xc030, // secure + // tls 1.2 chacha + ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 = 0xcca9, // recommended + ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 = 0xcca8, // secure + // tls 1.3 (all are recommended) + AES_128_GCM_SHA256 = 0x1301, + AES_256_GCM_SHA384 = 0x1302, + CHACHA20_POLY1305_SHA256 = 0x1303, + AEGIS_128L_SHA256 = 0x1307, + // AEGIS_256_SHA512 = 0x1306, + _, + + pub fn validate(cs: CipherSuite) !void { + if (cipher_suites.includes(cipher_suites.tls12, cs)) return; + if (cipher_suites.includes(cipher_suites.tls13, cs)) return; + return error.TlsIllegalParameter; + } + + pub const Versions = enum { + both, + tls_1_3, + tls_1_2, + }; + + // get tls versions from list of cipher suites + pub fn versions(list: []const CipherSuite) !Versions { + var has_12 = false; + var has_13 = false; + for (list) |cs| { + if (cipher_suites.includes(cipher_suites.tls12, cs)) { + has_12 = true; + } else { + if (cipher_suites.includes(cipher_suites.tls13, cs)) has_13 = true; + } + } + if (has_12 and has_13) return .both; + if (has_12) return .tls_1_2; + if (has_13) return .tls_1_3; + return error.TlsIllegalParameter; + } + + pub const KeyExchangeAlgorithm = enum { + ecdhe, + rsa, + }; + + pub fn keyExchange(s: CipherSuite) KeyExchangeAlgorithm { + return switch (s) { + // Random premaster secret, encrypted with publich key from certificate. + // No server key exchange message. + .RSA_WITH_AES_128_CBC_SHA, + .RSA_WITH_AES_128_CBC_SHA256, + => .rsa, + else => .ecdhe, + }; + } + + pub const HashTag = enum { + sha256, + sha384, + sha512, + }; + + pub inline fn hash(cs: CipherSuite) HashTag { + return switch (cs) { + .ECDHE_RSA_WITH_AES_256_CBC_SHA384, + .ECDHE_RSA_WITH_AES_256_GCM_SHA384, + .ECDHE_ECDSA_WITH_AES_256_CBC_SHA384, + .ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + .AES_256_GCM_SHA384, + => .sha384, + else => .sha256, + }; + } +}; + +const testing = std.testing; +const testu = @import("testu.zig"); + +test "CipherSuite validate" { + { + const cs: CipherSuite = .AES_256_GCM_SHA384; + try cs.validate(); + try testing.expectEqual(cs.hash(), .sha384); + try testing.expectEqual(cs.keyExchange(), .ecdhe); + } + { + const cs: CipherSuite = .AES_128_GCM_SHA256; + try cs.validate(); + try testing.expectEqual(.sha256, cs.hash()); + try testing.expectEqual(.ecdhe, cs.keyExchange()); + } + for (cipher_suites.tls12) |cs| { + try cs.validate(); + _ = cs.hash(); + _ = cs.keyExchange(); + } +} + +test "CipherSuite versions" { + try testing.expectEqual(.tls_1_3, CipherSuite.versions(&[_]CipherSuite{.AES_128_GCM_SHA256})); + try testing.expectEqual(.both, CipherSuite.versions(&[_]CipherSuite{ .AES_128_GCM_SHA256, .ECDHE_ECDSA_WITH_AES_128_CBC_SHA })); + try testing.expectEqual(.tls_1_2, CipherSuite.versions(&[_]CipherSuite{.RSA_WITH_AES_128_CBC_SHA})); +} + +test "gcm 1.2 encrypt overhead" { + inline for ([_]type{ + Aead12Aes128Gcm, + Aead12Aes256Gcm, + }) |T| { + { + const expected_key_len = switch (T) { + Aead12Aes128Gcm => 16, + Aead12Aes256Gcm => 32, + else => unreachable, + }; + try testing.expectEqual(expected_key_len, T.key_len); + try testing.expectEqual(16, T.auth_tag_len); + try testing.expectEqual(12, T.nonce_len); + try testing.expectEqual(4, T.iv_len); + try testing.expectEqual(29, T.encrypt_overhead); + } + } +} + +test "cbc 1.2 encrypt overhead" { + try testing.expectEqual(85, encrypt_overhead_tls_12); + + inline for ([_]type{ + CbcAes128Sha1, + CbcAes128Sha256, + CbcAes256Sha384, + }) |T| { + switch (T) { + CbcAes128Sha1 => { + try testing.expectEqual(20, T.mac_len); + try testing.expectEqual(16, T.key_len); + try testing.expectEqual(57, T.encrypt_overhead); + }, + CbcAes128Sha256 => { + try testing.expectEqual(32, T.mac_len); + try testing.expectEqual(16, T.key_len); + try testing.expectEqual(69, T.encrypt_overhead); + }, + CbcAes256Sha384 => { + try testing.expectEqual(48, T.mac_len); + try testing.expectEqual(32, T.key_len); + try testing.expectEqual(85, T.encrypt_overhead); + }, + else => unreachable, + } + try testing.expectEqual(16, T.paddedLength(1)); // cbc block padding + try testing.expectEqual(16, T.iv_len); + } +} + +test "overhead tls 1.3" { + try testing.expectEqual(22, encrypt_overhead_tls_13); + try expectOverhead(Aes128GcmSha256, 16, 16, 12, 22); + try expectOverhead(Aes256GcmSha384, 32, 16, 12, 22); + try expectOverhead(ChaChaSha256, 32, 16, 12, 22); + try expectOverhead(Aegis128Sha256, 16, 16, 16, 22); + // and tls 1.2 chacha + try expectOverhead(Aead12ChaCha, 32, 16, 12, 21); +} + +fn expectOverhead(T: type, key_len: usize, auth_tag_len: usize, nonce_len: usize, overhead: usize) !void { + try testing.expectEqual(key_len, T.key_len); + try testing.expectEqual(auth_tag_len, T.auth_tag_len); + try testing.expectEqual(nonce_len, T.nonce_len); + try testing.expectEqual(overhead, T.encrypt_overhead); +} + +test "client/server encryption tls 1.3" { + inline for (cipher_suites.tls13) |cs| { + var buf: [256]u8 = undefined; + testu.fill(&buf); + const secret = Transcript.Secret{ + .client = buf[0..128], + .server = buf[128..], + }; + var client_cipher = try Cipher.initTls13(cs, secret, .client); + var server_cipher = try Cipher.initTls13(cs, secret, .server); + try encryptDecrypt(&client_cipher, &server_cipher); + + try client_cipher.keyUpdateEncrypt(); + try server_cipher.keyUpdateDecrypt(); + try encryptDecrypt(&client_cipher, &server_cipher); + + try client_cipher.keyUpdateDecrypt(); + try server_cipher.keyUpdateEncrypt(); + try encryptDecrypt(&client_cipher, &server_cipher); + } +} + +test "client/server encryption tls 1.2" { + inline for (cipher_suites.tls12) |cs| { + var key_material: [256]u8 = undefined; + testu.fill(&key_material); + var client_cipher = try Cipher.initTls12(cs, &key_material, .client); + var server_cipher = try Cipher.initTls12(cs, &key_material, .server); + try encryptDecrypt(&client_cipher, &server_cipher); + } +} + +fn encryptDecrypt(client_cipher: *Cipher, server_cipher: *Cipher) !void { + const cleartext = + \\ Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do + \\ eiusmod tempor incididunt ut labore et dolore magna aliqua. + ; + var buf: [256]u8 = undefined; + + { // client to server + // encrypt + const encrypted = try client_cipher.encrypt(&buf, .application_data, cleartext); + const expected_encrypted_len = switch (client_cipher.*) { + inline else => |f| brk: { + const T = @TypeOf(f); + break :brk switch (T) { + CbcAes128Sha1, + CbcAes128Sha256, + CbcAes256Sha256, + CbcAes256Sha384, + => record.header_len + T.paddedLength(T.iv_len + cleartext.len + T.mac_len), + Aead12Aes128Gcm, + Aead12Aes256Gcm, + Aead12ChaCha, + Aes128GcmSha256, + Aes256GcmSha384, + ChaChaSha256, + Aegis128Sha256, + => cleartext.len + T.encrypt_overhead, + else => unreachable, + }; + }, + }; + try testing.expectEqual(expected_encrypted_len, encrypted.len); + // decrypt + const content_type, const decrypted = try server_cipher.decrypt(&buf, Record.init(encrypted)); + try testing.expectEqualSlices(u8, cleartext, decrypted); + try testing.expectEqual(.application_data, content_type); + } + // server to client + { + const encrypted = try server_cipher.encrypt(&buf, .application_data, cleartext); + const content_type, const decrypted = try client_cipher.decrypt(&buf, Record.init(encrypted)); + try testing.expectEqualSlices(u8, cleartext, decrypted); + try testing.expectEqual(.application_data, content_type); + } +} diff --git a/src/tls.zig/connection.zig b/src/tls.zig/connection.zig new file mode 100644 index 0000000..8ac8935 --- /dev/null +++ b/src/tls.zig/connection.zig @@ -0,0 +1,666 @@ +const std = @import("std"); +const assert = std.debug.assert; + +const proto = @import("protocol.zig"); +const record = @import("record.zig"); +const cipher = @import("cipher.zig"); +const Cipher = cipher.Cipher; + +const async_io = @import("../io.zig"); +const Cbk = async_io.Cbk; +const Loop = async_io.Blocking; +const Ctx = @import("../std/http/Client.zig").Ctx; + +pub fn connection(stream: anytype) Connection(@TypeOf(stream)) { + return .{ + .stream = stream, + .rec_rdr = record.reader(stream), + }; +} + +pub fn Connection(comptime Stream: type) type { + return struct { + stream: Stream, // underlying stream + rec_rdr: record.Reader(Stream), + cipher: Cipher = undefined, + + max_encrypt_seq: u64 = std.math.maxInt(u64) - 1, + key_update_requested: bool = false, + + read_buf: []const u8 = "", + received_close_notify: bool = false, + + const Self = @This(); + + /// Encrypts and writes single tls record to the stream. + fn writeRecord(c: *Self, content_type: proto.ContentType, bytes: []const u8) !void { + assert(bytes.len <= cipher.max_cleartext_len); + var write_buf: [cipher.max_ciphertext_record_len]u8 = undefined; + // If key update is requested send key update message and update + // my encryption keys. + if (c.cipher.encryptSeq() >= c.max_encrypt_seq or @atomicLoad(bool, &c.key_update_requested, .monotonic)) { + @atomicStore(bool, &c.key_update_requested, false, .monotonic); + + // If the request_update field is set to "update_requested", + // then the receiver MUST send a KeyUpdate of its own with + // request_update set to "update_not_requested" prior to sending + // its next Application Data record. This mechanism allows + // either side to force an update to the entire connection, but + // causes an implementation which receives multiple KeyUpdates + // while it is silent to respond with a single update. + // + // rfc: https://datatracker.ietf.org/doc/html/rfc8446#autoid-57 + const key_update = &record.handshakeHeader(.key_update, 1) ++ [_]u8{0}; + const rec = try c.cipher.encrypt(&write_buf, .handshake, key_update); + try c.stream.writeAll(rec); + try c.cipher.keyUpdateEncrypt(); + } + const rec = try c.cipher.encrypt(&write_buf, content_type, bytes); + try c.stream.writeAll(rec); + } + + fn writeAlert(c: *Self, err: anyerror) !void { + const cleartext = proto.alertFromError(err); + var buf: [128]u8 = undefined; + const ciphertext = try c.cipher.encrypt(&buf, .alert, &cleartext); + c.stream.writeAll(ciphertext) catch {}; + } + + /// Returns next record of cleartext data. + /// Can be used in iterator like loop without memcpy to another buffer: + /// while (try client.next()) |buf| { ... } + pub fn next(c: *Self) ReadError!?[]const u8 { + const content_type, const data = c.nextRecord() catch |err| { + try c.writeAlert(err); + return err; + } orelse return null; + if (content_type != .application_data) return error.TlsUnexpectedMessage; + return data; + } + + fn nextRecord(c: *Self) ReadError!?struct { proto.ContentType, []const u8 } { + if (c.eof()) return null; + while (true) { + const content_type, const cleartext = try c.rec_rdr.nextDecrypt(&c.cipher) orelse return null; + + switch (content_type) { + .application_data => {}, + .handshake => { + const handshake_type: proto.Handshake = @enumFromInt(cleartext[0]); + switch (handshake_type) { + // skip new session ticket and read next record + .new_session_ticket => continue, + .key_update => { + if (cleartext.len != 5) return error.TlsDecodeError; + // rfc: Upon receiving a KeyUpdate, the receiver MUST + // update its receiving keys. + try c.cipher.keyUpdateDecrypt(); + const key: proto.KeyUpdateRequest = @enumFromInt(cleartext[4]); + switch (key) { + .update_requested => { + @atomicStore(bool, &c.key_update_requested, true, .monotonic); + }, + .update_not_requested => {}, + else => return error.TlsIllegalParameter, + } + // this record is handled read next + continue; + }, + else => {}, + } + }, + .alert => { + if (cleartext.len < 2) return error.TlsUnexpectedMessage; + try proto.Alert.parse(cleartext[0..2].*).toError(); + // server side clean shutdown + c.received_close_notify = true; + return null; + }, + else => return error.TlsUnexpectedMessage, + } + return .{ content_type, cleartext }; + } + } + + pub fn eof(c: *Self) bool { + return c.received_close_notify and c.read_buf.len == 0; + } + + pub fn close(c: *Self) !void { + if (c.received_close_notify) return; + try c.writeRecord(.alert, &proto.Alert.closeNotify()); + } + + // read, write interface + + pub const ReadError = Stream.ReadError || proto.Alert.Error || + error{ + TlsBadVersion, + TlsUnexpectedMessage, + TlsRecordOverflow, + TlsDecryptError, + TlsDecodeError, + TlsBadRecordMac, + TlsIllegalParameter, + BufferOverflow, + }; + pub const WriteError = Stream.WriteError || + error{ + BufferOverflow, + TlsUnexpectedMessage, + }; + + pub const Reader = std.io.Reader(*Self, ReadError, read); + pub const Writer = std.io.Writer(*Self, WriteError, write); + + pub fn reader(c: *Self) Reader { + return .{ .context = c }; + } + + pub fn writer(c: *Self) Writer { + return .{ .context = c }; + } + + /// Encrypts cleartext and writes it to the underlying stream as single + /// tls record. Max single tls record payload length is 1<<14 (16K) + /// bytes. + pub fn write(c: *Self, bytes: []const u8) WriteError!usize { + const n = @min(bytes.len, cipher.max_cleartext_len); + try c.writeRecord(.application_data, bytes[0..n]); + return n; + } + + /// Encrypts cleartext and writes it to the underlying stream. If needed + /// splits cleartext into multiple tls record. + pub fn writeAll(c: *Self, bytes: []const u8) WriteError!void { + var index: usize = 0; + while (index < bytes.len) { + index += try c.write(bytes[index..]); + } + } + + pub fn read(c: *Self, buffer: []u8) ReadError!usize { + if (c.read_buf.len == 0) { + c.read_buf = try c.next() orelse return 0; + } + const n = @min(c.read_buf.len, buffer.len); + @memcpy(buffer[0..n], c.read_buf[0..n]); + c.read_buf = c.read_buf[n..]; + return n; + } + + /// Returns the number of bytes read. If the number read is smaller than + /// `buffer.len`, it means the stream reached the end. + pub fn readAll(c: *Self, buffer: []u8) ReadError!usize { + return c.readAtLeast(buffer, buffer.len); + } + + /// Returns the number of bytes read, calling the underlying read function + /// the minimal number of times until the buffer has at least `len` bytes + /// filled. If the number read is less than `len` it means the stream + /// reached the end. + pub fn readAtLeast(c: *Self, buffer: []u8, len: usize) ReadError!usize { + assert(len <= buffer.len); + var index: usize = 0; + while (index < len) { + const amt = try c.read(buffer[index..]); + if (amt == 0) break; + index += amt; + } + return index; + } + + /// Returns the number of bytes read. If the number read is less than + /// the space provided it means the stream reached the end. + pub fn readv(c: *Self, iovecs: []std.posix.iovec) !usize { + var vp: VecPut = .{ .iovecs = iovecs }; + while (true) { + if (c.read_buf.len == 0) { + c.read_buf = try c.next() orelse break; + } + const n = vp.put(c.read_buf); + const read_buf_len = c.read_buf.len; + c.read_buf = c.read_buf[n..]; + if ((n < read_buf_len) or + (n == read_buf_len and !c.rec_rdr.hasMore())) + break; + } + return vp.total; + } + + fn onWriteAll(ctx: *Ctx, res: anyerror!void) anyerror!void { + res catch |err| return ctx.pop(err); + + if (ctx._tls_write_bytes.len - ctx._tls_write_index > 0) { + const rec = ctx.conn().tls_client.prepareRecord(ctx.stream(), ctx) catch |err| return ctx.pop(err); + return ctx.stream().async_writeAll(rec, ctx, onWriteAll) catch |err| return ctx.pop(err); + } + + return ctx.pop({}); + } + + pub fn async_writeAll(c: *Self, stream: anytype, bytes: []const u8, ctx: *Ctx, comptime cbk: Cbk) !void { + assert(bytes.len <= cipher.max_cleartext_len); + + ctx._tls_write_bytes = bytes; + ctx._tls_write_index = 0; + const rec = try c.prepareRecord(stream, ctx); + + try ctx.push(cbk); + return stream.async_writeAll(rec, ctx, onWriteAll); + } + + fn prepareRecord(c: *Self, stream: anytype, ctx: *Ctx) ![]const u8 { + const len = @min(ctx._tls_write_bytes.len - ctx._tls_write_index, cipher.max_cleartext_len); + + // If key update is requested send key update message and update + // my encryption keys. + if (c.cipher.encryptSeq() >= c.max_encrypt_seq or @atomicLoad(bool, &c.key_update_requested, .monotonic)) { + @atomicStore(bool, &c.key_update_requested, false, .monotonic); + + // If the request_update field is set to "update_requested", + // then the receiver MUST send a KeyUpdate of its own with + // request_update set to "update_not_requested" prior to sending + // its next Application Data record. This mechanism allows + // either side to force an update to the entire connection, but + // causes an implementation which receives multiple KeyUpdates + // while it is silent to respond with a single update. + // + // rfc: https://datatracker.ietf.org/doc/html/rfc8446#autoid-57 + const key_update = &record.handshakeHeader(.key_update, 1) ++ [_]u8{0}; + const rec = try c.cipher.encrypt(&ctx._tls_write_buf, .handshake, key_update); + try stream.writeAll(rec); // TODO async + try c.cipher.keyUpdateEncrypt(); + } + + defer ctx._tls_write_index += len; + return c.cipher.encrypt(&ctx._tls_write_buf, .application_data, ctx._tls_write_bytes[ctx._tls_write_index..len]); + } + + fn onReadv(ctx: *Ctx, res: anyerror!void) anyerror!void { + res catch |err| return ctx.pop(err); + + if (ctx._tls_read_buf == null) { + // end of read + ctx.setLen(ctx._vp.total); + return ctx.pop({}); + } + + while (true) { + const n = ctx._vp.put(ctx._tls_read_buf.?); + const read_buf_len = ctx._tls_read_buf.?.len; + const c = ctx.conn().tls_client; + + if (read_buf_len == 0) { + // read another buffer + c.async_next(ctx.stream(), ctx, onReadv) catch |err| return ctx.pop(err); + } + + ctx._tls_read_buf = ctx._tls_read_buf.?[n..]; + + if ((n < read_buf_len) or (n == read_buf_len and !c.rec_rdr.hasMore())) { + // end of read + ctx.setLen(ctx._vp.total); + return ctx.pop({}); + } + } + } + + pub fn async_readv(c: *Self, stream: anytype, iovecs: []std.posix.iovec, ctx: *Ctx, comptime cbk: Cbk) !void { + try ctx.push(cbk); + ctx._vp = .{ .iovecs = iovecs }; + + return c.async_next(stream, ctx, onReadv); + } + + fn onNext(ctx: *Ctx, res: anyerror!void) anyerror!void { + res catch |err| { + ctx.conn().tls_client.writeAlert(err) catch |e| std.log.err("onNext: write alert: {any}", .{e}); // TODO async + return ctx.pop(err); + }; + + if (ctx._tls_read_content_type != .application_data) { + return ctx.pop(error.TlsUnexpectedMessage); + } + + return ctx.pop({}); + } + + pub fn async_next(c: *Self, stream: anytype, ctx: *Ctx, comptime cbk: Cbk) !void { + try ctx.push(cbk); + + return c.async_next_decrypt(stream, ctx, onNext); + } + + pub fn onNextDecrypt(ctx: *Ctx, res: anyerror!void) anyerror!void { + res catch |err| return ctx.pop(err); + + const c = ctx.conn().tls_client; + // TOOD not sure if this works in my async case... + if (c.eof()) { + ctx._tls_read_buf = null; + return ctx.pop({}); + } + + const content_type = ctx._tls_read_content_type; + + switch (content_type) { + .application_data => {}, + .handshake => { + const handshake_type: proto.Handshake = @enumFromInt(ctx._tls_read_buf.?[0]); + switch (handshake_type) { + // skip new session ticket and read next record + .new_session_ticket => return c.async_next_record(ctx.stream(), ctx, onNextDecrypt) catch |err| return ctx.pop(err), + .key_update => { + if (ctx._tls_read_buf.?.len != 5) return ctx.pop(error.TlsDecodeError); + // rfc: Upon receiving a KeyUpdate, the receiver MUST + // update its receiving keys. + try c.cipher.keyUpdateDecrypt(); + const key: proto.KeyUpdateRequest = @enumFromInt(ctx._tls_read_buf.?[4]); + switch (key) { + .update_requested => { + @atomicStore(bool, &c.key_update_requested, true, .monotonic); + }, + .update_not_requested => {}, + else => return ctx.pop(error.TlsIllegalParameter), + } + // this record is handled read next + c.async_next_record(ctx.stream(), ctx, onNextDecrypt) catch |err| return ctx.pop(err); + }, + else => {}, + } + }, + .alert => { + if (ctx._tls_read_buf.?.len < 2) return ctx.pop(error.TlsUnexpectedMessage); + try proto.Alert.parse(ctx._tls_read_buf.?[0..2].*).toError(); + // server side clean shutdown + c.received_close_notify = true; + ctx._tls_read_buf = null; + return ctx.pop({}); + }, + else => return ctx.pop(error.TlsUnexpectedMessage), + } + + return ctx.pop({}); + } + + pub fn async_next_decrypt(c: *Self, stream: anytype, ctx: *Ctx, comptime cbk: Cbk) !void { + try ctx.push(cbk); + + return c.async_next_record(stream, ctx, onNextDecrypt) catch |err| return ctx.pop(err); + } + + pub fn onNextRecord(ctx: *Ctx, res: anyerror!void) anyerror!void { + res catch |err| return ctx.pop(err); + + const rec = ctx._tls_read_record orelse { + ctx._tls_read_buf = null; + return ctx.pop({}); + }; + + if (rec.protocol_version != .tls_1_2) return error.TlsBadVersion; + + const c = ctx.conn().tls_client; + const cph = &c.cipher; + + ctx._tls_read_content_type, ctx._tls_read_buf = cph.decrypt( + // Reuse reader buffer for cleartext. `rec.header` and + // `rec.payload`(ciphertext) are also pointing somewhere in + // this buffer. Decrypter is first reading then writing a + // block, cleartext has less length then ciphertext, + // cleartext starts from the beginning of the buffer, so + // ciphertext is always ahead of cleartext. + c.rec_rdr.buffer[0..c.rec_rdr.start], + rec, + ) catch |err| return ctx.pop(err); + + return ctx.pop({}); + } + + pub fn async_next_record(c: *Self, stream: anytype, ctx: *Ctx, comptime cbk: Cbk) !void { + try ctx.push(cbk); + + return c.async_reader_next(stream, ctx, onNextRecord); + } + + pub fn onReaderNext(ctx: *Ctx, res: anyerror!void) anyerror!void { + res catch |err| return ctx.pop(err); + + const c = ctx.conn().tls_client; + + const n = ctx.len(); + if (n == 0) { + ctx._tls_read_record = null; + return ctx.pop({}); + } + c.rec_rdr.end += n; + + return c.readNext(ctx); + } + + pub fn readNext(c: *Self, ctx: *Ctx) anyerror!void { + const buffer = c.rec_rdr.buffer[c.rec_rdr.start..c.rec_rdr.end]; + // If we have 5 bytes header. + if (buffer.len >= record.header_len) { + const record_header = buffer[0..record.header_len]; + const payload_len = std.mem.readInt(u16, record_header[3..5], .big); + if (payload_len > cipher.max_ciphertext_len) + return error.TlsRecordOverflow; + const record_len = record.header_len + payload_len; + // If we have whole record + if (buffer.len >= record_len) { + c.rec_rdr.start += record_len; + ctx._tls_read_record = record.Record.init(buffer[0..record_len]); + return ctx.pop({}); + } + } + { // Move dirty part to the start of the buffer. + const n = c.rec_rdr.end - c.rec_rdr.start; + if (n > 0 and c.rec_rdr.start > 0) { + if (c.rec_rdr.start > n) { + @memcpy(c.rec_rdr.buffer[0..n], c.rec_rdr.buffer[c.rec_rdr.start..][0..n]); + } else { + std.mem.copyForwards(u8, c.rec_rdr.buffer[0..n], c.rec_rdr.buffer[c.rec_rdr.start..][0..n]); + } + } + c.rec_rdr.start = 0; + c.rec_rdr.end = n; + } + // Read more from inner_reader. + return ctx.stream() + .async_read(c.rec_rdr.buffer[c.rec_rdr.end..], ctx, onReaderNext) catch |err| return ctx.pop(err); + } + + pub fn async_reader_next(c: *Self, _: anytype, ctx: *Ctx, comptime cbk: Cbk) !void { + try ctx.push(cbk); + return c.readNext(ctx); + } + }; +} + +const testing = std.testing; +const data12 = @import("testdata/tls12.zig"); +const testu = @import("testu.zig"); + +test "encrypt decrypt" { + var output_buf: [1024]u8 = undefined; + const stream = testu.Stream.init(&(data12.server_pong ** 3), &output_buf); + var conn: Connection(@TypeOf(stream)) = .{ .stream = stream, .rec_rdr = record.reader(stream) }; + conn.cipher = try Cipher.initTls12(.ECDHE_RSA_WITH_AES_128_CBC_SHA, &data12.key_material, .client); + conn.cipher.ECDHE_RSA_WITH_AES_128_CBC_SHA.rnd = testu.random(0); // use fixed rng + + conn.stream.output.reset(); + { // encrypt verify data from example + _ = testu.random(0x40); // sets iv to 40, 41, ... 4f + try conn.writeRecord(.handshake, &data12.client_finished); + try testing.expectEqualSlices(u8, &data12.verify_data_encrypted_msg, conn.stream.output.getWritten()); + } + + conn.stream.output.reset(); + { // encrypt ping + const cleartext = "ping"; + _ = testu.random(0); // sets iv to 00, 01, ... 0f + //conn.encrypt_seq = 1; + + try conn.writeAll(cleartext); + try testing.expectEqualSlices(u8, &data12.encrypted_ping_msg, conn.stream.output.getWritten()); + } + { // decrypt server pong message + conn.cipher.ECDHE_RSA_WITH_AES_128_CBC_SHA.decrypt_seq = 1; + try testing.expectEqualStrings("pong", (try conn.next()).?); + } + { // test reader interface + conn.cipher.ECDHE_RSA_WITH_AES_128_CBC_SHA.decrypt_seq = 1; + var rdr = conn.reader(); + var buffer: [4]u8 = undefined; + const n = try rdr.readAll(&buffer); + try testing.expectEqualStrings("pong", buffer[0..n]); + } + { // test readv interface + conn.cipher.ECDHE_RSA_WITH_AES_128_CBC_SHA.decrypt_seq = 1; + var buffer: [9]u8 = undefined; + var iovecs = [_]std.posix.iovec{ + .{ .base = &buffer, .len = 3 }, + .{ .base = buffer[3..], .len = 3 }, + .{ .base = buffer[6..], .len = 3 }, + }; + const n = try conn.readv(iovecs[0..]); + try testing.expectEqual(4, n); + try testing.expectEqualStrings("pong", buffer[0..n]); + } +} + +// Copied from: https://github.com/ziglang/zig/blob/455899668b620dfda40252501c748c0a983555bd/lib/std/crypto/tls/Client.zig#L1354 +/// Abstraction for sending multiple byte buffers to a slice of iovecs. +pub const VecPut = struct { + iovecs: []const std.posix.iovec, + idx: usize = 0, + off: usize = 0, + total: usize = 0, + + /// Returns the amount actually put which is always equal to bytes.len + /// unless the vectors ran out of space. + pub fn put(vp: *VecPut, bytes: []const u8) usize { + if (vp.idx >= vp.iovecs.len) return 0; + var bytes_i: usize = 0; + while (true) { + const v = vp.iovecs[vp.idx]; + const dest = v.base[vp.off..v.len]; + const src = bytes[bytes_i..][0..@min(dest.len, bytes.len - bytes_i)]; + @memcpy(dest[0..src.len], src); + bytes_i += src.len; + vp.off += src.len; + if (vp.off >= v.len) { + vp.off = 0; + vp.idx += 1; + if (vp.idx >= vp.iovecs.len) { + vp.total += bytes_i; + return bytes_i; + } + } + if (bytes_i >= bytes.len) { + vp.total += bytes_i; + return bytes_i; + } + } + } +}; + +test "client/server connection" { + const BufReaderWriter = struct { + buf: []u8, + wp: usize = 0, + rp: usize = 0, + + const Self = @This(); + + pub fn write(self: *Self, bytes: []const u8) !usize { + if (self.wp == self.buf.len) return error.NoSpaceLeft; + + const n = @min(bytes.len, self.buf.len - self.wp); + @memcpy(self.buf[self.wp..][0..n], bytes[0..n]); + self.wp += n; + return n; + } + + pub fn writeAll(self: *Self, bytes: []const u8) !void { + var n: usize = 0; + while (n < bytes.len) { + n += try self.write(bytes[n..]); + } + } + + pub fn read(self: *Self, bytes: []u8) !usize { + const n = @min(bytes.len, self.wp - self.rp); + if (n == 0) return 0; + @memcpy(bytes[0..n], self.buf[self.rp..][0..n]); + self.rp += n; + if (self.rp == self.wp) { + self.wp = 0; + self.rp = 0; + } + return n; + } + }; + + const TestStream = struct { + inner_stream: *BufReaderWriter, + const Self = @This(); + pub const ReadError = error{}; + pub const WriteError = error{NoSpaceLeft}; + pub fn read(self: *Self, bytes: []u8) !usize { + return try self.inner_stream.read(bytes); + } + pub fn writeAll(self: *Self, bytes: []const u8) !void { + return try self.inner_stream.writeAll(bytes); + } + }; + + const buf_len = 32 * 1024; + const tls_records_in_buf = (std.math.divCeil(comptime_int, buf_len, cipher.max_cleartext_len) catch unreachable); + const overhead: usize = tls_records_in_buf * @import("cipher.zig").encrypt_overhead_tls_13; + var buf: [buf_len + overhead]u8 = undefined; + var inner_stream = BufReaderWriter{ .buf = &buf }; + + const cipher_client, const cipher_server = brk: { + const Transcript = @import("transcript.zig").Transcript; + const CipherSuite = @import("cipher.zig").CipherSuite; + const cipher_suite: CipherSuite = .AES_256_GCM_SHA384; + + var rnd: [128]u8 = undefined; + std.crypto.random.bytes(&rnd); + const secret = Transcript.Secret{ + .client = rnd[0..64], + .server = rnd[64..], + }; + + break :brk .{ + try Cipher.initTls13(cipher_suite, secret, .client), + try Cipher.initTls13(cipher_suite, secret, .server), + }; + }; + + var conn1 = connection(TestStream{ .inner_stream = &inner_stream }); + conn1.cipher = cipher_client; + + var conn2 = connection(TestStream{ .inner_stream = &inner_stream }); + conn2.cipher = cipher_server; + + var prng = std.Random.DefaultPrng.init(0); + const random = prng.random(); + var send_buf: [buf_len]u8 = undefined; + var recv_buf: [buf_len]u8 = undefined; + random.bytes(&send_buf); // fill send buffer with random bytes + + for (0..16) |_| { + const n = buf_len; //random.uintLessThan(usize, buf_len); + + const sent = send_buf[0..n]; + try conn1.writeAll(sent); + const r = try conn2.readAll(&recv_buf); + const received = recv_buf[0..r]; + + try testing.expectEqual(n, r); + try testing.expectEqualSlices(u8, sent, received); + } +} diff --git a/src/tls.zig/handshake_client.zig b/src/tls.zig/handshake_client.zig new file mode 100644 index 0000000..e7b48cf --- /dev/null +++ b/src/tls.zig/handshake_client.zig @@ -0,0 +1,955 @@ +const std = @import("std"); +const assert = std.debug.assert; +const crypto = std.crypto; +const mem = std.mem; +const Certificate = crypto.Certificate; + +const cipher = @import("cipher.zig"); +const Cipher = cipher.Cipher; +const CipherSuite = cipher.CipherSuite; +const cipher_suites = cipher.cipher_suites; +const Transcript = @import("transcript.zig").Transcript; +const record = @import("record.zig"); +const rsa = @import("rsa/rsa.zig"); +const key_log = @import("key_log.zig"); +const PrivateKey = @import("PrivateKey.zig"); +const proto = @import("protocol.zig"); + +const common = @import("handshake_common.zig"); +const dupe = common.dupe; +const CertificateBuilder = common.CertificateBuilder; +const CertificateParser = common.CertificateParser; +const DhKeyPair = common.DhKeyPair; +const CertBundle = common.CertBundle; +const CertKeyPair = common.CertKeyPair; + +pub const Options = struct { + host: []const u8, + /// Set of root certificate authorities that clients use when verifying + /// server certificates. + root_ca: CertBundle, + + /// Controls whether a client verifies the server's certificate chain and + /// host name. + insecure_skip_verify: bool = false, + + /// List of cipher suites to use. + /// To use just tls 1.3 cipher suites: + /// .cipher_suites = &tls.CipherSuite.tls13, + /// To select particular cipher suite: + /// .cipher_suites = &[_]tls.CipherSuite{tls.CipherSuite.CHACHA20_POLY1305_SHA256}, + cipher_suites: []const CipherSuite = cipher_suites.all, + + /// List of named groups to use. + /// To use specific named group: + /// .named_groups = &[_]tls.NamedGroup{.secp384r1}, + named_groups: []const proto.NamedGroup = supported_named_groups, + + /// Client authentication certificates and private key. + auth: ?CertKeyPair = null, + + /// If this structure is provided it will be filled with handshake attributes + /// at the end of the handshake process. + diagnostic: ?*Diagnostic = null, + + /// For logging current connection tls keys, so we can share them with + /// Wireshark and analyze decrypted traffic there. + key_log_callback: ?key_log.Callback = null, + + pub const Diagnostic = struct { + tls_version: proto.Version = @enumFromInt(0), + cipher_suite_tag: CipherSuite = @enumFromInt(0), + named_group: proto.NamedGroup = @enumFromInt(0), + signature_scheme: proto.SignatureScheme = @enumFromInt(0), + client_signature_scheme: proto.SignatureScheme = @enumFromInt(0), + }; +}; + +const supported_named_groups = &[_]proto.NamedGroup{ + .x25519, + .secp256r1, + .secp384r1, + .x25519_kyber768d00, +}; + +/// Handshake parses tls server message and creates client messages. Collects +/// tls attributes: server random, cipher suite and so on. Client messages are +/// created using provided buffer. Provided record reader is used to get tls +/// record when needed. +pub fn Handshake(comptime Stream: type) type { + const RecordReaderT = record.Reader(Stream); + return struct { + client_random: [32]u8, + server_random: [32]u8 = undefined, + master_secret: [48]u8 = undefined, + key_material: [48 * 4]u8 = undefined, // for sha256 32 * 4 is filled, for sha384 48 * 4 + + transcript: Transcript = .{}, + cipher_suite: CipherSuite = @enumFromInt(0), + named_group: ?proto.NamedGroup = null, + dh_kp: DhKeyPair, + rsa_secret: RsaSecret, + tls_version: proto.Version = .tls_1_2, + cipher: Cipher = undefined, + cert: CertificateParser = undefined, + client_certificate_requested: bool = false, + // public key len: x25519 = 32, secp256r1 = 65, secp384r1 = 97, x25519_kyber768d00 = 1120 + server_pub_key_buf: [2048]u8 = undefined, + server_pub_key: []const u8 = undefined, + + rec_rdr: *RecordReaderT, // tls record reader + buffer: []u8, // scratch buffer used in all messages creation + + const HandshakeT = @This(); + + pub fn init(buf: []u8, rec_rdr: *RecordReaderT) HandshakeT { + return .{ + .client_random = undefined, + .dh_kp = undefined, + .rsa_secret = undefined, + //.now_sec = std.time.timestamp(), + .buffer = buf, + .rec_rdr = rec_rdr, + }; + } + + fn initKeys( + h: *HandshakeT, + named_groups: []const proto.NamedGroup, + ) !void { + const init_keys_buf_len = 32 + 46 + DhKeyPair.seed_len; + var buf: [init_keys_buf_len]u8 = undefined; + crypto.random.bytes(&buf); + + h.client_random = buf[0..32].*; + h.rsa_secret = RsaSecret.init(buf[32..][0..46].*); + h.dh_kp = try DhKeyPair.init(buf[32 + 46 ..][0..DhKeyPair.seed_len].*, named_groups); + } + + /// Handshake exchanges messages with server to get agreement about + /// cryptographic parameters. That upgrades existing client-server + /// connection to TLS connection. Returns cipher used in application for + /// encrypted message exchange. + /// + /// Handles TLS 1.2 and TLS 1.3 connections. After initial client hello + /// server chooses in its server hello which TLS version will be used. + /// + /// TLS 1.2 handshake messages exchange: + /// Client Server + /// -------------------------------------------------------------- + /// ClientHello client flight 1 ---> + /// ServerHello + /// Certificate + /// ServerKeyExchange + /// CertificateRequest* + /// <--- server flight 1 ServerHelloDone + /// Certificate* + /// ClientKeyExchange + /// CertificateVerify* + /// ChangeCipherSpec + /// Finished client flight 2 ---> + /// ChangeCipherSpec + /// <--- server flight 2 Finished + /// + /// TLS 1.3 handshake messages exchange: + /// Client Server + /// -------------------------------------------------------------- + /// ClientHello client flight 1 ---> + /// ServerHello + /// {EncryptedExtensions} + /// {CertificateRequest*} + /// {Certificate} + /// {CertificateVerify} + /// <--- server flight 1 {Finished} + /// ChangeCipherSpec + /// {Certificate*} + /// {CertificateVerify*} + /// Finished client flight 2 ---> + /// + /// * - optional + /// {} - encrypted + /// + /// References: + /// https://datatracker.ietf.org/doc/html/rfc5246#section-7.3 + /// https://datatracker.ietf.org/doc/html/rfc8446#section-2 + /// + pub fn handshake(h: *HandshakeT, w: Stream, opt: Options) !Cipher { + defer h.updateDiagnostic(opt); + try h.initKeys(opt.named_groups); + h.cert = .{ + .host = opt.host, + .root_ca = opt.root_ca.bundle, + .skip_verify = opt.insecure_skip_verify, + }; + + try w.writeAll(try h.makeClientHello(opt)); // client flight 1 + try h.readServerFlight1(); // server flight 1 + h.transcript.use(h.cipher_suite.hash()); + + // tls 1.3 specific handshake part + if (h.tls_version == .tls_1_3) { + try h.generateHandshakeCipher(opt.key_log_callback); + try h.readEncryptedServerFlight1(); // server flight 1 + const app_cipher = try h.generateApplicationCipher(opt.key_log_callback); + try w.writeAll(try h.makeClientFlight2Tls13(opt.auth)); // client flight 2 + return app_cipher; + } + + // tls 1.2 specific handshake part + try h.generateCipher(opt.key_log_callback); + try w.writeAll(try h.makeClientFlight2Tls12(opt.auth)); // client flight 2 + try h.readServerFlight2(); // server flight 2 + return h.cipher; + } + + /// Prepare key material and generate cipher for TLS 1.2 + fn generateCipher(h: *HandshakeT, key_log_callback: ?key_log.Callback) !void { + try h.verifyCertificateSignatureTls12(); + try h.generateKeyMaterial(key_log_callback); + h.cipher = try Cipher.initTls12(h.cipher_suite, &h.key_material, .client); + } + + /// Generate TLS 1.2 pre master secret, master secret and key material. + fn generateKeyMaterial(h: *HandshakeT, key_log_callback: ?key_log.Callback) !void { + const pre_master_secret = if (h.named_group) |named_group| + try h.dh_kp.sharedKey(named_group, h.server_pub_key) + else + &h.rsa_secret.secret; + + _ = dupe( + &h.master_secret, + h.transcript.masterSecret(pre_master_secret, h.client_random, h.server_random), + ); + _ = dupe( + &h.key_material, + h.transcript.keyMaterial(&h.master_secret, h.client_random, h.server_random), + ); + if (key_log_callback) |cb| { + cb(key_log.label.client_random, &h.client_random, &h.master_secret); + } + } + + /// TLS 1.3 cipher used during handshake + fn generateHandshakeCipher(h: *HandshakeT, key_log_callback: ?key_log.Callback) !void { + const shared_key = try h.dh_kp.sharedKey(h.named_group.?, h.server_pub_key); + const handshake_secret = h.transcript.handshakeSecret(shared_key); + if (key_log_callback) |cb| { + cb(key_log.label.server_handshake_traffic_secret, &h.client_random, handshake_secret.server); + cb(key_log.label.client_handshake_traffic_secret, &h.client_random, handshake_secret.client); + } + h.cipher = try Cipher.initTls13(h.cipher_suite, handshake_secret, .client); + } + + /// TLS 1.3 application (client) cipher + fn generateApplicationCipher(h: *HandshakeT, key_log_callback: ?key_log.Callback) !Cipher { + const application_secret = h.transcript.applicationSecret(); + if (key_log_callback) |cb| { + cb(key_log.label.server_traffic_secret_0, &h.client_random, application_secret.server); + cb(key_log.label.client_traffic_secret_0, &h.client_random, application_secret.client); + } + return try Cipher.initTls13(h.cipher_suite, application_secret, .client); + } + + fn makeClientHello(h: *HandshakeT, opt: Options) ![]const u8 { + // Buffer will have this parts: + // | header | payload | extensions | + // + // Header will be written last because we need to know length of + // payload and extensions when creating it. Payload has + // extensions length (u16) as last element. + // + var buffer = h.buffer; + const header_len = 9; // tls record header (5 bytes) and handshake header (4 bytes) + const tls_versions = try CipherSuite.versions(opt.cipher_suites); + // Payload writer, preserve header_len bytes for handshake header. + var payload = record.Writer{ .buf = buffer[header_len..] }; + try payload.writeEnum(proto.Version.tls_1_2); + try payload.write(&h.client_random); + try payload.writeByte(0); // no session id + try payload.writeEnumArray(CipherSuite, opt.cipher_suites); + try payload.write(&[_]u8{ 0x01, 0x00 }); // no compression + + // Extensions writer starts after payload and preserves 2 more + // bytes for extension len in payload. + var ext = record.Writer{ .buf = buffer[header_len + payload.pos + 2 ..] }; + try ext.writeExtension(.supported_versions, switch (tls_versions) { + .both => &[_]proto.Version{ .tls_1_3, .tls_1_2 }, + .tls_1_3 => &[_]proto.Version{.tls_1_3}, + .tls_1_2 => &[_]proto.Version{.tls_1_2}, + }); + try ext.writeExtension(.signature_algorithms, common.supported_signature_algorithms); + + try ext.writeExtension(.supported_groups, opt.named_groups); + if (tls_versions != .tls_1_2) { + var keys: [supported_named_groups.len][]const u8 = undefined; + for (opt.named_groups, 0..) |ng, i| { + keys[i] = try h.dh_kp.publicKey(ng); + } + try ext.writeKeyShare(opt.named_groups, keys[0..opt.named_groups.len]); + } + try ext.writeServerName(opt.host); + + // Extensions length at the end of the payload. + try payload.writeInt(@as(u16, @intCast(ext.pos))); + + // Header at the start of the buffer. + const body_len = payload.pos + ext.pos; + buffer[0..header_len].* = record.header(.handshake, 4 + body_len) ++ + record.handshakeHeader(.client_hello, body_len); + + const msg = buffer[0 .. header_len + body_len]; + h.transcript.update(msg[record.header_len..]); + return msg; + } + + /// Process first flight of the messages from the server. + /// Read server hello message. If TLS 1.3 is chosen in server hello + /// return. For TLS 1.2 continue and read certificate, key_exchange + /// eventual certificate request and hello done messages. + fn readServerFlight1(h: *HandshakeT) !void { + var handshake_states: []const proto.Handshake = &.{.server_hello}; + + while (true) { + var d = try h.rec_rdr.nextDecoder(); + try d.expectContentType(.handshake); + + h.transcript.update(d.payload); + + // Multiple handshake messages can be packed in single tls record. + while (!d.eof()) { + const handshake_type = try d.decode(proto.Handshake); + + const length = try d.decode(u24); + if (length > cipher.max_cleartext_len) + return error.TlsUnsupportedFragmentedHandshakeMessage; + + brk: { + for (handshake_states) |state| + if (state == handshake_type) break :brk; + return error.TlsUnexpectedMessage; + } + switch (handshake_type) { + .server_hello => { // server hello, ref: https://datatracker.ietf.org/doc/html/rfc5246#section-7.4.1.3 + try h.parseServerHello(&d, length); + if (h.tls_version == .tls_1_3) { + if (!d.eof()) return error.TlsIllegalParameter; + return; // end of tls 1.3 server flight 1 + } + handshake_states = if (h.cert.skip_verify) + &.{ .certificate, .server_key_exchange, .server_hello_done } + else + &.{.certificate}; + }, + .certificate => { + try h.cert.parseCertificate(&d, h.tls_version); + handshake_states = if (h.cipher_suite.keyExchange() == .rsa) + &.{.server_hello_done} + else + &.{.server_key_exchange}; + }, + .server_key_exchange => { + try h.parseServerKeyExchange(&d); + handshake_states = &.{ .certificate_request, .server_hello_done }; + }, + .certificate_request => { + h.client_certificate_requested = true; + try d.skip(length); + handshake_states = &.{.server_hello_done}; + }, + .server_hello_done => { + if (length != 0) return error.TlsIllegalParameter; + return; + }, + else => return error.TlsUnexpectedMessage, + } + } + } + } + + /// Parse server hello message. + fn parseServerHello(h: *HandshakeT, d: *record.Decoder, length: u24) !void { + if (try d.decode(proto.Version) != proto.Version.tls_1_2) + return error.TlsBadVersion; + h.server_random = try d.array(32); + if (isServerHelloRetryRequest(&h.server_random)) + return error.TlsServerHelloRetryRequest; + + const session_id_len = try d.decode(u8); + if (session_id_len > 32) return error.TlsIllegalParameter; + try d.skip(session_id_len); + + h.cipher_suite = try d.decode(CipherSuite); + try h.cipher_suite.validate(); + try d.skip(1); // skip compression method + + const extensions_present = length > 2 + 32 + 1 + session_id_len + 2 + 1; + if (extensions_present) { + const exs_len = try d.decode(u16); + var l: usize = 0; + while (l < exs_len) { + const typ = try d.decode(proto.Extension); + const len = try d.decode(u16); + defer l += len + 4; + + switch (typ) { + .supported_versions => { + switch (try d.decode(proto.Version)) { + .tls_1_2, .tls_1_3 => |v| h.tls_version = v, + else => return error.TlsIllegalParameter, + } + if (len != 2) return error.TlsIllegalParameter; + }, + .key_share => { + h.named_group = try d.decode(proto.NamedGroup); + h.server_pub_key = dupe(&h.server_pub_key_buf, try d.slice(try d.decode(u16))); + if (len != h.server_pub_key.len + 4) return error.TlsIllegalParameter; + }, + else => { + try d.skip(len); + }, + } + } + } + } + + fn isServerHelloRetryRequest(server_random: []const u8) bool { + // Ref: https://datatracker.ietf.org/doc/html/rfc8446#section-4.1.3 + const hello_retry_request_magic = [32]u8{ + 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91, + 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C, + }; + return std.mem.eql(u8, server_random, &hello_retry_request_magic); + } + + fn parseServerKeyExchange(h: *HandshakeT, d: *record.Decoder) !void { + const curve_type = try d.decode(proto.Curve); + h.named_group = try d.decode(proto.NamedGroup); + h.server_pub_key = dupe(&h.server_pub_key_buf, try d.slice(try d.decode(u8))); + h.cert.signature_scheme = try d.decode(proto.SignatureScheme); + h.cert.signature = dupe(&h.cert.signature_buf, try d.slice(try d.decode(u16))); + if (curve_type != .named_curve) return error.TlsIllegalParameter; + } + + /// Read encrypted part (after server hello) of the server first flight + /// for TLS 1.3: change cipher spec, eventual certificate request, + /// certificate, certificate verify and handshake finished messages. + fn readEncryptedServerFlight1(h: *HandshakeT) !void { + var cleartext_buf = h.buffer; + var cleartext_buf_head: usize = 0; + var cleartext_buf_tail: usize = 0; + var handshake_states: []const proto.Handshake = &.{.encrypted_extensions}; + + outer: while (true) { + // wrapped record decoder + const rec = (try h.rec_rdr.next() orelse return error.EndOfStream); + if (rec.protocol_version != .tls_1_2) return error.TlsBadVersion; + switch (rec.content_type) { + .change_cipher_spec => {}, + .application_data => { + const content_type, const cleartext = try h.cipher.decrypt( + cleartext_buf[cleartext_buf_tail..], + rec, + ); + cleartext_buf_tail += cleartext.len; + if (cleartext_buf_tail > cleartext_buf.len) return error.TlsRecordOverflow; + + var d = record.Decoder.init(content_type, cleartext_buf[cleartext_buf_head..cleartext_buf_tail]); + try d.expectContentType(.handshake); + while (!d.eof()) { + const start_idx = d.idx; + const handshake_type = try d.decode(proto.Handshake); + const length = try d.decode(u24); + + // std.debug.print("handshake loop: {} {} {} {}\n", .{ handshake_type, length, d.payload.len, d.idx }); + if (length > cipher.max_cleartext_len) + return error.TlsUnsupportedFragmentedHandshakeMessage; + if (length > d.rest().len) + continue :outer; // fragmented handshake into multiple records + + defer { + const handshake_payload = d.payload[start_idx..d.idx]; + h.transcript.update(handshake_payload); + cleartext_buf_head += handshake_payload.len; + } + + brk: { + for (handshake_states) |state| + if (state == handshake_type) break :brk; + return error.TlsUnexpectedMessage; + } + switch (handshake_type) { + .encrypted_extensions => { + try d.skip(length); + handshake_states = if (h.cert.skip_verify) + &.{ .certificate_request, .certificate, .finished } + else + &.{ .certificate_request, .certificate }; + }, + .certificate_request => { + h.client_certificate_requested = true; + try d.skip(length); + handshake_states = if (h.cert.skip_verify) + &.{ .certificate, .finished } + else + &.{.certificate}; + }, + .certificate => { + try h.cert.parseCertificate(&d, h.tls_version); + handshake_states = &.{.certificate_verify}; + }, + .certificate_verify => { + try h.cert.parseCertificateVerify(&d); + try h.cert.verifySignature(h.transcript.serverCertificateVerify()); + handshake_states = &.{.finished}; + }, + .finished => { + const actual = try d.slice(length); + var buf: [Transcript.max_mac_length]u8 = undefined; + const expected = h.transcript.serverFinishedTls13(&buf); + if (!mem.eql(u8, expected, actual)) + return error.TlsDecryptError; + return; + }, + else => return error.TlsUnexpectedMessage, + } + } + cleartext_buf_head = 0; + cleartext_buf_tail = 0; + }, + else => return error.TlsUnexpectedMessage, + } + } + } + + fn verifyCertificateSignatureTls12(h: *HandshakeT) !void { + if (h.cipher_suite.keyExchange() != .ecdhe) return; + const verify_bytes = brk: { + var w = record.Writer{ .buf = h.buffer }; + try w.write(&h.client_random); + try w.write(&h.server_random); + try w.writeEnum(proto.Curve.named_curve); + try w.writeEnum(h.named_group.?); + try w.writeInt(@as(u8, @intCast(h.server_pub_key.len))); + try w.write(h.server_pub_key); + break :brk w.getWritten(); + }; + try h.cert.verifySignature(verify_bytes); + } + + /// Create client key exchange, change cipher spec and handshake + /// finished messages for tls 1.2. + /// If client certificate is requested also adds client certificate and + /// certificate verify messages. + fn makeClientFlight2Tls12(h: *HandshakeT, auth: ?CertKeyPair) ![]const u8 { + var w = record.Writer{ .buf = h.buffer }; + var cert_builder: ?CertificateBuilder = null; + + // Client certificate message + if (h.client_certificate_requested) { + if (auth) |a| { + const cb = h.certificateBuilder(a); + cert_builder = cb; + const client_certificate = try cb.makeCertificate(w.getPayload()); + h.transcript.update(client_certificate); + try w.advanceRecord(.handshake, client_certificate.len); + } else { + const empty_certificate = &record.handshakeHeader(.certificate, 3) ++ [_]u8{ 0, 0, 0 }; + h.transcript.update(empty_certificate); + try w.writeRecord(.handshake, empty_certificate); + } + } + + // Client key exchange message + { + const key_exchange = try h.makeClientKeyExchange(w.getPayload()); + h.transcript.update(key_exchange); + try w.advanceRecord(.handshake, key_exchange.len); + } + + // Client certificate verify message + if (cert_builder) |cb| { + const certificate_verify = try cb.makeCertificateVerify(w.getPayload()); + h.transcript.update(certificate_verify); + try w.advanceRecord(.handshake, certificate_verify.len); + } + + // Client change cipher spec message + try w.writeRecord(.change_cipher_spec, &[_]u8{1}); + + // Client handshake finished message + { + const client_finished = &record.handshakeHeader(.finished, 12) ++ + h.transcript.clientFinishedTls12(&h.master_secret); + h.transcript.update(client_finished); + try h.writeEncrypted(&w, client_finished); + } + + return w.getWritten(); + } + + /// Create client change cipher spec and handshake finished messages for + /// tls 1.3. + /// If the client certificate is requested by the server and client is + /// configured with certificates and private key then client certificate + /// and client certificate verify messages are also created. If the + /// server has requested certificate but the client is not configured + /// empty certificate message is sent, as is required by rfc. + fn makeClientFlight2Tls13(h: *HandshakeT, auth: ?CertKeyPair) ![]const u8 { + var w = record.Writer{ .buf = h.buffer }; + + // Client change cipher spec message + try w.writeRecord(.change_cipher_spec, &[_]u8{1}); + + if (h.client_certificate_requested) { + if (auth) |a| { + const cb = h.certificateBuilder(a); + { + const certificate = try cb.makeCertificate(w.getPayload()); + h.transcript.update(certificate); + try h.writeEncrypted(&w, certificate); + } + { + const certificate_verify = try cb.makeCertificateVerify(w.getPayload()); + h.transcript.update(certificate_verify); + try h.writeEncrypted(&w, certificate_verify); + } + } else { + // Empty certificate message and no certificate verify message + const empty_certificate = &record.handshakeHeader(.certificate, 4) ++ [_]u8{ 0, 0, 0, 0 }; + h.transcript.update(empty_certificate); + try h.writeEncrypted(&w, empty_certificate); + } + } + + // Client handshake finished message + { + const client_finished = try h.makeClientFinishedTls13(w.getPayload()); + h.transcript.update(client_finished); + try h.writeEncrypted(&w, client_finished); + } + + return w.getWritten(); + } + + fn certificateBuilder(h: *HandshakeT, auth: CertKeyPair) CertificateBuilder { + return .{ + .bundle = auth.bundle, + .key = auth.key, + .transcript = &h.transcript, + .tls_version = h.tls_version, + .side = .client, + }; + } + + fn makeClientFinishedTls13(h: *HandshakeT, buf: []u8) ![]const u8 { + var w = record.Writer{ .buf = buf }; + const verify_data = h.transcript.clientFinishedTls13(w.getHandshakePayload()); + try w.advanceHandshake(.finished, verify_data.len); + return w.getWritten(); + } + + fn makeClientKeyExchange(h: *HandshakeT, buf: []u8) ![]const u8 { + var w = record.Writer{ .buf = buf }; + if (h.named_group) |named_group| { + const key = try h.dh_kp.publicKey(named_group); + try w.writeHandshakeHeader(.client_key_exchange, 1 + key.len); + try w.writeInt(@as(u8, @intCast(key.len))); + try w.write(key); + } else { + const key = try h.rsa_secret.encrypted(h.cert.pub_key_algo, h.cert.pub_key); + try w.writeHandshakeHeader(.client_key_exchange, 2 + key.len); + try w.writeInt(@as(u16, @intCast(key.len))); + try w.write(key); + } + return w.getWritten(); + } + + fn readServerFlight2(h: *HandshakeT) !void { + // Read server change cipher spec message. + { + var d = try h.rec_rdr.nextDecoder(); + try d.expectContentType(.change_cipher_spec); + } + // Read encrypted server handshake finished message. Verify that + // content of the server finished message is based on transcript + // hash and master secret. + { + const content_type, const server_finished = + try h.rec_rdr.nextDecrypt(&h.cipher) orelse return error.EndOfStream; + if (content_type != .handshake) + return error.TlsUnexpectedMessage; + const expected = record.handshakeHeader(.finished, 12) ++ h.transcript.serverFinishedTls12(&h.master_secret); + if (!mem.eql(u8, server_finished, &expected)) + return error.TlsBadRecordMac; + } + } + + /// Write encrypted handshake message into `w` + fn writeEncrypted(h: *HandshakeT, w: *record.Writer, cleartext: []const u8) !void { + const ciphertext = try h.cipher.encrypt(w.getFree(), .handshake, cleartext); + w.pos += ciphertext.len; + } + + // Copy handshake parameters to opt.diagnostic + fn updateDiagnostic(h: *HandshakeT, opt: Options) void { + if (opt.diagnostic) |d| { + d.tls_version = h.tls_version; + d.cipher_suite_tag = h.cipher_suite; + d.named_group = h.named_group orelse @as(proto.NamedGroup, @enumFromInt(0x0000)); + d.signature_scheme = h.cert.signature_scheme; + if (opt.auth) |a| + d.client_signature_scheme = a.key.signature_scheme; + } + } + }; +} + +const RsaSecret = struct { + secret: [48]u8, + + fn init(rand: [46]u8) RsaSecret { + return .{ .secret = [_]u8{ 0x03, 0x03 } ++ rand }; + } + + // Pre master secret encrypted with certificate public key. + inline fn encrypted( + self: RsaSecret, + cert_pub_key_algo: Certificate.Parsed.PubKeyAlgo, + cert_pub_key: []const u8, + ) ![]const u8 { + if (cert_pub_key_algo != .rsaEncryption) return error.TlsBadSignatureScheme; + const pk = try rsa.PublicKey.fromDer(cert_pub_key); + var out: [512]u8 = undefined; + return try pk.encryptPkcsv1_5(&self.secret, &out); + } +}; + +const testing = std.testing; +const data12 = @import("testdata/tls12.zig"); +const data13 = @import("testdata/tls13.zig"); +const testu = @import("testu.zig"); + +fn testReader(data: []const u8) record.Reader(std.io.FixedBufferStream([]const u8)) { + return record.reader(std.io.fixedBufferStream(data)); +} +const TestHandshake = Handshake(std.io.FixedBufferStream([]const u8)); + +test "parse tls 1.2 server hello" { + var h = brk: { + var buffer: [1024]u8 = undefined; + var rec_rdr = testReader(&data12.server_hello_responses); + break :brk TestHandshake.init(&buffer, &rec_rdr); + }; + + // Set to known instead of random + h.client_random = data12.client_random; + h.dh_kp.x25519_kp.secret_key = data12.client_secret; + + // Parse server hello, certificate and key exchange messages. + // Read cipher suite, named group, signature scheme, server random certificate public key + // Verify host name, signature + // Calculate key material + h.cert = .{ .host = "example.ulfheim.net", .skip_verify = true, .root_ca = .{} }; + try h.readServerFlight1(); + try testing.expectEqual(.ECDHE_RSA_WITH_AES_128_CBC_SHA, h.cipher_suite); + try testing.expectEqual(.x25519, h.named_group.?); + try testing.expectEqual(.rsa_pkcs1_sha256, h.cert.signature_scheme); + try testing.expectEqualSlices(u8, &data12.server_random, &h.server_random); + try testing.expectEqualSlices(u8, &data12.server_pub_key, h.server_pub_key); + try testing.expectEqualSlices(u8, &data12.signature, h.cert.signature); + try testing.expectEqualSlices(u8, &data12.cert_pub_key, h.cert.pub_key); + + try h.verifyCertificateSignatureTls12(); + try h.generateKeyMaterial(null); + + try testing.expectEqualSlices(u8, &data12.key_material, h.key_material[0..data12.key_material.len]); +} + +test "verify google.com certificate" { + var h = brk: { + var buffer: [1024]u8 = undefined; + var rec_rdr = testReader(@embedFile("testdata/google.com/server_hello")); + break :brk TestHandshake.init(&buffer, &rec_rdr); + }; + h.client_random = @embedFile("testdata/google.com/client_random").*; + + var ca_bundle: Certificate.Bundle = .{}; + try ca_bundle.rescan(testing.allocator); + defer ca_bundle.deinit(testing.allocator); + + h.cert = .{ .host = "google.com", .skip_verify = true, .root_ca = .{}, .now_sec = 1714846451 }; + try h.readServerFlight1(); + try h.verifyCertificateSignatureTls12(); +} + +test "parse tls 1.3 server hello" { + var rec_rdr = testReader(&data13.server_hello); + var d = (try rec_rdr.nextDecoder()); + + const handshake_type = try d.decode(proto.Handshake); + const length = try d.decode(u24); + try testing.expectEqual(0x000076, length); + try testing.expectEqual(.server_hello, handshake_type); + + var h = TestHandshake.init(undefined, undefined); + try h.parseServerHello(&d, length); + + try testing.expectEqual(.AES_256_GCM_SHA384, h.cipher_suite); + try testing.expectEqualSlices(u8, &data13.server_random, &h.server_random); + try testing.expectEqual(.tls_1_3, h.tls_version); + try testing.expectEqual(.x25519, h.named_group); + try testing.expectEqualSlices(u8, &data13.server_pub_key, h.server_pub_key); +} + +test "init tls 1.3 handshake cipher" { + const cipher_suite_tag: CipherSuite = .AES_256_GCM_SHA384; + + var transcript = Transcript{}; + transcript.use(cipher_suite_tag.hash()); + transcript.update(data13.client_hello[record.header_len..]); + transcript.update(data13.server_hello[record.header_len..]); + + var dh_kp = DhKeyPair{ + .x25519_kp = .{ + .public_key = data13.client_public_key, + .secret_key = data13.client_private_key, + }, + }; + const shared_key = try dh_kp.sharedKey(.x25519, &data13.server_pub_key); + try testing.expectEqualSlices(u8, &data13.shared_key, shared_key); + + const cph = try Cipher.initTls13(cipher_suite_tag, transcript.handshakeSecret(shared_key), .client); + + const c = &cph.AES_256_GCM_SHA384; + try testing.expectEqualSlices(u8, &data13.server_handshake_key, &c.decrypt_key); + try testing.expectEqualSlices(u8, &data13.client_handshake_key, &c.encrypt_key); + try testing.expectEqualSlices(u8, &data13.server_handshake_iv, &c.decrypt_iv); + try testing.expectEqualSlices(u8, &data13.client_handshake_iv, &c.encrypt_iv); +} + +fn initExampleHandshake(h: *TestHandshake) !void { + h.cipher_suite = .AES_256_GCM_SHA384; + h.transcript.use(h.cipher_suite.hash()); + h.transcript.update(data13.client_hello[record.header_len..]); + h.transcript.update(data13.server_hello[record.header_len..]); + h.cipher = try Cipher.initTls13(h.cipher_suite, h.transcript.handshakeSecret(&data13.shared_key), .client); + h.tls_version = .tls_1_3; + h.cert.now_sec = 1714846451; + h.server_pub_key = &data13.server_pub_key; +} + +test "tls 1.3 decrypt wrapped record" { + var cph = brk: { + var h = TestHandshake.init(undefined, undefined); + try initExampleHandshake(&h); + break :brk h.cipher; + }; + + var cleartext_buf: [1024]u8 = undefined; + { + const rec = record.Record.init(&data13.server_encrypted_extensions_wrapped); + + const content_type, const cleartext = try cph.decrypt(&cleartext_buf, rec); + try testing.expectEqual(.handshake, content_type); + try testing.expectEqualSlices(u8, &data13.server_encrypted_extensions, cleartext); + } + { + const rec = record.Record.init(&data13.server_certificate_wrapped); + + const content_type, const cleartext = try cph.decrypt(&cleartext_buf, rec); + try testing.expectEqual(.handshake, content_type); + try testing.expectEqualSlices(u8, &data13.server_certificate, cleartext); + } +} + +test "tls 1.3 process server flight" { + var buffer: [1024]u8 = undefined; + var h = brk: { + var rec_rdr = testReader(&data13.server_flight); + break :brk TestHandshake.init(&buffer, &rec_rdr); + }; + + try initExampleHandshake(&h); + h.cert = .{ .host = "example.ulfheim.net", .skip_verify = true, .root_ca = .{} }; + try h.readEncryptedServerFlight1(); + + { // application cipher keys calculation + try testing.expectEqualSlices(u8, &data13.handshake_hash, &h.transcript.sha384.hash.peek()); + + var cph = try Cipher.initTls13(h.cipher_suite, h.transcript.applicationSecret(), .client); + const c = &cph.AES_256_GCM_SHA384; + try testing.expectEqualSlices(u8, &data13.server_application_key, &c.decrypt_key); + try testing.expectEqualSlices(u8, &data13.client_application_key, &c.encrypt_key); + try testing.expectEqualSlices(u8, &data13.server_application_iv, &c.decrypt_iv); + try testing.expectEqualSlices(u8, &data13.client_application_iv, &c.encrypt_iv); + + const encrypted = try cph.encrypt(&buffer, .application_data, "ping"); + try testing.expectEqualSlices(u8, &data13.client_ping_wrapped, encrypted); + } + { // client finished message + var buf: [4 + Transcript.max_mac_length]u8 = undefined; + const client_finished = try h.makeClientFinishedTls13(&buf); + try testing.expectEqualSlices(u8, &data13.client_finished_verify_data, client_finished[4..]); + const encrypted = try h.cipher.encrypt(&buffer, .handshake, client_finished); + try testing.expectEqualSlices(u8, &data13.client_finished_wrapped, encrypted); + } +} + +test "create client hello" { + var h = brk: { + var buffer: [1024]u8 = undefined; + var h = TestHandshake.init(&buffer, undefined); + h.client_random = testu.hexToBytes( + \\ 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f + ); + break :brk h; + }; + + const actual = try h.makeClientHello(.{ + .host = "google.com", + .root_ca = .{}, + .cipher_suites = &[_]CipherSuite{CipherSuite.ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, + .named_groups = &[_]proto.NamedGroup{ .x25519, .secp256r1, .secp384r1 }, + }); + + const expected = testu.hexToBytes( + "16 03 03 00 6d " ++ // record header + "01 00 00 69 " ++ // handshake header + "03 03 " ++ // protocol version + "00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f " ++ // client random + "00 " ++ // no session id + "00 02 c0 2b " ++ // cipher suites + "01 00 " ++ // compression methods + "00 3e " ++ // extensions length + "00 2b 00 03 02 03 03 " ++ // supported versions extension + "00 0d 00 14 00 12 04 03 05 03 08 04 08 05 08 06 08 07 02 01 04 01 05 01 " ++ // signature algorithms extension + "00 0a 00 08 00 06 00 1d 00 17 00 18 " ++ // named groups extension + "00 00 00 0f 00 0d 00 00 0a 67 6f 6f 67 6c 65 2e 63 6f 6d ", // server name extension + ); + try testing.expectEqualSlices(u8, &expected, actual); +} + +test "handshake verify server finished message" { + var buffer: [1024]u8 = undefined; + var rec_rdr = testReader(&data12.server_handshake_finished_msgs); + var h = TestHandshake.init(&buffer, &rec_rdr); + + h.cipher_suite = .ECDHE_ECDSA_WITH_AES_128_CBC_SHA; + h.master_secret = data12.master_secret; + + // add handshake messages to the transcript + for (data12.handshake_messages) |msg| { + h.transcript.update(msg[record.header_len..]); + } + + // expect verify data + const client_finished = h.transcript.clientFinishedTls12(&h.master_secret); + try testing.expectEqualSlices(u8, &data12.client_finished, &record.handshakeHeader(.finished, 12) ++ client_finished); + + // init client with prepared key_material + h.cipher = try Cipher.initTls12(.ECDHE_RSA_WITH_AES_128_CBC_SHA, &data12.key_material, .client); + + // check that server verify data matches calculates from hashes of all handshake messages + h.transcript.update(&data12.client_finished); + try h.readServerFlight2(); +} diff --git a/src/tls.zig/handshake_common.zig b/src/tls.zig/handshake_common.zig new file mode 100644 index 0000000..178a3ce --- /dev/null +++ b/src/tls.zig/handshake_common.zig @@ -0,0 +1,448 @@ +const std = @import("std"); +const assert = std.debug.assert; +const mem = std.mem; +const crypto = std.crypto; +const Certificate = crypto.Certificate; + +const Transcript = @import("transcript.zig").Transcript; +const PrivateKey = @import("PrivateKey.zig"); +const record = @import("record.zig"); +const rsa = @import("rsa/rsa.zig"); +const proto = @import("protocol.zig"); + +const X25519 = crypto.dh.X25519; +const EcdsaP256Sha256 = crypto.sign.ecdsa.EcdsaP256Sha256; +const EcdsaP384Sha384 = crypto.sign.ecdsa.EcdsaP384Sha384; +const Kyber768 = crypto.kem.kyber_d00.Kyber768; + +pub const supported_signature_algorithms = &[_]proto.SignatureScheme{ + .ecdsa_secp256r1_sha256, + .ecdsa_secp384r1_sha384, + .rsa_pss_rsae_sha256, + .rsa_pss_rsae_sha384, + .rsa_pss_rsae_sha512, + .ed25519, + .rsa_pkcs1_sha1, + .rsa_pkcs1_sha256, + .rsa_pkcs1_sha384, +}; + +pub const CertKeyPair = struct { + /// A chain of one or more certificates, leaf first. + /// + /// Each X.509 certificate contains the public key of a key pair, extra + /// information (the name of the holder, the name of an issuer of the + /// certificate, validity time spans) and a signature generated using the + /// private key of the issuer of the certificate. + /// + /// All certificates from the bundle are sent to the other side when creating + /// Certificate tls message. + /// + /// Leaf certificate and private key are used to create signature for + /// CertifyVerify tls message. + bundle: Certificate.Bundle, + + /// Private key corresponding to the public key in leaf certificate from the + /// bundle. + key: PrivateKey, + + pub fn load( + allocator: std.mem.Allocator, + dir: std.fs.Dir, + cert_path: []const u8, + key_path: []const u8, + ) !CertKeyPair { + var bundle: Certificate.Bundle = .{}; + try bundle.addCertsFromFilePath(allocator, dir, cert_path); + + const key_file = try dir.openFile(key_path, .{}); + defer key_file.close(); + const key = try PrivateKey.fromFile(allocator, key_file); + + return .{ .bundle = bundle, .key = key }; + } + + pub fn deinit(c: *CertKeyPair, allocator: std.mem.Allocator) void { + c.bundle.deinit(allocator); + } +}; + +pub const CertBundle = struct { + // A chain of one or more certificates. + // + // They are used to verify that certificate chain sent by the other side + // forms valid trust chain. + bundle: Certificate.Bundle = .{}, + + pub fn fromFile(allocator: std.mem.Allocator, dir: std.fs.Dir, path: []const u8) !CertBundle { + var bundle: Certificate.Bundle = .{}; + try bundle.addCertsFromFilePath(allocator, dir, path); + return .{ .bundle = bundle }; + } + + pub fn fromSystem(allocator: std.mem.Allocator) !CertBundle { + var bundle: Certificate.Bundle = .{}; + try bundle.rescan(allocator); + return .{ .bundle = bundle }; + } + + pub fn deinit(cb: *CertBundle, allocator: std.mem.Allocator) void { + cb.bundle.deinit(allocator); + } +}; + +pub const CertificateBuilder = struct { + bundle: Certificate.Bundle, + key: PrivateKey, + transcript: *Transcript, + tls_version: proto.Version = .tls_1_3, + side: proto.Side = .client, + + pub fn makeCertificate(h: CertificateBuilder, buf: []u8) ![]const u8 { + var w = record.Writer{ .buf = buf }; + const certs = h.bundle.bytes.items; + const certs_count = h.bundle.map.size; + + // Differences between tls 1.3 and 1.2 + // TLS 1.3 has request context in header and extensions for each certificate. + // Here we use empty length for each field. + // TLS 1.2 don't have these two fields. + const request_context, const extensions = if (h.tls_version == .tls_1_3) + .{ &[_]u8{0}, &[_]u8{ 0, 0 } } + else + .{ &[_]u8{}, &[_]u8{} }; + const certs_len = certs.len + (3 + extensions.len) * certs_count; + + // Write handshake header + try w.writeHandshakeHeader(.certificate, certs_len + request_context.len + 3); + try w.write(request_context); + try w.writeInt(@as(u24, @intCast(certs_len))); + + // Write each certificate + var index: u32 = 0; + while (index < certs.len) { + const e = try Certificate.der.Element.parse(certs, index); + const cert = certs[index..e.slice.end]; + try w.writeInt(@as(u24, @intCast(cert.len))); // certificate length + try w.write(cert); // certificate + try w.write(extensions); // certificate extensions + index = e.slice.end; + } + return w.getWritten(); + } + + pub fn makeCertificateVerify(h: CertificateBuilder, buf: []u8) ![]const u8 { + var w = record.Writer{ .buf = buf }; + const signature, const signature_scheme = try h.createSignature(); + try w.writeHandshakeHeader(.certificate_verify, signature.len + 4); + try w.writeEnum(signature_scheme); + try w.writeInt(@as(u16, @intCast(signature.len))); + try w.write(signature); + return w.getWritten(); + } + + /// Creates signature for client certificate signature message. + /// Returns signature bytes and signature scheme. + inline fn createSignature(h: CertificateBuilder) !struct { []const u8, proto.SignatureScheme } { + switch (h.key.signature_scheme) { + inline .ecdsa_secp256r1_sha256, + .ecdsa_secp384r1_sha384, + => |comptime_scheme| { + const Ecdsa = SchemeEcdsa(comptime_scheme); + const key = h.key.key.ecdsa; + const key_len = Ecdsa.SecretKey.encoded_length; + if (key.len < key_len) return error.InvalidEncoding; + const secret_key = try Ecdsa.SecretKey.fromBytes(key[0..key_len].*); + const key_pair = try Ecdsa.KeyPair.fromSecretKey(secret_key); + var signer = try key_pair.signer(null); + h.setSignatureVerifyBytes(&signer); + const signature = try signer.finalize(); + var buf: [Ecdsa.Signature.der_encoded_length_max]u8 = undefined; + return .{ signature.toDer(&buf), comptime_scheme }; + }, + inline .rsa_pss_rsae_sha256, + .rsa_pss_rsae_sha384, + .rsa_pss_rsae_sha512, + => |comptime_scheme| { + const Hash = SchemeHash(comptime_scheme); + var signer = try h.key.key.rsa.signerOaep(Hash, null); + h.setSignatureVerifyBytes(&signer); + var buf: [512]u8 = undefined; + const signature = try signer.finalize(&buf); + return .{ signature.bytes, comptime_scheme }; + }, + else => return error.TlsUnknownSignatureScheme, + } + } + + fn setSignatureVerifyBytes(h: CertificateBuilder, signer: anytype) void { + if (h.tls_version == .tls_1_2) { + // tls 1.2 signature uses current transcript hash value. + // ref: https://datatracker.ietf.org/doc/html/rfc5246.html#section-7.4.8 + const Hash = @TypeOf(signer.h); + signer.h = h.transcript.hash(Hash); + } else { + // tls 1.3 signature is computed over concatenation of 64 spaces, + // context, separator and content. + // ref: https://datatracker.ietf.org/doc/html/rfc8446#section-4.4.3 + if (h.side == .server) { + signer.update(h.transcript.serverCertificateVerify()); + } else { + signer.update(h.transcript.clientCertificateVerify()); + } + } + } + + fn SchemeEcdsa(comptime scheme: proto.SignatureScheme) type { + return switch (scheme) { + .ecdsa_secp256r1_sha256 => EcdsaP256Sha256, + .ecdsa_secp384r1_sha384 => EcdsaP384Sha384, + else => unreachable, + }; + } +}; + +pub const CertificateParser = struct { + pub_key_algo: Certificate.Parsed.PubKeyAlgo = undefined, + pub_key_buf: [600]u8 = undefined, + pub_key: []const u8 = undefined, + + signature_scheme: proto.SignatureScheme = @enumFromInt(0), + signature_buf: [1024]u8 = undefined, + signature: []const u8 = undefined, + + root_ca: Certificate.Bundle, + host: []const u8, + skip_verify: bool = false, + now_sec: i64 = 0, + + pub fn parseCertificate(h: *CertificateParser, d: *record.Decoder, tls_version: proto.Version) !void { + if (h.now_sec == 0) { + h.now_sec = std.time.timestamp(); + } + if (tls_version == .tls_1_3) { + const request_context = try d.decode(u8); + if (request_context != 0) return error.TlsIllegalParameter; + } + + var trust_chain_established = false; + var last_cert: ?Certificate.Parsed = null; + const certs_len = try d.decode(u24); + const start_idx = d.idx; + while (d.idx - start_idx < certs_len) { + const cert_len = try d.decode(u24); + // std.debug.print("=> {} {} {} {}\n", .{ certs_len, d.idx, cert_len, d.payload.len }); + const cert = try d.slice(cert_len); + if (tls_version == .tls_1_3) { + // certificate extensions present in tls 1.3 + try d.skip(try d.decode(u16)); + } + if (trust_chain_established) + continue; + + const subject = try (Certificate{ .buffer = cert, .index = 0 }).parse(); + if (last_cert) |pc| { + if (pc.verify(subject, h.now_sec)) { + last_cert = subject; + } else |err| switch (err) { + error.CertificateIssuerMismatch => { + // skip certificate which is not part of the chain + continue; + }, + else => return err, + } + } else { // first certificate + if (!h.skip_verify and h.host.len > 0) { + try subject.verifyHostName(h.host); + } + h.pub_key = dupe(&h.pub_key_buf, subject.pubKey()); + h.pub_key_algo = subject.pub_key_algo; + last_cert = subject; + } + if (!h.skip_verify) { + if (h.root_ca.verify(last_cert.?, h.now_sec)) |_| { + trust_chain_established = true; + } else |err| switch (err) { + error.CertificateIssuerNotFound => {}, + else => return err, + } + } + } + if (!h.skip_verify and !trust_chain_established) { + return error.CertificateIssuerNotFound; + } + } + + pub fn parseCertificateVerify(h: *CertificateParser, d: *record.Decoder) !void { + h.signature_scheme = try d.decode(proto.SignatureScheme); + h.signature = dupe(&h.signature_buf, try d.slice(try d.decode(u16))); + } + + pub fn verifySignature(h: *CertificateParser, verify_bytes: []const u8) !void { + switch (h.signature_scheme) { + inline .ecdsa_secp256r1_sha256, + .ecdsa_secp384r1_sha384, + => |comptime_scheme| { + if (h.pub_key_algo != .X9_62_id_ecPublicKey) return error.TlsBadSignatureScheme; + const cert_named_curve = h.pub_key_algo.X9_62_id_ecPublicKey; + switch (cert_named_curve) { + inline .secp384r1, .X9_62_prime256v1 => |comptime_cert_named_curve| { + const Ecdsa = SchemeEcdsaCert(comptime_scheme, comptime_cert_named_curve); + const key = try Ecdsa.PublicKey.fromSec1(h.pub_key); + const sig = try Ecdsa.Signature.fromDer(h.signature); + try sig.verify(verify_bytes, key); + }, + else => return error.TlsUnknownSignatureScheme, + } + }, + .ed25519 => { + if (h.pub_key_algo != .curveEd25519) return error.TlsBadSignatureScheme; + const Eddsa = crypto.sign.Ed25519; + if (h.signature.len != Eddsa.Signature.encoded_length) return error.InvalidEncoding; + const sig = Eddsa.Signature.fromBytes(h.signature[0..Eddsa.Signature.encoded_length].*); + if (h.pub_key.len != Eddsa.PublicKey.encoded_length) return error.InvalidEncoding; + const key = try Eddsa.PublicKey.fromBytes(h.pub_key[0..Eddsa.PublicKey.encoded_length].*); + try sig.verify(verify_bytes, key); + }, + inline .rsa_pss_rsae_sha256, + .rsa_pss_rsae_sha384, + .rsa_pss_rsae_sha512, + => |comptime_scheme| { + if (h.pub_key_algo != .rsaEncryption) return error.TlsBadSignatureScheme; + const Hash = SchemeHash(comptime_scheme); + const pk = try rsa.PublicKey.fromDer(h.pub_key); + const sig = rsa.Pss(Hash).Signature{ .bytes = h.signature }; + try sig.verify(verify_bytes, pk, null); + }, + inline .rsa_pkcs1_sha1, + .rsa_pkcs1_sha256, + .rsa_pkcs1_sha384, + .rsa_pkcs1_sha512, + => |comptime_scheme| { + if (h.pub_key_algo != .rsaEncryption) return error.TlsBadSignatureScheme; + const Hash = SchemeHash(comptime_scheme); + const pk = try rsa.PublicKey.fromDer(h.pub_key); + const sig = rsa.PKCS1v1_5(Hash).Signature{ .bytes = h.signature }; + try sig.verify(verify_bytes, pk); + }, + else => return error.TlsUnknownSignatureScheme, + } + } + + fn SchemeEcdsaCert(comptime scheme: proto.SignatureScheme, comptime cert_named_curve: Certificate.NamedCurve) type { + const Sha256 = crypto.hash.sha2.Sha256; + const Sha384 = crypto.hash.sha2.Sha384; + const Ecdsa = crypto.sign.ecdsa.Ecdsa; + + return switch (scheme) { + .ecdsa_secp256r1_sha256 => Ecdsa(cert_named_curve.Curve(), Sha256), + .ecdsa_secp384r1_sha384 => Ecdsa(cert_named_curve.Curve(), Sha384), + else => @compileError("bad scheme"), + }; + } +}; + +fn SchemeHash(comptime scheme: proto.SignatureScheme) type { + const Sha256 = crypto.hash.sha2.Sha256; + const Sha384 = crypto.hash.sha2.Sha384; + const Sha512 = crypto.hash.sha2.Sha512; + + return switch (scheme) { + .rsa_pkcs1_sha1 => crypto.hash.Sha1, + .rsa_pss_rsae_sha256, .rsa_pkcs1_sha256 => Sha256, + .rsa_pss_rsae_sha384, .rsa_pkcs1_sha384 => Sha384, + .rsa_pss_rsae_sha512, .rsa_pkcs1_sha512 => Sha512, + else => @compileError("bad scheme"), + }; +} + +pub fn dupe(buf: []u8, data: []const u8) []u8 { + const n = @min(data.len, buf.len); + @memcpy(buf[0..n], data[0..n]); + return buf[0..n]; +} + +pub const DhKeyPair = struct { + x25519_kp: X25519.KeyPair = undefined, + secp256r1_kp: EcdsaP256Sha256.KeyPair = undefined, + secp384r1_kp: EcdsaP384Sha384.KeyPair = undefined, + kyber768_kp: Kyber768.KeyPair = undefined, + + pub const seed_len = 32 + 32 + 48 + 64; + + pub fn init(seed: [seed_len]u8, named_groups: []const proto.NamedGroup) !DhKeyPair { + var kp: DhKeyPair = .{}; + for (named_groups) |ng| + switch (ng) { + .x25519 => kp.x25519_kp = try X25519.KeyPair.create(seed[0..][0..X25519.seed_length].*), + .secp256r1 => kp.secp256r1_kp = try EcdsaP256Sha256.KeyPair.create(seed[32..][0..EcdsaP256Sha256.KeyPair.seed_length].*), + .secp384r1 => kp.secp384r1_kp = try EcdsaP384Sha384.KeyPair.create(seed[32 + 32 ..][0..EcdsaP384Sha384.KeyPair.seed_length].*), + .x25519_kyber768d00 => kp.kyber768_kp = try Kyber768.KeyPair.create(seed[32 + 32 + 48 ..][0..Kyber768.seed_length].*), + else => return error.TlsIllegalParameter, + }; + return kp; + } + + pub inline fn sharedKey(self: DhKeyPair, named_group: proto.NamedGroup, server_pub_key: []const u8) ![]const u8 { + return switch (named_group) { + .x25519 => brk: { + if (server_pub_key.len != X25519.public_length) + return error.TlsIllegalParameter; + break :brk &(try X25519.scalarmult( + self.x25519_kp.secret_key, + server_pub_key[0..X25519.public_length].*, + )); + }, + .secp256r1 => brk: { + const pk = try EcdsaP256Sha256.PublicKey.fromSec1(server_pub_key); + const mul = try pk.p.mulPublic(self.secp256r1_kp.secret_key.bytes, .big); + break :brk &mul.affineCoordinates().x.toBytes(.big); + }, + .secp384r1 => brk: { + const pk = try EcdsaP384Sha384.PublicKey.fromSec1(server_pub_key); + const mul = try pk.p.mulPublic(self.secp384r1_kp.secret_key.bytes, .big); + break :brk &mul.affineCoordinates().x.toBytes(.big); + }, + .x25519_kyber768d00 => brk: { + const xksl = crypto.dh.X25519.public_length; + const hksl = xksl + Kyber768.ciphertext_length; + if (server_pub_key.len != hksl) + return error.TlsIllegalParameter; + + break :brk &((crypto.dh.X25519.scalarmult( + self.x25519_kp.secret_key, + server_pub_key[0..xksl].*, + ) catch return error.TlsDecryptFailure) ++ (self.kyber768_kp.secret_key.decaps( + server_pub_key[xksl..hksl], + ) catch return error.TlsDecryptFailure)); + }, + else => return error.TlsIllegalParameter, + }; + } + + // Returns 32, 65, 97 or 1216 bytes + pub inline fn publicKey(self: DhKeyPair, named_group: proto.NamedGroup) ![]const u8 { + return switch (named_group) { + .x25519 => &self.x25519_kp.public_key, + .secp256r1 => &self.secp256r1_kp.public_key.toUncompressedSec1(), + .secp384r1 => &self.secp384r1_kp.public_key.toUncompressedSec1(), + .x25519_kyber768d00 => &self.x25519_kp.public_key ++ self.kyber768_kp.public_key.toBytes(), + else => return error.TlsIllegalParameter, + }; + } +}; + +const testing = std.testing; +const testu = @import("testu.zig"); + +test "DhKeyPair.x25519" { + var seed: [DhKeyPair.seed_len]u8 = undefined; + testu.fill(&seed); + const server_pub_key = &testu.hexToBytes("3303486548531f08d91e675caf666c2dc924ac16f47a861a7f4d05919d143637"); + const expected = &testu.hexToBytes( + \\ F1 67 FB 4A 49 B2 91 77 08 29 45 A1 F7 08 5A 21 + \\ AF FE 9E 78 C2 03 9B 81 92 40 72 73 74 7A 46 1E + ); + const kp = try DhKeyPair.init(seed, &.{.x25519}); + try testing.expectEqualSlices(u8, expected, try kp.sharedKey(.x25519, server_pub_key)); +} diff --git a/src/tls.zig/handshake_server.zig b/src/tls.zig/handshake_server.zig new file mode 100644 index 0000000..c26e8c6 --- /dev/null +++ b/src/tls.zig/handshake_server.zig @@ -0,0 +1,520 @@ +const std = @import("std"); +const assert = std.debug.assert; +const crypto = std.crypto; +const mem = std.mem; +const Certificate = crypto.Certificate; + +const cipher = @import("cipher.zig"); +const Cipher = cipher.Cipher; +const CipherSuite = @import("cipher.zig").CipherSuite; +const cipher_suites = @import("cipher.zig").cipher_suites; +const Transcript = @import("transcript.zig").Transcript; +const record = @import("record.zig"); +const PrivateKey = @import("PrivateKey.zig"); +const proto = @import("protocol.zig"); + +const common = @import("handshake_common.zig"); +const dupe = common.dupe; +const CertificateBuilder = common.CertificateBuilder; +const CertificateParser = common.CertificateParser; +const DhKeyPair = common.DhKeyPair; +const CertBundle = common.CertBundle; +const CertKeyPair = common.CertKeyPair; + +pub const Options = struct { + /// Server authentication. If null server will not send Certificate and + /// CertificateVerify message. + auth: ?CertKeyPair, + + /// If not null server will request client certificate. If auth_type is + /// .request empty client certificate message will be accepted. + /// Client certificate will be verified with root_ca certificates. + client_auth: ?ClientAuth = null, +}; + +pub const ClientAuth = struct { + /// Set of root certificate authorities that server use when verifying + /// client certificates. + root_ca: CertBundle, + + auth_type: Type = .require, + + pub const Type = enum { + /// Client certificate will be requested during the handshake, but does + /// not require that the client send any certificates. + request, + /// Client certificate will be requested during the handshake, and client + /// has to send valid certificate. + require, + }; +}; + +pub fn Handshake(comptime Stream: type) type { + const RecordReaderT = record.Reader(Stream); + return struct { + // public key len: x25519 = 32, secp256r1 = 65, secp384r1 = 97 + const max_pub_key_len = 98; + const supported_named_groups = &[_]proto.NamedGroup{ .x25519, .secp256r1, .secp384r1 }; + + server_random: [32]u8 = undefined, + client_random: [32]u8 = undefined, + legacy_session_id_buf: [32]u8 = undefined, + legacy_session_id: []u8 = "", + cipher_suite: CipherSuite = @enumFromInt(0), + signature_scheme: proto.SignatureScheme = @enumFromInt(0), + named_group: proto.NamedGroup = @enumFromInt(0), + client_pub_key_buf: [max_pub_key_len]u8 = undefined, + client_pub_key: []u8 = "", + server_pub_key_buf: [max_pub_key_len]u8 = undefined, + server_pub_key: []u8 = "", + + cipher: Cipher = undefined, + transcript: Transcript = .{}, + rec_rdr: *RecordReaderT, + buffer: []u8, + + const HandshakeT = @This(); + + pub fn init(buf: []u8, rec_rdr: *RecordReaderT) HandshakeT { + return .{ + .rec_rdr = rec_rdr, + .buffer = buf, + }; + } + + fn writeAlert(h: *HandshakeT, stream: Stream, cph: ?*Cipher, err: anyerror) !void { + if (cph) |c| { + const cleartext = proto.alertFromError(err); + const ciphertext = try c.encrypt(h.buffer, .alert, &cleartext); + stream.writeAll(ciphertext) catch {}; + } else { + const alert = record.header(.alert, 2) ++ proto.alertFromError(err); + stream.writeAll(&alert) catch {}; + } + } + + pub fn handshake(h: *HandshakeT, stream: Stream, opt: Options) !Cipher { + crypto.random.bytes(&h.server_random); + if (opt.auth) |a| { + // required signature scheme in client hello + h.signature_scheme = a.key.signature_scheme; + } + + h.readClientHello() catch |err| { + try h.writeAlert(stream, null, err); + return err; + }; + h.transcript.use(h.cipher_suite.hash()); + + const server_flight = brk: { + var w = record.Writer{ .buf = h.buffer }; + + const shared_key = h.sharedKey() catch |err| { + try h.writeAlert(stream, null, err); + return err; + }; + { + const hello = try h.makeServerHello(w.getFree()); + h.transcript.update(hello[record.header_len..]); + w.pos += hello.len; + } + { + const handshake_secret = h.transcript.handshakeSecret(shared_key); + h.cipher = try Cipher.initTls13(h.cipher_suite, handshake_secret, .server); + } + try w.writeRecord(.change_cipher_spec, &[_]u8{1}); + { + const encrypted_extensions = &record.handshakeHeader(.encrypted_extensions, 2) ++ [_]u8{ 0, 0 }; + h.transcript.update(encrypted_extensions); + try h.writeEncrypted(&w, encrypted_extensions); + } + if (opt.client_auth) |_| { + const certificate_request = try makeCertificateRequest(w.getPayload()); + h.transcript.update(certificate_request); + try h.writeEncrypted(&w, certificate_request); + } + if (opt.auth) |a| { + const cm = CertificateBuilder{ + .bundle = a.bundle, + .key = a.key, + .transcript = &h.transcript, + .side = .server, + }; + { + const certificate = try cm.makeCertificate(w.getPayload()); + h.transcript.update(certificate); + try h.writeEncrypted(&w, certificate); + } + { + const certificate_verify = try cm.makeCertificateVerify(w.getPayload()); + h.transcript.update(certificate_verify); + try h.writeEncrypted(&w, certificate_verify); + } + } + { + const finished = try h.makeFinished(w.getPayload()); + h.transcript.update(finished); + try h.writeEncrypted(&w, finished); + } + break :brk w.getWritten(); + }; + try stream.writeAll(server_flight); + + var app_cipher = brk: { + const application_secret = h.transcript.applicationSecret(); + break :brk try Cipher.initTls13(h.cipher_suite, application_secret, .server); + }; + + h.readClientFlight2(opt) catch |err| { + // Alert received from client + if (!mem.startsWith(u8, @errorName(err), "TlsAlert")) { + try h.writeAlert(stream, &app_cipher, err); + } + return err; + }; + return app_cipher; + } + + inline fn sharedKey(h: *HandshakeT) ![]const u8 { + var seed: [DhKeyPair.seed_len]u8 = undefined; + crypto.random.bytes(&seed); + var kp = try DhKeyPair.init(seed, supported_named_groups); + h.server_pub_key = dupe(&h.server_pub_key_buf, try kp.publicKey(h.named_group)); + return try kp.sharedKey(h.named_group, h.client_pub_key); + } + + fn readClientFlight2(h: *HandshakeT, opt: Options) !void { + var cleartext_buf = h.buffer; + var cleartext_buf_head: usize = 0; + var cleartext_buf_tail: usize = 0; + var handshake_state: proto.Handshake = .finished; + var cert: CertificateParser = undefined; + if (opt.client_auth) |client_auth| { + cert = .{ .root_ca = client_auth.root_ca.bundle, .host = "" }; + handshake_state = .certificate; + } + + outer: while (true) { + const rec = (try h.rec_rdr.next() orelse return error.EndOfStream); + if (rec.protocol_version != .tls_1_2 and rec.content_type != .alert) + return error.TlsProtocolVersion; + + switch (rec.content_type) { + .change_cipher_spec => { + if (rec.payload.len != 1) return error.TlsUnexpectedMessage; + }, + .application_data => { + const content_type, const cleartext = try h.cipher.decrypt( + cleartext_buf[cleartext_buf_tail..], + rec, + ); + cleartext_buf_tail += cleartext.len; + if (cleartext_buf_tail > cleartext_buf.len) return error.TlsRecordOverflow; + + var d = record.Decoder.init(content_type, cleartext_buf[cleartext_buf_head..cleartext_buf_tail]); + try d.expectContentType(.handshake); + while (!d.eof()) { + const start_idx = d.idx; + const handshake_type = try d.decode(proto.Handshake); + const length = try d.decode(u24); + + if (length > cipher.max_cleartext_len) + return error.TlsRecordOverflow; + if (length > d.rest().len) + continue :outer; // fragmented handshake into multiple records + + defer { + const handshake_payload = d.payload[start_idx..d.idx]; + h.transcript.update(handshake_payload); + cleartext_buf_head += handshake_payload.len; + } + + if (handshake_state != handshake_type) + return error.TlsUnexpectedMessage; + + switch (handshake_type) { + .certificate => { + if (length == 4) { + // got empty certificate message + if (opt.client_auth.?.auth_type == .require) + return error.TlsCertificateRequired; + try d.skip(length); + handshake_state = .finished; + } else { + try cert.parseCertificate(&d, .tls_1_3); + handshake_state = .certificate_verify; + } + }, + .certificate_verify => { + try cert.parseCertificateVerify(&d); + cert.verifySignature(h.transcript.clientCertificateVerify()) catch |err| return switch (err) { + error.TlsUnknownSignatureScheme => error.TlsIllegalParameter, + else => error.TlsDecryptError, + }; + handshake_state = .finished; + }, + .finished => { + const actual = try d.slice(length); + var buf: [Transcript.max_mac_length]u8 = undefined; + const expected = h.transcript.clientFinishedTls13(&buf); + if (!mem.eql(u8, expected, actual)) + return if (expected.len == actual.len) + error.TlsDecryptError + else + error.TlsDecodeError; + return; + }, + else => return error.TlsUnexpectedMessage, + } + } + cleartext_buf_head = 0; + cleartext_buf_tail = 0; + }, + .alert => { + var d = rec.decoder(); + return d.raiseAlert(); + }, + else => return error.TlsUnexpectedMessage, + } + } + } + + fn makeFinished(h: *HandshakeT, buf: []u8) ![]const u8 { + var w = record.Writer{ .buf = buf }; + const verify_data = h.transcript.serverFinishedTls13(w.getHandshakePayload()); + try w.advanceHandshake(.finished, verify_data.len); + return w.getWritten(); + } + + /// Write encrypted handshake message into `w` + fn writeEncrypted(h: *HandshakeT, w: *record.Writer, cleartext: []const u8) !void { + const ciphertext = try h.cipher.encrypt(w.getFree(), .handshake, cleartext); + w.pos += ciphertext.len; + } + + fn makeServerHello(h: *HandshakeT, buf: []u8) ![]const u8 { + const header_len = 9; // tls record header (5 bytes) and handshake header (4 bytes) + var w = record.Writer{ .buf = buf[header_len..] }; + + try w.writeEnum(proto.Version.tls_1_2); + try w.write(&h.server_random); + { + try w.writeInt(@as(u8, @intCast(h.legacy_session_id.len))); + if (h.legacy_session_id.len > 0) try w.write(h.legacy_session_id); + } + try w.writeEnum(h.cipher_suite); + try w.write(&[_]u8{0}); // compression method + + var e = record.Writer{ .buf = buf[header_len + w.pos + 2 ..] }; + { // supported versions extension + try e.writeEnum(proto.Extension.supported_versions); + try e.writeInt(@as(u16, 2)); + try e.writeEnum(proto.Version.tls_1_3); + } + { // key share extension + const key_len: u16 = @intCast(h.server_pub_key.len); + try e.writeEnum(proto.Extension.key_share); + try e.writeInt(key_len + 4); + try e.writeEnum(h.named_group); + try e.writeInt(key_len); + try e.write(h.server_pub_key); + } + try w.writeInt(@as(u16, @intCast(e.pos))); // extensions length + + const payload_len = w.pos + e.pos; + buf[0..header_len].* = record.header(.handshake, 4 + payload_len) ++ + record.handshakeHeader(.server_hello, payload_len); + + return buf[0 .. header_len + payload_len]; + } + + fn makeCertificateRequest(buf: []u8) ![]const u8 { + // handshake header + context length + extensions length + const header_len = 4 + 1 + 2; + + // First write extensions, leave space for header. + var ext = record.Writer{ .buf = buf[header_len..] }; + try ext.writeExtension(.signature_algorithms, common.supported_signature_algorithms); + + var w = record.Writer{ .buf = buf }; + try w.writeHandshakeHeader(.certificate_request, ext.pos + 3); + try w.writeInt(@as(u8, 0)); // certificate request context length = 0 + try w.writeInt(@as(u16, @intCast(ext.pos))); // extensions length + assert(w.pos == header_len); + w.pos += ext.pos; + + return w.getWritten(); + } + + fn readClientHello(h: *HandshakeT) !void { + var d = try h.rec_rdr.nextDecoder(); + try d.expectContentType(.handshake); + h.transcript.update(d.payload); + + const handshake_type = try d.decode(proto.Handshake); + if (handshake_type != .client_hello) return error.TlsUnexpectedMessage; + _ = try d.decode(u24); // handshake length + if (try d.decode(proto.Version) != .tls_1_2) return error.TlsProtocolVersion; + + h.client_random = try d.array(32); + { // legacy session id + const len = try d.decode(u8); + h.legacy_session_id = dupe(&h.legacy_session_id_buf, try d.slice(len)); + } + { // cipher suites + const end_idx = try d.decode(u16) + d.idx; + + while (d.idx < end_idx) { + const cipher_suite = try d.decode(CipherSuite); + if (cipher_suites.includes(cipher_suites.tls13, cipher_suite) and + @intFromEnum(h.cipher_suite) == 0) + { + h.cipher_suite = cipher_suite; + } + } + if (@intFromEnum(h.cipher_suite) == 0) + return error.TlsHandshakeFailure; + } + try d.skip(2); // compression methods + + var key_share_received = false; + // extensions + const extensions_end_idx = try d.decode(u16) + d.idx; + while (d.idx < extensions_end_idx) { + const extension_type = try d.decode(proto.Extension); + const extension_len = try d.decode(u16); + + switch (extension_type) { + .supported_versions => { + var tls_1_3_supported = false; + const end_idx = try d.decode(u8) + d.idx; + while (d.idx < end_idx) { + if (try d.decode(proto.Version) == proto.Version.tls_1_3) { + tls_1_3_supported = true; + } + } + if (!tls_1_3_supported) return error.TlsProtocolVersion; + }, + .key_share => { + if (extension_len == 0) return error.TlsDecodeError; + key_share_received = true; + var selected_named_group_idx = supported_named_groups.len; + const end_idx = try d.decode(u16) + d.idx; + while (d.idx < end_idx) { + const named_group = try d.decode(proto.NamedGroup); + switch (@intFromEnum(named_group)) { + 0x0001...0x0016, + 0x001a...0x001c, + 0xff01...0xff02, + => return error.TlsIllegalParameter, + else => {}, + } + const client_pub_key = try d.slice(try d.decode(u16)); + for (supported_named_groups, 0..) |supported, idx| { + if (named_group == supported and idx < selected_named_group_idx) { + h.named_group = named_group; + h.client_pub_key = dupe(&h.client_pub_key_buf, client_pub_key); + selected_named_group_idx = idx; + } + } + } + if (@intFromEnum(h.named_group) == 0) + return error.TlsIllegalParameter; + }, + .supported_groups => { + const end_idx = try d.decode(u16) + d.idx; + while (d.idx < end_idx) { + const named_group = try d.decode(proto.NamedGroup); + switch (@intFromEnum(named_group)) { + 0x0001...0x0016, + 0x001a...0x001c, + 0xff01...0xff02, + => return error.TlsIllegalParameter, + else => {}, + } + } + }, + .signature_algorithms => { + if (@intFromEnum(h.signature_scheme) == 0) { + try d.skip(extension_len); + } else { + var found = false; + const list_len = try d.decode(u16); + if (list_len == 0) return error.TlsDecodeError; + const end_idx = list_len + d.idx; + while (d.idx < end_idx) { + const signature_scheme = try d.decode(proto.SignatureScheme); + if (signature_scheme == h.signature_scheme) found = true; + } + if (!found) return error.TlsHandshakeFailure; + } + }, + else => { + try d.skip(extension_len); + }, + } + } + if (!key_share_received) return error.TlsMissingExtension; + if (@intFromEnum(h.named_group) == 0) return error.TlsIllegalParameter; + } + }; +} + +const testing = std.testing; +const data13 = @import("testdata/tls13.zig"); +const testu = @import("testu.zig"); + +fn testReader(data: []const u8) record.Reader(std.io.FixedBufferStream([]const u8)) { + return record.reader(std.io.fixedBufferStream(data)); +} +const TestHandshake = Handshake(std.io.FixedBufferStream([]const u8)); + +test "read client hello" { + var buffer: [1024]u8 = undefined; + var rec_rdr = testReader(&data13.client_hello); + var h = TestHandshake.init(&buffer, &rec_rdr); + h.signature_scheme = .ecdsa_secp521r1_sha512; // this must be supported in signature_algorithms extension + try h.readClientHello(); + + try testing.expectEqual(CipherSuite.AES_256_GCM_SHA384, h.cipher_suite); + try testing.expectEqual(.x25519, h.named_group); + try testing.expectEqualSlices(u8, &data13.client_random, &h.client_random); + try testing.expectEqualSlices(u8, &data13.client_public_key, h.client_pub_key); +} + +test "make server hello" { + var buffer: [128]u8 = undefined; + var h = TestHandshake.init(&buffer, undefined); + h.cipher_suite = .AES_256_GCM_SHA384; + testu.fillFrom(&h.server_random, 0); + testu.fillFrom(&h.server_pub_key_buf, 0x20); + h.named_group = .x25519; + h.server_pub_key = h.server_pub_key_buf[0..32]; + + const actual = try h.makeServerHello(&buffer); + const expected = &testu.hexToBytes( + \\ 16 03 03 00 5a 02 00 00 56 + \\ 03 03 + \\ 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f + \\ 00 + \\ 13 02 00 + \\ 00 2e 00 2b 00 02 03 04 + \\ 00 33 00 24 00 1d 00 20 + \\ 20 21 22 23 24 25 26 27 28 29 2a 2b 2c 2d 2e 2f 30 31 32 33 34 35 36 37 38 39 3a 3b 3c 3d 3e 3f + ); + try testing.expectEqualSlices(u8, expected, actual); +} + +test "make certificate request" { + var buffer: [32]u8 = undefined; + + const expected = testu.hexToBytes("0d 00 00 1b" ++ // handshake header + "00 00 18" ++ // extension length + "00 0d" ++ // signature algorithms extension + "00 14" ++ // extension length + "00 12" ++ // list length 6 * 2 bytes + "04 03 05 03 08 04 08 05 08 06 08 07 02 01 04 01 05 01" // signature schemes + ); + const actual = try TestHandshake.makeCertificateRequest(&buffer); + try testing.expectEqualSlices(u8, &expected, actual); +} diff --git a/src/tls.zig/key_log.zig b/src/tls.zig/key_log.zig new file mode 100644 index 0000000..2da83f4 --- /dev/null +++ b/src/tls.zig/key_log.zig @@ -0,0 +1,60 @@ +//! Exporting tls key so we can share them with Wireshark and analyze decrypted +//! traffic in Wireshark. +//! To configure Wireshark to use exprted keys see curl reference. +//! +//! References: +//! curl: https://everything.curl.dev/usingcurl/tls/sslkeylogfile.html +//! openssl: https://www.openssl.org/docs/manmaster/man3/SSL_CTX_set_keylog_callback.html +//! https://udn.realityripple.com/docs/Mozilla/Projects/NSS/Key_Log_Format + +const std = @import("std"); + +const key_log_file_env = "SSLKEYLOGFILE"; + +pub const label = struct { + // tls 1.3 + pub const client_handshake_traffic_secret: []const u8 = "CLIENT_HANDSHAKE_TRAFFIC_SECRET"; + pub const server_handshake_traffic_secret: []const u8 = "SERVER_HANDSHAKE_TRAFFIC_SECRET"; + pub const client_traffic_secret_0: []const u8 = "CLIENT_TRAFFIC_SECRET_0"; + pub const server_traffic_secret_0: []const u8 = "SERVER_TRAFFIC_SECRET_0"; + // tls 1.2 + pub const client_random: []const u8 = "CLIENT_RANDOM"; +}; + +pub const Callback = *const fn (label: []const u8, client_random: []const u8, secret: []const u8) void; + +/// Writes tls keys to the file pointed by SSLKEYLOGFILE environment variable. +pub fn callback(label_: []const u8, client_random: []const u8, secret: []const u8) void { + if (std.posix.getenv(key_log_file_env)) |file_name| { + fileAppend(file_name, label_, client_random, secret) catch return; + } +} + +pub fn fileAppend(file_name: []const u8, label_: []const u8, client_random: []const u8, secret: []const u8) !void { + var buf: [1024]u8 = undefined; + const line = try formatLine(&buf, label_, client_random, secret); + try fileWrite(file_name, line); +} + +fn fileWrite(file_name: []const u8, line: []const u8) !void { + var file = try std.fs.createFileAbsolute(file_name, .{ .truncate = false }); + defer file.close(); + const stat = try file.stat(); + try file.seekTo(stat.size); + try file.writeAll(line); +} + +pub fn formatLine(buf: []u8, label_: []const u8, client_random: []const u8, secret: []const u8) ![]const u8 { + var fbs = std.io.fixedBufferStream(buf); + const w = fbs.writer(); + try w.print("{s} ", .{label_}); + for (client_random) |b| { + try std.fmt.formatInt(b, 16, .lower, .{ .width = 2, .fill = '0' }, w); + } + try w.writeByte(' '); + for (secret) |b| { + try std.fmt.formatInt(b, 16, .lower, .{ .width = 2, .fill = '0' }, w); + } + try w.writeByte('\n'); + return fbs.getWritten(); +} diff --git a/src/tls.zig/main.zig b/src/tls.zig/main.zig new file mode 100644 index 0000000..b974377 --- /dev/null +++ b/src/tls.zig/main.zig @@ -0,0 +1,51 @@ +const std = @import("std"); + +pub const CipherSuite = @import("cipher.zig").CipherSuite; +pub const cipher_suites = @import("cipher.zig").cipher_suites; +pub const PrivateKey = @import("PrivateKey.zig"); +pub const Connection = @import("connection.zig").Connection; +pub const ClientOptions = @import("handshake_client.zig").Options; +pub const ServerOptions = @import("handshake_server.zig").Options; +pub const key_log = @import("key_log.zig"); +pub const proto = @import("protocol.zig"); +pub const NamedGroup = proto.NamedGroup; +pub const Version = proto.Version; +const common = @import("handshake_common.zig"); +pub const CertBundle = common.CertBundle; +pub const CertKeyPair = common.CertKeyPair; + +pub const record = @import("record.zig"); +const connection = @import("connection.zig").connection; +const max_ciphertext_record_len = @import("cipher.zig").max_ciphertext_record_len; +const HandshakeServer = @import("handshake_server.zig").Handshake; +const HandshakeClient = @import("handshake_client.zig").Handshake; + +pub fn client(stream: anytype, opt: ClientOptions) !Connection(@TypeOf(stream)) { + const Stream = @TypeOf(stream); + var conn = connection(stream); + var write_buf: [max_ciphertext_record_len]u8 = undefined; + var h = HandshakeClient(Stream).init(&write_buf, &conn.rec_rdr); + conn.cipher = try h.handshake(conn.stream, opt); + return conn; +} + +pub fn server(stream: anytype, opt: ServerOptions) !Connection(@TypeOf(stream)) { + const Stream = @TypeOf(stream); + var conn = connection(stream); + var write_buf: [max_ciphertext_record_len]u8 = undefined; + var h = HandshakeServer(Stream).init(&write_buf, &conn.rec_rdr); + conn.cipher = try h.handshake(conn.stream, opt); + return conn; +} + +test { + _ = @import("handshake_common.zig"); + _ = @import("handshake_server.zig"); + _ = @import("handshake_client.zig"); + + _ = @import("connection.zig"); + _ = @import("cipher.zig"); + _ = @import("record.zig"); + _ = @import("transcript.zig"); + _ = @import("PrivateKey.zig"); +} diff --git a/src/tls.zig/protocol.zig b/src/tls.zig/protocol.zig new file mode 100644 index 0000000..e3bb07a --- /dev/null +++ b/src/tls.zig/protocol.zig @@ -0,0 +1,302 @@ +pub const Version = enum(u16) { + tls_1_2 = 0x0303, + tls_1_3 = 0x0304, + _, +}; + +pub const ContentType = enum(u8) { + invalid = 0, + change_cipher_spec = 20, + alert = 21, + handshake = 22, + application_data = 23, + _, +}; + +pub const Handshake = enum(u8) { + client_hello = 1, + server_hello = 2, + new_session_ticket = 4, + end_of_early_data = 5, + encrypted_extensions = 8, + certificate = 11, + server_key_exchange = 12, + certificate_request = 13, + server_hello_done = 14, + certificate_verify = 15, + client_key_exchange = 16, + finished = 20, + key_update = 24, + message_hash = 254, + _, +}; + +pub const Curve = enum(u8) { + named_curve = 0x03, + _, +}; + +pub const Extension = enum(u16) { + /// RFC 6066 + server_name = 0, + /// RFC 6066 + max_fragment_length = 1, + /// RFC 6066 + status_request = 5, + /// RFC 8422, 7919 + supported_groups = 10, + /// RFC 8446 + signature_algorithms = 13, + /// RFC 5764 + use_srtp = 14, + /// RFC 6520 + heartbeat = 15, + /// RFC 7301 + application_layer_protocol_negotiation = 16, + /// RFC 6962 + signed_certificate_timestamp = 18, + /// RFC 7250 + client_certificate_type = 19, + /// RFC 7250 + server_certificate_type = 20, + /// RFC 7685 + padding = 21, + /// RFC 8446 + pre_shared_key = 41, + /// RFC 8446 + early_data = 42, + /// RFC 8446 + supported_versions = 43, + /// RFC 8446 + cookie = 44, + /// RFC 8446 + psk_key_exchange_modes = 45, + /// RFC 8446 + certificate_authorities = 47, + /// RFC 8446 + oid_filters = 48, + /// RFC 8446 + post_handshake_auth = 49, + /// RFC 8446 + signature_algorithms_cert = 50, + /// RFC 8446 + key_share = 51, + + _, +}; + +pub fn alertFromError(err: anyerror) [2]u8 { + return [2]u8{ @intFromEnum(Alert.Level.fatal), @intFromEnum(Alert.fromError(err)) }; +} + +pub const Alert = enum(u8) { + pub const Level = enum(u8) { + warning = 1, + fatal = 2, + _, + }; + + pub const Error = error{ + TlsAlertUnexpectedMessage, + TlsAlertBadRecordMac, + TlsAlertRecordOverflow, + TlsAlertHandshakeFailure, + TlsAlertBadCertificate, + TlsAlertUnsupportedCertificate, + TlsAlertCertificateRevoked, + TlsAlertCertificateExpired, + TlsAlertCertificateUnknown, + TlsAlertIllegalParameter, + TlsAlertUnknownCa, + TlsAlertAccessDenied, + TlsAlertDecodeError, + TlsAlertDecryptError, + TlsAlertProtocolVersion, + TlsAlertInsufficientSecurity, + TlsAlertInternalError, + TlsAlertInappropriateFallback, + TlsAlertMissingExtension, + TlsAlertUnsupportedExtension, + TlsAlertUnrecognizedName, + TlsAlertBadCertificateStatusResponse, + TlsAlertUnknownPskIdentity, + TlsAlertCertificateRequired, + TlsAlertNoApplicationProtocol, + TlsAlertUnknown, + }; + + close_notify = 0, + unexpected_message = 10, + bad_record_mac = 20, + record_overflow = 22, + handshake_failure = 40, + bad_certificate = 42, + unsupported_certificate = 43, + certificate_revoked = 44, + certificate_expired = 45, + certificate_unknown = 46, + illegal_parameter = 47, + unknown_ca = 48, + access_denied = 49, + decode_error = 50, + decrypt_error = 51, + protocol_version = 70, + insufficient_security = 71, + internal_error = 80, + inappropriate_fallback = 86, + user_canceled = 90, + missing_extension = 109, + unsupported_extension = 110, + unrecognized_name = 112, + bad_certificate_status_response = 113, + unknown_psk_identity = 115, + certificate_required = 116, + no_application_protocol = 120, + _, + + pub fn toError(alert: Alert) Error!void { + return switch (alert) { + .close_notify => {}, // not an error + .unexpected_message => error.TlsAlertUnexpectedMessage, + .bad_record_mac => error.TlsAlertBadRecordMac, + .record_overflow => error.TlsAlertRecordOverflow, + .handshake_failure => error.TlsAlertHandshakeFailure, + .bad_certificate => error.TlsAlertBadCertificate, + .unsupported_certificate => error.TlsAlertUnsupportedCertificate, + .certificate_revoked => error.TlsAlertCertificateRevoked, + .certificate_expired => error.TlsAlertCertificateExpired, + .certificate_unknown => error.TlsAlertCertificateUnknown, + .illegal_parameter => error.TlsAlertIllegalParameter, + .unknown_ca => error.TlsAlertUnknownCa, + .access_denied => error.TlsAlertAccessDenied, + .decode_error => error.TlsAlertDecodeError, + .decrypt_error => error.TlsAlertDecryptError, + .protocol_version => error.TlsAlertProtocolVersion, + .insufficient_security => error.TlsAlertInsufficientSecurity, + .internal_error => error.TlsAlertInternalError, + .inappropriate_fallback => error.TlsAlertInappropriateFallback, + .user_canceled => {}, // not an error + .missing_extension => error.TlsAlertMissingExtension, + .unsupported_extension => error.TlsAlertUnsupportedExtension, + .unrecognized_name => error.TlsAlertUnrecognizedName, + .bad_certificate_status_response => error.TlsAlertBadCertificateStatusResponse, + .unknown_psk_identity => error.TlsAlertUnknownPskIdentity, + .certificate_required => error.TlsAlertCertificateRequired, + .no_application_protocol => error.TlsAlertNoApplicationProtocol, + _ => error.TlsAlertUnknown, + }; + } + + pub fn fromError(err: anyerror) Alert { + return switch (err) { + error.TlsUnexpectedMessage => .unexpected_message, + error.TlsBadRecordMac => .bad_record_mac, + error.TlsRecordOverflow => .record_overflow, + error.TlsHandshakeFailure => .handshake_failure, + error.TlsBadCertificate => .bad_certificate, + error.TlsUnsupportedCertificate => .unsupported_certificate, + error.TlsCertificateRevoked => .certificate_revoked, + error.TlsCertificateExpired => .certificate_expired, + error.TlsCertificateUnknown => .certificate_unknown, + error.TlsIllegalParameter, + error.IdentityElement, + error.InvalidEncoding, + => .illegal_parameter, + error.TlsUnknownCa => .unknown_ca, + error.TlsAccessDenied => .access_denied, + error.TlsDecodeError => .decode_error, + error.TlsDecryptError => .decrypt_error, + error.TlsProtocolVersion => .protocol_version, + error.TlsInsufficientSecurity => .insufficient_security, + error.TlsInternalError => .internal_error, + error.TlsInappropriateFallback => .inappropriate_fallback, + error.TlsMissingExtension => .missing_extension, + error.TlsUnsupportedExtension => .unsupported_extension, + error.TlsUnrecognizedName => .unrecognized_name, + error.TlsBadCertificateStatusResponse => .bad_certificate_status_response, + error.TlsUnknownPskIdentity => .unknown_psk_identity, + error.TlsCertificateRequired => .certificate_required, + error.TlsNoApplicationProtocol => .no_application_protocol, + else => .internal_error, + }; + } + + pub fn parse(buf: [2]u8) Alert { + const level: Alert.Level = @enumFromInt(buf[0]); + const alert: Alert = @enumFromInt(buf[1]); + _ = level; + return alert; + } + + pub fn closeNotify() [2]u8 { + return [2]u8{ + @intFromEnum(Alert.Level.warning), + @intFromEnum(Alert.close_notify), + }; + } +}; + +pub const SignatureScheme = enum(u16) { + // RSASSA-PKCS1-v1_5 algorithms + rsa_pkcs1_sha256 = 0x0401, + rsa_pkcs1_sha384 = 0x0501, + rsa_pkcs1_sha512 = 0x0601, + + // ECDSA algorithms + ecdsa_secp256r1_sha256 = 0x0403, + ecdsa_secp384r1_sha384 = 0x0503, + ecdsa_secp521r1_sha512 = 0x0603, + + // RSASSA-PSS algorithms with public key OID rsaEncryption + rsa_pss_rsae_sha256 = 0x0804, + rsa_pss_rsae_sha384 = 0x0805, + rsa_pss_rsae_sha512 = 0x0806, + + // EdDSA algorithms + ed25519 = 0x0807, + ed448 = 0x0808, + + // RSASSA-PSS algorithms with public key OID RSASSA-PSS + rsa_pss_pss_sha256 = 0x0809, + rsa_pss_pss_sha384 = 0x080a, + rsa_pss_pss_sha512 = 0x080b, + + // Legacy algorithms + rsa_pkcs1_sha1 = 0x0201, + ecdsa_sha1 = 0x0203, + + _, +}; + +pub const NamedGroup = enum(u16) { + // Elliptic Curve Groups (ECDHE) + secp256r1 = 0x0017, + secp384r1 = 0x0018, + secp521r1 = 0x0019, + x25519 = 0x001D, + x448 = 0x001E, + + // Finite Field Groups (DHE) + ffdhe2048 = 0x0100, + ffdhe3072 = 0x0101, + ffdhe4096 = 0x0102, + ffdhe6144 = 0x0103, + ffdhe8192 = 0x0104, + + // Hybrid post-quantum key agreements + x25519_kyber512d00 = 0xFE30, + x25519_kyber768d00 = 0x6399, + + _, +}; + +pub const KeyUpdateRequest = enum(u8) { + update_not_requested = 0, + update_requested = 1, + _, +}; + +pub const Side = enum { + client, + server, +}; diff --git a/src/tls.zig/record.zig b/src/tls.zig/record.zig new file mode 100644 index 0000000..6c4df32 --- /dev/null +++ b/src/tls.zig/record.zig @@ -0,0 +1,405 @@ +const std = @import("std"); +const assert = std.debug.assert; +const mem = std.mem; + +const proto = @import("protocol.zig"); +const cipher = @import("cipher.zig"); +const Cipher = cipher.Cipher; +const record = @import("record.zig"); + +pub const header_len = 5; + +pub fn header(content_type: proto.ContentType, payload_len: usize) [header_len]u8 { + const int2 = std.crypto.tls.int2; + return [1]u8{@intFromEnum(content_type)} ++ + int2(@intFromEnum(proto.Version.tls_1_2)) ++ + int2(@intCast(payload_len)); +} + +pub fn handshakeHeader(handshake_type: proto.Handshake, payload_len: usize) [4]u8 { + const int3 = std.crypto.tls.int3; + return [1]u8{@intFromEnum(handshake_type)} ++ int3(@intCast(payload_len)); +} + +pub fn reader(inner_reader: anytype) Reader(@TypeOf(inner_reader)) { + return .{ .inner_reader = inner_reader }; +} + +pub fn Reader(comptime InnerReader: type) type { + return struct { + inner_reader: InnerReader, + + buffer: [cipher.max_ciphertext_record_len]u8 = undefined, + start: usize = 0, + end: usize = 0, + + const ReaderT = @This(); + + pub fn nextDecoder(r: *ReaderT) !Decoder { + const rec = (try r.next()) orelse return error.EndOfStream; + if (@intFromEnum(rec.protocol_version) != 0x0300 and + @intFromEnum(rec.protocol_version) != 0x0301 and + rec.protocol_version != .tls_1_2) + return error.TlsBadVersion; + return .{ + .content_type = rec.content_type, + .payload = rec.payload, + }; + } + + pub fn contentType(buf: []const u8) proto.ContentType { + return @enumFromInt(buf[0]); + } + + pub fn protocolVersion(buf: []const u8) proto.Version { + return @enumFromInt(mem.readInt(u16, buf[1..3], .big)); + } + + pub fn next(r: *ReaderT) !?Record { + while (true) { + const buffer = r.buffer[r.start..r.end]; + // If we have 5 bytes header. + if (buffer.len >= record.header_len) { + const record_header = buffer[0..record.header_len]; + const payload_len = mem.readInt(u16, record_header[3..5], .big); + if (payload_len > cipher.max_ciphertext_len) + return error.TlsRecordOverflow; + const record_len = record.header_len + payload_len; + // If we have whole record + if (buffer.len >= record_len) { + r.start += record_len; + return Record.init(buffer[0..record_len]); + } + } + { // Move dirty part to the start of the buffer. + const n = r.end - r.start; + if (n > 0 and r.start > 0) { + if (r.start > n) { + @memcpy(r.buffer[0..n], r.buffer[r.start..][0..n]); + } else { + mem.copyForwards(u8, r.buffer[0..n], r.buffer[r.start..][0..n]); + } + } + r.start = 0; + r.end = n; + } + { // Read more from inner_reader. + const n = try r.inner_reader.read(r.buffer[r.end..]); + if (n == 0) return null; + r.end += n; + } + } + } + + pub fn nextDecrypt(r: *ReaderT, cph: *Cipher) !?struct { proto.ContentType, []const u8 } { + const rec = (try r.next()) orelse return null; + if (rec.protocol_version != .tls_1_2) return error.TlsBadVersion; + + return try cph.decrypt( + // Reuse reader buffer for cleartext. `rec.header` and + // `rec.payload`(ciphertext) are also pointing somewhere in + // this buffer. Decrypter is first reading then writing a + // block, cleartext has less length then ciphertext, + // cleartext starts from the beginning of the buffer, so + // ciphertext is always ahead of cleartext. + r.buffer[0..r.start], + rec, + ); + } + + pub fn hasMore(r: *ReaderT) bool { + return r.end > r.start; + } + }; +} + +pub const Record = struct { + content_type: proto.ContentType, + protocol_version: proto.Version = .tls_1_2, + header: []const u8, + payload: []const u8, + + pub fn init(buffer: []const u8) Record { + return .{ + .content_type = @enumFromInt(buffer[0]), + .protocol_version = @enumFromInt(mem.readInt(u16, buffer[1..3], .big)), + .header = buffer[0..record.header_len], + .payload = buffer[record.header_len..], + }; + } + + pub fn decoder(r: @This()) Decoder { + return Decoder.init(r.content_type, @constCast(r.payload)); + } +}; + +pub const Decoder = struct { + content_type: proto.ContentType, + payload: []const u8, + idx: usize = 0, + + pub fn init(content_type: proto.ContentType, payload: []u8) Decoder { + return .{ + .content_type = content_type, + .payload = payload, + }; + } + + pub fn decode(d: *Decoder, comptime T: type) !T { + switch (@typeInfo(T)) { + .Int => |info| switch (info.bits) { + 8 => { + try skip(d, 1); + return d.payload[d.idx - 1]; + }, + 16 => { + try skip(d, 2); + const b0: u16 = d.payload[d.idx - 2]; + const b1: u16 = d.payload[d.idx - 1]; + return (b0 << 8) | b1; + }, + 24 => { + try skip(d, 3); + const b0: u24 = d.payload[d.idx - 3]; + const b1: u24 = d.payload[d.idx - 2]; + const b2: u24 = d.payload[d.idx - 1]; + return (b0 << 16) | (b1 << 8) | b2; + }, + else => @compileError("unsupported int type: " ++ @typeName(T)), + }, + .Enum => |info| { + const int = try d.decode(info.tag_type); + if (info.is_exhaustive) @compileError("exhaustive enum cannot be used"); + return @as(T, @enumFromInt(int)); + }, + else => @compileError("unsupported type: " ++ @typeName(T)), + } + } + + pub fn array(d: *Decoder, comptime len: usize) ![len]u8 { + try d.skip(len); + return d.payload[d.idx - len ..][0..len].*; + } + + pub fn slice(d: *Decoder, len: usize) ![]const u8 { + try d.skip(len); + return d.payload[d.idx - len ..][0..len]; + } + + pub fn skip(d: *Decoder, amt: usize) !void { + if (d.idx + amt > d.payload.len) return error.TlsDecodeError; + d.idx += amt; + } + + pub fn rest(d: Decoder) []const u8 { + return d.payload[d.idx..]; + } + + pub fn eof(d: Decoder) bool { + return d.idx == d.payload.len; + } + + pub fn expectContentType(d: *Decoder, content_type: proto.ContentType) !void { + if (d.content_type == content_type) return; + + switch (d.content_type) { + .alert => try d.raiseAlert(), + else => return error.TlsUnexpectedMessage, + } + } + + pub fn raiseAlert(d: *Decoder) !void { + if (d.payload.len < 2) return error.TlsUnexpectedMessage; + try proto.Alert.parse(try d.array(2)).toError(); + return error.TlsAlertCloseNotify; + } +}; + +const testing = std.testing; +const data12 = @import("testdata/tls12.zig"); +const testu = @import("testu.zig"); +const CipherSuite = @import("cipher.zig").CipherSuite; + +test Reader { + var fbs = std.io.fixedBufferStream(&data12.server_responses); + var rdr = reader(fbs.reader()); + + const expected = [_]struct { + content_type: proto.ContentType, + payload_len: usize, + }{ + .{ .content_type = .handshake, .payload_len = 49 }, + .{ .content_type = .handshake, .payload_len = 815 }, + .{ .content_type = .handshake, .payload_len = 300 }, + .{ .content_type = .handshake, .payload_len = 4 }, + .{ .content_type = .change_cipher_spec, .payload_len = 1 }, + .{ .content_type = .handshake, .payload_len = 64 }, + }; + for (expected) |e| { + const rec = (try rdr.next()).?; + try testing.expectEqual(e.content_type, rec.content_type); + try testing.expectEqual(e.payload_len, rec.payload.len); + try testing.expectEqual(.tls_1_2, rec.protocol_version); + } +} + +test Decoder { + var fbs = std.io.fixedBufferStream(&data12.server_responses); + var rdr = reader(fbs.reader()); + + var d = (try rdr.nextDecoder()); + try testing.expectEqual(.handshake, d.content_type); + + try testing.expectEqual(.server_hello, try d.decode(proto.Handshake)); + try testing.expectEqual(45, try d.decode(u24)); // length + try testing.expectEqual(.tls_1_2, try d.decode(proto.Version)); + try testing.expectEqualStrings( + &testu.hexToBytes("707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f"), + try d.slice(32), + ); // server random + try testing.expectEqual(0, try d.decode(u8)); // session id len + try testing.expectEqual(.ECDHE_RSA_WITH_AES_128_CBC_SHA, try d.decode(CipherSuite)); + try testing.expectEqual(0, try d.decode(u8)); // compression method + try testing.expectEqual(5, try d.decode(u16)); // extension length + try testing.expectEqual(5, d.rest().len); + try d.skip(5); + try testing.expect(d.eof()); +} + +pub const Writer = struct { + buf: []u8, + pos: usize = 0, + + pub fn write(self: *Writer, data: []const u8) !void { + defer self.pos += data.len; + if (self.pos + data.len > self.buf.len) return error.BufferOverflow; + @memcpy(self.buf[self.pos..][0..data.len], data); + } + + pub fn writeByte(self: *Writer, b: u8) !void { + defer self.pos += 1; + if (self.pos == self.buf.len) return error.BufferOverflow; + self.buf[self.pos] = b; + } + + pub fn writeEnum(self: *Writer, value: anytype) !void { + try self.writeInt(@intFromEnum(value)); + } + + pub fn writeInt(self: *Writer, value: anytype) !void { + const IntT = @TypeOf(value); + const bytes = @divExact(@typeInfo(IntT).Int.bits, 8); + const free = self.buf[self.pos..]; + if (free.len < bytes) return error.BufferOverflow; + mem.writeInt(IntT, free[0..bytes], value, .big); + self.pos += bytes; + } + + pub fn writeHandshakeHeader(self: *Writer, handshake_type: proto.Handshake, payload_len: usize) !void { + try self.write(&record.handshakeHeader(handshake_type, payload_len)); + } + + /// Should be used after writing handshake payload in buffer provided by `getHandshakePayload`. + pub fn advanceHandshake(self: *Writer, handshake_type: proto.Handshake, payload_len: usize) !void { + try self.write(&record.handshakeHeader(handshake_type, payload_len)); + self.pos += payload_len; + } + + /// Record payload is already written by using buffer space from `getPayload`. + /// Now when we know payload len we can write record header and advance over payload. + pub fn advanceRecord(self: *Writer, content_type: proto.ContentType, payload_len: usize) !void { + try self.write(&record.header(content_type, payload_len)); + self.pos += payload_len; + } + + pub fn writeRecord(self: *Writer, content_type: proto.ContentType, payload: []const u8) !void { + try self.write(&record.header(content_type, payload.len)); + try self.write(payload); + } + + /// Preserves space for record header and returns buffer free space. + pub fn getPayload(self: *Writer) []u8 { + return self.buf[self.pos + record.header_len ..]; + } + + /// Preserves space for handshake header and returns buffer free space. + pub fn getHandshakePayload(self: *Writer) []u8 { + return self.buf[self.pos + 4 ..]; + } + + pub fn getWritten(self: *Writer) []const u8 { + return self.buf[0..self.pos]; + } + + pub fn getFree(self: *Writer) []u8 { + return self.buf[self.pos..]; + } + + pub fn writeEnumArray(self: *Writer, comptime E: type, tags: []const E) !void { + assert(@sizeOf(E) == 2); + try self.writeInt(@as(u16, @intCast(tags.len * 2))); + for (tags) |t| { + try self.writeEnum(t); + } + } + + pub fn writeExtension( + self: *Writer, + comptime et: proto.Extension, + tags: anytype, + ) !void { + try self.writeEnum(et); + if (et == .supported_versions) { + try self.writeInt(@as(u16, @intCast(tags.len * 2 + 1))); + try self.writeInt(@as(u8, @intCast(tags.len * 2))); + } else { + try self.writeInt(@as(u16, @intCast(tags.len * 2 + 2))); + try self.writeInt(@as(u16, @intCast(tags.len * 2))); + } + for (tags) |t| { + try self.writeEnum(t); + } + } + + pub fn writeKeyShare( + self: *Writer, + named_groups: []const proto.NamedGroup, + keys: []const []const u8, + ) !void { + assert(named_groups.len == keys.len); + try self.writeEnum(proto.Extension.key_share); + var l: usize = 0; + for (keys) |key| { + l += key.len + 4; + } + try self.writeInt(@as(u16, @intCast(l + 2))); + try self.writeInt(@as(u16, @intCast(l))); + for (named_groups, 0..) |ng, i| { + const key = keys[i]; + try self.writeEnum(ng); + try self.writeInt(@as(u16, @intCast(key.len))); + try self.write(key); + } + } + + pub fn writeServerName(self: *Writer, host: []const u8) !void { + const host_len: u16 = @intCast(host.len); + try self.writeEnum(proto.Extension.server_name); + try self.writeInt(host_len + 5); // byte length of extension payload + try self.writeInt(host_len + 3); // server_name_list byte count + try self.writeByte(0); // name type + try self.writeInt(host_len); + try self.write(host); + } +}; + +test "Writer" { + var buf: [16]u8 = undefined; + var w = Writer{ .buf = &buf }; + + try w.write("ab"); + try w.writeEnum(proto.Curve.named_curve); + try w.writeEnum(proto.NamedGroup.x25519); + try w.writeInt(@as(u16, 0x1234)); + try testing.expectEqualSlices(u8, &[_]u8{ 'a', 'b', 0x03, 0x00, 0x1d, 0x12, 0x34 }, w.getWritten()); +} diff --git a/src/tls.zig/rsa/der.zig b/src/tls.zig/rsa/der.zig new file mode 100644 index 0000000..743a65a --- /dev/null +++ b/src/tls.zig/rsa/der.zig @@ -0,0 +1,467 @@ +//! An encoding of ASN.1. +//! +//! Distinguised Encoding Rules as defined in X.690 and X.691. +//! +//! A version of Basic Encoding Rules (BER) where there is exactly ONE way to +//! represent non-constructed elements. This is useful for cryptographic signatures. +//! +//! Currently an implementation detail of the standard library not fit for public +//! use since it's missing an encoder. + +const std = @import("std"); +const builtin = @import("builtin"); + +pub const Index = usize; +const log = std.log.scoped(.der); + +/// A secure DER parser that: +/// - Does NOT read memory outside `bytes`. +/// - Does NOT return elements with slices outside `bytes`. +/// - Errors on values that do NOT follow DER rules. +/// - Lengths that could be represented in a shorter form. +/// - Booleans that are not 0xff or 0x00. +pub const Parser = struct { + bytes: []const u8, + index: Index = 0, + + pub const Error = Element.Error || error{ + UnexpectedElement, + InvalidIntegerEncoding, + Overflow, + NonCanonical, + }; + + pub fn expectBool(self: *Parser) Error!bool { + const ele = try self.expect(.universal, false, .boolean); + if (ele.slice.len() != 1) return error.InvalidBool; + + return switch (self.view(ele)[0]) { + 0x00 => false, + 0xff => true, + else => error.InvalidBool, + }; + } + + pub fn expectBitstring(self: *Parser) Error!BitString { + const ele = try self.expect(.universal, false, .bitstring); + const bytes = self.view(ele); + const right_padding = bytes[0]; + if (right_padding >= 8) return error.InvalidBitString; + return .{ + .bytes = bytes[1..], + .right_padding = @intCast(right_padding), + }; + } + + // TODO: return high resolution date time type instead of epoch seconds + pub fn expectDateTime(self: *Parser) Error!i64 { + const ele = try self.expect(.universal, false, null); + const bytes = self.view(ele); + switch (ele.identifier.tag) { + .utc_time => { + // Example: "YYMMDD000000Z" + if (bytes.len != 13) + return error.InvalidDateTime; + if (bytes[12] != 'Z') + return error.InvalidDateTime; + + var date: Date = undefined; + date.year = try parseTimeDigits(bytes[0..2], 0, 99); + date.year += if (date.year >= 50) 1900 else 2000; + date.month = try parseTimeDigits(bytes[2..4], 1, 12); + date.day = try parseTimeDigits(bytes[4..6], 1, 31); + const time = try parseTime(bytes[6..12]); + + return date.toEpochSeconds() + time.toSec(); + }, + .generalized_time => { + // Examples: + // "19920622123421Z" + // "19920722132100.3Z" + if (bytes.len < 15) + return error.InvalidDateTime; + + var date: Date = undefined; + date.year = try parseYear4(bytes[0..4]); + date.month = try parseTimeDigits(bytes[4..6], 1, 12); + date.day = try parseTimeDigits(bytes[6..8], 1, 31); + const time = try parseTime(bytes[8..14]); + + return date.toEpochSeconds() + time.toSec(); + }, + else => return error.InvalidDateTime, + } + } + + pub fn expectOid(self: *Parser) Error![]const u8 { + const oid = try self.expect(.universal, false, .object_identifier); + return self.view(oid); + } + + pub fn expectEnum(self: *Parser, comptime Enum: type) Error!Enum { + const oid = try self.expectOid(); + return Enum.oids.get(oid) orelse { + if (builtin.mode == .Debug) { + var buf: [256]u8 = undefined; + var stream = std.io.fixedBufferStream(&buf); + try @import("./oid.zig").decode(oid, stream.writer()); + log.warn("unknown oid {s} for enum {s}\n", .{ stream.getWritten(), @typeName(Enum) }); + } + return error.UnknownObjectId; + }; + } + + pub fn expectInt(self: *Parser, comptime T: type) Error!T { + const ele = try self.expectPrimitive(.integer); + const bytes = self.view(ele); + + const info = @typeInfo(T); + if (info != .Int) @compileError(@typeName(T) ++ " is not an int type"); + const Shift = std.math.Log2Int(u8); + + var result: std.meta.Int(.unsigned, info.Int.bits) = 0; + for (bytes, 0..) |b, index| { + const shifted = @shlWithOverflow(b, @as(Shift, @intCast(index * 8))); + if (shifted[1] == 1) return error.Overflow; + + result |= shifted[0]; + } + + return @bitCast(result); + } + + pub fn expectString(self: *Parser, allowed: std.EnumSet(String.Tag)) Error!String { + const ele = try self.expect(.universal, false, null); + switch (ele.identifier.tag) { + inline .string_utf8, + .string_numeric, + .string_printable, + .string_teletex, + .string_videotex, + .string_ia5, + .string_visible, + .string_universal, + .string_bmp, + => |t| { + const tagname = @tagName(t)["string_".len..]; + const tag = std.meta.stringToEnum(String.Tag, tagname) orelse unreachable; + if (allowed.contains(tag)) { + return String{ .tag = tag, .data = self.view(ele) }; + } + }, + else => {}, + } + return error.UnexpectedElement; + } + + pub fn expectPrimitive(self: *Parser, tag: ?Identifier.Tag) Error!Element { + var elem = try self.expect(.universal, false, tag); + if (tag == .integer and elem.slice.len() > 0) { + if (self.view(elem)[0] == 0) elem.slice.start += 1; + if (elem.slice.len() > 0 and self.view(elem)[0] == 0) return error.InvalidIntegerEncoding; + } + return elem; + } + + /// Remember to call `expectEnd` + pub fn expectSequence(self: *Parser) Error!Element { + return try self.expect(.universal, true, .sequence); + } + + /// Remember to call `expectEnd` + pub fn expectSequenceOf(self: *Parser) Error!Element { + return try self.expect(.universal, true, .sequence_of); + } + + pub fn expectEnd(self: *Parser, val: usize) Error!void { + if (self.index != val) return error.NonCanonical; // either forgot to parse end OR an attacker + } + + pub fn expect( + self: *Parser, + class: ?Identifier.Class, + constructed: ?bool, + tag: ?Identifier.Tag, + ) Error!Element { + if (self.index >= self.bytes.len) return error.EndOfStream; + + const res = try Element.init(self.bytes, self.index); + if (tag) |e| { + if (res.identifier.tag != e) return error.UnexpectedElement; + } + if (constructed) |e| { + if (res.identifier.constructed != e) return error.UnexpectedElement; + } + if (class) |e| { + if (res.identifier.class != e) return error.UnexpectedElement; + } + self.index = if (res.identifier.constructed) res.slice.start else res.slice.end; + return res; + } + + pub fn view(self: Parser, elem: Element) []const u8 { + return elem.slice.view(self.bytes); + } + + pub fn seek(self: *Parser, index: usize) void { + self.index = index; + } + + pub fn eof(self: *Parser) bool { + return self.index == self.bytes.len; + } +}; + +pub const Element = struct { + identifier: Identifier, + slice: Slice, + + pub const Slice = struct { + start: Index, + end: Index, + + pub fn len(self: Slice) Index { + return self.end - self.start; + } + + pub fn view(self: Slice, bytes: []const u8) []const u8 { + return bytes[self.start..self.end]; + } + }; + + pub const Error = error{ InvalidLength, EndOfStream }; + + pub fn init(bytes: []const u8, index: Index) Error!Element { + var stream = std.io.fixedBufferStream(bytes[index..]); + var reader = stream.reader(); + + const identifier = @as(Identifier, @bitCast(try reader.readByte())); + const size_or_len_size = try reader.readByte(); + + var start = index + 2; + // short form between 0-127 + if (size_or_len_size < 128) { + const end = start + size_or_len_size; + if (end > bytes.len) return error.InvalidLength; + + return .{ .identifier = identifier, .slice = .{ .start = start, .end = end } }; + } + + // long form between 0 and std.math.maxInt(u1024) + const len_size: u7 = @truncate(size_or_len_size); + start += len_size; + if (len_size > @sizeOf(Index)) return error.InvalidLength; + const len = try reader.readVarInt(Index, .big, len_size); + if (len < 128) return error.InvalidLength; // should have used short form + + const end = std.math.add(Index, start, len) catch return error.InvalidLength; + if (end > bytes.len) return error.InvalidLength; + + return .{ .identifier = identifier, .slice = .{ .start = start, .end = end } }; + } +}; + +test Element { + const short_form = [_]u8{ 0x30, 0x03, 0x02, 0x01, 0x09 }; + try std.testing.expectEqual(Element{ + .identifier = Identifier{ .tag = .sequence, .constructed = true, .class = .universal }, + .slice = .{ .start = 2, .end = short_form.len }, + }, Element.init(&short_form, 0)); + + const long_form = [_]u8{ 0x30, 129, 129 } ++ [_]u8{0} ** 129; + try std.testing.expectEqual(Element{ + .identifier = Identifier{ .tag = .sequence, .constructed = true, .class = .universal }, + .slice = .{ .start = 3, .end = long_form.len }, + }, Element.init(&long_form, 0)); +} + +test "parser.expectInt" { + const one = [_]u8{ 2, 1, 1 }; + var parser = Parser{ .bytes = &one }; + try std.testing.expectEqual(@as(u8, 1), try parser.expectInt(u8)); +} + +pub const Identifier = packed struct(u8) { + tag: Tag, + constructed: bool, + class: Class, + + pub const Class = enum(u2) { + universal, + application, + context_specific, + private, + }; + + // https://www.oss.com/asn1/resources/asn1-made-simple/asn1-quick-reference/asn1-tags.html + pub const Tag = enum(u5) { + boolean = 1, + integer = 2, + bitstring = 3, + octetstring = 4, + null = 5, + object_identifier = 6, + real = 9, + enumerated = 10, + string_utf8 = 12, + sequence = 16, + sequence_of = 17, + string_numeric = 18, + string_printable = 19, + string_teletex = 20, + string_videotex = 21, + string_ia5 = 22, + utc_time = 23, + generalized_time = 24, + string_visible = 26, + string_universal = 28, + string_bmp = 30, + _, + }; +}; + +pub const BitString = struct { + bytes: []const u8, + right_padding: u3, + + pub fn bitLen(self: BitString) usize { + return self.bytes.len * 8 + self.right_padding; + } +}; + +pub const String = struct { + tag: Tag, + data: []const u8, + + pub const Tag = enum { + /// Blessed. + utf8, + /// us-ascii ([-][0-9][eE][.])* + numeric, + /// us-ascii ([A-Z][a-z][0-9][.?!,][ \t])* + printable, + /// iso-8859-1 with escaping into different character sets. + /// Cursed. + teletex, + /// iso-8859-1 + videotex, + /// us-ascii first 128 characters. + ia5, + /// us-ascii without control characters. + visible, + /// utf-32-be + universal, + /// utf-16-be + bmp, + }; + + pub const all = [_]Tag{ + .utf8, + .numeric, + .printable, + .teletex, + .videotex, + .ia5, + .visible, + .universal, + .bmp, + }; +}; + +const Date = struct { + year: Year, + month: u8, + day: u8, + + const Year = std.time.epoch.Year; + + fn toEpochSeconds(date: Date) i64 { + // Euclidean Affine Transform by Cassio and Neri. + // Shift and correction constants for 1970-01-01. + const s = 82; + const K = 719468 + 146097 * s; + const L = 400 * s; + + const Y_G: u32 = date.year; + const M_G: u32 = date.month; + const D_G: u32 = date.day; + // Map to computational calendar. + const J: u32 = if (M_G <= 2) 1 else 0; + const Y: u32 = Y_G + L - J; + const M: u32 = if (J != 0) M_G + 12 else M_G; + const D: u32 = D_G - 1; + const C: u32 = Y / 100; + + // Rata die. + const y_star: u32 = 1461 * Y / 4 - C + C / 4; + const m_star: u32 = (979 * M - 2919) / 32; + const N: u32 = y_star + m_star + D; + const days: i32 = @intCast(N - K); + + return @as(i64, days) * std.time.epoch.secs_per_day; + } +}; + +const Time = struct { + hour: std.math.IntFittingRange(0, 24), + minute: std.math.IntFittingRange(0, 60), + second: std.math.IntFittingRange(0, 60), + + fn toSec(t: Time) i64 { + var sec: i64 = 0; + sec += @as(i64, t.hour) * 60 * 60; + sec += @as(i64, t.minute) * 60; + sec += t.second; + return sec; + } +}; + +fn parseTimeDigits( + text: *const [2]u8, + min: comptime_int, + max: comptime_int, +) !std.math.IntFittingRange(min, max) { + const result = std.fmt.parseInt(std.math.IntFittingRange(min, max), text, 10) catch + return error.InvalidTime; + if (result < min) return error.InvalidTime; + if (result > max) return error.InvalidTime; + return result; +} + +test parseTimeDigits { + const expectEqual = std.testing.expectEqual; + try expectEqual(@as(u8, 0), try parseTimeDigits("00", 0, 99)); + try expectEqual(@as(u8, 99), try parseTimeDigits("99", 0, 99)); + try expectEqual(@as(u8, 42), try parseTimeDigits("42", 0, 99)); + + const expectError = std.testing.expectError; + try expectError(error.InvalidTime, parseTimeDigits("13", 1, 12)); + try expectError(error.InvalidTime, parseTimeDigits("00", 1, 12)); + try expectError(error.InvalidTime, parseTimeDigits("Di", 0, 99)); +} + +fn parseYear4(text: *const [4]u8) !Date.Year { + const result = std.fmt.parseInt(Date.Year, text, 10) catch return error.InvalidYear; + if (result > 9999) return error.InvalidYear; + return result; +} + +test parseYear4 { + const expectEqual = std.testing.expectEqual; + try expectEqual(@as(Date.Year, 0), try parseYear4("0000")); + try expectEqual(@as(Date.Year, 9999), try parseYear4("9999")); + try expectEqual(@as(Date.Year, 1988), try parseYear4("1988")); + + const expectError = std.testing.expectError; + try expectError(error.InvalidYear, parseYear4("999b")); + try expectError(error.InvalidYear, parseYear4("crap")); + try expectError(error.InvalidYear, parseYear4("r:bQ")); +} + +fn parseTime(bytes: *const [6]u8) !Time { + return .{ + .hour = try parseTimeDigits(bytes[0..2], 0, 23), + .minute = try parseTimeDigits(bytes[2..4], 0, 59), + .second = try parseTimeDigits(bytes[4..6], 0, 59), + }; +} diff --git a/src/tls.zig/rsa/oid.zig b/src/tls.zig/rsa/oid.zig new file mode 100644 index 0000000..fd360c3 --- /dev/null +++ b/src/tls.zig/rsa/oid.zig @@ -0,0 +1,132 @@ +//! Developed by ITU-U and ISO/IEC for naming objects. Used in DER. +//! +//! This implementation supports any number of `u32` arcs. + +const Arc = u32; +const encoding_base = 128; + +/// Returns encoded length. +pub fn encodeLen(dot_notation: []const u8) !usize { + var split = std.mem.splitScalar(u8, dot_notation, '.'); + if (split.next() == null) return 0; + if (split.next() == null) return 1; + + var res: usize = 1; + while (split.next()) |s| { + const parsed = try std.fmt.parseUnsigned(Arc, s, 10); + const n_bytes = if (parsed == 0) 0 else std.math.log(Arc, encoding_base, parsed); + + res += n_bytes; + res += 1; + } + + return res; +} + +pub const EncodeError = std.fmt.ParseIntError || error{ + MissingPrefix, + BufferTooSmall, +}; + +pub fn encode(dot_notation: []const u8, buf: []u8) EncodeError![]const u8 { + if (buf.len < try encodeLen(dot_notation)) return error.BufferTooSmall; + + var split = std.mem.splitScalar(u8, dot_notation, '.'); + const first_str = split.next() orelse return error.MissingPrefix; + const second_str = split.next() orelse return error.MissingPrefix; + + const first = try std.fmt.parseInt(u8, first_str, 10); + const second = try std.fmt.parseInt(u8, second_str, 10); + + buf[0] = first * 40 + second; + + var i: usize = 1; + while (split.next()) |s| { + var parsed = try std.fmt.parseUnsigned(Arc, s, 10); + const n_bytes = if (parsed == 0) 0 else std.math.log(Arc, encoding_base, parsed); + + for (0..n_bytes) |j| { + const place = std.math.pow(Arc, encoding_base, n_bytes - @as(Arc, @intCast(j))); + const digit: u8 = @intCast(@divFloor(parsed, place)); + + buf[i] = digit | 0x80; + parsed -= digit * place; + + i += 1; + } + buf[i] = @intCast(parsed); + i += 1; + } + + return buf[0..i]; +} + +pub fn decode(encoded: []const u8, writer: anytype) @TypeOf(writer).Error!void { + const first = @divTrunc(encoded[0], 40); + const second = encoded[0] - first * 40; + try writer.print("{d}.{d}", .{ first, second }); + + var i: usize = 1; + while (i != encoded.len) { + const n_bytes: usize = brk: { + var res: usize = 1; + var j: usize = i; + while (encoded[j] & 0x80 != 0) { + res += 1; + j += 1; + } + break :brk res; + }; + + var n: usize = 0; + for (0..n_bytes) |j| { + const place = std.math.pow(usize, encoding_base, n_bytes - j - 1); + n += place * (encoded[i] & 0b01111111); + i += 1; + } + try writer.print(".{d}", .{n}); + } +} + +pub fn encodeComptime(comptime dot_notation: []const u8) [encodeLen(dot_notation) catch unreachable]u8 { + @setEvalBranchQuota(10_000); + var buf: [encodeLen(dot_notation) catch unreachable]u8 = undefined; + _ = encode(dot_notation, &buf) catch unreachable; + return buf; +} + +const std = @import("std"); + +fn testOid(expected_encoded: []const u8, expected_dot_notation: []const u8) !void { + var buf: [256]u8 = undefined; + const encoded = try encode(expected_dot_notation, &buf); + try std.testing.expectEqualSlices(u8, expected_encoded, encoded); + + var stream = std.io.fixedBufferStream(&buf); + try decode(expected_encoded, stream.writer()); + try std.testing.expectEqualStrings(expected_dot_notation, stream.getWritten()); +} + +test "encode and decode" { + // https://learn.microsoft.com/en-us/windows/win32/seccertenroll/about-object-identifier + try testOid( + &[_]u8{ 0x2b, 0x06, 0x01, 0x04, 0x01, 0x82, 0x37, 0x15, 0x14 }, + "1.3.6.1.4.1.311.21.20", + ); + // https://luca.ntop.org/Teaching/Appunti/asn1.html + try testOid(&[_]u8{ 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d }, "1.2.840.113549"); + // https://www.sysadmins.lv/blog-en/how-to-encode-object-identifier-to-an-asn1-der-encoded-string.aspx + try testOid(&[_]u8{ 0x2a, 0x86, 0x8d, 0x20 }, "1.2.100000"); + try testOid( + &[_]u8{ 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x0b }, + "1.2.840.113549.1.1.11", + ); + try testOid(&[_]u8{ 0x2b, 0x65, 0x70 }, "1.3.101.112"); +} + +test encodeComptime { + try std.testing.expectEqual( + [_]u8{ 0x2b, 0x06, 0x01, 0x04, 0x01, 0x82, 0x37, 0x15, 0x14 }, + encodeComptime("1.3.6.1.4.1.311.21.20"), + ); +} diff --git a/src/tls.zig/rsa/rsa.zig b/src/tls.zig/rsa/rsa.zig new file mode 100644 index 0000000..5e5f42f --- /dev/null +++ b/src/tls.zig/rsa/rsa.zig @@ -0,0 +1,880 @@ +//! RFC8017: Public Key Cryptography Standards #1 v2.2 (PKCS1) +const std = @import("std"); +const der = @import("der.zig"); +const ff = std.crypto.ff; + +pub const max_modulus_bits = 4096; +const max_modulus_len = max_modulus_bits / 8; + +const Modulus = std.crypto.ff.Modulus(max_modulus_bits); +const Fe = Modulus.Fe; + +pub const ValueError = error{ + Modulus, + Exponent, +}; + +pub const PublicKey = struct { + /// `n` + modulus: Modulus, + /// `e` + public_exponent: Fe, + + pub const FromBytesError = ValueError || ff.OverflowError || ff.FieldElementError || ff.InvalidModulusError || error{InsecureBitCount}; + + pub fn fromBytes(mod: []const u8, exp: []const u8) FromBytesError!PublicKey { + const modulus = try Modulus.fromBytes(mod, .big); + if (modulus.bits() <= 512) return error.InsecureBitCount; + const public_exponent = try Fe.fromBytes(modulus, exp, .big); + + if (std.debug.runtime_safety) { + // > the RSA public exponent e is an integer between 3 and n - 1 satisfying + // > GCD(e,\lambda(n)) = 1, where \lambda(n) = LCM(r_1 - 1, ..., r_u - 1) + const e_v = public_exponent.toPrimitive(u32) catch return error.Exponent; + if (!public_exponent.isOdd()) return error.Exponent; + if (e_v < 3) return error.Exponent; + if (modulus.v.compare(public_exponent.v) == .lt) return error.Exponent; + } + + return .{ .modulus = modulus, .public_exponent = public_exponent }; + } + + pub fn fromDer(bytes: []const u8) (der.Parser.Error || FromBytesError)!PublicKey { + var parser = der.Parser{ .bytes = bytes }; + + const seq = try parser.expectSequence(); + defer parser.seek(seq.slice.end); + + const modulus = try parser.expectPrimitive(.integer); + const pub_exp = try parser.expectPrimitive(.integer); + + try parser.expectEnd(seq.slice.end); + try parser.expectEnd(bytes.len); + + return try fromBytes(parser.view(modulus), parser.view(pub_exp)); + } + + /// Deprecated. + /// + /// Encrypt a short message using RSAES-PKCS1-v1_5. + /// The use of this scheme for encrypting an arbitrary message, as opposed to a + /// randomly generated key, is NOT RECOMMENDED. + pub fn encryptPkcsv1_5(pk: PublicKey, msg: []const u8, out: []u8) ![]const u8 { + // align variable names with spec + const k = byteLen(pk.modulus.bits()); + if (out.len < k) return error.BufferTooSmall; + if (msg.len > k - 11) return error.MessageTooLong; + + // EM = 0x00 || 0x02 || PS || 0x00 || M. + var em = out[0..k]; + em[0] = 0; + em[1] = 2; + + const ps = em[2..][0 .. k - msg.len - 3]; + // Section: 7.2.1 + // PS consists of pseudo-randomly generated nonzero octets. + for (ps) |*v| { + v.* = std.crypto.random.uintLessThan(u8, 0xff) + 1; + } + + em[em.len - msg.len - 1] = 0; + @memcpy(em[em.len - msg.len ..][0..msg.len], msg); + + const m = try Fe.fromBytes(pk.modulus, em, .big); + const e = try pk.modulus.powPublic(m, pk.public_exponent); + try e.toBytes(em, .big); + return em; + } + + /// Encrypt a short message using Optimal Asymmetric Encryption Padding (RSAES-OAEP). + pub fn encryptOaep( + pk: PublicKey, + comptime Hash: type, + msg: []const u8, + label: []const u8, + out: []u8, + ) ![]const u8 { + // align variable names with spec + const k = byteLen(pk.modulus.bits()); + if (out.len < k) return error.BufferTooSmall; + + if (msg.len > k - 2 * Hash.digest_length - 2) return error.MessageTooLong; + + // EM = 0x00 || maskedSeed || maskedDB. + var em = out[0..k]; + em[0] = 0; + const seed = em[1..][0..Hash.digest_length]; + std.crypto.random.bytes(seed); + + // DB = lHash || PS || 0x01 || M. + var db = em[1 + seed.len ..]; + const lHash = labelHash(Hash, label); + @memcpy(db[0..lHash.len], &lHash); + @memset(db[lHash.len .. db.len - msg.len - 2], 0); + db[db.len - msg.len - 1] = 1; + @memcpy(db[db.len - msg.len ..], msg); + + var mgf_buf: [max_modulus_len]u8 = undefined; + + const db_mask = mgf1(Hash, seed, mgf_buf[0..db.len]); + for (db, db_mask) |*v, m| v.* ^= m; + + const seed_mask = mgf1(Hash, db, mgf_buf[0..seed.len]); + for (seed, seed_mask) |*v, m| v.* ^= m; + + const m = try Fe.fromBytes(pk.modulus, em, .big); + const e = try pk.modulus.powPublic(m, pk.public_exponent); + try e.toBytes(em, .big); + return em; + } +}; + +pub fn byteLen(bits: usize) usize { + return std.math.divCeil(usize, bits, 8) catch unreachable; +} + +pub const SecretKey = struct { + /// `d` + private_exponent: Fe, + + pub const FromBytesError = ValueError || ff.OverflowError || ff.FieldElementError; + + pub fn fromBytes(n: Modulus, exp: []const u8) FromBytesError!SecretKey { + const d = try Fe.fromBytes(n, exp, .big); + if (std.debug.runtime_safety) { + // > The RSA private exponent d is a positive integer less than n + // > satisfying e * d == 1 (mod \lambda(n)), + if (!d.isOdd()) return error.Exponent; + if (d.v.compare(n.v) != .lt) return error.Exponent; + } + + return .{ .private_exponent = d }; + } +}; + +pub const KeyPair = struct { + public: PublicKey, + secret: SecretKey, + + pub const FromDerError = PublicKey.FromBytesError || SecretKey.FromBytesError || der.Parser.Error || error{ KeyMismatch, InvalidVersion }; + + pub fn fromDer(bytes: []const u8) FromDerError!KeyPair { + var parser = der.Parser{ .bytes = bytes }; + const seq = try parser.expectSequence(); + const version = try parser.expectInt(u8); + + const mod = try parser.expectPrimitive(.integer); + const pub_exp = try parser.expectPrimitive(.integer); + const sec_exp = try parser.expectPrimitive(.integer); + + const public = try PublicKey.fromBytes(parser.view(mod), parser.view(pub_exp)); + const secret = try SecretKey.fromBytes(public.modulus, parser.view(sec_exp)); + + const prime1 = try parser.expectPrimitive(.integer); + const prime2 = try parser.expectPrimitive(.integer); + const exp1 = try parser.expectPrimitive(.integer); + const exp2 = try parser.expectPrimitive(.integer); + const coeff = try parser.expectPrimitive(.integer); + _ = .{ exp1, exp2, coeff }; + + switch (version) { + 0 => {}, + 1 => { + _ = try parser.expectSequenceOf(); + while (!parser.eof()) { + _ = try parser.expectSequence(); + const ri = try parser.expectPrimitive(.integer); + const di = try parser.expectPrimitive(.integer); + const ti = try parser.expectPrimitive(.integer); + _ = .{ ri, di, ti }; + } + }, + else => return error.InvalidVersion, + } + + try parser.expectEnd(seq.slice.end); + try parser.expectEnd(bytes.len); + + if (std.debug.runtime_safety) { + const p = try Fe.fromBytes(public.modulus, parser.view(prime1), .big); + const q = try Fe.fromBytes(public.modulus, parser.view(prime2), .big); + + // check that n = p * q + const expected_zero = public.modulus.mul(p, q); + if (!expected_zero.isZero()) return error.KeyMismatch; + + // TODO: check that d * e is one mod p-1 and mod q-1. Note d and e were bound + // const de = secret.private_exponent.mul(public.public_exponent); + // const one = public.modulus.one(); + + // if (public.modulus.mul(de, p).compare(one) != .eq) return error.KeyMismatch; + // if (public.modulus.mul(de, q).compare(one) != .eq) return error.KeyMismatch; + } + + return .{ .public = public, .secret = secret }; + } + + /// Deprecated. + pub fn signPkcsv1_5(kp: KeyPair, comptime Hash: type, msg: []const u8, out: []u8) !PKCS1v1_5(Hash).Signature { + var st = try signerPkcsv1_5(kp, Hash); + st.update(msg); + return try st.finalize(out); + } + + /// Deprecated. + pub fn signerPkcsv1_5(kp: KeyPair, comptime Hash: type) !PKCS1v1_5(Hash).Signer { + return PKCS1v1_5(Hash).Signer.init(kp); + } + + /// Deprecated. + pub fn decryptPkcsv1_5(kp: KeyPair, ciphertext: []const u8, out: []u8) ![]const u8 { + const k = byteLen(kp.public.modulus.bits()); + if (out.len < k) return error.BufferTooSmall; + + const em = out[0..k]; + + const m = try Fe.fromBytes(kp.public.modulus, ciphertext, .big); + const e = try kp.public.modulus.pow(m, kp.secret.private_exponent); + try e.toBytes(em, .big); + + // Care shall be taken to ensure that an opponent cannot + // distinguish these error conditions, whether by error + // message or timing. + const msg_start = ct.lastIndexOfScalar(em, 0) orelse em.len; + const ps_len = em.len - msg_start; + if (ct.@"or"(em[0] != 0, ct.@"or"(em[1] != 2, ps_len < 8))) { + return error.Inconsistent; + } + + return em[msg_start + 1 ..]; + } + + pub fn signOaep( + kp: KeyPair, + comptime Hash: type, + msg: []const u8, + salt: ?[]const u8, + out: []u8, + ) !Pss(Hash).Signature { + var st = try signerOaep(kp, Hash, salt); + st.update(msg); + return try st.finalize(out); + } + + /// Salt must outlive returned `PSS.Signer`. + pub fn signerOaep(kp: KeyPair, comptime Hash: type, salt: ?[]const u8) !Pss(Hash).Signer { + return Pss(Hash).Signer.init(kp, salt); + } + + pub fn decryptOaep( + kp: KeyPair, + comptime Hash: type, + ciphertext: []const u8, + label: []const u8, + out: []u8, + ) ![]u8 { + // align variable names with spec + const k = byteLen(kp.public.modulus.bits()); + if (out.len < k) return error.BufferTooSmall; + + const mod = try Fe.fromBytes(kp.public.modulus, ciphertext, .big); + const exp = kp.public.modulus.pow(mod, kp.secret.private_exponent) catch unreachable; + const em = out[0..k]; + try exp.toBytes(em, .big); + + const y = em[0]; + const seed = em[1..][0..Hash.digest_length]; + const db = em[1 + Hash.digest_length ..]; + + var mgf_buf: [max_modulus_len]u8 = undefined; + + const seed_mask = mgf1(Hash, db, mgf_buf[0..seed.len]); + for (seed, seed_mask) |*v, m| v.* ^= m; + + const db_mask = mgf1(Hash, seed, mgf_buf[0..db.len]); + for (db, db_mask) |*v, m| v.* ^= m; + + const expected_hash = labelHash(Hash, label); + const actual_hash = db[0..expected_hash.len]; + + // Care shall be taken to ensure that an opponent cannot + // distinguish these error conditions, whether by error + // message or timing. + const msg_start = ct.indexOfScalarPos(em, expected_hash.len + 1, 1) orelse 0; + if (ct.@"or"(y != 0, ct.@"or"(msg_start == 0, !ct.memEql(&expected_hash, actual_hash)))) { + return error.Inconsistent; + } + + return em[msg_start + 1 ..]; + } + + /// Encrypt short plaintext with secret key. + pub fn encrypt(kp: KeyPair, plaintext: []const u8, out: []u8) !void { + const n = kp.public.modulus; + const k = byteLen(n.bits()); + if (plaintext.len > k) return error.MessageTooLong; + + const msg_as_int = try Fe.fromBytes(n, plaintext, .big); + const enc_as_int = try n.pow(msg_as_int, kp.secret.private_exponent); + try enc_as_int.toBytes(out, .big); + } +}; + +/// Deprecated. +/// +/// Signature Scheme with Appendix v1.5 (RSASSA-PKCS1-v1_5) +/// +/// This standard has been superceded by PSS which is formally proven secure +/// and has fewer footguns. +pub fn PKCS1v1_5(comptime Hash: type) type { + return struct { + const PkcsT = @This(); + pub const Signature = struct { + bytes: []const u8, + + const Self = @This(); + + pub fn verifier(self: Self, public_key: PublicKey) !Verifier { + return Verifier.init(self, public_key); + } + + pub fn verify(self: Self, msg: []const u8, public_key: PublicKey) !void { + var st = Verifier.init(self, public_key); + st.update(msg); + return st.verify(); + } + }; + + pub const Signer = struct { + h: Hash, + key_pair: KeyPair, + + fn init(key_pair: KeyPair) Signer { + return .{ + .h = Hash.init(.{}), + .key_pair = key_pair, + }; + } + + pub fn update(self: *Signer, data: []const u8) void { + self.h.update(data); + } + + pub fn finalize(self: *Signer, out: []u8) !PkcsT.Signature { + const k = byteLen(self.key_pair.public.modulus.bits()); + if (out.len < k) return error.BufferTooSmall; + + var hash: [Hash.digest_length]u8 = undefined; + self.h.final(&hash); + + const em = try emsaEncode(hash, out[0..k]); + try self.key_pair.encrypt(em, em); + return .{ .bytes = em }; + } + }; + + pub const Verifier = struct { + h: Hash, + sig: PkcsT.Signature, + public_key: PublicKey, + + fn init(sig: PkcsT.Signature, public_key: PublicKey) Verifier { + return Verifier{ + .h = Hash.init(.{}), + .sig = sig, + .public_key = public_key, + }; + } + + pub fn update(self: *Verifier, data: []const u8) void { + self.h.update(data); + } + + pub fn verify(self: *Verifier) !void { + const pk = self.public_key; + const s = try Fe.fromBytes(pk.modulus, self.sig.bytes, .big); + const emm = try pk.modulus.powPublic(s, pk.public_exponent); + + var em_buf: [max_modulus_len]u8 = undefined; + const em = em_buf[0..byteLen(pk.modulus.bits())]; + try emm.toBytes(em, .big); + + var hash: [Hash.digest_length]u8 = undefined; + self.h.final(&hash); + + // TODO: compare hash values instead of emsa values + const expected = try emsaEncode(hash, em); + + if (!std.mem.eql(u8, expected, em)) return error.Inconsistent; + } + }; + + /// PKCS Encrypted Message Signature Appendix + fn emsaEncode(hash: [Hash.digest_length]u8, out: []u8) ![]u8 { + const digest_header = comptime digestHeader(); + const tLen = digest_header.len + Hash.digest_length; + const emLen = out.len; + if (emLen < tLen + 11) return error.ModulusTooShort; + if (out.len < emLen) return error.BufferTooSmall; + + var res = out[0..emLen]; + res[0] = 0; + res[1] = 1; + const padding_len = emLen - tLen - 3; + @memset(res[2..][0..padding_len], 0xff); + res[2 + padding_len] = 0; + @memcpy(res[2 + padding_len + 1 ..][0..digest_header.len], digest_header); + @memcpy(res[res.len - hash.len ..], &hash); + + return res; + } + + /// DER encoded header. Sequence of digest algo + digest. + /// TODO: use a DER encoder instead + fn digestHeader() []const u8 { + const sha2 = std.crypto.hash.sha2; + // Section 9.2 Notes 1. + return switch (Hash) { + std.crypto.hash.Sha1 => &hexToBytes( + \\30 21 30 09 06 05 2b 0e 03 02 1a 05 00 04 14 + ), + sha2.Sha224 => &hexToBytes( + \\30 2d 30 0d 06 09 60 86 48 01 65 03 04 02 04 + \\05 00 04 1c + ), + sha2.Sha256 => &hexToBytes( + \\30 31 30 0d 06 09 60 86 48 01 65 03 04 02 01 05 00 + \\04 20 + ), + sha2.Sha384 => &hexToBytes( + \\30 41 30 0d 06 09 60 86 48 01 65 03 04 02 02 05 00 + \\04 30 + ), + sha2.Sha512 => &hexToBytes( + \\30 51 30 0d 06 09 60 86 48 01 65 03 04 02 03 05 00 + \\04 40 + ), + // sha2.Sha512224 => &hexToBytes( + // \\30 2d 30 0d 06 09 60 86 48 01 65 03 04 02 05 + // \\05 00 04 1c + // ), + // sha2.Sha512256 => &hexToBytes( + // \\30 31 30 0d 06 09 60 86 48 01 65 03 04 02 06 + // \\05 00 04 20 + // ), + else => @compileError("unknown Hash " ++ @typeName(Hash)), + }; + } + }; +} + +/// Probabilistic Signature Scheme (RSASSA-PSS) +pub fn Pss(comptime Hash: type) type { + // RFC 4055 S3.1 + const default_salt_len = Hash.digest_length; + return struct { + pub const Signature = struct { + bytes: []const u8, + + const Self = @This(); + + pub fn verifier(self: Self, public_key: PublicKey) !Verifier { + return Verifier.init(self, public_key); + } + + pub fn verify(self: Self, msg: []const u8, public_key: PublicKey, salt_len: ?usize) !void { + var st = Verifier.init(self, public_key, salt_len orelse default_salt_len); + st.update(msg); + return st.verify(); + } + }; + + const PssT = @This(); + + pub const Signer = struct { + h: Hash, + key_pair: KeyPair, + salt: ?[]const u8, + + fn init(key_pair: KeyPair, salt: ?[]const u8) Signer { + return .{ + .h = Hash.init(.{}), + .key_pair = key_pair, + .salt = salt, + }; + } + + pub fn update(self: *Signer, data: []const u8) void { + self.h.update(data); + } + + pub fn finalize(self: *Signer, out: []u8) !PssT.Signature { + var hashed: [Hash.digest_length]u8 = undefined; + self.h.final(&hashed); + + const salt = if (self.salt) |s| s else brk: { + var res: [default_salt_len]u8 = undefined; + std.crypto.random.bytes(&res); + break :brk &res; + }; + + const em_bits = self.key_pair.public.modulus.bits() - 1; + const em = try emsaEncode(hashed, salt, em_bits, out); + try self.key_pair.encrypt(em, em); + return .{ .bytes = em }; + } + }; + + pub const Verifier = struct { + h: Hash, + sig: PssT.Signature, + public_key: PublicKey, + salt_len: usize, + + fn init(sig: PssT.Signature, public_key: PublicKey, salt_len: usize) Verifier { + return Verifier{ + .h = Hash.init(.{}), + .sig = sig, + .public_key = public_key, + .salt_len = salt_len, + }; + } + + pub fn update(self: *Verifier, data: []const u8) void { + self.h.update(data); + } + + pub fn verify(self: *Verifier) !void { + const pk = self.public_key; + const s = try Fe.fromBytes(pk.modulus, self.sig.bytes, .big); + const emm = try pk.modulus.powPublic(s, pk.public_exponent); + + var em_buf: [max_modulus_len]u8 = undefined; + const em_bits = pk.modulus.bits() - 1; + const em_len = std.math.divCeil(usize, em_bits, 8) catch unreachable; + var em = em_buf[0..em_len]; + try emm.toBytes(em, .big); + + if (em.len < Hash.digest_length + self.salt_len + 2) return error.Inconsistent; + if (em[em.len - 1] != 0xbc) return error.Inconsistent; + + const db = em[0 .. em.len - Hash.digest_length - 1]; + if (@clz(db[0]) < em.len * 8 - em_bits) return error.Inconsistent; + + const expected_hash = em[db.len..][0..Hash.digest_length]; + var mgf_buf: [max_modulus_len]u8 = undefined; + const db_mask = mgf1(Hash, expected_hash, mgf_buf[0..db.len]); + for (db, db_mask) |*v, m| v.* ^= m; + + for (1..db.len - self.salt_len - 1) |i| { + if (db[i] != 0) return error.Inconsistent; + } + if (db[db.len - self.salt_len - 1] != 1) return error.Inconsistent; + const salt = db[db.len - self.salt_len ..]; + var mp_buf: [max_modulus_len]u8 = undefined; + var mp = mp_buf[0 .. 8 + Hash.digest_length + self.salt_len]; + @memset(mp[0..8], 0); + self.h.final(mp[8..][0..Hash.digest_length]); + @memcpy(mp[8 + Hash.digest_length ..][0..salt.len], salt); + + var actual_hash: [Hash.digest_length]u8 = undefined; + Hash.hash(mp, &actual_hash, .{}); + + if (!std.mem.eql(u8, expected_hash, &actual_hash)) return error.Inconsistent; + } + }; + + /// PSS Encrypted Message Signature Appendix + fn emsaEncode(msg_hash: [Hash.digest_length]u8, salt: []const u8, em_bits: usize, out: []u8) ![]u8 { + const em_len = std.math.divCeil(usize, em_bits, 8) catch unreachable; + + if (em_len < Hash.digest_length + salt.len + 2) return error.Encoding; + + // EM = maskedDB || H || 0xbc + var em = out[0..em_len]; + em[em.len - 1] = 0xbc; + + var mp_buf: [max_modulus_len]u8 = undefined; + // M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt; + const mp = mp_buf[0 .. 8 + Hash.digest_length + salt.len]; + @memset(mp[0..8], 0); + @memcpy(mp[8..][0..Hash.digest_length], &msg_hash); + @memcpy(mp[8 + Hash.digest_length ..][0..salt.len], salt); + + // H = Hash(M') + const hash = em[em.len - 1 - Hash.digest_length ..][0..Hash.digest_length]; + Hash.hash(mp, hash, .{}); + + // DB = PS || 0x01 || salt + var db = em[0 .. em_len - Hash.digest_length - 1]; + @memset(db[0 .. db.len - salt.len - 1], 0); + db[db.len - salt.len - 1] = 1; + @memcpy(db[db.len - salt.len ..], salt); + + var mgf_buf: [max_modulus_len]u8 = undefined; + const db_mask = mgf1(Hash, hash, mgf_buf[0..db.len]); + for (db, db_mask) |*v, m| v.* ^= m; + + // Set the leftmost 8emLen - emBits bits of the leftmost octet + // in maskedDB to zero. + const shift = std.math.comptimeMod(8 * em_len - em_bits, 8); + const mask = @as(u8, 0xff) >> shift; + db[0] &= mask; + + return em; + } + }; +} + +/// Mask generation function. Currently the only one defined. +fn mgf1(comptime Hash: type, seed: []const u8, out: []u8) []u8 { + var c: [@sizeOf(u32)]u8 = undefined; + var tmp: [Hash.digest_length]u8 = undefined; + + var i: usize = 0; + var counter: u32 = 0; + while (i < out.len) : (counter += 1) { + var hasher = Hash.init(.{}); + hasher.update(seed); + std.mem.writeInt(u32, &c, counter, .big); + hasher.update(&c); + + const left = out.len - i; + if (left >= Hash.digest_length) { + // optimization: write straight to `out` + hasher.final(out[i..][0..Hash.digest_length]); + i += Hash.digest_length; + } else { + hasher.final(&tmp); + @memcpy(out[i..][0..left], tmp[0..left]); + i += left; + } + } + + return out; +} + +test mgf1 { + const Hash = std.crypto.hash.sha2.Sha256; + var out: [Hash.digest_length * 2 + 1]u8 = undefined; + try std.testing.expectEqualSlices( + u8, + &hexToBytes( + \\ed 1b 84 6b b9 26 39 00 c8 17 82 ad 08 eb 17 01 + \\fa 8c 72 21 c6 57 63 77 31 7f 5c e8 09 89 9f + ), + mgf1(Hash, "asdf", out[0 .. Hash.digest_length - 1]), + ); + try std.testing.expectEqualSlices( + u8, + &hexToBytes( + \\ed 1b 84 6b b9 26 39 00 c8 17 82 ad 08 eb 17 01 + \\fa 8c 72 21 c6 57 63 77 31 7f 5c e8 09 89 9f 5a + \\22 F2 80 D5 28 08 F4 93 83 76 00 DE 09 E4 EC 92 + \\4A 2C 7C EF 0D F7 7B BE 8F 7F 12 CB 8F 33 A6 65 + \\AB + ), + mgf1(Hash, "asdf", &out), + ); +} + +/// For OAEP. +inline fn labelHash(comptime Hash: type, label: []const u8) [Hash.digest_length]u8 { + if (label.len == 0) { + // magic constants from NIST + const sha2 = std.crypto.hash.sha2; + switch (Hash) { + std.crypto.hash.Sha1 => return hexToBytes( + \\da39a3ee 5e6b4b0d 3255bfef 95601890 + \\afd80709 + ), + sha2.Sha256 => return hexToBytes( + \\e3b0c442 98fc1c14 9afbf4c8 996fb924 + \\27ae41e4 649b934c a495991b 7852b855 + ), + sha2.Sha384 => return hexToBytes( + \\38b060a7 51ac9638 4cd9327e b1b1e36a + \\21fdb711 14be0743 4c0cc7bf 63f6e1da + \\274edebf e76f65fb d51ad2f1 4898b95b + ), + sha2.Sha512 => return hexToBytes( + \\cf83e135 7eefb8bd f1542850 d66d8007 + \\d620e405 0b5715dc 83f4a921 d36ce9ce + \\47d0d13c 5d85f2b0 ff8318d2 877eec2f + \\63b931bd 47417a81 a538327a f927da3e + ), + // just use the empty hash... + else => {}, + } + } + var res: [Hash.digest_length]u8 = undefined; + Hash.hash(label, &res, .{}); + return res; +} + +const ct = if (std.options.side_channels_mitigations == .none) ct_unprotected else ct_protected; + +const ct_unprotected = struct { + fn lastIndexOfScalar(slice: []const u8, value: u8) ?usize { + return std.mem.lastIndexOfScalar(u8, slice, value); + } + + fn indexOfScalarPos(slice: []const u8, start_index: usize, value: u8) ?usize { + return std.mem.indexOfScalarPos(u8, slice, start_index, value); + } + + fn memEql(a: []const u8, b: []const u8) bool { + return std.mem.eql(u8, a, b); + } + + fn @"and"(a: bool, b: bool) bool { + return a and b; + } + + fn @"or"(a: bool, b: bool) bool { + return a or b; + } +}; + +const ct_protected = struct { + fn lastIndexOfScalar(slice: []const u8, value: u8) ?usize { + var res: ?usize = null; + var i: usize = slice.len; + while (i != 0) { + i -= 1; + if (@intFromBool(res == null) & @intFromBool(slice[i] == value) == 1) res = i; + } + return res; + } + + fn indexOfScalarPos(slice: []const u8, start_index: usize, value: u8) ?usize { + var res: ?usize = null; + for (slice[start_index..], start_index..) |c, j| { + if (c == value) res = j; + } + return res; + } + + fn memEql(a: []const u8, b: []const u8) bool { + var res: u1 = 1; + for (a, b) |a_elem, b_elem| { + res &= @intFromBool(a_elem == b_elem); + } + return res == 1; + } + + fn @"and"(a: bool, b: bool) bool { + return (@intFromBool(a) & @intFromBool(b)) == 1; + } + + fn @"or"(a: bool, b: bool) bool { + return (@intFromBool(a) | @intFromBool(b)) == 1; + } +}; + +test ct { + const c = ct_unprotected; + try std.testing.expectEqual(true, c.@"or"(true, false)); + try std.testing.expectEqual(true, c.@"and"(true, true)); + try std.testing.expectEqual(true, c.memEql("Asdf", "Asdf")); + try std.testing.expectEqual(false, c.memEql("asdf", "Asdf")); + try std.testing.expectEqual(3, c.indexOfScalarPos("asdff", 1, 'f')); + try std.testing.expectEqual(4, c.lastIndexOfScalar("asdff", 'f')); +} + +fn removeNonHex(comptime hex: []const u8) []const u8 { + var res: [hex.len]u8 = undefined; + var i: usize = 0; + for (hex) |c| { + if (std.ascii.isHex(c)) { + res[i] = c; + i += 1; + } + } + return res[0..i]; +} + +/// For readable copy/pasting from hex viewers. +fn hexToBytes(comptime hex: []const u8) [removeNonHex(hex).len / 2]u8 { + const hex2 = comptime removeNonHex(hex); + comptime var res: [hex2.len / 2]u8 = undefined; + _ = comptime std.fmt.hexToBytes(&res, hex2) catch unreachable; + return res; +} + +test hexToBytes { + const hex = + \\e3b0c442 98fc1c14 9afbf4c8 996fb924 + \\27ae41e4 649b934c a495991b 7852b855 + ; + try std.testing.expectEqual( + [_]u8{ + 0xe3, 0xb0, 0xc4, 0x42, 0x98, 0xfc, 0x1c, 0x14, + 0x9a, 0xfb, 0xf4, 0xc8, 0x99, 0x6f, 0xb9, 0x24, + 0x27, 0xae, 0x41, 0xe4, 0x64, 0x9b, 0x93, 0x4c, + 0xa4, 0x95, 0x99, 0x1b, 0x78, 0x52, 0xb8, 0x55, + }, + hexToBytes(hex), + ); +} + +const TestHash = std.crypto.hash.sha2.Sha256; +fn testKeypair() !KeyPair { + const keypair_bytes = @embedFile("testdata/id_rsa.der"); + const kp = try KeyPair.fromDer(keypair_bytes); + try std.testing.expectEqual(2048, kp.public.modulus.bits()); + return kp; +} + +test "rsa PKCS1-v1_5 encrypt and decrypt" { + const kp = try testKeypair(); + + const msg = "rsa PKCS1-v1_5 encrypt and decrypt"; + var out: [max_modulus_len]u8 = undefined; + const enc = try kp.public.encryptPkcsv1_5(msg, &out); + + var out2: [max_modulus_len]u8 = undefined; + const dec = try kp.decryptPkcsv1_5(enc, &out2); + + try std.testing.expectEqualSlices(u8, msg, dec); +} + +test "rsa OAEP encrypt and decrypt" { + const kp = try testKeypair(); + + const msg = "rsa OAEP encrypt and decrypt"; + const label = ""; + var out: [max_modulus_len]u8 = undefined; + const enc = try kp.public.encryptOaep(TestHash, msg, label, &out); + + var out2: [max_modulus_len]u8 = undefined; + const dec = try kp.decryptOaep(TestHash, enc, label, &out2); + + try std.testing.expectEqualSlices(u8, msg, dec); +} + +test "rsa PKCS1-v1_5 signature" { + const kp = try testKeypair(); + + const msg = "rsa PKCS1-v1_5 signature"; + var out: [max_modulus_len]u8 = undefined; + + const signature = try kp.signPkcsv1_5(TestHash, msg, &out); + try signature.verify(msg, kp.public); +} + +test "rsa PSS signature" { + const kp = try testKeypair(); + + const msg = "rsa PSS signature"; + var out: [max_modulus_len]u8 = undefined; + + const salts = [_][]const u8{ "asdf", "" }; + for (salts) |salt| { + const signature = try kp.signOaep(TestHash, msg, salt, &out); + try signature.verify(msg, kp.public, salt.len); + } + + const signature = try kp.signOaep(TestHash, msg, null, &out); // random salt + try signature.verify(msg, kp.public, null); +} diff --git a/src/tls.zig/rsa/testdata/id_rsa.der b/src/tls.zig/rsa/testdata/id_rsa.der new file mode 100644 index 0000000..9e4f133 Binary files /dev/null and b/src/tls.zig/rsa/testdata/id_rsa.der differ diff --git a/src/tls.zig/testdata/ec_prime256v1_private_key.pem b/src/tls.zig/testdata/ec_prime256v1_private_key.pem new file mode 100644 index 0000000..67ebf38 --- /dev/null +++ b/src/tls.zig/testdata/ec_prime256v1_private_key.pem @@ -0,0 +1,5 @@ +-----BEGIN EC PRIVATE KEY----- +MHcCAQEEINJSRKv8kSKEzLHptfAlg+LGh4/pHHlq0XLf30Q9pcztoAoGCCqGSM49 +AwEHoUQDQgAEJpmLyp8aGCgyMcFIJaIq/+4V1K6nPpeoih3bT2npeplF9eyXj7rm +8eW9Ua6VLhq71mqtMC+YLm+IkORBVq1cuA== +-----END EC PRIVATE KEY----- diff --git a/src/tls.zig/testdata/ec_private_key.pem b/src/tls.zig/testdata/ec_private_key.pem new file mode 100644 index 0000000..95048aa --- /dev/null +++ b/src/tls.zig/testdata/ec_private_key.pem @@ -0,0 +1,6 @@ +-----BEGIN PRIVATE KEY----- +MIG2AgEAMBAGByqGSM49AgEGBSuBBAAiBIGeMIGbAgEBBDAQNT3KGxUdBqpxuO/z +GSJDePMgmB6xLytkfnHQMCqQquXrmcOQZT3BJhm+PwggmwGhZANiAATKxBc6kfqA +piA+Z0rIjVwaZaBNGnP4UZ5TqVewQ/dP9/BQCca2SJpsXauGLcUPmK4sKFQxGe6d +fzq9O50lo7qHEOIpwDBdhRp+oqB6sN2hMtCPbp6eyzsUlm3FUyhN9D0= +-----END PRIVATE KEY----- diff --git a/src/tls.zig/testdata/ec_secp384r1_private_key.pem b/src/tls.zig/testdata/ec_secp384r1_private_key.pem new file mode 100644 index 0000000..62eac9e --- /dev/null +++ b/src/tls.zig/testdata/ec_secp384r1_private_key.pem @@ -0,0 +1,6 @@ +-----BEGIN EC PRIVATE KEY----- +MIGkAgEBBDDubYpeDdOwxksyQIDiOt6LHt3ikts2HNuR6rqhBg1CLdmp3AVDKfF4 +fPkIr8UDH22gBwYFK4EEACKhZANiAARcVFUVv3bIHS6BEfLt98rtps7XP1y26m2n +v5x/5ecbDH2p7AXBYerJERKFi7ZFE1DSrSAj+KK8otjdEG44ZA2Mtl5AHwDVrKde +RgtavVoreHhLN80jJOun8JnFXQjdNsA= +-----END EC PRIVATE KEY----- diff --git a/src/tls.zig/testdata/ec_secp521r1_private_key.pem b/src/tls.zig/testdata/ec_secp521r1_private_key.pem new file mode 100644 index 0000000..5b7f932 --- /dev/null +++ b/src/tls.zig/testdata/ec_secp521r1_private_key.pem @@ -0,0 +1,7 @@ +-----BEGIN EC PRIVATE KEY----- +MIHcAgEBBEIB8C9axyQY6mgjjC6htLjc8hGylrDsh4BCv9669JaDj5vbxmCnTNlg +OuS6C9+uJNMbwm6CoIjB7RcgDTrxxX7oCyegBwYFK4EEACOhgYkDgYYABABAT5Q8 +aOj9U0iuJE5tXfKnYTgPuvD6keHZAGJ5veM9uR6jr3BhfGubD6bnlD+cIBQzYWo0 +y/BNMzCRJ55PDCNU5gGLw+vkwhJ1lGF5OS6l2oG5WN3fe6cYo+uJD7+PB3WYNIuX +Ls0oidsEM0Q4WLblQOEP6VLGf4qTcZyhoFWYfkjWiw== +-----END EC PRIVATE KEY----- diff --git a/src/tls.zig/testdata/google.com/client_random b/src/tls.zig/testdata/google.com/client_random new file mode 100644 index 0000000..e817c90 --- /dev/null +++ b/src/tls.zig/testdata/google.com/client_random @@ -0,0 +1 @@ +'”’ßqp0x­0)ì©–Ã~Ì+Œ`‡¬tY4•©D_ \ No newline at end of file diff --git a/src/tls.zig/testdata/google.com/server_hello b/src/tls.zig/testdata/google.com/server_hello new file mode 100644 index 0000000..57a8076 Binary files /dev/null and b/src/tls.zig/testdata/google.com/server_hello differ diff --git a/src/tls.zig/testdata/rsa_private_key.pem b/src/tls.zig/testdata/rsa_private_key.pem new file mode 100644 index 0000000..b8cc7a6 --- /dev/null +++ b/src/tls.zig/testdata/rsa_private_key.pem @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDe9yPmdcxv3dVu +D4wJ+GLjYBvAfYzVBFAsNuI79zOfoRSvvs8aD0z1yzlwDjuX1iH3SJF5ynxo/Opi +oVpyT3hXDszyo1AF8UzKUXMQmhiOcfW0xz6+TO831IRLghzsCKPMBz1cC+WFP/62 +RHePPGovM8Nd9vIpRgQlfgXZ+DstpEBmnw1tGvq8CsWLhkMw7xQgQZ21zD5jtUgE +J8lc02IoX/W25HdJmayESqZnpZoaN8dgTLrBcM9XZEoh6gVTEOyUcUpDBIMAqloo +vPKMWBSS0oMX9HspD+eHokeyUxkSI/tLzlr4oYT5sfO/4/oQ+K2vh84DDqAsE3FX +xFVESETLAgMBAAECggEAUDuAmKqlEVAzQDKqAuB1vTpVYjQLnI+7xd1OFaQD2Jpf +VkqEPe1plT03AwKsIRw2BsT/TGM315PDSBCl+mJsfG9gAqQP5MOLDXa3wC6jTYbm +ktHr2xDWODHqFT3R6IHHZ2DnjJrfUc7QeogyucFUuH2Y/NQjGgUO8urhcikoKmi3 +kBiAHCHWNqhrSpzdFLifhe6VC/TGFwKqTepN+TnX3Z20HdL4kkYPGEGA9OonVSn4 +N1m/Q+yj6xm6vBMGlT0lS8lyz0EKb6rLedR7+rEJfOIvhVFEi8aXjkb5a6wIh5LO +rwu/jL0nUY8J5NP5BKz68gRwPtmmKBfCLXTpJUACSQKBgQDmPHEZkBC5wl8plx16 +hrwwSdJuQy0b6BYZO06gpYBOIIENULijKwZzoMYaL8zivGT3KIluEelA7+NXnCuk +NUx7LieeZ+ChIUuRLvT02H9lH11d1Va2PmBgRmUKgul26YyaxeIy3UzjPbbgUFJv +t970IRfgS8qGD9KuhdlovZlCzwKBgQD36mo4BxgO1xmr4Qq0WSgQi2QBMAP9lpE4 +Lc59UP5qvNrGXLGPsirdzz6VSeMGrrxGDyof+fGG9d0Wt0+8OMRysSuVua+SRiJ4 +ugoaCzLbsq6pzDWPXf/wzVevjKTIGh4ZXk6Qa7IHqyEmvOnvxdDsL3iZXgPcQoIF +HybqHU9NRQKBgC6tnGSJX8q5jJ+bAp//xxGnNeGi/vdEc46EBqntQ/kS//caIYT7 +SSCSPPe8Lzbc6T9u2YYWXYsL17TAddyh7bKfpeqottMUNAToV0N4zUNMO5q1kRH7 +zYBXZU7fQcQZD6elbPnRAjCkJ3qM7lm2Fp66QuP3mcTaWmWFv5FLt1HjAoGANVaF +y9Aa6PZ2W3hraSnVaNnUhjziXujKDaAtUODgG+7N0ueWfCgE+PvhpxTid0mY0Cnr +Ej4gLL0w9/YwfXppKZPcoLX2hC36tKayDbBjHMlwsq9wxoueyRwkxWwo97RGzYZw +uLmy79ttonv6iM+yh14fQD/t7LGSb6+oG656pVECgYEA0oya1vG0WL3K8ip8io4c +ovB2K1Uf7EyFzxJHJt6QpmXlPDKkwc6JzpKGJdCi09Pz49U63HodxahtB831rbAY +EduOUQ5scTKf66qA9/kEyClnwl14ZCds7/mu9ioZ7D0VNmWPFsYHaGKAUxsq97nb +xw9Y4zAdgbDcl1bzN9XCDKs= +-----END PRIVATE KEY----- diff --git a/src/tls.zig/testdata/tls12.zig b/src/tls.zig/testdata/tls12.zig new file mode 100644 index 0000000..e5bd1c4 --- /dev/null +++ b/src/tls.zig/testdata/tls12.zig @@ -0,0 +1,244 @@ +/// Messages from The Illustrated TLS 1.2 Connection +/// https://tls12.xargs.org/ +const hexToBytes = @import("../testu.zig").hexToBytes; + +pub const client_hello = hexToBytes( + \\ 16 03 01 00 a5 01 00 00 a1 03 03 00 01 02 03 04 + \\ 05 06 07 08 09 0a 0b 0c 0d 0e 0f 10 11 12 13 14 + \\ 15 16 17 18 19 1a 1b 1c 1d 1e 1f 00 00 20 cc a8 + \\ cc a9 c0 2f c0 30 c0 2b c0 2c c0 13 c0 09 c0 14 + \\ c0 0a 00 9c 00 9d 00 2f 00 35 c0 12 00 0a 01 00 + \\ 00 58 00 00 00 18 00 16 00 00 13 65 78 61 6d 70 + \\ 6c 65 2e 75 6c 66 68 65 69 6d 2e 6e 65 74 00 05 + \\ 00 05 01 00 00 00 00 00 0a 00 0a 00 08 00 1d 00 + \\ 17 00 18 00 19 00 0b 00 02 01 00 00 0d 00 12 00 + \\ 10 04 01 04 03 05 01 05 03 06 01 06 03 02 01 02 + \\ 03 ff 01 00 01 00 00 12 00 00 +); +pub const server_hello = hexToBytes( + \\ 16 03 03 00 31 02 00 00 2d 03 03 70 71 72 73 74 + \\ 75 76 77 78 79 7a 7b 7c 7d 7e 7f 80 81 82 83 84 + \\ 85 86 87 88 89 8a 8b 8c 8d 8e 8f 00 c0 13 00 00 + \\ 05 ff 01 00 01 00 +); +pub const server_certificate = hexToBytes( + \\ 16 03 03 03 2f 0b 00 03 2b 00 03 28 00 03 25 30 + \\ 82 03 21 30 82 02 09 a0 03 02 01 02 02 08 15 5a + \\ 92 ad c2 04 8f 90 30 0d 06 09 2a 86 48 86 f7 0d + \\ 01 01 0b 05 00 30 22 31 0b 30 09 06 03 55 04 06 + \\ 13 02 55 53 31 13 30 11 06 03 55 04 0a 13 0a 45 + \\ 78 61 6d 70 6c 65 20 43 41 30 1e 17 0d 31 38 31 + \\ 30 30 35 30 31 33 38 31 37 5a 17 0d 31 39 31 30 + \\ 30 35 30 31 33 38 31 37 5a 30 2b 31 0b 30 09 06 + \\ 03 55 04 06 13 02 55 53 31 1c 30 1a 06 03 55 04 + \\ 03 13 13 65 78 61 6d 70 6c 65 2e 75 6c 66 68 65 + \\ 69 6d 2e 6e 65 74 30 82 01 22 30 0d 06 09 2a 86 + \\ 48 86 f7 0d 01 01 01 05 00 03 82 01 0f 00 30 82 + \\ 01 0a 02 82 01 01 00 c4 80 36 06 ba e7 47 6b 08 + \\ 94 04 ec a7 b6 91 04 3f f7 92 bc 19 ee fb 7d 74 + \\ d7 a8 0d 00 1e 7b 4b 3a 4a e6 0f e8 c0 71 fc 73 + \\ e7 02 4c 0d bc f4 bd d1 1d 39 6b ba 70 46 4a 13 + \\ e9 4a f8 3d f3 e1 09 59 54 7b c9 55 fb 41 2d a3 + \\ 76 52 11 e1 f3 dc 77 6c aa 53 37 6e ca 3a ec be + \\ c3 aa b7 3b 31 d5 6c b6 52 9c 80 98 bc c9 e0 28 + \\ 18 e2 0b f7 f8 a0 3a fd 17 04 50 9e ce 79 bd 9f + \\ 39 f1 ea 69 ec 47 97 2e 83 0f b5 ca 95 de 95 a1 + \\ e6 04 22 d5 ee be 52 79 54 a1 e7 bf 8a 86 f6 46 + \\ 6d 0d 9f 16 95 1a 4c f7 a0 46 92 59 5c 13 52 f2 + \\ 54 9e 5a fb 4e bf d7 7a 37 95 01 44 e4 c0 26 87 + \\ 4c 65 3e 40 7d 7d 23 07 44 01 f4 84 ff d0 8f 7a + \\ 1f a0 52 10 d1 f4 f0 d5 ce 79 70 29 32 e2 ca be + \\ 70 1f df ad 6b 4b b7 11 01 f4 4b ad 66 6a 11 13 + \\ 0f e2 ee 82 9e 4d 02 9d c9 1c dd 67 16 db b9 06 + \\ 18 86 ed c1 ba 94 21 02 03 01 00 01 a3 52 30 50 + \\ 30 0e 06 03 55 1d 0f 01 01 ff 04 04 03 02 05 a0 + \\ 30 1d 06 03 55 1d 25 04 16 30 14 06 08 2b 06 01 + \\ 05 05 07 03 02 06 08 2b 06 01 05 05 07 03 01 30 + \\ 1f 06 03 55 1d 23 04 18 30 16 80 14 89 4f de 5b + \\ cc 69 e2 52 cf 3e a3 00 df b1 97 b8 1d e1 c1 46 + \\ 30 0d 06 09 2a 86 48 86 f7 0d 01 01 0b 05 00 03 + \\ 82 01 01 00 59 16 45 a6 9a 2e 37 79 e4 f6 dd 27 + \\ 1a ba 1c 0b fd 6c d7 55 99 b5 e7 c3 6e 53 3e ff + \\ 36 59 08 43 24 c9 e7 a5 04 07 9d 39 e0 d4 29 87 + \\ ff e3 eb dd 09 c1 cf 1d 91 44 55 87 0b 57 1d d1 + \\ 9b df 1d 24 f8 bb 9a 11 fe 80 fd 59 2b a0 39 8c + \\ de 11 e2 65 1e 61 8c e5 98 fa 96 e5 37 2e ef 3d + \\ 24 8a fd e1 74 63 eb bf ab b8 e4 d1 ab 50 2a 54 + \\ ec 00 64 e9 2f 78 19 66 0d 3f 27 cf 20 9e 66 7f + \\ ce 5a e2 e4 ac 99 c7 c9 38 18 f8 b2 51 07 22 df + \\ ed 97 f3 2e 3e 93 49 d4 c6 6c 9e a6 39 6d 74 44 + \\ 62 a0 6b 42 c6 d5 ba 68 8e ac 3a 01 7b dd fc 8e + \\ 2c fc ad 27 cb 69 d3 cc dc a2 80 41 44 65 d3 ae + \\ 34 8c e0 f3 4a b2 fb 9c 61 83 71 31 2b 19 10 41 + \\ 64 1c 23 7f 11 a5 d6 5c 84 4f 04 04 84 99 38 71 + \\ 2b 95 9e d6 85 bc 5c 5d d6 45 ed 19 90 94 73 40 + \\ 29 26 dc b4 0e 34 69 a1 59 41 e8 e2 cc a8 4b b6 + \\ 08 46 36 a0 +); +pub const server_key_exchange = hexToBytes( + \\ 16 03 03 01 2c 0c 00 01 28 03 00 1d 20 9f d7 ad + \\ 6d cf f4 29 8d d3 f9 6d 5b 1b 2a f9 10 a0 53 5b + \\ 14 88 d7 f8 fa bb 34 9a 98 28 80 b6 15 04 01 01 + \\ 00 04 02 b6 61 f7 c1 91 ee 59 be 45 37 66 39 bd + \\ c3 d4 bb 81 e1 15 ca 73 c8 34 8b 52 5b 0d 23 38 + \\ aa 14 46 67 ed 94 31 02 14 12 cd 9b 84 4c ba 29 + \\ 93 4a aa cc e8 73 41 4e c1 1c b0 2e 27 2d 0a d8 + \\ 1f 76 7d 33 07 67 21 f1 3b f3 60 20 cf 0b 1f d0 + \\ ec b0 78 de 11 28 be ba 09 49 eb ec e1 a1 f9 6e + \\ 20 9d c3 6e 4f ff d3 6b 67 3a 7d dc 15 97 ad 44 + \\ 08 e4 85 c4 ad b2 c8 73 84 12 49 37 25 23 80 9e + \\ 43 12 d0 c7 b3 52 2e f9 83 ca c1 e0 39 35 ff 13 + \\ a8 e9 6b a6 81 a6 2e 40 d3 e7 0a 7f f3 58 66 d3 + \\ d9 99 3f 9e 26 a6 34 c8 1b 4e 71 38 0f cd d6 f4 + \\ e8 35 f7 5a 64 09 c7 dc 2c 07 41 0e 6f 87 85 8c + \\ 7b 94 c0 1c 2e 32 f2 91 76 9e ac ca 71 64 3b 8b + \\ 98 a9 63 df 0a 32 9b ea 4e d6 39 7e 8c d0 1a 11 + \\ 0a b3 61 ac 5b ad 1c cd 84 0a 6c 8a 6e aa 00 1a + \\ 9d 7d 87 dc 33 18 64 35 71 22 6c 4d d2 c2 ac 41 + \\ fb +); +pub const server_hello_done = hexToBytes("16 03 03 00 04 0e 00 00 00 "); +pub const server_change_cipher_spec = hexToBytes("14 03 03 00 01 01 "); + +pub const server_handshake_finished = hexToBytes( + \\ 16 03 03 00 40 51 52 53 54 55 56 57 58 59 5a 5b + \\ 5c 5d 5e 5f 60 18 e0 75 31 7b 10 03 15 f6 08 1f + \\ cb f3 13 78 1a ac 73 ef e1 9f e2 5b a1 af 59 c2 + \\ 0b e9 4f c0 1b da 2d 68 00 29 8b 73 a7 e8 49 d7 + \\ 4b d4 94 cf 7d +); +pub const client_key_exchange_for_transcript = hexToBytes( + \\ 16 03 03 00 25 10 00 00 21 20 35 80 72 d6 36 58 + \\ 80 d1 ae ea 32 9a df 91 21 38 38 51 ed 21 a2 8e + \\ 3b 75 e9 65 d0 d2 cd 16 62 54 +); + +pub const server_hello_responses = server_hello ++ server_certificate ++ server_key_exchange ++ server_hello_done; + +pub const server_responses = server_hello_responses ++ server_change_cipher_spec ++ server_handshake_finished; + +pub const server_handshake_finished_msgs = server_change_cipher_spec ++ server_handshake_finished; + +pub const master_secret = hexToBytes( + \\ 91 6a bf 9d a5 59 73 e1 36 14 ae 0a 3f 5d 3f 37 + \\ b0 23 ba 12 9a ee 02 cc 91 34 33 81 27 cd 70 49 + \\ 78 1c 8e 19 fc 1e b2 a7 38 7a c0 6a e2 37 34 4c +); + +pub const client_key_exchange = hexToBytes( + \\ 16 03 03 00 25 10 00 00 21 20 00 01 02 03 04 05 + \\ 06 07 08 09 0a 0b 0c 0d 0e 0f 10 11 12 13 14 15 + \\ 16 17 18 19 1a 1b 1c 1d 1e 1f +); +pub const client_change_cyper_spec = hexToBytes("14 03 03 00 01 01 "); + +pub const client_handshake_finished = hexToBytes( + \\ 16 03 03 00 40 20 21 22 23 24 25 26 27 28 29 2a + \\ 2b 2c 2d 2e 2f a9 ac f5 5a f3 7a 90 17 63 ff 91 + \\ 68 9a b7 ee a0 d4 0c 1c ca 62 44 ef f3 0b a3 6d + \\ d0 df 86 3f 7d e3 98 d3 1a cc 37 6a e6 7a 00 6d + \\ 8c 08 bc 8a 5a +); + +pub const handshake_messages = [_][]const u8{ + &client_hello, + &server_hello, + &server_certificate, + &server_key_exchange, + &server_hello_done, + &client_key_exchange_for_transcript, +}; + +pub const client_finished = hexToBytes("14 00 00 0c cf 91 96 26 f1 36 0c 53 6a aa d7 3a "); + +// with iv 40 " ++ 41 ... 4f +// client_sequence = 0 +pub const verify_data_encrypted_msg = hexToBytes( + \\ 16 03 03 00 40 40 41 42 43 44 45 46 47 48 49 4a + \\ 4b 4c 4d 4e 4f 22 7b c9 ba 81 ef 30 f2 a8 a7 8f + \\ f1 df 50 84 4d 58 04 b7 ee b2 e2 14 c3 2b 68 92 + \\ ac a3 db 7b 78 07 7f dd 90 06 7c 51 6b ac b3 ba + \\ 90 de df 72 0f +); + +// with iv 00 " ++ 01 ... 1f +// client_sequence = 1 +pub const encrypted_ping_msg = hexToBytes( + \\ 17 03 03 00 30 00 01 02 03 04 05 06 07 08 09 0a + \\ 0b 0c 0d 0e 0f 6c 42 1c 71 c4 2b 18 3b fa 06 19 + \\ 5d 13 3d 0a 09 d0 0f c7 cb 4e 0f 5d 1c da 59 d1 + \\ 47 ec 79 0c 99 +); + +pub const key_material = hexToBytes( + \\ 1b 7d 11 7c 7d 5f 69 0b c2 63 ca e8 ef 60 af 0f + \\ 18 78 ac c2 2a d8 bd d8 c6 01 a6 17 12 6f 63 54 + \\ 0e b2 09 06 f7 81 fa d2 f6 56 d0 37 b1 73 ef 3e + \\ 11 16 9f 27 23 1a 84 b6 75 2a 18 e7 a9 fc b7 cb + \\ cd d8 f9 8d d8 f7 69 eb a0 d2 55 0c 92 38 ee bf + \\ ef 5c 32 25 1a bb 67 d6 43 45 28 db 49 37 d5 40 + \\ d3 93 13 5e 06 a1 1b b8 0e 45 ea eb e3 2c ac 72 + \\ 75 74 38 fb b3 df 64 5c bd a4 06 7c df a0 f8 48 +); + +pub const server_pong = hexToBytes( + \\ 17 03 03 00 30 61 62 63 64 65 66 67 68 69 6a 6b + \\ 6c 6d 6e 6f 70 97 83 48 8a f5 fa 20 bf 7a 2e f6 + \\ 9d eb b5 34 db 9f b0 7a 8c 27 21 de e5 40 9f 77 + \\ af 0c 3d de 56 +); + +pub const client_random = hexToBytes( + \\ 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f +); + +pub const server_random = hexToBytes( + \\ 70 71 72 73 74 75 76 77 78 79 7a 7b 7c 7d 7e 7f 80 81 82 83 84 85 86 87 88 89 8a 8b 8c 8d 8e 8f +); + +pub const client_secret = hexToBytes( + \\ 20 21 22 23 24 25 26 27 28 29 2a 2b 2c 2d 2e 2f 30 31 32 33 34 35 36 37 38 39 3a 3b 3c 3d 3e 3f +); + +pub const server_pub_key = hexToBytes( + \\ 9f d7 ad 6d cf f4 29 8d d3 f9 6d 5b 1b 2a f9 10 a0 53 5b 14 88 d7 f8 fa bb 34 9a 98 28 80 b6 15 +); + +pub const signature = hexToBytes( + \\ 04 02 b6 61 f7 c1 91 ee 59 be 45 37 66 39 bd c3 + \\ d4 bb 81 e1 15 ca 73 c8 34 8b 52 5b 0d 23 38 aa + \\ 14 46 67 ed 94 31 02 14 12 cd 9b 84 4c ba 29 93 + \\ 4a aa cc e8 73 41 4e c1 1c b0 2e 27 2d 0a d8 1f + \\ 76 7d 33 07 67 21 f1 3b f3 60 20 cf 0b 1f d0 ec + \\ b0 78 de 11 28 be ba 09 49 eb ec e1 a1 f9 6e 20 + \\ 9d c3 6e 4f ff d3 6b 67 3a 7d dc 15 97 ad 44 08 + \\ e4 85 c4 ad b2 c8 73 84 12 49 37 25 23 80 9e 43 + \\ 12 d0 c7 b3 52 2e f9 83 ca c1 e0 39 35 ff 13 a8 + \\ e9 6b a6 81 a6 2e 40 d3 e7 0a 7f f3 58 66 d3 d9 + \\ 99 3f 9e 26 a6 34 c8 1b 4e 71 38 0f cd d6 f4 e8 + \\ 35 f7 5a 64 09 c7 dc 2c 07 41 0e 6f 87 85 8c 7b + \\ 94 c0 1c 2e 32 f2 91 76 9e ac ca 71 64 3b 8b 98 + \\ a9 63 df 0a 32 9b ea 4e d6 39 7e 8c d0 1a 11 0a + \\ b3 61 ac 5b ad 1c cd 84 0a 6c 8a 6e aa 00 1a 9d + \\ 7d 87 dc 33 18 64 35 71 22 6c 4d d2 c2 ac 41 fb +); + +pub const cert_pub_key = hexToBytes( + \\ 30 82 01 0a 02 82 01 01 00 c4 80 36 06 ba e7 47 + \\ 6b 08 94 04 ec a7 b6 91 04 3f f7 92 bc 19 ee fb + \\ 7d 74 d7 a8 0d 00 1e 7b 4b 3a 4a e6 0f e8 c0 71 + \\ fc 73 e7 02 4c 0d bc f4 bd d1 1d 39 6b ba 70 46 + \\ 4a 13 e9 4a f8 3d f3 e1 09 59 54 7b c9 55 fb 41 + \\ 2d a3 76 52 11 e1 f3 dc 77 6c aa 53 37 6e ca 3a + \\ ec be c3 aa b7 3b 31 d5 6c b6 52 9c 80 98 bc c9 + \\ e0 28 18 e2 0b f7 f8 a0 3a fd 17 04 50 9e ce 79 + \\ bd 9f 39 f1 ea 69 ec 47 97 2e 83 0f b5 ca 95 de + \\ 95 a1 e6 04 22 d5 ee be 52 79 54 a1 e7 bf 8a 86 + \\ f6 46 6d 0d 9f 16 95 1a 4c f7 a0 46 92 59 5c 13 + \\ 52 f2 54 9e 5a fb 4e bf d7 7a 37 95 01 44 e4 c0 + \\ 26 87 4c 65 3e 40 7d 7d 23 07 44 01 f4 84 ff d0 + \\ 8f 7a 1f a0 52 10 d1 f4 f0 d5 ce 79 70 29 32 e2 + \\ ca be 70 1f df ad 6b 4b b7 11 01 f4 4b ad 66 6a + \\ 11 13 0f e2 ee 82 9e 4d 02 9d c9 1c dd 67 16 db + \\ b9 06 18 86 ed c1 ba 94 21 02 03 01 00 01 +); diff --git a/src/tls.zig/testdata/tls13.zig b/src/tls.zig/testdata/tls13.zig new file mode 100644 index 0000000..f98f9ff --- /dev/null +++ b/src/tls.zig/testdata/tls13.zig @@ -0,0 +1,64 @@ +const hexToBytes = @import("../testu.zig").hexToBytes; + +pub const client_hello = + hexToBytes("16030100f8010000f40303000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20e0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff000813021303130100ff010000a30000001800160000136578616d706c652e756c666865696d2e6e6574000b000403000102000a00160014001d0017001e0019001801000101010201030104002300000016000000170000000d001e001c040305030603080708080809080a080b080408050806040105010601002b0003020304002d00020101003300260024001d0020358072d6365880d1aeea329adf9121383851ed21a28e3b75e965d0d2cd166254"); + +pub const server_hello = + hexToBytes("160303007a") ++ // record header + hexToBytes("020000760303") ++ // handshake header, server version + hexToBytes("707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f") ++ // server_random + hexToBytes("20e0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff") ++ // session id + hexToBytes("130200") ++ // cipher suite, compression method + hexToBytes("002e002b00020304") ++ // extensions, supported version + hexToBytes("00330024001d00209fd7ad6dcff4298dd3f96d5b1b2af910a0535b1488d7f8fabb349a982880b615"); // extension key share + +pub const client_random = hexToBytes( + \\ 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f +); + +pub const server_random = + hexToBytes("707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f"); +pub const server_pub_key = + hexToBytes("9fd7ad6dcff4298dd3f96d5b1b2af910a0535b1488d7f8fabb349a982880b615"); +pub const client_private_key = + hexToBytes("202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f"); +pub const client_public_key = + hexToBytes("358072d6365880d1aeea329adf9121383851ed21a28e3b75e965d0d2cd166254"); + +pub const shared_key = hexToBytes("df4a291baa1eb7cfa6934b29b474baad2697e29f1f920dcc77c8a0a088447624"); +pub const server_handshake_key = hexToBytes("9f13575ce3f8cfc1df64a77ceaffe89700b492ad31b4fab01c4792be1b266b7f"); +pub const server_handshake_iv = hexToBytes("9563bc8b590f671f488d2da3"); +pub const client_handshake_key = hexToBytes("1135b4826a9a70257e5a391ad93093dfd7c4214812f493b3e3daae1eb2b1ac69"); +pub const client_handshake_iv = hexToBytes("4256d2e0e88babdd05eb2f27"); + +pub const server_application_key = hexToBytes("01f78623f17e3edcc09e944027ba3218d57c8e0db93cd3ac419309274700ac27"); +pub const server_application_iv = hexToBytes("196a750b0c5049c0cc51a541"); +pub const client_application_key = hexToBytes("de2f4c7672723a692319873e5c227606691a32d1c59d8b9f51dbb9352e9ca9cc"); +pub const client_application_iv = hexToBytes("bb007956f474b25de902432f"); + +pub const server_encrypted_extensions_wrapped = + hexToBytes("17030300176be02f9da7c2dc9ddef56f2468b90adfa25101ab0344ae"); +pub const server_encrypted_extensions = + hexToBytes("080000020000"); + +pub const server_certificate_wrapped = + hexToBytes("1703030343baf00a9be50f3f2307e726edcbdacbe4b18616449d46c6207af6e9953ee5d2411ba65d31feaf4f78764f2d693987186cc01329c187a5e4608e8d27b318e98dd94769f7739ce6768392caca8dcc597d77ec0d1272233785f6e69d6f43effa8e7905edfdc4037eee5933e990a7972f206913a31e8d04931366d3d8bcd6a4a4d647dd4bd80b0ff863ce3554833d744cf0e0b9c07cae726dd23f9953df1f1ce3aceb3b7230871e92310cfb2b098486f43538f8e82d8404e5c6c25f66a62ebe3c5f26232640e20a769175ef83483cd81e6cb16e78dfad4c1b714b04b45f6ac8d1065ad18c13451c9055c47da300f93536ea56f531986d6492775393c4ccb095467092a0ec0b43ed7a0687cb470ce350917b0ac30c6e5c24725a78c45f9f5f29b6626867f6f79ce054273547b36df030bd24af10d632dba54fc4e890bd0586928c0206ca2e28e44e227a2d5063195935df38da8936092eef01e84cad2e49d62e470a6c7745f625ec39e4fc23329c79d1172876807c36d736ba42bb69b004ff55f93850dc33c1f98abb92858324c76ff1eb085db3c1fc50f74ec04442e622973ea70743418794c388140bb492d6294a0540e5a59cfae60ba0f14899fca71333315ea083a68e1d7c1e4cdc2f56bcd6119681a4adbc1bbf42afd806c3cbd42a076f545dee4e118d0b396754be2b042a685dd4727e89c0386a94d3cd6ecb9820e9d49afeed66c47e6fc243eabebbcb0b02453877f5ac5dbfbdf8db1052a3c994b224cd9aaaf56b026bb9efa2e01302b36401ab6494e7018d6e5b573bd38bcef023b1fc92946bbca0209ca5fa926b4970b1009103645cb1fcfe552311ff730558984370038fd2cce2a91fc74d6f3e3ea9f843eed356f6f82d35d03bc24b81b58ceb1a43ec9437e6f1e50eb6f555e321fd67c8332eb1b832aa8d795a27d479c6e27d5a61034683891903f66421d094e1b00a9a138d861e6f78a20ad3e1580054d2e305253c713a02fe1e28deee7336246f6ae34331806b46b47b833c39b9d31cd300c2a6ed831399776d07f570eaf0059a2c68a5f3ae16b617404af7b7231a4d942758fc020b3f23ee8c15e36044cfd67cd640993b16207597fbf385ea7a4d99e8d456ff83d41f7b8b4f069b028a2a63a919a70e3a10e3084158faa5bafa30186c6b2f238eb530c73e"); +pub const server_certificate = + hexToBytes("0b00032e0000032a0003253082032130820209a0030201020208155a92adc2048f90300d06092a864886f70d01010b05003022310b300906035504061302555331133011060355040a130a4578616d706c65204341301e170d3138313030353031333831375a170d3139313030353031333831375a302b310b3009060355040613025553311c301a060355040313136578616d706c652e756c666865696d2e6e657430820122300d06092a864886f70d01010105000382010f003082010a0282010100c4803606bae7476b089404eca7b691043ff792bc19eefb7d74d7a80d001e7b4b3a4ae60fe8c071fc73e7024c0dbcf4bdd11d396bba70464a13e94af83df3e10959547bc955fb412da3765211e1f3dc776caa53376eca3aecbec3aab73b31d56cb6529c8098bcc9e02818e20bf7f8a03afd1704509ece79bd9f39f1ea69ec47972e830fb5ca95de95a1e60422d5eebe527954a1e7bf8a86f6466d0d9f16951a4cf7a04692595c1352f2549e5afb4ebfd77a37950144e4c026874c653e407d7d23074401f484ffd08f7a1fa05210d1f4f0d5ce79702932e2cabe701fdfad6b4bb71101f44bad666a11130fe2ee829e4d029dc91cdd6716dbb9061886edc1ba94210203010001a3523050300e0603551d0f0101ff0404030205a0301d0603551d250416301406082b0601050507030206082b06010505070301301f0603551d23041830168014894fde5bcc69e252cf3ea300dfb197b81de1c146300d06092a864886f70d01010b05000382010100591645a69a2e3779e4f6dd271aba1c0bfd6cd75599b5e7c36e533eff3659084324c9e7a504079d39e0d42987ffe3ebdd09c1cf1d914455870b571dd19bdf1d24f8bb9a11fe80fd592ba0398cde11e2651e618ce598fa96e5372eef3d248afde17463ebbfabb8e4d1ab502a54ec0064e92f7819660d3f27cf209e667fce5ae2e4ac99c7c93818f8b2510722dfed97f32e3e9349d4c66c9ea6396d744462a06b42c6d5ba688eac3a017bddfc8e2cfcad27cb69d3ccdca280414465d3ae348ce0f34ab2fb9c618371312b191041641c237f11a5d65c844f0404849938712b959ed685bc5c5dd645ed19909473402926dcb40e3469a15941e8e2cca84bb6084636a00000"); + +pub const server_certificate_verify_wrapped = hexToBytes("170303011973719fce07ec2f6d3bba0292a0d40b2770c06a271799a53314f6f77fc95c5fe7b9a4329fd9548c670ebeea2f2d5c351dd9356ef2dcd52eb137bd3a676522f8cd0fb7560789ad7b0e3caba2e37e6b4199c6793b3346ed46cf740a9fa1fec414dc715c415c60e575703ce6a34b70b5191aa6a61a18faff216c687ad8d17e12a7e99915a611bfc1a2befc15e6e94d784642e682fd17382a348c301056b940c9847200408bec56c81ea3d7217ab8e85a88715395899c90587f72e8ddd74b26d8edc1c7c837d9f2ebbc260962219038b05654a63a0b12999b4a8306a3ddcc0e17c53ba8f9c80363f7841354d291b4ace0c0f330c0fcd5aa9deef969ae8ab2d98da88ebb6ea80a3a11f00ea296a3232367ff075e1c66dd9cbedc4713"); +pub const server_finished_wrapped = hexToBytes("17030300451061de27e51c2c9f342911806f282b710c10632ca5006755880dbf7006002d0e84fed9adf27a43b5192303e4df5c285d58e3c76224078440c0742374744aecf28cf3182fd0"); + +pub const handshake_hash = hexToBytes("fa6800169a6baac19159524fa7b9721b41be3c9db6f3f93fa5ff7e3db3ece204d2b456c51046e40ec5312c55a86126f5"); + +pub const client_finished_verify_data = hexToBytes("bff56a671b6c659d0a7c5dd18428f58bdd38b184a3ce342d9fde95cbd5056f7da7918ee320eab7a93abd8f1c02454d27"); + +pub const client_finished_wrapped = hexToBytes("17030300459ff9b063175177322a46dd9896f3c3bb820ab51743ebc25fdadd53454b73deb54cc7248d411a18bccf657a960824e9a19364837c350a69a88d4bf635c85eb874aebc9dfde8"); + +pub const client_ping_wrapped = hexToBytes("1703030015828139cb7b73aaabf5b82fbf9a2961bcde10038a32"); +pub const server_flight = + hexToBytes("140303000101") ++ + server_encrypted_extensions_wrapped ++ + server_certificate_wrapped ++ + server_certificate_verify_wrapped ++ + server_finished_wrapped; diff --git a/src/tls.zig/testu.zig b/src/tls.zig/testu.zig new file mode 100644 index 0000000..255fe6d --- /dev/null +++ b/src/tls.zig/testu.zig @@ -0,0 +1,117 @@ +const std = @import("std"); + +pub fn bufPrint(var_name: []const u8, buf: []const u8) void { + // std.debug.print("\nconst {s} = [_]u8{{\n", .{var_name}); + // for (buf, 1..) |b, i| { + // std.debug.print("0x{x:0>2}, ", .{b}); + // if (i % 16 == 0) + // std.debug.print("\n", .{}); + // } + // std.debug.print("}};\n", .{}); + + std.debug.print("const {s} = \"", .{var_name}); + const charset = "0123456789abcdef"; + for (buf) |b| { + const x = charset[b >> 4]; + const y = charset[b & 15]; + std.debug.print("{c}{c} ", .{ x, y }); + } + std.debug.print("\"\n", .{}); +} + +const random_instance = std.Random{ .ptr = undefined, .fillFn = randomFillFn }; +var random_seed: u8 = 0; + +pub fn randomFillFn(_: *anyopaque, buf: []u8) void { + for (buf) |*v| { + v.* = random_seed; + random_seed +%= 1; + } +} + +pub fn random(seed: u8) std.Random { + random_seed = seed; + return random_instance; +} + +// Fill buf with 0,1,..ff,0,... +pub fn fill(buf: []u8) void { + fillFrom(buf, 0); +} + +pub fn fillFrom(buf: []u8, start: u8) void { + var i: u8 = start; + for (buf) |*v| { + v.* = i; + i +%= 1; + } +} + +pub const Stream = struct { + output: std.io.FixedBufferStream([]u8) = undefined, + input: std.io.FixedBufferStream([]const u8) = undefined, + + pub fn init(input: []const u8, output: []u8) Stream { + return .{ + .input = std.io.fixedBufferStream(input), + .output = std.io.fixedBufferStream(output), + }; + } + + pub const ReadError = error{}; + pub const WriteError = error{NoSpaceLeft}; + + pub fn write(self: *Stream, buf: []const u8) !usize { + return try self.output.writer().write(buf); + } + + pub fn writeAll(self: *Stream, buffer: []const u8) !void { + var n: usize = 0; + while (n < buffer.len) { + n += try self.write(buffer[n..]); + } + } + + pub fn read(self: *Stream, buffer: []u8) !usize { + return self.input.read(buffer); + } +}; + +// Copied from: https://github.com/clickingbuttons/zig/blob/f1cea91624fd2deae28bfb2414a4fd9c7e246883/lib/std/crypto/rsa.zig#L791 +/// For readable copy/pasting from hex viewers. +pub fn hexToBytes(comptime hex: []const u8) [removeNonHex(hex).len / 2]u8 { + @setEvalBranchQuota(1000 * 100); + const hex2 = comptime removeNonHex(hex); + comptime var res: [hex2.len / 2]u8 = undefined; + _ = comptime std.fmt.hexToBytes(&res, hex2) catch unreachable; + return res; +} + +fn removeNonHex(comptime hex: []const u8) []const u8 { + @setEvalBranchQuota(1000 * 100); + var res: [hex.len]u8 = undefined; + var i: usize = 0; + for (hex) |c| { + if (std.ascii.isHex(c)) { + res[i] = c; + i += 1; + } + } + return res[0..i]; +} + +test hexToBytes { + const hex = + \\e3b0c442 98fc1c14 9afbf4c8 996fb924 + \\27ae41e4 649b934c a495991b 7852b855 + ; + try std.testing.expectEqual( + [_]u8{ + 0xe3, 0xb0, 0xc4, 0x42, 0x98, 0xfc, 0x1c, 0x14, + 0x9a, 0xfb, 0xf4, 0xc8, 0x99, 0x6f, 0xb9, 0x24, + 0x27, 0xae, 0x41, 0xe4, 0x64, 0x9b, 0x93, 0x4c, + 0xa4, 0x95, 0x99, 0x1b, 0x78, 0x52, 0xb8, 0x55, + }, + hexToBytes(hex), + ); +} diff --git a/src/tls.zig/transcript.zig b/src/tls.zig/transcript.zig new file mode 100644 index 0000000..59c9498 --- /dev/null +++ b/src/tls.zig/transcript.zig @@ -0,0 +1,297 @@ +const std = @import("std"); +const crypto = std.crypto; +const tls = crypto.tls; +const hkdfExpandLabel = tls.hkdfExpandLabel; + +const Sha256 = crypto.hash.sha2.Sha256; +const Sha384 = crypto.hash.sha2.Sha384; +const Sha512 = crypto.hash.sha2.Sha512; + +const HashTag = @import("cipher.zig").CipherSuite.HashTag; + +// Transcript holds hash of all handshake message. +// +// Until the server hello is parsed we don't know which hash (sha256, sha384, +// sha512) will be used so we update all of them. Handshake process will set +// `selected` field once cipher suite is known. Other function will use that +// selected hash. We continue to calculate all hashes because client certificate +// message could use different hash than the other part of the handshake. +// Handshake hash is dictated by the server selected cipher. Client certificate +// hash is dictated by the private key used. +// +// Most of the functions are inlined because they are returning pointers. +// +pub const Transcript = struct { + sha256: Type(.sha256) = .{ .hash = Sha256.init(.{}) }, + sha384: Type(.sha384) = .{ .hash = Sha384.init(.{}) }, + sha512: Type(.sha512) = .{ .hash = Sha512.init(.{}) }, + + tag: HashTag = .sha256, + + pub const max_mac_length = Type(.sha512).mac_length; + + // Transcript Type from hash tag + fn Type(h: HashTag) type { + return switch (h) { + .sha256 => TranscriptT(Sha256), + .sha384 => TranscriptT(Sha384), + .sha512 => TranscriptT(Sha512), + }; + } + + /// Set hash to use in all following function calls. + pub fn use(t: *Transcript, tag: HashTag) void { + t.tag = tag; + } + + pub fn update(t: *Transcript, buf: []const u8) void { + t.sha256.hash.update(buf); + t.sha384.hash.update(buf); + t.sha512.hash.update(buf); + } + + // tls 1.2 handshake specific + + pub inline fn masterSecret( + t: *Transcript, + pre_master_secret: []const u8, + client_random: [32]u8, + server_random: [32]u8, + ) []const u8 { + return switch (t.tag) { + inline else => |h| &@field(t, @tagName(h)).masterSecret( + pre_master_secret, + client_random, + server_random, + ), + }; + } + + pub inline fn keyMaterial( + t: *Transcript, + master_secret: []const u8, + client_random: [32]u8, + server_random: [32]u8, + ) []const u8 { + return switch (t.tag) { + inline else => |h| &@field(t, @tagName(h)).keyExpansion( + master_secret, + client_random, + server_random, + ), + }; + } + + pub fn clientFinishedTls12(t: *Transcript, master_secret: []const u8) [12]u8 { + return switch (t.tag) { + inline else => |h| @field(t, @tagName(h)).clientFinishedTls12(master_secret), + }; + } + + pub fn serverFinishedTls12(t: *Transcript, master_secret: []const u8) [12]u8 { + return switch (t.tag) { + inline else => |h| @field(t, @tagName(h)).serverFinishedTls12(master_secret), + }; + } + + // tls 1.3 handshake specific + + pub inline fn serverCertificateVerify(t: *Transcript) []const u8 { + return switch (t.tag) { + inline else => |h| &@field(t, @tagName(h)).serverCertificateVerify(), + }; + } + + pub inline fn clientCertificateVerify(t: *Transcript) []const u8 { + return switch (t.tag) { + inline else => |h| &@field(t, @tagName(h)).clientCertificateVerify(), + }; + } + + pub fn serverFinishedTls13(t: *Transcript, buf: []u8) []const u8 { + return switch (t.tag) { + inline else => |h| @field(t, @tagName(h)).serverFinishedTls13(buf), + }; + } + + pub fn clientFinishedTls13(t: *Transcript, buf: []u8) []const u8 { + return switch (t.tag) { + inline else => |h| @field(t, @tagName(h)).clientFinishedTls13(buf), + }; + } + + pub const Secret = struct { + client: []const u8, + server: []const u8, + }; + + pub inline fn handshakeSecret(t: *Transcript, shared_key: []const u8) Secret { + return switch (t.tag) { + inline else => |h| @field(t, @tagName(h)).handshakeSecret(shared_key), + }; + } + + pub inline fn applicationSecret(t: *Transcript) Secret { + return switch (t.tag) { + inline else => |h| @field(t, @tagName(h)).applicationSecret(), + }; + } + + // other + + pub fn Hkdf(h: HashTag) type { + return Type(h).Hkdf; + } + + /// Copy of the current hash value + pub inline fn hash(t: *Transcript, comptime Hash: type) Hash { + return switch (Hash) { + Sha256 => t.sha256.hash, + Sha384 => t.sha384.hash, + Sha512 => t.sha512.hash, + else => @compileError("unimplemented"), + }; + } +}; + +fn TranscriptT(comptime Hash: type) type { + return struct { + const Hmac = crypto.auth.hmac.Hmac(Hash); + const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); + const mac_length = Hmac.mac_length; + + hash: Hash, + handshake_secret: [Hmac.mac_length]u8 = undefined, + server_finished_key: [Hmac.key_length]u8 = undefined, + client_finished_key: [Hmac.key_length]u8 = undefined, + + const Self = @This(); + + fn init(transcript: Hash) Self { + return .{ .transcript = transcript }; + } + + fn serverCertificateVerify(c: *Self) [64 + 34 + Hash.digest_length]u8 { + return ([1]u8{0x20} ** 64) ++ + "TLS 1.3, server CertificateVerify\x00".* ++ + c.hash.peek(); + } + + // ref: https://www.rfc-editor.org/rfc/rfc8446#section-4.4.3 + fn clientCertificateVerify(c: *Self) [64 + 34 + Hash.digest_length]u8 { + return ([1]u8{0x20} ** 64) ++ + "TLS 1.3, client CertificateVerify\x00".* ++ + c.hash.peek(); + } + + fn masterSecret( + _: *Self, + pre_master_secret: []const u8, + client_random: [32]u8, + server_random: [32]u8, + ) [mac_length * 2]u8 { + const seed = "master secret" ++ client_random ++ server_random; + + var a1: [mac_length]u8 = undefined; + var a2: [mac_length]u8 = undefined; + Hmac.create(&a1, seed, pre_master_secret); + Hmac.create(&a2, &a1, pre_master_secret); + + var p1: [mac_length]u8 = undefined; + var p2: [mac_length]u8 = undefined; + Hmac.create(&p1, a1 ++ seed, pre_master_secret); + Hmac.create(&p2, a2 ++ seed, pre_master_secret); + + return p1 ++ p2; + } + + fn keyExpansion( + _: *Self, + master_secret: []const u8, + client_random: [32]u8, + server_random: [32]u8, + ) [mac_length * 4]u8 { + const seed = "key expansion" ++ server_random ++ client_random; + + const a0 = seed; + var a1: [mac_length]u8 = undefined; + var a2: [mac_length]u8 = undefined; + var a3: [mac_length]u8 = undefined; + var a4: [mac_length]u8 = undefined; + Hmac.create(&a1, a0, master_secret); + Hmac.create(&a2, &a1, master_secret); + Hmac.create(&a3, &a2, master_secret); + Hmac.create(&a4, &a3, master_secret); + + var key_material: [mac_length * 4]u8 = undefined; + Hmac.create(key_material[0..mac_length], a1 ++ seed, master_secret); + Hmac.create(key_material[mac_length .. mac_length * 2], a2 ++ seed, master_secret); + Hmac.create(key_material[mac_length * 2 .. mac_length * 3], a3 ++ seed, master_secret); + Hmac.create(key_material[mac_length * 3 ..], a4 ++ seed, master_secret); + return key_material; + } + + fn clientFinishedTls12(self: *Self, master_secret: []const u8) [12]u8 { + const seed = "client finished" ++ self.hash.peek(); + var a1: [mac_length]u8 = undefined; + var p1: [mac_length]u8 = undefined; + Hmac.create(&a1, seed, master_secret); + Hmac.create(&p1, a1 ++ seed, master_secret); + return p1[0..12].*; + } + + fn serverFinishedTls12(self: *Self, master_secret: []const u8) [12]u8 { + const seed = "server finished" ++ self.hash.peek(); + var a1: [mac_length]u8 = undefined; + var p1: [mac_length]u8 = undefined; + Hmac.create(&a1, seed, master_secret); + Hmac.create(&p1, a1 ++ seed, master_secret); + return p1[0..12].*; + } + + // tls 1.3 + + inline fn handshakeSecret(self: *Self, shared_key: []const u8) Transcript.Secret { + const hello_hash = self.hash.peek(); + + const zeroes = [1]u8{0} ** Hash.digest_length; + const early_secret = Hkdf.extract(&[1]u8{0}, &zeroes); + const empty_hash = tls.emptyHash(Hash); + const hs_derived_secret = hkdfExpandLabel(Hkdf, early_secret, "derived", &empty_hash, Hash.digest_length); + + self.handshake_secret = Hkdf.extract(&hs_derived_secret, shared_key); + const client_secret = hkdfExpandLabel(Hkdf, self.handshake_secret, "c hs traffic", &hello_hash, Hash.digest_length); + const server_secret = hkdfExpandLabel(Hkdf, self.handshake_secret, "s hs traffic", &hello_hash, Hash.digest_length); + + self.server_finished_key = hkdfExpandLabel(Hkdf, server_secret, "finished", "", Hmac.key_length); + self.client_finished_key = hkdfExpandLabel(Hkdf, client_secret, "finished", "", Hmac.key_length); + + return .{ .client = &client_secret, .server = &server_secret }; + } + + inline fn applicationSecret(self: *Self) Transcript.Secret { + const handshake_hash = self.hash.peek(); + + const empty_hash = tls.emptyHash(Hash); + const zeroes = [1]u8{0} ** Hash.digest_length; + const ap_derived_secret = hkdfExpandLabel(Hkdf, self.handshake_secret, "derived", &empty_hash, Hash.digest_length); + const master_secret = Hkdf.extract(&ap_derived_secret, &zeroes); + + const client_secret = hkdfExpandLabel(Hkdf, master_secret, "c ap traffic", &handshake_hash, Hash.digest_length); + const server_secret = hkdfExpandLabel(Hkdf, master_secret, "s ap traffic", &handshake_hash, Hash.digest_length); + + return .{ .client = &client_secret, .server = &server_secret }; + } + + fn serverFinishedTls13(self: *Self, buf: []u8) []const u8 { + Hmac.create(buf[0..mac_length], &self.hash.peek(), &self.server_finished_key); + return buf[0..mac_length]; + } + + // client finished message with header + fn clientFinishedTls13(self: *Self, buf: []u8) []const u8 { + Hmac.create(buf[0..mac_length], &self.hash.peek(), &self.client_finished_key); + return buf[0..mac_length]; + } + }; +}