diff --git a/src/execution/__tests__/stream-test.ts b/src/execution/__tests__/stream-test.ts index eebdf3ff640..d951d0d5e9a 100644 --- a/src/execution/__tests__/stream-test.ts +++ b/src/execution/__tests__/stream-test.ts @@ -1,6 +1,7 @@ import { expect } from 'chai'; import { describe, it } from 'mocha'; +import { invariant } from '../../jsutils/invariant'; import { isAsyncIterable } from '../../jsutils/isAsyncIterable'; import { parse } from '../../language/parser'; @@ -74,6 +75,36 @@ const query = new GraphQLObjectType({ yield await Promise.resolve({}); }, }, + asyncIterableListDelayed: { + type: new GraphQLList(friendType), + async *resolve() { + for (const friend of friends) { + // pause an additional ms before yielding to allow time + // for tests to return or throw before next value is processed. + // eslint-disable-next-line no-await-in-loop + await new Promise((r) => setTimeout(r, 1)); + yield friend; + } + }, + }, + asyncIterableListNoReturn: { + type: new GraphQLList(friendType), + resolve() { + let i = 0; + return { + [Symbol.asyncIterator]: () => ({ + async next() { + const friend = friends[i++]; + if (friend) { + await new Promise((r) => setTimeout(r, 1)); + return { value: friend, done: false }; + } + return { value: undefined, done: true }; + }, + }), + }; + }, + }, asyncIterableListDelayedClose: { type: new GraphQLList(friendType), async *resolve() { @@ -697,4 +728,172 @@ describe('Execute: stream directive', () => { }, ]); }); + it('Returns underlying async iterables when dispatcher is returned', async () => { + const document = parse(` + query { + asyncIterableListDelayed @stream(initialCount: 1) { + name + id + } + } + `); + const schema = new GraphQLSchema({ query }); + + const executeResult = await execute({ schema, document, rootValue: {} }); + invariant(isAsyncIterable(executeResult)); + const iterator = executeResult[Symbol.asyncIterator](); + + const result1 = await iterator.next(); + expect(result1).to.deep.equal({ + done: false, + value: { + data: { + asyncIterableListDelayed: [ + { + id: '1', + name: 'Luke', + }, + ], + }, + hasNext: true, + }, + }); + + iterator.return?.(); + + // this result had started processing before return was called + const result2 = await iterator.next(); + expect(result2).to.deep.equal({ + done: false, + value: { + data: { + id: '2', + name: 'Han', + }, + hasNext: true, + path: ['asyncIterableListDelayed', 1], + }, + }); + + // third result is not returned because async iterator has returned + const result3 = await iterator.next(); + expect(result3).to.deep.equal({ + done: false, + value: { + hasNext: false, + }, + }); + }); + it('Can return async iterable when underlying iterable does not have a return method', async () => { + const document = parse(` + query { + asyncIterableListNoReturn @stream(initialCount: 1) { + name + id + } + } + `); + const schema = new GraphQLSchema({ query }); + + const executeResult = await execute({ schema, document, rootValue: {} }); + invariant(isAsyncIterable(executeResult)); + const iterator = executeResult[Symbol.asyncIterator](); + + const result1 = await iterator.next(); + expect(result1).to.deep.equal({ + done: false, + value: { + data: { + asyncIterableListNoReturn: [ + { + id: '1', + name: 'Luke', + }, + ], + }, + hasNext: true, + }, + }); + + iterator.return?.(); + + // this result had started processing before return was called + const result2 = await iterator.next(); + expect(result2).to.deep.equal({ + done: false, + value: { + data: { + id: '2', + name: 'Han', + }, + hasNext: true, + path: ['asyncIterableListNoReturn', 1], + }, + }); + + // third result is not returned because async iterator has returned + const result3 = await iterator.next(); + expect(result3).to.deep.equal({ + done: false, + value: { + hasNext: false, + }, + }); + }); + it('Returns underlying async iterables when dispatcher is thrown', async () => { + const document = parse(` + query { + asyncIterableListDelayed @stream(initialCount: 1) { + name + id + } + } + `); + const schema = new GraphQLSchema({ query }); + + const executeResult = await execute({ schema, document, rootValue: {} }); + invariant(isAsyncIterable(executeResult)); + const iterator = executeResult[Symbol.asyncIterator](); + + const result1 = await iterator.next(); + expect(result1).to.deep.equal({ + done: false, + value: { + data: { + asyncIterableListDelayed: [ + { + id: '1', + name: 'Luke', + }, + ], + }, + hasNext: true, + }, + }); + + iterator.throw?.(new Error('bad')); + + // this result had started processing before return was called + const result2 = await iterator.next(); + expect(result2).to.deep.equal({ + done: false, + value: { + data: { + id: '2', + name: 'Han', + }, + hasNext: true, + path: ['asyncIterableListDelayed', 1], + }, + }); + + // third result is not returned because async iterator has returned + const result3 = await iterator.next(); + expect(result3).to.deep.equal({ + done: false, + value: { + hasNext: false, + }, + }); + }); }); diff --git a/src/execution/execute.ts b/src/execution/execute.ts index cff73c5482c..71424f20b9d 100644 --- a/src/execution/execute.ts +++ b/src/execution/execute.ts @@ -1620,10 +1620,14 @@ interface DispatcherResult { export class Dispatcher { _subsequentPayloads: Array>>; _initialResult?: ExecutionResult; + _iterators: Array>; + _isDone: boolean; _hasReturnedInitialResult: boolean; constructor() { this._subsequentPayloads = []; + this._iterators = []; + this._isDone = false; this._hasReturnedInitialResult = false; } @@ -1692,6 +1696,8 @@ export class Dispatcher { label?: string, ): void { const subsequentPayloads = this._subsequentPayloads; + const iterators = this._iterators; + iterators.push(iterator); function next(index: number) { const fieldPath = addPath(path, index, undefined); const patchErrors: Array = []; @@ -1699,6 +1705,7 @@ export class Dispatcher { iterator.next().then( ({ value: data, done }) => { if (done) { + iterators.splice(iterators.indexOf(iterator), 1); return { value: undefined, done: true }; } @@ -1769,6 +1776,14 @@ export class Dispatcher { } _race(): Promise> { + if (this._isDone) { + return Promise.resolve({ + value: { + hasNext: false, + }, + done: false, + }); + } return new Promise<{ promise: Promise>; }>((resolve) => { @@ -1828,15 +1843,29 @@ export class Dispatcher { return this._race(); } - get( - initialResult: ExecutionResult, - ): AsyncIterableIterator { + async _return(): Promise> { + await Promise.all(this._iterators.map((iterator) => iterator.return?.())); + this._isDone = true; + return { value: undefined, done: true }; + } + + async _throw( + error?: unknown, + ): Promise> { + await Promise.all(this._iterators.map((iterator) => iterator.return?.())); + this._isDone = true; + return Promise.reject(error); + } + + get(initialResult: ExecutionResult): AsyncGenerator { this._initialResult = initialResult; return { [Symbol.asyncIterator]() { return this; }, next: () => this._next(), + return: () => this._return(), + throw: (error?: unknown) => this._throw(error), }; } }