diff --git a/packages/core/src/com/hosts/ws-client-host.ts b/packages/core/src/com/hosts/ws-client-host.ts index ffb874529..707986d34 100644 --- a/packages/core/src/com/hosts/ws-client-host.ts +++ b/packages/core/src/com/hosts/ws-client-host.ts @@ -10,7 +10,8 @@ export class WsClientHost extends BaseHost implements IDisposable { isDisposed = this.disposables.isDisposed; public connected: Promise; private socketClient: Socket; - public subscribers = new EventEmitter<{ disconnect: void; reconnect: void }>(); + public subscribers = new EventEmitter<{ disconnect: string; reconnect: void; connect: void }>(); + private stableClientId = crypto.randomUUID(); constructor(url: string, options?: Partial) { super(); @@ -24,10 +25,12 @@ export class WsClientHost extends BaseHost implements IDisposable { this.socketClient = io(url, { transports: ['websocket'], - forceNew: true, withCredentials: true, // Pass Cookie to socket io connection path, query, + auth: { + clientId: this.stableClientId, + }, ...options, }); @@ -36,15 +39,16 @@ export class WsClientHost extends BaseHost implements IDisposable { }); this.socketClient.on('connect', () => { - this.socketClient.on('message', (data: unknown) => { - this.emitMessageHandlers(data as Message); - }); + this.subscribers.emit('connect', undefined); resolve(); }); - this.socketClient.on('disconnect', () => { - this.subscribers.emit('disconnect', undefined); - this.socketClient.close(); + this.socketClient.on('message', (data: unknown) => { + this.emitMessageHandlers(data as Message); + }); + + this.socketClient.on('disconnect', (reason: string) => { + this.subscribers.emit('disconnect', reason); }); this.socketClient.on('reconnect', () => { @@ -57,4 +61,18 @@ export class WsClientHost extends BaseHost implements IDisposable { public postMessage(data: any) { this.socketClient.emit('message', data); } + + disconnectSocket() { + if (this.socketClient.connected) { + this.socketClient.disconnect(); + } + } + reconnectSocket() { + if (!this.socketClient.connected) { + this.socketClient.connect(); + } + } + isConnected(): boolean { + return this.socketClient.connected; + } } diff --git a/packages/runtime-node/src/ws-node-host.ts b/packages/runtime-node/src/ws-node-host.ts index 81b0e3b23..be4a31a20 100644 --- a/packages/runtime-node/src/ws-node-host.ts +++ b/packages/runtime-node/src/ws-node-host.ts @@ -14,47 +14,128 @@ export class WsHost extends BaseHost { } } +type ClientEnvId = string; +type ClientId = string; + export class WsServerHost extends BaseHost implements IDisposable { - private socketToEnvId = new Map(); + private clients = new Map< + ClientId, + { + socket: io.Socket; + namespacedEnvIds: Set; + disposeTimer?: NodeJS.Timeout; + } + >(); private disposables = new SafeDisposable(WsServerHost.name); dispose = this.disposables.dispose; isDisposed = this.disposables.isDisposed; + private disposeGraceMs: number; - constructor(private server: io.Server | io.Namespace) { + constructor( + private server: io.Server | io.Namespace, + config: { disposeGraceMs?: number } = {}, + ) { super(); + this.disposeGraceMs = config.disposeGraceMs ?? 120_000; this.server.on('connection', this.onConnection); this.disposables.add('connection', () => this.server.off('connection', this.onConnection)); this.disposables.add('clear handlers', () => this.handlers.clear()); + this.disposables.add('dispose clients', () => { + // clear pending dispose timers and emit dispose messages for all env IDs + for (const client of this.clients.values()) { + if (client.disposeTimer) { + clearTimeout(client.disposeTimer); + } + this.emitDisposeMessagesForClient(client.namespacedEnvIds); + } + this.clients.clear(); + }); + } + + private extractClientIdAndEnvId(namespacedId: string): { stableClientId: string; envId: string } | undefined { + const slashIndex = namespacedId.indexOf('/'); + if (slashIndex === -1) { + return undefined; + } + return { + stableClientId: namespacedId.slice(0, slashIndex), + envId: namespacedId.slice(slashIndex + 1), + }; + } + + private emitDisposeMessagesForClient(namespacedEnvIds: Set): void { + for (const envId of namespacedEnvIds) { + this.emitMessageHandlers({ + type: 'dispose', + from: envId, + origin: envId, + to: '*', + forwardingChain: [], + }); + } } public postMessage(data: Message) { if (data.to !== '*') { - if (this.socketToEnvId.has(data.to)) { - const { socket, clientID } = this.socketToEnvId.get(data.to)!; - data.to = clientID; - socket.emit('message', data); - } else { - this.server.emit('message', data); + const parsed = this.extractClientIdAndEnvId(data.to); + if (parsed) { + const client = this.clients.get(parsed.stableClientId); + + if (client) { + data.to = parsed.envId; + client.socket.emit('message', data); + return; + } } + // If not found in any client, broadcast + this.server.emit('message', data); } else { this.server.emit('message', data); } } private onConnection = (socket: io.Socket): void => { - const nameSpace = (original: string) => `${socket.id}/${original}`; + const clientId = socket.handshake.auth?.clientId; + if (!clientId) { + socket.disconnect(true); + return; + } + + // Handle reconnection: update socket and clear dispose timer + const existingClient = this.clients.get(clientId); + if (existingClient) { + // Clear dispose timer if exists + if (existingClient.disposeTimer) { + clearTimeout(existingClient.disposeTimer); + existingClient.disposeTimer = undefined; + } + + // remove old socket listeners + existingClient.socket.removeAllListeners(); + // Update socket reference + existingClient.socket = socket; + } else { + // New connection: create client entry + this.clients.set(clientId, { + socket, + namespacedEnvIds: new Set(), + }); + } + const onMessage = (message: Message): void => { - // this mapping should not be here because of forwarding of messages - // maybe change message forwarding to have 'forward destination' and correct 'from' - // also maybe we can put the init of the map on 'connection' event - // maybe we can notify from client about the new connected id - const originId = nameSpace(message.origin); - const fromId = nameSpace(message.from); - this.socketToEnvId.set(fromId, { socket, clientID: message.from }); - this.socketToEnvId.set(originId, { socket, clientID: message.origin }); - // modify message to be able to forward it - message.from = fromId; - message.origin = originId; + const client = this.clients.get(clientId); + if (!client) return; + // Namespace the env IDs with stableClientId to differentiate between clients + const namespacedFrom = `${clientId}/${message.from}`; + const namespacedOrigin = `${clientId}/${message.origin}`; + + // Track namespaced env IDs for this client + client.namespacedEnvIds.add(namespacedFrom); + client.namespacedEnvIds.add(namespacedOrigin); + + // Modify message with namespaced IDs for routing + message.from = namespacedFrom; + message.origin = namespacedOrigin; this.emitMessageHandlers(message); }; @@ -62,18 +143,18 @@ export class WsServerHost extends BaseHost implements IDisposable { socket.once('disconnect', () => { socket.off('message', onMessage); - for (const [envId, { socket: soc }] of this.socketToEnvId.entries()) { - if (socket === soc) { - this.socketToEnvId.delete(envId); - this.emitMessageHandlers({ - type: 'dispose', - from: envId, - origin: envId, - to: '*', - forwardingChain: [], - }); - } - } + + const client = this.clients.get(clientId); + if (!client) return; + + // Delay dispose to allow for socket recovery + client.disposeTimer = setTimeout(() => { + const clientToDispose = this.clients.get(clientId); + if (!clientToDispose) return; + + this.clients.delete(clientId); + this.emitDisposeMessagesForClient(clientToDispose.namespacedEnvIds); + }, this.disposeGraceMs); }); }; } diff --git a/packages/runtime-node/test/node-com.unit.ts b/packages/runtime-node/test/node-com.unit.ts index 11c592ba5..4a1a927e4 100644 --- a/packages/runtime-node/test/node-com.unit.ts +++ b/packages/runtime-node/test/node-com.unit.ts @@ -14,7 +14,7 @@ import { expect } from 'chai'; import { safeListeningHttpServer } from 'create-listening-server'; import { fork } from 'node:child_process'; import type { Socket } from 'node:net'; -import { waitFor } from 'promise-assist'; +import { sleep, waitFor } from 'promise-assist'; import sinon, { spy } from 'sinon'; import * as io from 'socket.io'; @@ -24,6 +24,7 @@ interface ICommunicationTestApi { } describe('Socket communication', () => { + const disposeGraceMs = 10; let clientHost: WsClientHost; let serverHost: WsServerHost; let socketServer: io.Server; @@ -56,7 +57,9 @@ describe('Socket communication', () => { }); clientHost = new WsClientHost(serverTopology['server-host']); - serverHost = new WsServerHost(nameSpace); + disposables.add(() => clientHost.dispose()); + serverHost = new WsServerHost(nameSpace, { disposeGraceMs }); + disposables.add(() => serverHost.dispose()); await clientHost.connected; }); @@ -225,19 +228,20 @@ describe('Socket communication', () => { it('notifies if environment is disconnected', async () => { const spy = sinon.spy(); const clientCom = new Communication(clientHost, 'client-host', serverTopology); - const { id } = await socketClientInitializer({ + const socketClient = await socketClientInitializer({ communication: clientCom, env: new Environment('server-host', 'node', 'single'), }); + disposables.add(() => socketClient.dispose()); + expect(socketClient.id).to.not.eq(undefined); - expect(id).to.not.eq(undefined); - - const host = clientCom.getEnvironmentHost(id); + const host = clientCom.getEnvironmentHost(socketClient.id); (host as WsClientHost).subscribers.on('disconnect', spy); await socketServer.close(); await waitFor( () => { expect(spy.callCount).to.be.eq(1); + expect(spy.firstCall.args[0]).to.be.a('string'); }, { timeout: 2_000, @@ -251,22 +255,26 @@ describe('Socket communication', () => { const { waitForCall: waitForClient1Call, spy: spyClient1 } = createWaitForCall<(ev: { data: Message }) => void>('client'); const clientHost1 = new WsClientHost(serverTopology['server-host']!); + disposables.add(() => clientHost1.dispose()); const clientHost2 = new WsClientHost(serverTopology['server-host']!); + const clientCom1 = new Communication(clientHost1, 'client-host1', serverTopology); const clientCom2 = new Communication(clientHost2, 'client-host2', serverTopology); new Communication(serverHost, 'server-host'); - await socketClientInitializer({ + const socketClient1 = await socketClientInitializer({ communication: clientCom1, env: { env: 'server-host', }, }); - await socketClientInitializer({ + disposables.add(() => socketClient1.dispose()); + const socketClient2 = await socketClientInitializer({ communication: clientCom2, env: { env: 'server-host', }, }); + disposables.add(() => socketClient2.dispose()); clientCom1.registerEnv('client-host2', clientCom1.getEnvironmentHost('server-host')!); serverHost.addEventListener('message', spyServer); clientHost1.addEventListener('message', spyClient1); @@ -284,6 +292,99 @@ describe('Socket communication', () => { expect(message.from).to.equal('server-host'); }); }); + + it('should handle client reconnection and cancel delayed dispose', async () => { + const COMMUNICATION_ID = 'reconnect-test'; + const { spy: disposeSpy } = createWaitForCall<(ev: { data: Message }) => void>('dispose'); + const { waitForCall: waitForConnect, spy: connectSpy } = createWaitForCall<() => void>('connect'); + + const clientCom = new Communication(clientHost, 'client-host', serverTopology); + const serverCom = new Communication(serverHost, 'server-host'); + + serverCom.registerAPI( + { id: COMMUNICATION_ID }, + { + sayHello: () => 'hello', + sayHelloWithDataAndParams: (name: string) => `hello ${name}`, + }, + ); + + // Test communication works + const methods = clientCom.apiProxy({ id: 'server-host' }, { id: COMMUNICATION_ID }); + expect(await methods.sayHello()).to.eql('hello'); + + // Listen for dispose & reconnect messages + serverHost.addEventListener('message', disposeSpy); + clientHost.subscribers.on('connect', connectSpy); + + // Disconnect and quickly reconnect (before dispose delay expires) + clientHost.disconnectSocket(); + expect(clientHost.isConnected()).to.eql(false); + + // Reconnect immediately (within the 10ms dispose delay) + clientHost.reconnectSocket(); + await waitForConnect(() => true); + + // Wait a bit more than the dispose delay to ensure dispose timer would have fired + await sleep(disposeGraceMs * 2); + + // Verify no dispose message was sent (since reconnection cancelled it) + expect(disposeSpy.callCount).to.eql(0); + + // Verify communication still works after reconnection + expect(await methods.sayHello()).to.eql('hello'); + expect(await methods.sayHelloWithDataAndParams('reconnected')).to.eq('hello reconnected'); + }); + + it('should emit dispose message if client does not reconnect within dispose delay', async () => { + const COMMUNICATION_ID = 'dispose-test'; + const { waitForCall: waitForDispose, spy: disposeSpy } = + createWaitForCall<(ev: { data: Message }) => void>('dispose'); + + const clientCom = new Communication(clientHost, 'client-host', serverTopology); + const serverCom = new Communication(serverHost, 'server-host'); + + serverCom.registerAPI( + { id: COMMUNICATION_ID }, + { + sayHello: () => 'hello', + sayHelloWithDataAndParams: (name: string) => `hello ${name}`, + }, + ); + + // Test communication works + const methods = clientCom.apiProxy({ id: 'server-host' }, { id: COMMUNICATION_ID }); + expect(await methods.sayHello()).to.eql('hello'); + + // Register dispose listener after initial communication + serverHost.addEventListener('message', disposeSpy); + + // Disconnect without reconnecting + clientHost.disconnectSocket(); + expect(clientHost.isConnected()).to.eql(false); + + // Wait for dispose message to be emitted after delay + await waitForDispose(([arg]) => { + const message = arg.data as DisposeMessage; + expect(message.type).to.eql('dispose'); + expect(message.from).to.include('/client-host'); + expect(message.origin).to.include('/client-host'); + return true; + }); + + // Verify that connection can be established again after dispose + const { waitForCall: waitForConnect, spy: connectSpy } = + createWaitForCall<() => void>('reconnect-after-dispose'); + clientHost.subscribers.on('connect', connectSpy); + + clientHost.reconnectSocket(); + await waitForConnect(() => true); + expect(clientHost.isConnected()).to.eql(true); + + // Verify communication works again after reconnection + expect(await methods.sayHello()).to.eql('hello'); + expect(await methods.sayHelloWithDataAndParams('after-dispose')).to.eq('hello after-dispose'); + }); }); describe('IPC communication', () => { diff --git a/packages/runtime-node/test/node-env.manager.unit.ts b/packages/runtime-node/test/node-env.manager.unit.ts index d63c61406..280d31d0d 100644 --- a/packages/runtime-node/test/node-env.manager.unit.ts +++ b/packages/runtime-node/test/node-env.manager.unit.ts @@ -12,6 +12,14 @@ import { runEnv as runAEnv } from '../test-kit/entrypoints/a.node.js'; import testFeature from '../test-kit/feature/test-feature.js'; describe('NodeEnvManager', () => { + const disposables = new Set<() => Promise | void>(); + afterEach(async () => { + for (const dispose of Array.from(disposables).reverse()) { + await dispose(); + } + disposables.clear(); + }); + const meta = { url: import.meta.resolve('../test-kit/entrypoints/') }; const testCommunicationId = 'test'; @@ -39,14 +47,11 @@ describe('NodeEnvManager', () => { }; manager = new NodeEnvManager(meta, featureEnvironmentsMapping); + disposables.add(() => manager.dispose()); const { port } = await manager.autoLaunch(new Map([['feature', 'test-feature']])); nodeEnvsPort = port; communication = getClientCom(port); - }); - - afterEach(async () => { - await communication.dispose(); - await manager.dispose(); + disposables.add(() => communication.dispose()); }); it('should reach env "a"', async () => { @@ -68,7 +73,9 @@ describe('NodeEnvManager', () => { it('should handle two communication with the same', async () => { // setup new com instance with the same id const communication2 = new Communication(new BaseHost(), testCommunicationId); + disposables.add(() => communication2.dispose()); const host = new WsClientHost('http://localhost:' + nodeEnvsPort, {}); + disposables.add(() => host.dispose()); communication2.registerEnv(aEnv.env, host); communication2.registerEnv(bEnv.env, host); @@ -85,20 +92,21 @@ describe('NodeEnvManager', () => { }); describe('NodeEnvManager with 2 node envs, one remote the other in a worker thread', () => { - let closeEnvA: () => Promise; let nodeEnvsManager: NodeEnvManager; let communication: Communication; beforeEach(async () => { const { port: aPort, socketServer, close } = await launchEngineHttpServer(); - closeEnvA = close; + disposables.add(() => close()); + const wsServerHost = new WsServerHost(socketServer); + disposables.add(() => wsServerHost.dispose()); await runAEnv({ Feature: testFeature, topLevelConfig: [ COM.configure({ config: { - host: new WsServerHost(socketServer), + host: wsServerHost, id: aEnv.env, }, }), @@ -125,13 +133,10 @@ describe('NodeEnvManager', () => { }; nodeEnvsManager = new NodeEnvManager(meta, featureEnvironmentsMapping); + disposables.add(() => nodeEnvsManager.dispose()); const { port } = await nodeEnvsManager.autoLaunch(new Map([['feature', 'test-feature']])); communication = getClientCom(port); - }); - afterEach(async () => { - await communication.dispose(); - await closeEnvA(); - await nodeEnvsManager.dispose(); + disposables.add(() => communication.dispose()); }); it('should reach env "a"', async () => {