Skip to content

Commit

Permalink
add tls.zig
Browse files Browse the repository at this point in the history
  • Loading branch information
krichprollsch committed Nov 13, 2024
1 parent 7fb1db3 commit 0f04c1f
Show file tree
Hide file tree
Showing 33 changed files with 7,126 additions and 2,012 deletions.
2 changes: 1 addition & 1 deletion build.zig
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ pub fn build(b: *std.Build) void {
// Creates a step for unit testing. This only builds the test executable
// but does not run it.
const lib_unit_tests = b.addTest(.{
.root_source_file = b.path("src/root.zig"),
.root_source_file = b.path("src/main.zig"),
.target = target,
.optimize = optimize,
});
Expand Down
4 changes: 2 additions & 2 deletions src/io.zig
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
const std = @import("std");

const Ctx = @import("std/http/Client.zig").Ctx;
const Cbk = @import("std/http/Client.zig").Cbk;
pub const Ctx = @import("std/http/Client.zig").Ctx;
pub const Cbk = @import("std/http/Client.zig").Cbk;

pub const Blocking = struct {
pub fn connect(
Expand Down
2 changes: 1 addition & 1 deletion src/main.zig
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub fn main() !void {
var ctx = try Client.Ctx.init(&loop, &req);
defer ctx.deinit();

var server_header_buffer: [2048]u8 = undefined;
var server_header_buffer: [1024 * 1024]u8 = undefined;

try client.async_open(
.GET,
Expand Down
1,968 changes: 0 additions & 1,968 deletions src/std/crypto/tls/Client.zig

This file was deleted.

13 changes: 4 additions & 9 deletions src/std/http.zig
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
pub const Client = @import("http/Client.zig");
pub const Server = @import("http/Server.zig");
pub const protocol = @import("http/protocol.zig");
pub const HeadParser = @import("http/HeadParser.zig");
pub const ChunkParser = @import("http/ChunkParser.zig");
pub const HeaderIterator = @import("http/HeaderIterator.zig");
pub const HeadParser = std.http.HeadParser;
pub const ChunkParser = std.http.ChunkParser;
pub const HeaderIterator = std.http.HeaderIterator;

pub const Version = enum {
@"HTTP/1.0",
Expand Down Expand Up @@ -308,16 +308,11 @@ pub const Header = struct {
};

const builtin = @import("builtin");
const std = @import("std.zig");
const std = @import("std");

test {
_ = Client;
_ = Method;
_ = Server;
_ = Status;
_ = HeadParser;
_ = ChunkParser;
if (builtin.os.tag != .wasi) {
_ = @import("http/test.zig");
}
}
74 changes: 44 additions & 30 deletions src/std/http/Client.zig
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@ const use_vectors = builtin.zig_backend != .stage2_x86_64;
const Client = @This();
const proto = @import("protocol.zig");

const async_net = @import("../net.zig");
const async_tls = @import("../crypto/tls/Client.zig");
const tls23 = @import("../../tls.zig/main.zig");
const VecPut = @import("../../tls.zig/connection.zig").VecPut;
const GenericStack = @import("../../stack.zig").Stack;
const async_io = @import("../../io.zig");
const Loop = async_io.Blocking;

const cipher = @import("../../tls.zig/cipher.zig");

pub const disable_tls = std.options.http_disable_tls;

/// Used for all client allocations. Must be thread-safe.
Expand Down Expand Up @@ -196,9 +198,9 @@ pub const ConnectionPool = struct {

/// An interface to either a plain or TLS connection.
pub const Connection = struct {
stream: async_net.Stream,
stream: net.Stream,
/// undefined unless protocol is tls.
tls_client: if (!disable_tls) *async_tls.Client else void,
tls_client: if (!disable_tls) *tls23.Connection(net.Stream) else void,

/// The protocol that this connection is using.
protocol: Protocol,
Expand Down Expand Up @@ -244,12 +246,12 @@ pub const Connection = struct {
}

pub fn readvDirectTls(conn: *Connection, buffers: []std.posix.iovec) ReadError!usize {
return conn.tls_client.readv(conn.stream, buffers) catch |err| {
return conn.tls_client.readv(buffers) catch |err| {
// https://github.com/ziglang/zig/issues/2473
if (mem.startsWith(u8, @errorName(err), "TlsAlert")) return error.TlsAlert;

switch (err) {
error.TlsConnectionTruncated, error.TlsRecordOverflow, error.TlsDecodeError, error.TlsBadRecordMac, error.TlsBadLength, error.TlsIllegalParameter, error.TlsUnexpectedMessage => return error.TlsFailure,
error.TlsRecordOverflow, error.TlsBadRecordMac, error.TlsUnexpectedMessage => return error.TlsFailure,
error.ConnectionTimedOut => return error.ConnectionTimedOut,
error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer,
else => return error.UnexpectedReadFailure,
Expand Down Expand Up @@ -289,6 +291,7 @@ pub const Connection = struct {
if (conn.read_end != conn.read_start) return;

ctx._iovecs = try ctx.alloc().alloc(std.posix.iovec, 1);
errdefer ctx.alloc().free(ctx._iovecs);
const iovecs = [1]std.posix.iovec{
.{ .base = &conn.read_buf, .len = conn.read_buf.len },
};
Expand Down Expand Up @@ -369,7 +372,7 @@ pub const Connection = struct {
}

pub fn writeAllDirectTls(conn: *Connection, buffer: []const u8) WriteError!void {
return conn.tls_client.writeAll(conn.stream, buffer) catch |err| switch (err) {
return conn.tls_client.writeAll(buffer) catch |err| switch (err) {
error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer,
else => return error.UnexpectedWriteFailure,
};
Expand Down Expand Up @@ -476,7 +479,7 @@ pub const Connection = struct {
if (disable_tls) unreachable;

// try to cleanly close the TLS connection, for any server that cares.
_ = conn.tls_client.writeEnd(conn.stream, "", true) catch {};
conn.tls_client.close() catch {};
allocator.destroy(conn.tls_client);
}

Expand Down Expand Up @@ -1644,13 +1647,14 @@ fn setConnection(ctx: *Ctx, res: anyerror!void) !void {
if (ctx.data.conn.protocol == .tls) {
if (disable_tls) unreachable;

ctx.data.conn.tls_client = try ctx.alloc().create(async_tls.Client);
ctx.data.conn.tls_client = try ctx.alloc().create(tls23.Connection(net.Stream));
errdefer ctx.alloc().destroy(ctx.data.conn.tls_client);

ctx.data.conn.tls_client.* = async_tls.Client.init(ctx.data.conn.stream, ctx.req.client.ca_bundle, ctx.data.conn.host) catch return error.TlsInitializationFailed;
// This is appropriate for HTTPS because the HTTP headers contain
// the content length which is used to detect truncation attacks.
ctx.data.conn.tls_client.allow_truncation_attacks = true;
// TODO tls23.client does an handshake to pick a cipher.
ctx.data.conn.tls_client.* = tls23.client(ctx.data.conn.stream, .{
.host = ctx.data.conn.host,
.root_ca = .{ .bundle = ctx.req.client.ca_bundle },
}) catch return error.TlsInitializationFailed;
}

// add connection node in pool
Expand Down Expand Up @@ -1720,13 +1724,14 @@ pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connec
if (protocol == .tls) {
if (disable_tls) unreachable;

conn.data.tls_client = try client.allocator.create(async_tls.Client);
conn.data.tls_client = try client.allocator.create(tls23.Connection(net.Stream));
errdefer client.allocator.destroy(conn.data.tls_client);

conn.data.tls_client.* = async_tls.Client.init(stream, client.ca_bundle, host) catch return error.TlsInitializationFailed;
// This is appropriate for HTTPS because the HTTP headers contain
// the content length which is used to detect truncation attacks.
conn.data.tls_client.allow_truncation_attacks = true;
// TODO tls23.client does an handshake to pick a cipher.
conn.data.tls_client.* = tls23.client(stream, .{
.host = host,
.root_ca = .{ .bundle = client.ca_bundle },
}) catch return error.TlsInitializationFailed;
}

client.connection_pool.addUsed(conn);
Expand Down Expand Up @@ -1755,7 +1760,7 @@ pub fn async_connectTcp(
if (disable_tls and protocol == .tls)
return error.TlsInitializationFailed;

return async_net.async_tcpConnectToHost(
return net.async_tcpConnectToHost(
client.allocator,
host,
port,
Expand Down Expand Up @@ -2317,17 +2322,26 @@ pub const Ctx = struct {
_buffer: ?[]const u8 = null,
_len: ?usize = null,

// TLS readvAtLeast
_off_i: usize = 0,
_vec_i: usize = 0,
_tls_len: usize = 0,
_iovecs: []std.posix.iovec = undefined,

// TLS readvAdvanced
_vp: *async_tls.VecPut = undefined,
_cleartext_stack_buffer: []u8 = undefined,
_in_stack_buffer: []u8 = undefined,
_first_iov: []u8 = undefined,
// TLS readvAtLeast
// _off_i: usize = 0,
// _vec_i: usize = 0,
// _tls_len: usize = 0,

// TLS readv
_vp: VecPut = undefined,
// _tls_read_buf contains the next decrypted buffer
_tls_read_buf: ?[]u8 = undefined,
_tls_read_content_type: tls23.proto.ContentType = undefined,

// _tls_read_record contains the crypted record
_tls_read_record: ?tls23.record.Record = null,

// TLS writeAll
_tls_write_bytes: []const u8 = undefined,
_tls_write_index: usize = 0,
_tls_write_buf: [cipher.max_ciphertext_record_len]u8 = undefined,

pub fn init(loop: *Loop, req: *Request) !Ctx {
const connection = try req.client.allocator.create(Connection);
Expand Down Expand Up @@ -2404,7 +2418,7 @@ pub const Ctx = struct {
return self.req.connection.?;
}

pub fn stream(self: Ctx) async_net.Stream {
pub fn stream(self: Ctx) net.Stream {
return self.conn().stream;
}
};
Expand All @@ -2424,7 +2438,7 @@ fn onRequestWait(ctx: *Ctx, res: anyerror!void) !void {
};
std.log.debug("REQUEST WAITED", .{});
std.log.debug("Status code: {any}", .{ctx.req.response.status});
const body = try ctx.req.reader().readAllAlloc(ctx.alloc(), 2000);
const body = try ctx.req.reader().readAllAlloc(ctx.alloc(), 1024 * 1024);
defer ctx.alloc().free(body);
std.log.debug("Body: \n{s}", .{body});
}
Expand Down
2 changes: 1 addition & 1 deletion src/std/http/Server.zig
Original file line number Diff line number Diff line change
Expand Up @@ -1137,7 +1137,7 @@ fn rebase(s: *Server, index: usize) void {
s.read_buffer_len = index + leftover.len;
}

const std = @import("../std.zig");
const std = @import("std");
const http = std.http;
const mem = std.mem;
const net = std.net;
Expand Down
Loading

0 comments on commit 0f04c1f

Please sign in to comment.