Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 26 additions & 8 deletions packages/core/src/com/hosts/ws-client-host.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ export class WsClientHost extends BaseHost implements IDisposable {
isDisposed = this.disposables.isDisposed;
public connected: Promise<void>;
private socketClient: Socket;
public subscribers = new EventEmitter<{ disconnect: void; reconnect: void }>();
public subscribers = new EventEmitter<{ disconnect: string; reconnect: void; connect: void }>();
private stableClientId = crypto.randomUUID();

constructor(url: string, options?: Partial<SocketOptions>) {
super();
Expand All @@ -24,10 +25,12 @@ export class WsClientHost extends BaseHost implements IDisposable {

this.socketClient = io(url, {
transports: ['websocket'],
forceNew: true,
withCredentials: true, // Pass Cookie to socket io connection
path,
query,
auth: {
clientId: this.stableClientId,
},
...options,
});

Expand All @@ -36,15 +39,16 @@ export class WsClientHost extends BaseHost implements IDisposable {
});

this.socketClient.on('connect', () => {
this.socketClient.on('message', (data: unknown) => {
this.emitMessageHandlers(data as Message);
});
this.subscribers.emit('connect', undefined);
resolve();
});

this.socketClient.on('disconnect', () => {
this.subscribers.emit('disconnect', undefined);
this.socketClient.close();
this.socketClient.on('message', (data: unknown) => {
this.emitMessageHandlers(data as Message);
});

this.socketClient.on('disconnect', (reason: string) => {
this.subscribers.emit('disconnect', reason);
});

this.socketClient.on('reconnect', () => {
Expand All @@ -57,4 +61,18 @@ export class WsClientHost extends BaseHost implements IDisposable {
public postMessage(data: any) {
this.socketClient.emit('message', data);
}

disconnectSocket() {
if (this.socketClient.connected) {
this.socketClient.disconnect();
}
}
reconnectSocket() {
if (!this.socketClient.connected) {
this.socketClient.connect();
}
}
isConnected(): boolean {
return this.socketClient.connected;
}
}
145 changes: 113 additions & 32 deletions packages/runtime-node/src/ws-node-host.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,66 +14,147 @@ export class WsHost extends BaseHost {
}
}

type ClientEnvId = string;
type ClientId = string;

export class WsServerHost extends BaseHost implements IDisposable {
private socketToEnvId = new Map<string, { socket: io.Socket; clientID: string }>();
private clients = new Map<
ClientId,
{
socket: io.Socket;
namespacedEnvIds: Set<ClientEnvId>;
disposeTimer?: NodeJS.Timeout;
}
>();
private disposables = new SafeDisposable(WsServerHost.name);
dispose = this.disposables.dispose;
isDisposed = this.disposables.isDisposed;
private disposeGraceMs: number;

constructor(private server: io.Server | io.Namespace) {
constructor(
private server: io.Server | io.Namespace,
config: { disposeGraceMs?: number } = {},
) {
super();
this.disposeGraceMs = config.disposeGraceMs ?? 120_000;
this.server.on('connection', this.onConnection);
this.disposables.add('connection', () => this.server.off('connection', this.onConnection));
this.disposables.add('clear handlers', () => this.handlers.clear());
this.disposables.add('dispose clients', () => {
// clear pending dispose timers and emit dispose messages for all env IDs
for (const client of this.clients.values()) {
if (client.disposeTimer) {
clearTimeout(client.disposeTimer);
}
this.emitDisposeMessagesForClient(client.namespacedEnvIds);
}
this.clients.clear();
});
}

private extractClientIdAndEnvId(namespacedId: string): { stableClientId: string; envId: string } | undefined {
const slashIndex = namespacedId.indexOf('/');
if (slashIndex === -1) {
return undefined;
}
return {
stableClientId: namespacedId.slice(0, slashIndex),
envId: namespacedId.slice(slashIndex + 1),
};
}

private emitDisposeMessagesForClient(namespacedEnvIds: Set<ClientEnvId>): void {
for (const envId of namespacedEnvIds) {
this.emitMessageHandlers({
type: 'dispose',
from: envId,
origin: envId,
to: '*',
forwardingChain: [],
});
}
}

public postMessage(data: Message) {
if (data.to !== '*') {
if (this.socketToEnvId.has(data.to)) {
const { socket, clientID } = this.socketToEnvId.get(data.to)!;
data.to = clientID;
socket.emit('message', data);
} else {
this.server.emit('message', data);
const parsed = this.extractClientIdAndEnvId(data.to);
if (parsed) {
const client = this.clients.get(parsed.stableClientId);

if (client) {
data.to = parsed.envId;
client.socket.emit('message', data);
return;
}
}
// If not found in any client, broadcast
this.server.emit('message', data);
} else {
this.server.emit('message', data);
}
}

private onConnection = (socket: io.Socket): void => {
const nameSpace = (original: string) => `${socket.id}/${original}`;
const clientId = socket.handshake.auth?.clientId;
if (!clientId) {
socket.disconnect(true);
return;
}

// Handle reconnection: update socket and clear dispose timer
const existingClient = this.clients.get(clientId);
if (existingClient) {
// Clear dispose timer if exists
if (existingClient.disposeTimer) {
clearTimeout(existingClient.disposeTimer);
existingClient.disposeTimer = undefined;
}

// remove old socket listeners
existingClient.socket.removeAllListeners();
// Update socket reference
existingClient.socket = socket;
} else {
// New connection: create client entry
this.clients.set(clientId, {
socket,
namespacedEnvIds: new Set(),
});
}

const onMessage = (message: Message): void => {
// this mapping should not be here because of forwarding of messages
// maybe change message forwarding to have 'forward destination' and correct 'from'
// also maybe we can put the init of the map on 'connection' event
// maybe we can notify from client about the new connected id
const originId = nameSpace(message.origin);
const fromId = nameSpace(message.from);
this.socketToEnvId.set(fromId, { socket, clientID: message.from });
this.socketToEnvId.set(originId, { socket, clientID: message.origin });
// modify message to be able to forward it
message.from = fromId;
message.origin = originId;
const client = this.clients.get(clientId);
if (!client) return;
// Namespace the env IDs with stableClientId to differentiate between clients
const namespacedFrom = `${clientId}/${message.from}`;
const namespacedOrigin = `${clientId}/${message.origin}`;

// Track namespaced env IDs for this client
client.namespacedEnvIds.add(namespacedFrom);
client.namespacedEnvIds.add(namespacedOrigin);

// Modify message with namespaced IDs for routing
message.from = namespacedFrom;
message.origin = namespacedOrigin;

this.emitMessageHandlers(message);
};
socket.on('message', onMessage);

socket.once('disconnect', () => {
socket.off('message', onMessage);
for (const [envId, { socket: soc }] of this.socketToEnvId.entries()) {
if (socket === soc) {
this.socketToEnvId.delete(envId);
this.emitMessageHandlers({
type: 'dispose',
from: envId,
origin: envId,
to: '*',
forwardingChain: [],
});
}
}

const client = this.clients.get(clientId);
if (!client) return;

// Delay dispose to allow for socket recovery
client.disposeTimer = setTimeout(() => {
const clientToDispose = this.clients.get(clientId);
if (!clientToDispose) return;

this.clients.delete(clientId);
this.emitDisposeMessagesForClient(clientToDispose.namespacedEnvIds);
}, this.disposeGraceMs);
});
};
}
Loading