diff --git a/packages/bun-types/bun.d.ts b/packages/bun-types/bun.d.ts index 185ab31607e04..99fd9de88efd0 100644 --- a/packages/bun-types/bun.d.ts +++ b/packages/bun-types/bun.d.ts @@ -2767,6 +2767,16 @@ declare module "bun" { compress?: boolean, ): ServerWebSocketSendStatus; + /** + * A count of connections subscribed to a given topic + * + * This operation will loop through each topic internally to get the count. + * + * @param topic the websocket topic to check how many subscribers are connected to + * @returns the number of subscribers + */ + subscriberCount(topic: string): number; + /** * Returns the client IP address and port of the given Request. If the request was closed or is a unix socket, returns null. * diff --git a/src/bun.js/api/server.classes.ts b/src/bun.js/api/server.classes.ts index ca71dc02fd449..3a267f0cfd9a4 100644 --- a/src/bun.js/api/server.classes.ts +++ b/src/bun.js/api/server.classes.ts @@ -16,6 +16,10 @@ function generate(name) { fn: "doPublish", length: 3, }, + subscriberCount: { + fn: "doSubscriberCount", + length: 1, + }, reload: { fn: "doReload", length: 2, diff --git a/src/bun.js/api/server.zig b/src/bun.js/api/server.zig index 04221c6672a16..8a1808492c054 100644 --- a/src/bun.js/api/server.zig +++ b/src/bun.js/api/server.zig @@ -4339,7 +4339,6 @@ pub const ServerWebSocket = struct { callframe: *JSC.CallFrame, ) JSValue { const args = callframe.arguments(4); - if (args.len < 1) { log("publish()", .{}); globalThis.throw("publish requires at least 1 argument", .{}); @@ -5354,6 +5353,31 @@ pub fn NewServer(comptime NamespaceType: type, comptime ssl_enabled_: bool, comp pub const doFetch = onFetch; pub const doRequestIP = JSC.wrapInstanceMethod(ThisServer, "requestIP", false); + pub fn doSubscriberCount(this: *ThisServer, globalThis: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) JSC.JSValue { + const arguments = callframe.arguments(1); + if (arguments.len < 1) { + globalThis.throwNotEnoughArguments("subscriberCount", 1, 0); + return .zero; + } + + if (arguments.ptr[0].isEmptyOrUndefinedOrNull()) { + globalThis.throwInvalidArguments("subscriberCount requires a topic name as a string", .{}); + return .zero; + } + + var topic = arguments.ptr[0].toSlice(globalThis, bun.default_allocator); + defer topic.deinit(); + if (globalThis.hasException()) { + return .zero; + } + + if (topic.len == 0) { + return JSValue.jsNumber(0); + } + + return JSValue.jsNumber((this.app.num_subscribers(topic.slice()))); + } + pub usingnamespace NamespaceType; pub usingnamespace bun.New(@This()); diff --git a/test/js/bun/websocket/websocket-server-fixture.js b/test/js/bun/websocket/websocket-server-fixture.js index 82fcc8844dac0..8e140b4b1f509 100644 --- a/test/js/bun/websocket/websocket-server-fixture.js +++ b/test/js/bun/websocket/websocket-server-fixture.js @@ -8,6 +8,7 @@ let pending = []; using server = Bun.serve({ + port: 0, websocket: { open(ws) { globalThis.sockets ??= []; diff --git a/test/js/bun/websocket/websocket-server.test.ts b/test/js/bun/websocket/websocket-server.test.ts index ec4de1e559579..3b283bc2b5d8e 100644 --- a/test/js/bun/websocket/websocket-server.test.ts +++ b/test/js/bun/websocket/websocket-server.test.ts @@ -492,9 +492,11 @@ describe("ServerWebSocket", () => { } } }; - test(label, (done, connect) => ({ + test(label, (done, connect, options) => ({ async open(ws) { + const initial = options.server.subscriberCount(topic); ws.subscribe(topic); + expect(options.server.subscriberCount(topic)).toBe(initial + 1); if (ws.data.id === 0) { await connect(); } else if (ws.data.id === 1) { @@ -525,10 +527,12 @@ describe("ServerWebSocket", () => { } } }; - test(label, done => ({ + test(label, (done, _, options) => ({ publishToSelf: true, async open(ws) { + const initial = options.server.subscriberCount(topic); ws.subscribe(topic); + expect(options.server.subscriberCount(topic)).toBe(initial + 1); send(ws); }, drain(ws) { @@ -690,7 +694,11 @@ describe("ServerWebSocket", () => { function test( label: string, - fn: (done: (err?: unknown) => void, connect: () => Promise) => Partial>, + fn: ( + done: (err?: unknown) => void, + connect: () => Promise, + options: { server: Server }, + ) => Partial>, timeout?: number, ) { it( @@ -705,6 +713,9 @@ function test( } }; let id = 0; + var options = { + server: undefined, + }; const server: Server = serve({ port: 0, fetch(request, server) { @@ -717,9 +728,11 @@ function test( websocket: { sendPings: false, message() {}, - ...fn(done, () => connect(server)), + ...fn(done, () => connect(server), options as any), }, }); + options.server = server; + expect(server.subscriberCount("empty topic")).toBe(0); await connect(server); }, { timeout: timeout ?? 1000 },