Skip to content

Commit

Permalink
Fix crash in WebSocket client when handshaking fails or when the HTTP…
Browse files Browse the repository at this point in the history
… response is invalid (#7933)

* Fix double-free in websocket client

* Update test

* Fix null pointer dereference

* Fix missing protect() / unprotect() call

* More careful checks

---------

Co-authored-by: Jarred Sumner <709451+Jarred-Sumner@users.noreply.github.com>
  • Loading branch information
Jarred-Sumner and Jarred-Sumner authored Jan 2, 2024
1 parent 9d6c064 commit 837cbd6
Show file tree
Hide file tree
Showing 8 changed files with 161 additions and 59 deletions.
11 changes: 8 additions & 3 deletions packages/bun-usockets/src/crypto/openssl.c
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ struct us_internal_ssl_socket_t *ssl_on_data(struct us_internal_ssl_socket_t *s,
s = (struct us_internal_ssl_socket_t *)context->sc.on_writable(
&s->s); // cast here!
// if we are closed here, then exit
if (us_socket_is_closed(0, &s->s)) {
if (!s || us_socket_is_closed(0, &s->s)) {
return s;
}
}
Expand Down Expand Up @@ -544,8 +544,13 @@ ssl_on_writable(struct us_internal_ssl_socket_t *s) {
0); // cast here!
}

// should this one come before we have read? should it come always? spurious
// on_writable is okay

// Do not call on_writable if the socket is closed.
// on close means the socket data is no longer accessible
if (!s || us_socket_is_closed(0, &s->s)) {
return 0;
}

s = context->on_writable(s);

return s;
Expand Down
11 changes: 6 additions & 5 deletions packages/bun-usockets/src/loop.c
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,8 @@ void us_internal_dispatch_ready_poll(struct us_poll_t *p, int error, int events)
#endif
}
cb->cb(cb->cb_expects_the_loop ? (struct us_internal_callback_t *) cb->loop : (struct us_internal_callback_t *) &cb->p);
break;
}
break;
case POLL_TYPE_SEMI_SOCKET: {
/* Both connect and listen sockets are semi-sockets
* but they poll for different events */
Expand All @@ -220,6 +220,7 @@ void us_internal_dispatch_ready_poll(struct us_poll_t *p, int error, int events)
/* Emit error, close without emitting on_close */
s->context->on_connect_error(s, 0);
us_socket_close_connecting(0, s);
s = NULL;
} else {
/* All sockets poll for readable */
us_poll_change(p, s->context->loop, LIBUS_SOCKET_READABLE);
Expand Down Expand Up @@ -274,8 +275,8 @@ void us_internal_dispatch_ready_poll(struct us_poll_t *p, int error, int events)
} while ((client_fd = bsd_accept_socket(us_poll_fd(p), &addr)) != LIBUS_SOCKET_ERROR);
}
}
}
break;
}
case POLL_TYPE_SOCKET_SHUT_DOWN:
case POLL_TYPE_SOCKET: {
/* We should only use s, no p after this point */
Expand All @@ -288,7 +289,7 @@ void us_internal_dispatch_ready_poll(struct us_poll_t *p, int error, int events)

s = s->context->on_writable(s);

if (us_socket_is_closed(0, s)) {
if (!s || us_socket_is_closed(0, s)) {
return;
}

Expand Down Expand Up @@ -346,13 +347,13 @@ void us_internal_dispatch_ready_poll(struct us_poll_t *p, int error, int events)
}

/* Such as epollerr epollhup */
if (error) {
if (error && s) {
/* Todo: decide what code we give here */
s = us_socket_close(0, s, 0, NULL);
return;
}
break;
}
break;
}
}

Expand Down
14 changes: 8 additions & 6 deletions src/bun.js/api/bun/socket.zig
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ pub const SocketConfig = struct {
return null;
}

const handlers = Handlers.fromJS(globalObject, opts.get(globalObject, "socket") orelse JSValue.zero, exception) orelse {
var handlers = Handlers.fromJS(globalObject, opts.get(globalObject, "socket") orelse JSValue.zero, exception) orelse {
hostname_or_unix.deinit();
return null;
};
Expand All @@ -444,6 +444,8 @@ pub const SocketConfig = struct {
default_data = default_data_value;
}

handlers.protect();

return SocketConfig{
.hostname_or_unix = hostname_or_unix,
.port = port,
Expand Down Expand Up @@ -547,7 +549,7 @@ pub const Listener = struct {
return .zero;
};

var prev_handlers = this.handlers;
var prev_handlers = &this.handlers;
prev_handlers.unprotect();
this.handlers = handlers; // TODO: this is a memory leak
this.handlers.protect();
Expand Down Expand Up @@ -579,7 +581,7 @@ pub const Listener = struct {
var hostname_or_unix = socket_config.hostname_or_unix;
const port = socket_config.port;
var ssl = socket_config.ssl;
var handlers = socket_config.handlers;
var handlers = &socket_config.handlers;
var protos: ?[]const u8 = null;
const exclusive = socket_config.exclusive;
handlers.is_server = true;
Expand Down Expand Up @@ -714,7 +716,7 @@ pub const Listener = struct {
};

var socket = Listener{
.handlers = handlers,
.handlers = handlers.*,
.connection = connection,
.ssl = ssl_enabled,
.socket_context = socket_context,
Expand Down Expand Up @@ -837,6 +839,7 @@ pub const Listener = struct {
this.poll_ref.unref(this.handlers.vm);
std.debug.assert(this.listener == null);
std.debug.assert(this.handlers.active_connections == 0);
this.handlers.unprotect();

if (this.socket_context) |ctx| {
ctx.deinit(this.ssl);
Expand Down Expand Up @@ -925,8 +928,6 @@ pub const Listener = struct {
const ssl_enabled = ssl != null;
defer if (ssl != null) ssl.?.deinit();

handlers.protect();

const ctx_opts: uws.us_bun_socket_context_options_t = JSC.API.ServerConfig.SSLConfig.asUSockets(socket_config.ssl);

globalObject.bunVM().eventLoop().ensureWaker();
Expand All @@ -938,6 +939,7 @@ pub const Listener = struct {
.code = if (port == null) bun.String.static("ENOENT") else bun.String.static("ECONNREFUSED"),
};
exception.* = err.toErrorInstance(globalObject).asObjectRef();
handlers.unprotect();
return .zero;
};

Expand Down
10 changes: 6 additions & 4 deletions src/bun.js/api/server.zig
Original file line number Diff line number Diff line change
Expand Up @@ -3651,6 +3651,7 @@ pub const ServerWebSocket = struct {
opened: bool = false,

pub usingnamespace JSC.Codegen.JSServerWebSocket;
pub usingnamespace bun.New(ServerWebSocket);

const log = Output.scoped(.WebSocketServer, false);

Expand Down Expand Up @@ -3958,7 +3959,7 @@ pub const ServerWebSocket = struct {

pub fn finalize(this: *ServerWebSocket) callconv(.C) void {
log("finalize", .{});
bun.default_allocator.destroy(this);
this.destroy();
}

pub fn publish(
Expand Down Expand Up @@ -5077,11 +5078,12 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp

resp.clearAborted();

const ws = this.vm.allocator.create(ServerWebSocket) catch return .zero;
ws.* = .{
data_value.ensureStillAlive();
const ws = ServerWebSocket.new(.{
.handler = &this.config.websocket.?.handler,
.this_value = data_value,
};
});
data_value.ensureStillAlive();

var sec_websocket_protocol_str = sec_websocket_protocol.toSlice(bun.default_allocator);
defer sec_websocket_protocol_str.deinit();
Expand Down
18 changes: 18 additions & 0 deletions src/bun.zig
Original file line number Diff line number Diff line change
Expand Up @@ -2835,7 +2835,13 @@ pub inline fn destroyWithAlloc(allocator: std.mem.Allocator, t: anytype) void {

pub fn New(comptime T: type) type {
return struct {
const allocation_logger = Output.scoped(.alloc, @hasDecl(T, "logAllocations"));

pub inline fn destroy(self: *T) void {
if (comptime Environment.allow_assert) {
allocation_logger("destroy({*})", .{self});
}

if (comptime is_heap_breakdown_enabled) {
HeapBreakdown.allocator(T).destroy(self);
} else {
Expand All @@ -2847,11 +2853,18 @@ pub fn New(comptime T: type) type {
if (comptime is_heap_breakdown_enabled) {
const ptr = HeapBreakdown.allocator(T).create(T) catch outOfMemory();
ptr.* = t;
if (comptime Environment.allow_assert) {
allocation_logger("new() = {*}", .{ptr});
}
return ptr;
}

const ptr = default_allocator.create(T) catch outOfMemory();
ptr.* = t;

if (comptime Environment.allow_assert) {
allocation_logger("new() = {*}", .{ptr});
}
return ptr;
}
};
Expand All @@ -2874,9 +2887,12 @@ pub fn NewRefCounted(comptime T: type, comptime deinit_fn: ?fn (self: *T) void)
}

return struct {
const allocation_logger = Output.scoped(.alloc, @hasDecl(T, "logAllocations"));

pub fn destroy(self: *T) void {
if (comptime Environment.allow_assert) {
std.debug.assert(self.ref_count == 0);
allocation_logger("destroy() = {*}", .{self});
}

if (comptime is_heap_breakdown_enabled) {
Expand Down Expand Up @@ -2909,6 +2925,7 @@ pub fn NewRefCounted(comptime T: type, comptime deinit_fn: ?fn (self: *T) void)

if (comptime Environment.allow_assert) {
std.debug.assert(ptr.ref_count == 1);
allocation_logger("new() = {*}", .{ptr});
}

return ptr;
Expand All @@ -2919,6 +2936,7 @@ pub fn NewRefCounted(comptime T: type, comptime deinit_fn: ?fn (self: *T) void)

if (comptime Environment.allow_assert) {
std.debug.assert(ptr.ref_count == 1);
allocation_logger("new() = {*}", .{ptr});
}

return ptr;
Expand Down
18 changes: 18 additions & 0 deletions src/deps/uws.zig
Original file line number Diff line number Diff line change
Expand Up @@ -764,6 +764,24 @@ pub fn NewSocketHandler(comptime is_ssl: bool) type {
@field(holder, socket_field_name) = adopted;
return holder;
}

pub fn adoptPtr(
socket: *Socket,
socket_ctx: *SocketContext,
comptime Context: type,
comptime socket_field_name: []const u8,
ctx: *Context,
) bool {
var adopted = ThisSocket{ .socket = us_socket_context_adopt_socket(comptime ssl_int, socket_ctx, socket, @sizeOf(*Context)) orelse return false };
const holder = adopted.ext(*anyopaque) orelse {
if (comptime bun.Environment.allow_assert) unreachable;
_ = us_socket_close(comptime ssl_int, socket, 0, null);
return false;
};
holder.* = ctx;
@field(ctx, socket_field_name) = adopted;
return true;
}
};
}

Expand Down
Loading

0 comments on commit 837cbd6

Please sign in to comment.