Skip to content

Commit

Permalink
use std reader/writer
Browse files Browse the repository at this point in the history
  • Loading branch information
marler8997 committed Jun 10, 2020
1 parent 22b7820 commit a47321b
Show file tree
Hide file tree
Showing 6 changed files with 107 additions and 77 deletions.
1 change: 0 additions & 1 deletion build.zig
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ pub fn build(b: *Builder) !void {
exe.setTarget(target);
exe.single_threaded = true;
exe.setBuildMode(mode);
exe.addPackage(Pkg { .name = "stdext", .path = "src-stdext/stdext.zig" });
if (openssl) {
exe.addPackage(Pkg { .name = "ssl", .path = "openssl/ssl.zig" });
exe.linkSystemLibrary("c");
Expand Down
13 changes: 6 additions & 7 deletions nossl/ssl.zig
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
const std = @import("std");

// TODO: why isn't this working?
//const stdext = @import("stdext");
const stdext= @import("../src-stdext/stdext.zig");
const readwrite = stdext.readwrite;
const ReaderWriter = readwrite.ReaderWriter;

pub fn init() anyerror!void {
std.debug.warn("[DEBUG] nossl init\n", .{});
}

pub const SslConn = struct {
rw: ReaderWriter,
pub fn init(file: std.fs.File, serverName: []const u8) !SslConn {
return error.NoSslConfigured;
}
pub fn deinit(self: SslConn) void { }
pub fn read(self: SslConn, data: []u8) !usize {
@panic("nossl has been configured");
}
pub fn write(self: SslConn, data: []const u8) !usize {
@panic("nossl has been configured");
}
};
84 changes: 39 additions & 45 deletions openssl/ssl.zig
Original file line number Diff line number Diff line change
@@ -1,13 +1,5 @@
const std = @import("std");

// TODO: why isn't this working?
//const stdext = @import("stdext");
const stdext= @import("../src-stdext/stdext.zig");
const readwrite = stdext.readwrite;
const Reader = readwrite.Reader;
const Writer = readwrite.Writer;
const ReaderWriter = readwrite.ReaderWriter;

const openssl = @cImport({
@cInclude("openssl/ssl.h");
@cInclude("openssl/err.h");
Expand Down Expand Up @@ -45,7 +37,6 @@ pub fn init() anyerror!void {
}

pub const SslConn = struct {
rw: ReaderWriter,
ctx: *openssl.SSL_CTX,
ssl: *openssl.SSL,

Expand Down Expand Up @@ -102,20 +93,30 @@ pub const SslConn = struct {
}

return SslConn {
.rw = .{
.reader = .{ .readFn = read },
.writer = .{ .writeFn = write },
},
.ctx = ctx,
.ssl = ssl,
};
}
pub fn deinit(self: SslConn) void {
openssl.SSL_CTX_free(self.ctx);
}
pub fn read(reader: *Reader, data: []u8) anyerror!usize {
const self = @fieldParentPtr(SslConn, "rw",
@fieldParentPtr(ReaderWriter, "reader", reader));

//pub const ReadError = FnError(@TypeOf(readBoth));
//pub const WriteError = FnError(@TypeOf(writeBoth));
pub const ReadError = error { };
pub const WriteError = error { };
pub const Reader = std.io.Reader(*SslConn, ReadError, read);
pub const Writer = std.io.Writer(*SslConn, WriteError, write);

pub fn reader(self: *@This()) Reader {
return .{ .context = self };
}

pub fn writer(self: *@This()) Writer {
return .{ .context = self };
}

pub fn read(self: *SslConn, data: []u8) !usize {
var readSize : usize = undefined;
const result = openssl.SSL_read_ex(self.ssl, data.ptr, data.len, &readSize);
if (1 == result)
Expand All @@ -127,36 +128,29 @@ pub const SslConn = struct {
else => std.debug.panic("SSL_read failed with {}\n", .{err}),
}
}
pub fn write(writer: *Writer, data: []const u8) anyerror!void {
const self = @fieldParentPtr(SslConn, "rw",
@fieldParentPtr(ReaderWriter, "writer", writer));
var written : usize = 0;
while (written < data.len) {
var writeSize = data.len - written;
if (writeSize > std.math.maxInt(c_int))
writeSize = std.math.maxInt(c_int);
const result = openssl.SSL_write(self.ssl, data.ptr + written, @intCast(c_int, writeSize));
if (result <= 0) {
const err = openssl.SSL_get_error(self.ssl, result);
switch (err) {
openssl.SSL_ERROR_NONE => unreachable,
openssl.SSL_ERROR_ZERO_RETURN => unreachable,
openssl.SSL_ERROR_WANT_READ
,openssl.SSL_ERROR_WANT_WRITE
,openssl.SSL_ERROR_WANT_CONNECT
,openssl.SSL_ERROR_WANT_ACCEPT
,openssl.SSL_ERROR_WANT_X509_LOOKUP
,openssl.SSL_ERROR_WANT_ASYNC
,openssl.SSL_ERROR_WANT_ASYNC_JOB
,openssl.SSL_ERROR_WANT_CLIENT_HELLO_CB
,openssl.SSL_ERROR_SYSCALL
,openssl.SSL_ERROR_SSL
=> std.debug.panic("SSL_write failed with {}\n", .{err}),
else
=> std.debug.panic("SSL_write failed with {}\n", .{err}),
}
pub fn write(self: *SslConn, data: []const u8) !usize {
// TODO: and writeSize with c_int mask, it's ok if we don't write all the data
const result = openssl.SSL_write(self.ssl, data.ptr, @intCast(c_int, data.len));
if (result <= 0) {
const err = openssl.SSL_get_error(self.ssl, result);
switch (err) {
openssl.SSL_ERROR_NONE => unreachable,
openssl.SSL_ERROR_ZERO_RETURN => unreachable,
openssl.SSL_ERROR_WANT_READ
,openssl.SSL_ERROR_WANT_WRITE
,openssl.SSL_ERROR_WANT_CONNECT
,openssl.SSL_ERROR_WANT_ACCEPT
,openssl.SSL_ERROR_WANT_X509_LOOKUP
,openssl.SSL_ERROR_WANT_ASYNC
,openssl.SSL_ERROR_WANT_ASYNC_JOB
,openssl.SSL_ERROR_WANT_CLIENT_HELLO_CB
,openssl.SSL_ERROR_SYSCALL
,openssl.SSL_ERROR_SSL
=> std.debug.panic("SSL_write failed with {}\n", .{err}),
else
=> std.debug.panic("SSL_write failed with {}\n", .{err}),
}
written += @intCast(usize, result);
}
return @intCast(usize, result);
}
};
6 changes: 1 addition & 5 deletions ziget-cmdline.zig
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
const std = @import("std");

const stdext = @import("stdext");
const readwrite = stdext.readwrite;

const ziget = @import("./ziget.zig");
const ssl = @import("ssl");

Expand Down Expand Up @@ -118,10 +115,9 @@ pub fn main() anyerror!u8 {
if (outFile.handle != std.io.getStdOut().handle)
outFile.close();
}
var outFileRw = readwrite.FileReaderWriter.init(outFile);

var downloadState = ziget.request.DownloadState.init();
ziget.request.download(url, &outFileRw.rw.writer, options, &downloadState) catch |e| switch (e) {
ziget.request.download(url, outFile.writer(), options, &downloadState) catch |e| switch (e) {
error.UnknownUrlScheme => {
printError("unknown url scheme '{}'", .{url.schemeString()});
return 1;
Expand Down
46 changes: 46 additions & 0 deletions ziget/net_stream.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
const std = @import("std");
const ssl = @import("ssl");

pub const NetStream = union(enum) {
File: *const std.fs.File,
Ssl: *ssl.SslConn,

pub const Reader = std.io.Reader(*@This(), FnErrorSet(@TypeOf(read)), read);
pub const Writer = std.io.Writer(*@This(), FnErrorSet(@TypeOf(write)), write);

pub fn initFile(file: *const std.fs.File) NetStream {
return .{ .File = file };
}
pub fn initSsl(conn: *ssl.SslConn) NetStream {
return .{ .Ssl = conn };
}

pub fn reader(self: *@This()) Reader {
return .{ .context = self };
}

pub fn writer(self: *@This()) Writer {
return .{ .context = self };
}

pub fn read(self: *@This(), dest: []u8) !usize {
switch (self.*) {
.File => |s| return s.read(dest),
.Ssl => |s| return s.read(dest),
}
}

pub fn write(self: *@This(), bytes: []const u8) !usize {
switch (self.*) {
.File => |s| return try s.write(bytes),
.Ssl => |s| return try s.write(bytes),
}
}
};

/// TODO: move this
/// Returns the error set for the given function type
fn FnErrorSet(comptime FnType: type) type {
const Return = @typeInfo(FnType).Fn.return_type.?;
return @typeInfo(Return).ErrorUnion.error_set;
}
34 changes: 15 additions & 19 deletions ziget/request.zig
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,8 @@ const testing = std.testing;

const Allocator = mem.Allocator;

const stdext = @import("stdext");
const readwrite = stdext.readwrite;
const Reader = readwrite.Reader;
const Writer = readwrite.Writer;
const FileReaderWriter = readwrite.FileReaderWriter;
const net_stream = @import("./net_stream.zig");
const NetStream = net_stream.NetStream;

const urlmod = @import("./url.zig");
const Url = urlmod.Url;
Expand Down Expand Up @@ -52,7 +49,7 @@ pub const DownloadState = struct {
};

// have to use anyerror for now because download and downloadHttp recursively call each other
pub fn download(url: Url, writer: *Writer, options: DownloadOptions, state: *DownloadState) !void {
pub fn download(url: Url, writer: var, options: DownloadOptions, state: *DownloadState) !void {
var nextUrl = url;
while (true) {
const result = switch (nextUrl) {
Expand Down Expand Up @@ -94,7 +91,7 @@ pub fn httpAlloc(allocator: *Allocator, method: []const u8, resource: []const u8
method, resource, host, headers});
}

pub fn sendHttpGet(allocator: *Allocator, writer: *Writer, httpUrl: Url.Http, keepAlive: bool) !void {
pub fn sendHttpGet(allocator: *Allocator, writer: var, httpUrl: Url.Http, keepAlive: bool) !void {
const request = try httpAlloc(allocator, "GET", httpUrl.str,
httpUrl.getHostPortString(),
if (keepAlive) "Connection: keep-alive\r\n" else "Connection: close\r\n"
Expand All @@ -104,14 +101,14 @@ pub fn sendHttpGet(allocator: *Allocator, writer: *Writer, httpUrl: Url.Http, ke
std.debug.warn("Sending HTTP Request...\n", .{});
std.debug.warn("--------------------------------------------------------------------------------\n", .{});
std.debug.warn("{}", .{request});
try writer.write(request);
try writer.writeAll(request);
}

const HttpResponseData = struct {
headerLimit: usize,
dataLimit: usize,
};
pub fn readHttpResponse(buffer: []u8, reader: *Reader) !HttpResponseData {
pub fn readHttpResponse(buffer: []u8, reader: var) !HttpResponseData {
var totalRead : usize = 0;
while (true) {
if (totalRead >= buffer.len)
Expand All @@ -131,34 +128,33 @@ pub fn readHttpResponse(buffer: []u8, reader: *Reader) !HttpResponseData {
}

// TODO: call sendFile on linux so we don't have to read the data into memory
pub fn forward(buffer: []u8, reader: *Reader, writer: *Writer) !void {
pub fn forward(buffer: []u8, reader: var, writer: var) !void {
while (true) {
var len = try reader.read(buffer);
if (len == 0) break;
try writer.write(buffer[0..len]);
try writer.writeAll(buffer[0..len]);
}
}

pub fn downloadHttpOrRedirect(httpUrl: Url.Http, writer: *Writer, options: DownloadOptions, state: *DownloadState) !DownloadResult {
pub fn downloadHttpOrRedirect(httpUrl: Url.Http, writer: var, options: DownloadOptions, state: *DownloadState) !DownloadResult {
const file = try net.tcpConnectToHost(options.allocator, httpUrl.getHostString(), httpUrl.getPortOrDefault());
defer {
// TODO: file.shutdown()???
file.close();
}
var fileRw = FileReaderWriter.init(file);
var rw = &fileRw.rw;
var stream = NetStream.initFile(&file);

var sslConn : ssl.SslConn = undefined;
if (httpUrl.secure) {
sslConn = try ssl.SslConn.init(file, httpUrl.getHostString());
rw = &sslConn.rw;
stream = NetStream.initSsl(&sslConn);
}
defer { if (httpUrl.secure) sslConn.deinit(); }

try sendHttpGet(options.allocator, &rw.writer, httpUrl, false);
try sendHttpGet(options.allocator, stream.writer(), httpUrl, false);

const buffer = options.buffer;
const response = try readHttpResponse(buffer, &rw.reader);
const response = try readHttpResponse(buffer, stream.reader());
std.debug.warn("--------------------------------------------------------------------------------\n", .{});
std.debug.warn("Received Http Response:\n", .{});
std.debug.warn("--------------------------------------------------------------------------------\n", .{});
Expand All @@ -180,8 +176,8 @@ pub fn downloadHttpOrRedirect(httpUrl: Url.Http, writer: *Writer, options: Downl
}

if (response.dataLimit > response.headerLimit) {
try writer.write(buffer[response.headerLimit..response.dataLimit]);
try writer.writeAll(buffer[response.headerLimit..response.dataLimit]);
}
try forward(buffer, &rw.reader, writer);
try forward(buffer, stream.reader(), writer);
return DownloadResult.Success;
}

0 comments on commit a47321b

Please sign in to comment.