@@ -34,7 +34,6 @@ import {
3434 collectFields ,
3535 createGraphQLError ,
3636 fakePromise ,
37- getAbortPromise ,
3837 getArgumentValues ,
3938 getDefinedRootType ,
4039 GraphQLResolveInfo ,
@@ -52,11 +51,10 @@ import {
5251 Path ,
5352 pathToArray ,
5453 promiseReduce ,
55- registerAbortSignalListener ,
5654} from '@graphql-tools/utils' ;
5755import { TypedDocumentNode } from '@graphql-typed-document-node/core' ;
5856import { DisposableSymbols } from '@whatwg-node/disposablestack' ;
59- import { handleMaybePromise } from '@whatwg-node/promise-helpers' ;
57+ import { createDeferredPromise , handleMaybePromise } from '@whatwg-node/promise-helpers' ;
6058import { coerceError } from './coerceError.js' ;
6159import { flattenAsyncIterable } from './flattenAsyncIterable.js' ;
6260import { invariant } from './invariant.js' ;
@@ -127,6 +125,8 @@ export interface ExecutionContext<TVariables = any, TContext = any> {
127125 errors : Array < GraphQLError > ;
128126 subsequentPayloads : Set < AsyncPayloadRecord > ;
129127 signal ?: AbortSignal ;
128+ onSignalAbort ?( handler : ( ) => void ) : void ;
129+ signalPromise ?: Promise < never > ;
130130}
131131
132132export interface FormattedExecutionResult <
@@ -421,6 +421,8 @@ export function buildExecutionContext<TData = any, TVariables = any, TContext =
421421 signal,
422422 } = args ;
423423
424+ signal ?. throwIfAborted ( ) ;
425+
424426 // If the schema used for execution is invalid, throw an error.
425427 assertValidSchema ( schema ) ;
426428
@@ -489,6 +491,31 @@ export function buildExecutionContext<TData = any, TVariables = any, TContext =
489491 return coercedVariableValues . errors ;
490492 }
491493
494+ signal ?. throwIfAborted ( ) ;
495+
496+ let onSignalAbort : ExecutionContext [ 'onSignalAbort' ] ;
497+ let signalPromise : ExecutionContext [ 'signalPromise' ] ;
498+
499+ if ( signal ) {
500+ const listeners = new Set < ( ) => void > ( ) ;
501+ const signalDeferred = createDeferredPromise < never > ( ) ;
502+ signalPromise = signalDeferred . promise ;
503+ const sharedListener = ( ) => {
504+ signalDeferred . reject ( signal . reason ) ;
505+ signal . removeEventListener ( 'abort' , sharedListener ) ;
506+ } ;
507+ signal . addEventListener ( 'abort' , sharedListener , { once : true } ) ;
508+ signalPromise . catch ( ( ) => {
509+ for ( const listener of listeners ) {
510+ listener ( ) ;
511+ }
512+ listeners . clear ( ) ;
513+ } ) ;
514+ onSignalAbort = handler => {
515+ listeners . add ( handler ) ;
516+ } ;
517+ }
518+
492519 return {
493520 schema,
494521 fragments,
@@ -502,6 +529,8 @@ export function buildExecutionContext<TData = any, TVariables = any, TContext =
502529 subsequentPayloads : new Set ( ) ,
503530 errors : [ ] ,
504531 signal,
532+ onSignalAbort,
533+ signalPromise,
505534 } ;
506535}
507536
@@ -626,9 +655,9 @@ function executeFields(
626655 }
627656 }
628657 } catch ( error ) {
629- if ( containsPromise ) {
658+ if ( error !== exeContext . signal ?. reason && containsPromise ) {
630659 // Ensure that any promises returned by other fields are handled, as they may also reject.
631- return promiseForObject ( results , exeContext . signal ) . finally ( ( ) => {
660+ return promiseForObject ( results , exeContext . signal , exeContext . signalPromise ) . finally ( ( ) => {
632661 throw error ;
633662 } ) ;
634663 }
@@ -643,7 +672,7 @@ function executeFields(
643672 // Otherwise, results is a map from field name to the result of resolving that
644673 // field, which is possibly a promise. Return a promise that will return this
645674 // same map, but with any promises replaced with the values they resolved to.
646- return promiseForObject ( results , exeContext . signal ) ;
675+ return promiseForObject ( results , exeContext . signal , exeContext . signalPromise ) ;
647676}
648677
649678/**
@@ -673,6 +702,7 @@ function executeField(
673702
674703 // Get the resolve function, regardless of if its result is normal or abrupt (error).
675704 try {
705+ exeContext . signal ?. throwIfAborted ( ) ;
676706 // Build a JS object of arguments from the field.arguments AST, using the
677707 // variables scope to fulfill any variable references.
678708 // TODO: find a way to memoize, in case this field is within a List type.
@@ -967,8 +997,9 @@ async function completeAsyncIteratorValue(
967997 iterator : AsyncIterator < unknown > ,
968998 asyncPayloadRecord ?: AsyncPayloadRecord ,
969999) : Promise < ReadonlyArray < unknown > > {
970- if ( exeContext . signal && iterator . return ) {
971- registerAbortSignalListener ( exeContext . signal , ( ) => {
1000+ exeContext . signal ?. throwIfAborted ( ) ;
1001+ if ( iterator . return ) {
1002+ exeContext . onSignalAbort ?.( ( ) => {
9721003 iterator . return ?.( ) ;
9731004 } ) ;
9741005 }
@@ -1746,18 +1777,25 @@ function executeSubscription(exeContext: ExecutionContext): MaybePromise<AsyncIt
17461777 const result = resolveFn ( rootValue , args , contextValue , info ) ;
17471778
17481779 if ( isPromise ( result ) ) {
1749- return result . then ( assertEventStream ) . then ( undefined , error => {
1750- throw locatedError ( error , fieldNodes , pathToArray ( path ) ) ;
1751- } ) ;
1780+ return result
1781+ . then ( result => assertEventStream ( result , exeContext . signal , exeContext . onSignalAbort ) )
1782+ . then ( undefined , error => {
1783+ throw locatedError ( error , fieldNodes , pathToArray ( path ) ) ;
1784+ } ) ;
17521785 }
17531786
1754- return assertEventStream ( result , exeContext . signal ) ;
1787+ return assertEventStream ( result , exeContext . signal , exeContext . onSignalAbort ) ;
17551788 } catch ( error ) {
17561789 throw locatedError ( error , fieldNodes , pathToArray ( path ) ) ;
17571790 }
17581791}
17591792
1760- function assertEventStream ( result : unknown , signal ?: AbortSignal ) : AsyncIterable < unknown > {
1793+ function assertEventStream (
1794+ result : unknown ,
1795+ signal ?: AbortSignal ,
1796+ onSignalAbort ?: ( handler : ( ) => void ) => void ,
1797+ ) : AsyncIterable < unknown > {
1798+ signal ?. throwIfAborted ( ) ;
17611799 if ( result instanceof Error ) {
17621800 throw result ;
17631801 }
@@ -1768,13 +1806,13 @@ function assertEventStream(result: unknown, signal?: AbortSignal): AsyncIterable
17681806 'Subscription field must return Async Iterable. ' + `Received: ${ inspect ( result ) } .` ,
17691807 ) ;
17701808 }
1771- if ( signal ) {
1809+ if ( onSignalAbort ) {
17721810 return {
17731811 [ Symbol . asyncIterator ] ( ) {
17741812 const asyncIterator = result [ Symbol . asyncIterator ] ( ) ;
17751813
17761814 if ( asyncIterator . return ) {
1777- registerAbortSignalListener ( signal , ( ) => {
1815+ onSignalAbort ?. ( ( ) => {
17781816 asyncIterator . return ?.( ) ;
17791817 } ) ;
17801818 }
@@ -2101,8 +2139,6 @@ function yieldSubsequentPayloads(
21012139) : AsyncGenerator < SubsequentIncrementalExecutionResult , void , void > {
21022140 let isDone = false ;
21032141
2104- const abortPromise = exeContext . signal ? getAbortPromise ( exeContext . signal ) : undefined ;
2105-
21062142 async function next ( ) : Promise < IteratorResult < SubsequentIncrementalExecutionResult , void > > {
21072143 if ( isDone ) {
21082144 return { value : undefined , done : true } ;
@@ -2112,8 +2148,8 @@ function yieldSubsequentPayloads(
21122148 record => record . promise ,
21132149 ) ;
21142150
2115- if ( abortPromise ) {
2116- await Promise . race ( [ abortPromise , ...subSequentPayloadPromises ] ) ;
2151+ if ( exeContext . signalPromise ) {
2152+ await Promise . race ( [ exeContext . signalPromise , ...subSequentPayloadPromises ] ) ;
21172153 } else {
21182154 await Promise . race ( subSequentPayloadPromises ) ;
21192155 }
0 commit comments