From fa0caa3bc688844af2c54268a4b5537c169a9ae3 Mon Sep 17 00:00:00 2001 From: Jacky Zhao Date: Wed, 4 Sep 2024 17:41:24 -0700 Subject: [PATCH 1/5] add tests for handlers cleanly exiting --- __tests__/cancellation.test.ts | 217 ++++++++++++++++++++++++++++++++- router/server.ts | 20 ++- 2 files changed, 230 insertions(+), 7 deletions(-) diff --git a/__tests__/cancellation.test.ts b/__tests__/cancellation.test.ts index 5dd79ccb..45904d95 100644 --- a/__tests__/cancellation.test.ts +++ b/__tests__/cancellation.test.ts @@ -39,7 +39,222 @@ function makeMockHandler( >(impl); } -describe.each(testMatrix(['ws', 'naive']))( +describe.each(testMatrix())( + 'clean handler cancellation ($transport.name transport, $codec.name codec)', + + async ({ transport, codec }) => { + const opts = { codec: codec.codec }; + + const { addPostTestCleanup, postTestCleanup } = createPostTestCleanups(); + let getClientTransport: TestSetupHelpers['getClientTransport']; + let getServerTransport: TestSetupHelpers['getServerTransport']; + beforeEach(async () => { + const setup = await transport.setup({ client: opts, server: opts }); + getClientTransport = setup.getClientTransport; + getServerTransport = setup.getServerTransport; + + return async () => { + await postTestCleanup(); + await setup.cleanup(); + }; + }); + + describe('e2e', () => { + test('rpc', async () => { + const clientTransport = getClientTransport('client'); + const serverTransport = getServerTransport(); + + const signalReceiver = vi.fn<(sig: AbortSignal) => void>(); + const services = { + service: ServiceSchema.define({ + rpc: Procedure.rpc({ + requestInit: Type.Object({}), + responseData: Type.Object({}), + handler: async ({ ctx }) => { + signalReceiver(ctx.signal); + + return Ok({}); + }, + }), + }), + }; + + const server = createServer(serverTransport, services); + const client = createClient( + clientTransport, + serverTransport.clientId, + ); + + await client.service.rpc.rpc({}); + + await waitFor(() => { + expect(signalReceiver).toHaveBeenCalledTimes(1); + }); + + const [sig] = signalReceiver.mock.calls[0]; + expect(sig.aborted).toEqual(true); + + await testFinishesCleanly({ + clientTransports: [clientTransport], + serverTransport, + server, + }); + }); + + test('stream', async () => { + const clientTransport = getClientTransport('client'); + const serverTransport = getServerTransport(); + addPostTestCleanup(async () => { + await cleanupTransports([clientTransport, serverTransport]); + }); + + const signalReceiver = vi.fn<(sig: AbortSignal) => void>(); + const services = { + service: ServiceSchema.define({ + stream: Procedure.stream({ + requestInit: Type.Object({}), + requestData: Type.Object({}), + responseData: Type.Object({}), + handler: async ({ ctx, resWritable }) => { + signalReceiver(ctx.signal); + + resWritable.write(Ok({})); + resWritable.close(); + + return; + }, + }), + }), + }; + + const server = createServer(serverTransport, services); + const client = createClient( + clientTransport, + serverTransport.clientId, + ); + + const { reqWritable, resReadable } = client.service.stream.stream({}); + + await waitFor(() => { + expect(signalReceiver).toHaveBeenCalledTimes(1); + }); + + const [sig] = signalReceiver.mock.calls[0]; + expect(sig.aborted).toEqual(false); + + reqWritable.close(); + await waitFor(() => expect(sig.aborted).toEqual(true)); + + // collect should resolve as the stream has been properly ended + await expect(resReadable.collect()).resolves.toEqual([Ok({})]); + + await testFinishesCleanly({ + clientTransports: [clientTransport], + serverTransport, + server, + }); + }); + + test('upload', async () => { + const clientTransport = getClientTransport('client'); + const serverTransport = getServerTransport(); + addPostTestCleanup(async () => { + await cleanupTransports([clientTransport, serverTransport]); + }); + + const signalReceiver = vi.fn<(sig: AbortSignal) => void>(); + const services = { + service: ServiceSchema.define({ + upload: Procedure.upload({ + requestInit: Type.Object({}), + requestData: Type.Object({}), + responseData: Type.Object({}), + handler: async ({ ctx }) => { + signalReceiver(ctx.signal); + + return Ok({}); + }, + }), + }), + }; + + const server = createServer(serverTransport, services); + const client = createClient( + clientTransport, + serverTransport.clientId, + ); + + const { reqWritable, finalize } = client.service.upload.upload({}); + + await waitFor(() => { + expect(signalReceiver).toHaveBeenCalledTimes(1); + }); + + const [sig] = signalReceiver.mock.calls[0]; + expect(sig.aborted).toEqual(false); + + reqWritable.close(); + await waitFor(() => expect(sig.aborted).toEqual(true)); + + expect(await finalize()).toEqual(Ok({})); + + await testFinishesCleanly({ + clientTransports: [clientTransport], + serverTransport, + server, + }); + }); + + test('subscribe', async () => { + const clientTransport = getClientTransport('client'); + const serverTransport = getServerTransport(); + addPostTestCleanup(async () => { + await cleanupTransports([clientTransport, serverTransport]); + }); + + const signalReceiver = vi.fn<(sig: AbortSignal) => void>(); + const services = { + service: ServiceSchema.define({ + subscribe: Procedure.subscription({ + requestInit: Type.Object({}), + responseData: Type.Object({}), + handler: async ({ ctx, resWritable }) => { + resWritable.close(); + signalReceiver(ctx.signal); + + return; + }, + }), + }), + }; + + const server = createServer(serverTransport, services); + const client = createClient( + clientTransport, + serverTransport.clientId, + ); + + const { resReadable } = client.service.subscribe.subscribe({}); + + await waitFor(() => { + expect(signalReceiver).toHaveBeenCalledTimes(1); + }); + + const [sig] = signalReceiver.mock.calls[0]; + expect(sig.aborted).toEqual(true); + await expect(resReadable.collect()).resolves.toEqual([]); + + await testFinishesCleanly({ + clientTransports: [clientTransport], + serverTransport, + server, + }); + }); + }); + }, +); + +describe.each(testMatrix())( 'client initiated cancellation ($transport.name transport, $codec.name codec)', async ({ transport, codec }) => { const opts = { codec: codec.codec }; diff --git a/router/server.ts b/router/server.ts index 15d7bfed..f94747aa 100644 --- a/router/server.ts +++ b/router/server.ts @@ -104,6 +104,7 @@ interface ProcStream { serviceName: string; sessionMetadata: ParsedMetadata; procedure: AnyProcedure; + endedSignal: AbortSignal; handleMsg: (msg: OpaqueTransportMessage) => void; handleSessionDisconnect: () => void; } @@ -199,7 +200,11 @@ class RiverServer ...newStreamProps, ...message, }); - this.streams.set(streamId, newStream); + + if (!newStream.endedSignal.aborted) { + // if the stream was immediately aborted, don't bother setting it up + this.streams.set(streamId, newStream); + } }; const handleSessionStatus = (evt: EventMap['sessionStatus']) => { @@ -369,6 +374,7 @@ class RiverServer }); }; + const finishedController = new AbortController(); const procStream: ProcStream = { from: from, streamId, @@ -377,6 +383,7 @@ class RiverServer sessionMetadata, procedure, handleMsg: onMessage, + endedSignal: finishedController.signal, handleSessionDisconnect: () => { cleanClose = false; const errPayload = { @@ -422,7 +429,6 @@ class RiverServer cancelStream(streamId, result); }; - const finishedController = new AbortController(); const cleanup = () => { finishedController.abort(); this.streams.delete(streamId); @@ -529,10 +535,6 @@ class RiverServer // only consists of an init message and we shouldn't expect follow up data if (procClosesWithInit) { closeReadable(); - } else if (procedure.type === 'rpc' || procedure.type === 'subscription') { - // Though things can work just fine if they eventually follow up with a stream - // control message with a close bit set, it's an unusual client implementation! - this.log?.warn('sent an init without a stream close', loggingMetadata); } const handlerContext: ProcedureHandlerContext = { @@ -570,6 +572,7 @@ class RiverServer } resWritable.write(responsePayload); + resWritable.close(); } catch (err) { onHandlerError(err, span); } finally { @@ -593,6 +596,8 @@ class RiverServer reqReadable, resWritable, }); + + resWritable.close(); } catch (err) { onHandlerError(err, span); } finally { @@ -616,6 +621,8 @@ class RiverServer reqInit: initPayload, resWritable: resWritable, }); + + resWritable.close(); } catch (err) { onHandlerError(err, span); } finally { @@ -645,6 +652,7 @@ class RiverServer } resWritable.write(responsePayload); + resWritable.close(); } catch (err) { onHandlerError(err, span); } finally { From 68fd71d22dc720b79504fdeb94677ac4549db517 Mon Sep 17 00:00:00 2001 From: Jacky Zhao Date: Wed, 4 Sep 2024 17:41:50 -0700 Subject: [PATCH 2/5] 0.200.3 --- package-lock.json | 4 ++-- package.json | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/package-lock.json b/package-lock.json index 5abceda7..fbbec3f9 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "@replit/river", - "version": "0.200.2", + "version": "0.200.3", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "@replit/river", - "version": "0.200.2", + "version": "0.200.3", "license": "MIT", "dependencies": { "@msgpack/msgpack": "^3.0.0-beta2", diff --git a/package.json b/package.json index d3b757e3..8802fa1d 100644 --- a/package.json +++ b/package.json @@ -1,7 +1,7 @@ { "name": "@replit/river", "description": "It's like tRPC but... with JSON Schema Support, duplex streaming and support for service multiplexing. Transport agnostic!", - "version": "0.200.2", + "version": "0.200.3", "type": "module", "exports": { ".": { From 01e86e5eb7c059d90518fb8ccdbee22a6ee5a183 Mon Sep 17 00:00:00 2001 From: Jacky Zhao Date: Wed, 4 Sep 2024 17:46:03 -0700 Subject: [PATCH 3/5] just set in createNewProcStream --- router/server.ts | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/router/server.ts b/router/server.ts index 1fdce056..23c78261 100644 --- a/router/server.ts +++ b/router/server.ts @@ -104,7 +104,6 @@ interface ProcStream { serviceName: string; sessionMetadata: ParsedMetadata; procedure: AnyProcedure; - endedSignal: AbortSignal; handleMsg: (msg: OpaqueTransportMessage) => void; handleSessionDisconnect: () => void; } @@ -196,15 +195,10 @@ class RiverServer } // if its not a cancelled stream, validate and create a new stream - const newStream = this.createNewProcStream({ + this.createNewProcStream({ ...newStreamProps, ...message, }); - - if (!newStream.endedSignal.aborted) { - // if the stream was immediately aborted, don't bother setting it up - this.streams.set(streamId, newStream); - } }; const handleSessionStatus = (evt: EventMap['sessionStatus']) => { @@ -383,7 +377,6 @@ class RiverServer sessionMetadata, procedure, handleMsg: onMessage, - endedSignal: finishedController.signal, handleSessionDisconnect: () => { cleanClose = false; const errPayload = { @@ -666,6 +659,8 @@ class RiverServer break; } + this.streams.set(streamId, procStream); + return procStream; } From 750160f4b7cef9dfcd964215d0f15bf69803be94 Mon Sep 17 00:00:00 2001 From: Jacky Zhao Date: Wed, 4 Sep 2024 17:47:43 -0700 Subject: [PATCH 4/5] remove extraneous resWritable.close --- router/server.ts | 6 ------ 1 file changed, 6 deletions(-) diff --git a/router/server.ts b/router/server.ts index 23c78261..550c48ce 100644 --- a/router/server.ts +++ b/router/server.ts @@ -567,7 +567,6 @@ class RiverServer } resWritable.write(responsePayload); - resWritable.close(); } catch (err) { onHandlerError(err, span); } finally { @@ -591,8 +590,6 @@ class RiverServer reqReadable, resWritable, }); - - resWritable.close(); } catch (err) { onHandlerError(err, span); } finally { @@ -616,8 +613,6 @@ class RiverServer reqInit: initPayload, resWritable: resWritable, }); - - resWritable.close(); } catch (err) { onHandlerError(err, span); } finally { @@ -647,7 +642,6 @@ class RiverServer } resWritable.write(responsePayload); - resWritable.close(); } catch (err) { onHandlerError(err, span); } finally { From c3ecbb799d06faad4c14379931338d2ddc948b8f Mon Sep 17 00:00:00 2001 From: Jacky Zhao Date: Wed, 4 Sep 2024 17:49:03 -0700 Subject: [PATCH 5/5] forgor to check signal --- router/server.ts | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/router/server.ts b/router/server.ts index 550c48ce..c181d856 100644 --- a/router/server.ts +++ b/router/server.ts @@ -234,7 +234,7 @@ class RiverServer this.transport.addEventListener('transportStatus', handleTransportStatus); } - private createNewProcStream(props: StreamInitProps): ProcStream { + private createNewProcStream(props: StreamInitProps) { const { streamId, initialSession, @@ -653,9 +653,9 @@ class RiverServer break; } - this.streams.set(streamId, procStream); - - return procStream; + if (!finishedController.signal.aborted) { + this.streams.set(streamId, procStream); + } } private getContext(service: AnyService, serviceName: string) {