Skip to content

Commit f1abc83

Browse files
committed
Reuse transport and signaling client for dialing with webrtc
1 parent c234c2e commit f1abc83

File tree

2 files changed

+144
-56
lines changed

2 files changed

+144
-56
lines changed

src/rpc/__tests__/dial.spec.ts

Lines changed: 114 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ import {
2727
} from '../../__mocks__/webrtc';
2828
import { withICEServers } from '../__fixtures__/dial-webrtc-options';
2929
import { createMockTransport } from '../../__mocks__/transports';
30-
import { createMockSignalingExchange } from '../__mocks__/signaling-exchanges';
3130
import { ClientChannel } from '../client-channel';
31+
import type { Transport } from '@connectrpc/connect';
3232

3333
vi.mock('../peer');
3434
vi.mock('../signaling-exchange');
@@ -52,28 +52,33 @@ const setupDialWebRTCMocks = () => {
5252
const peerConnection = createMockPeerConnection();
5353
const dataChannel = createMockDataChannel();
5454
const transport = createMockTransport();
55-
const signalingExchange = createMockSignalingExchange(transport);
5655

5756
vi.mocked(newPeerConnectionForClient).mockResolvedValue({
5857
pc: peerConnection,
5958
dc: dataChannel,
6059
});
6160

62-
vi.mocked(SignalingExchange).mockImplementation(() => signalingExchange);
63-
6461
const optionalWebRTCConfigFn = vi.fn().mockResolvedValue({
6562
config: {
6663
additionalIceServers: [],
6764
disableTrickle: false,
6865
},
6966
});
7067

71-
vi.mocked(createClient).mockReturnValue({
68+
const mockClient = {
7269
optionalWebRTCConfig: optionalWebRTCConfigFn,
73-
} as unknown as ReturnType<typeof createClient>);
70+
} as unknown as ReturnType<typeof createClient>;
7471

72+
vi.mocked(createClient).mockReturnValue(mockClient);
7573
vi.mocked(createGrpcWebTransport).mockReturnValue(transport);
7674

75+
const signalingExchange = {
76+
doExchange: vi.fn().mockResolvedValue(transport),
77+
terminate: vi.fn(),
78+
} as unknown as SignalingExchange;
79+
80+
vi.mocked(SignalingExchange).mockImplementation(() => signalingExchange);
81+
7782
return {
7883
peerConnection,
7984
dataChannel,
@@ -207,21 +212,18 @@ describe('dialWebRTC', () => {
207212
expect(vi.mocked(peerConnection.close)).toHaveBeenCalled();
208213
});
209214

210-
it('should close peer connection if dialDirect fails', async () => {
215+
it('should propagate error if transport creation fails', async () => {
211216
// Arrange
212-
const { peerConnection, transport } = setupDialWebRTCMocks();
213-
// First call succeeds (getOptionalWebRTCConfig), second call fails (signaling)
214-
vi.mocked(createGrpcWebTransport)
215-
.mockReturnValueOnce(transport)
216-
.mockImplementationOnce(() => {
217-
throw new Error('Transport creation failed');
218-
});
217+
setupDialWebRTCMocks();
218+
vi.mocked(createGrpcWebTransport).mockImplementation(() => {
219+
throw new Error('Transport creation failed');
220+
});
219221

220222
// Act & Assert
221223
await expect(dialWebRTC(TEST_URL, TEST_HOST)).rejects.toThrow(
222224
'Transport creation failed'
223225
);
224-
expect(vi.mocked(peerConnection.close)).toHaveBeenCalled();
226+
expect(newPeerConnectionForClient).not.toHaveBeenCalled();
225227
});
226228

227229
it('should rethrow errors after cleanup', async () => {
@@ -327,6 +329,103 @@ describe('validateDialOptions', () => {
327329
});
328330
});
329331

332+
describe('resource management', () => {
333+
it('should reuse a single transport for config fetching and signaling', async () => {
334+
// Arrange
335+
setupDialWebRTCMocks();
336+
337+
// Act
338+
await dialWebRTC(TEST_URL, TEST_HOST);
339+
340+
// Assert
341+
expect(createGrpcWebTransport).toHaveBeenCalledTimes(1);
342+
expect(createGrpcWebTransport).toHaveBeenCalledWith({
343+
baseUrl: TEST_URL,
344+
credentials: 'same-origin',
345+
});
346+
});
347+
348+
it('should reuse a single signaling client for config fetching and signaling', async () => {
349+
// Arrange
350+
setupDialWebRTCMocks();
351+
352+
// Act
353+
await dialWebRTC(TEST_URL, TEST_HOST);
354+
355+
// Assert
356+
expect(createClient).toHaveBeenCalledTimes(1);
357+
expect(createClient).toHaveBeenCalledWith(
358+
expect.anything(),
359+
expect.anything()
360+
);
361+
});
362+
363+
it('should not leak transports on successful connection', async () => {
364+
// Arrange
365+
const { transport } = setupDialWebRTCMocks();
366+
const transportCount = { created: 0 };
367+
368+
vi.mocked(createGrpcWebTransport).mockImplementation(() => {
369+
transportCount.created += 1;
370+
return transport;
371+
});
372+
373+
// Act
374+
await dialWebRTC(TEST_URL, TEST_HOST);
375+
376+
// Assert
377+
expect(transportCount.created).toBe(1);
378+
});
379+
380+
it('should not leak transports on connection failure', async () => {
381+
// Arrange
382+
const { transport, signalingExchange } = setupDialWebRTCMocks();
383+
const transportCount = { created: 0 };
384+
385+
vi.mocked(createGrpcWebTransport).mockImplementation(() => {
386+
transportCount.created += 1;
387+
return transport;
388+
});
389+
390+
const error = new Error('Connection failed');
391+
vi.mocked(signalingExchange.doExchange).mockRejectedValueOnce(error);
392+
393+
// Act
394+
await dialWebRTC(TEST_URL, TEST_HOST).catch(() => {
395+
// Ignore error for this test
396+
});
397+
398+
// Assert
399+
expect(transportCount.created).toBe(1);
400+
});
401+
402+
it('should use the same transport reference for both config and signaling', async () => {
403+
// Arrange
404+
setupDialWebRTCMocks();
405+
const capturedTransports: Transport[] = [];
406+
407+
vi.mocked(createClient).mockImplementation(
408+
(_service, capturedTransport) => {
409+
capturedTransports.push(capturedTransport);
410+
return {
411+
optionalWebRTCConfig: vi.fn().mockResolvedValue({
412+
config: {
413+
additionalIceServers: [],
414+
disableTrickle: false,
415+
},
416+
}),
417+
} as unknown as ReturnType<typeof createClient>;
418+
}
419+
);
420+
421+
// Act
422+
await dialWebRTC(TEST_URL, TEST_HOST);
423+
424+
// Assert
425+
expect(capturedTransports.length).toBe(1);
426+
});
427+
});
428+
330429
describe('dialDirect', () => {
331430
afterEach(() => {
332431
vi.restoreAllMocks();

src/rpc/dial.ts

Lines changed: 30 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -309,20 +309,24 @@ export interface WebRTCConnection {
309309
dataChannel: RTCDataChannel;
310310
}
311311

312-
const getOptionalWebRTCConfig = async (
312+
const getSignalingClient = async (
313313
signalingAddress: string,
314-
callOpts: CallOptions,
315-
dialOpts?: DialOptions,
314+
signalingExchangeOpts: DialOptions | undefined,
316315
transportCredentialsInclude = false
317-
): Promise<WebRTCConfig> => {
318-
const optsCopy = { ...dialOpts } as DialOptions;
319-
const directTransport = await dialDirect(
316+
) => {
317+
const transport = await dialDirect(
320318
signalingAddress,
321-
optsCopy,
319+
signalingExchangeOpts,
322320
transportCredentialsInclude
323321
);
324322

325-
const signalingClient = createClient(SignalingService, directTransport);
323+
return createClient(SignalingService, transport);
324+
};
325+
326+
const getOptionalWebRTCConfig = async (
327+
callOpts: CallOptions,
328+
signalingClient: ReturnType<typeof createClient<typeof SignalingService>>
329+
): Promise<WebRTCConfig> => {
326330
try {
327331
const resp = await signalingClient.optionalWebRTCConfig({}, callOpts);
328332
return resp.config ?? new WebRTCConfig();
@@ -363,18 +367,25 @@ export const dialWebRTC = async (
363367
};
364368

365369
/**
366-
* First complete our WebRTC options, gathering any extra information like
367-
* TURN servers from a cloud server.
370+
* First, derive options specifically for signaling against our target. Then
371+
* complete our WebRTC options, gathering any extra information like TURN
372+
* servers from a cloud server. This also creates the transport and signaling
373+
* client that we'll reuse to avoid resource leaks.
368374
*/
369-
const webrtcOpts = await processWebRTCOpts(
375+
const exchangeOpts = processSignalingExchangeOpts(
370376
usableSignalingAddress,
371-
callOpts,
372-
dialOpts,
373-
transportCredentialsInclude
377+
dialOpts
374378
);
375-
// then derive options specifically for signaling against our target.
376-
const exchangeOpts = processSignalingExchangeOpts(
379+
380+
const signalingClient = await getSignalingClient(
377381
usableSignalingAddress,
382+
exchangeOpts,
383+
transportCredentialsInclude
384+
);
385+
386+
const webrtcOpts = await processWebRTCOpts(
387+
signalingClient,
388+
callOpts,
378389
dialOpts
379390
);
380391

@@ -385,21 +396,6 @@ export const dialWebRTC = async (
385396
);
386397
let successful = false;
387398

388-
let directTransport: Transport;
389-
try {
390-
directTransport = await dialDirect(
391-
usableSignalingAddress,
392-
exchangeOpts,
393-
transportCredentialsInclude
394-
);
395-
} catch (error) {
396-
pc.close();
397-
dc.close();
398-
throw error;
399-
}
400-
401-
const signalingClient = createClient(SignalingService, directTransport);
402-
403399
const exchange = new SignalingExchange(
404400
signalingClient,
405401
callOpts,
@@ -453,18 +449,11 @@ export const dialWebRTC = async (
453449
};
454450

455451
const processWebRTCOpts = async (
456-
signalingAddress: string,
452+
signalingClient: ReturnType<typeof createClient<typeof SignalingService>>,
457453
callOpts: CallOptions,
458-
dialOpts?: DialOptions,
459-
transportCredentialsInclude = false
454+
dialOpts: DialOptions | undefined
460455
): Promise<DialWebRTCOptions> => {
461-
// Get TURN servers, if any.
462-
const config = await getOptionalWebRTCConfig(
463-
signalingAddress,
464-
callOpts,
465-
dialOpts,
466-
transportCredentialsInclude
467-
);
456+
const config = await getOptionalWebRTCConfig(callOpts, signalingClient);
468457
const additionalIceServers: RTCIceServer[] = config.additionalIceServers.map(
469458
(ice) => {
470459
const iceUrls = [];

0 commit comments

Comments
 (0)