diff --git a/src/rpc/RPCServer.ts b/src/rpc/RPCServer.ts index fe357dec13..e81413e8be 100644 --- a/src/rpc/RPCServer.ts +++ b/src/rpc/RPCServer.ts @@ -196,6 +196,13 @@ class RPCServer extends EventTarget { connectionInfo, ctx, ) => { + // Setting up abort controller + const abortController = new AbortController(); + if (ctx.signal.aborted) abortController.abort(ctx.signal.reason); + ctx.signal.addEventListener('abort', () => { + abortController.abort(ctx.signal.reason); + }); + const signal = abortController.signal; // Setting up middleware const middleware = this.middlewareFactory(); // Forward from the client to the server @@ -214,14 +221,14 @@ class RPCServer extends EventTarget { const reverseStream = middleware.reverse.writable; // Generator derived from handler const outputGen = async function* (): AsyncGenerator { - if (ctx.signal.aborted) throw ctx.signal.reason; + if (signal.aborted) throw signal.reason; // Input generator derived from the forward stream const inputGen = async function* (): AsyncIterable { for await (const data of forwardStream) { yield data.params as I; } }; - const handlerG = handler(inputGen(), connectionInfo, ctx); + const handlerG = handler(inputGen(), connectionInfo, { signal }); for await (const response of handlerG) { const responseMessage: JSONRPCResponseResult = { jsonrpc: '2.0', @@ -271,6 +278,8 @@ class RPCServer extends EventTarget { ), }), ); + // Abort with the reason + abortController.abort(reason); // If the output stream path fails then we need to end the generator // early. await outputGenerator.return(undefined); diff --git a/src/rpc/handlers.ts b/src/rpc/handlers.ts index af4aed4e08..48714b471a 100644 --- a/src/rpc/handlers.ts +++ b/src/rpc/handlers.ts @@ -31,6 +31,11 @@ abstract class DuplexHandler< Input extends JSONValue = JSONValue, Output extends JSONValue = JSONValue, > extends Handler { + /** + * Note that if the output has an error, the handler will not see this as an + * error. If you need to handle any clean up it should be handled in a + * `finally` block and check the abort signal for potential errors. + */ abstract handle( input: AsyncIterable, connectionInfo: ConnectionInfo, diff --git a/tests/rpc/RPCServer.test.ts b/tests/rpc/RPCServer.test.ts index d1560c62c6..083f947597 100644 --- a/tests/rpc/RPCServer.test.ts +++ b/tests/rpc/RPCServer.test.ts @@ -506,13 +506,15 @@ describe(`${RPCServer.name}`, () => { }, { numRuns: 1 }, ); - testProp.only( + testProp( 'should emit stream error if output stream fails', [specificMessageArb], async (messages) => { const handlerEndedProm = promise(); + let ctx: ContextCancellable | undefined; class TestMethod extends DuplexHandler { - public async *handle(input): AsyncIterable { + public async *handle(input, _, _ctx): AsyncIterable { + ctx = _ctx; // Echo input try { yield* input; @@ -564,6 +566,10 @@ describe(`${RPCServer.name}`, () => { expect(event.detail.cause).toBe(readerReason); // Check that the handler was cleaned up. await expect(handlerEndedProm.p).toResolve(); + // Check that an abort signal happened + expect(ctx).toBeDefined(); + expect(ctx?.signal.aborted).toBeTrue(); + expect(ctx?.signal.reason).toBe(readerReason); await rpcServer.destroy(); }, { numRuns: 1 },