diff --git a/CHANGELOG.md b/CHANGELOG.md index e69de29bb..4e248eadf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -0,0 +1 @@ +- Add support for callable function to return streaming response (#1629) diff --git a/spec/common/providers/https.spec.ts b/spec/common/providers/https.spec.ts index 9ea0f4f79..050577564 100644 --- a/spec/common/providers/https.spec.ts +++ b/spec/common/providers/https.spec.ts @@ -49,25 +49,33 @@ async function runCallableTest(test: CallTest): Promise { cors: { origin: true, methods: "POST" }, ...test.callableOption, }; - const callableFunctionV1 = https.onCallHandler(opts, (data, context) => { - expect(data).to.deep.equal(test.expectedData); - return test.callableFunction(data, context); - }); + const callableFunctionV1 = https.onCallHandler( + opts, + (data, context) => { + expect(data).to.deep.equal(test.expectedData); + return test.callableFunction(data, context); + }, + "gcfv1" + ); const responseV1 = await runHandler(callableFunctionV1, test.httpRequest); - expect(responseV1.body).to.deep.equal(test.expectedHttpResponse.body); + expect(responseV1.body).to.deep.equal(JSON.stringify(test.expectedHttpResponse.body)); expect(responseV1.headers).to.deep.equal(test.expectedHttpResponse.headers); expect(responseV1.status).to.equal(test.expectedHttpResponse.status); - const callableFunctionV2 = https.onCallHandler(opts, (request) => { - expect(request.data).to.deep.equal(test.expectedData); - return test.callableFunction2(request); - }); + const callableFunctionV2 = https.onCallHandler( + opts, + (request) => { + expect(request.data).to.deep.equal(test.expectedData); + return test.callableFunction2(request); + }, + "gcfv2" + ); const responseV2 = await runHandler(callableFunctionV2, test.httpRequest); - expect(responseV2.body).to.deep.equal(test.expectedHttpResponse.body); + expect(responseV2.body).to.deep.equal(JSON.stringify(test.expectedHttpResponse.body)); expect(responseV2.headers).to.deep.equal(test.expectedHttpResponse.headers); expect(responseV2.status).to.equal(test.expectedHttpResponse.status); } @@ -165,7 +173,7 @@ describe("onCallHandler", () => { status: 400, headers: expectedResponseHeaders, body: { - error: { status: "INVALID_ARGUMENT", message: "Bad Request" }, + error: { message: "Bad Request", status: "INVALID_ARGUMENT" }, }, }, }); @@ -203,7 +211,7 @@ describe("onCallHandler", () => { status: 400, headers: expectedResponseHeaders, body: { - error: { status: "INVALID_ARGUMENT", message: "Bad Request" }, + error: { message: "Bad Request", status: "INVALID_ARGUMENT" }, }, }, }); @@ -225,7 +233,7 @@ describe("onCallHandler", () => { status: 400, headers: expectedResponseHeaders, body: { - error: { status: "INVALID_ARGUMENT", message: "Bad Request" }, + error: { message: "Bad Request", status: "INVALID_ARGUMENT" }, }, }, }); @@ -244,7 +252,7 @@ describe("onCallHandler", () => { expectedHttpResponse: { status: 500, headers: expectedResponseHeaders, - body: { error: { status: "INTERNAL", message: "INTERNAL" } }, + body: { error: { message: "INTERNAL", status: "INTERNAL" } }, }, }); }); @@ -262,7 +270,7 @@ describe("onCallHandler", () => { expectedHttpResponse: { status: 500, headers: expectedResponseHeaders, - body: { error: { status: "INTERNAL", message: "INTERNAL" } }, + body: { error: { message: "INTERNAL", status: "INTERNAL" } }, }, }); }); @@ -280,7 +288,7 @@ describe("onCallHandler", () => { expectedHttpResponse: { status: 404, headers: expectedResponseHeaders, - body: { error: { status: "NOT_FOUND", message: "i am error" } }, + body: { error: { message: "i am error", status: "NOT_FOUND" } }, }, }); }); @@ -364,8 +372,8 @@ describe("onCallHandler", () => { headers: expectedResponseHeaders, body: { error: { - status: "UNAUTHENTICATED", message: "Unauthenticated", + status: "UNAUTHENTICATED", }, }, }, @@ -391,8 +399,8 @@ describe("onCallHandler", () => { headers: expectedResponseHeaders, body: { error: { - status: "UNAUTHENTICATED", message: "Unauthenticated", + status: "UNAUTHENTICATED", }, }, }, @@ -461,8 +469,8 @@ describe("onCallHandler", () => { headers: expectedResponseHeaders, body: { error: { - status: "UNAUTHENTICATED", message: "Unauthenticated", + status: "UNAUTHENTICATED", }, }, }, @@ -748,6 +756,53 @@ describe("onCallHandler", () => { }); }); }); + + describe("Streaming callables", () => { + it("returns data in SSE format for requests Accept: text/event-stream header", async () => { + const mockReq = mockRequest( + { message: "hello streaming" }, + "application/json", + {}, + { accept: "text/event-stream" } + ) as any; + const fn = https.onCallHandler( + { + cors: { origin: true, methods: "POST" }, + }, + (req, resp) => { + resp.write("hello"); + return "world"; + }, + "gcfv2" + ); + + const resp = await runHandler(fn, mockReq); + const data = [`data: {"message":"hello"}`, `data: {"result":"world"}`]; + expect(resp.body).to.equal([...data, ""].join("\n")); + }); + + it("returns error in SSE format", async () => { + const mockReq = mockRequest( + { message: "hello streaming" }, + "application/json", + {}, + { accept: "text/event-stream" } + ) as any; + const fn = https.onCallHandler( + { + cors: { origin: true, methods: "POST" }, + }, + () => { + throw new Error("BOOM"); + }, + "gcfv2" + ); + + const resp = await runHandler(fn, mockReq); + const data = [`data: {"error":{"message":"INTERNAL","status":"INTERNAL"}}`]; + expect(resp.body).to.equal([...data, ""].join("\n")); + }); + }); }); describe("encoding/decoding", () => { diff --git a/spec/helper.ts b/spec/helper.ts index 8dd78d82c..544061b0b 100644 --- a/spec/helper.ts +++ b/spec/helper.ts @@ -47,6 +47,7 @@ export function runHandler( // MockResponse mocks an express.Response. // This class lives here so it can reference resolve and reject. class MockResponse { + private sentBody = ""; private statusCode = 0; private headers: { [name: string]: string } = {}; private callback: () => void; @@ -65,7 +66,10 @@ export function runHandler( return this.headers[name]; } - public send(body: any) { + public send(sendBody: any) { + const toSend = typeof sendBody === "object" ? JSON.stringify(sendBody) : sendBody; + const body = this.sentBody ? this.sentBody + ((toSend as string) || "") : toSend; + resolve({ status: this.statusCode, headers: this.headers, @@ -76,6 +80,10 @@ export function runHandler( } } + public write(writeBody: any) { + this.sentBody += typeof writeBody === "object" ? JSON.stringify(writeBody) : writeBody; + } + public end() { this.send(undefined); } diff --git a/spec/v1/providers/https.spec.ts b/spec/v1/providers/https.spec.ts index c3a7671c0..96f54f569 100644 --- a/spec/v1/providers/https.spec.ts +++ b/spec/v1/providers/https.spec.ts @@ -276,7 +276,7 @@ describe("callable CORS", () => { const response = await runHandler(func, req as any); expect(response.status).to.equal(200); - expect(response.body).to.be.deep.equal({ result: 42 }); + expect(response.body).to.be.deep.equal(JSON.stringify({ result: 42 })); expect(response.headers).to.deep.equal(expectedResponseHeaders); }); }); diff --git a/spec/v2/providers/https.spec.ts b/spec/v2/providers/https.spec.ts index 643044338..77d69bfcc 100644 --- a/spec/v2/providers/https.spec.ts +++ b/spec/v2/providers/https.spec.ts @@ -417,7 +417,7 @@ describe("onCall", () => { req.method = "POST"; const resp = await runHandler(func, req as any); - expect(resp.body).to.deep.equal({ result: 42 }); + expect(resp.body).to.deep.equal(JSON.stringify({ result: 42 })); }); it("should enforce CORS options", async () => { @@ -496,7 +496,7 @@ describe("onCall", () => { const response = await runHandler(func, req as any); expect(response.status).to.equal(200); - expect(response.body).to.be.deep.equal({ result: 42 }); + expect(response.body).to.be.deep.equal(JSON.stringify({ result: 42 })); expect(response.headers).to.deep.equal(expectedResponseHeaders); }); diff --git a/src/common/providers/https.ts b/src/common/providers/https.ts index 2f0e56538..d10696837 100644 --- a/src/common/providers/https.ts +++ b/src/common/providers/https.ts @@ -141,6 +141,15 @@ export interface CallableRequest { rawRequest: Request; } +/** + * CallableProxyResponse exposes subset of express.Response object + * to allow writing partial, streaming responses back to the client. + */ +export interface CallableProxyResponse { + write: express.Response["write"]; + acceptsStreaming: boolean; +} + /** * The set of Firebase Functions status codes. The codes are the same at the * ones exposed by {@link https://github.com/grpc/grpc/blob/master/doc/statuscodes.md | gRPC}. @@ -673,7 +682,10 @@ async function checkAppCheckToken( } type v1CallableHandler = (data: any, context: CallableContext) => any | Promise; -type v2CallableHandler = (request: CallableRequest) => Res; +type v2CallableHandler = ( + request: CallableRequest, + response?: CallableProxyResponse +) => Res; /** @internal **/ export interface CallableOptions { @@ -685,9 +697,10 @@ export interface CallableOptions { /** @internal */ export function onCallHandler( options: CallableOptions, - handler: v1CallableHandler | v2CallableHandler + handler: v1CallableHandler | v2CallableHandler, + version: "gcfv1" | "gcfv2" ): (req: Request, res: express.Response) => Promise { - const wrapped = wrapOnCallHandler(options, handler); + const wrapped = wrapOnCallHandler(options, handler, version); return (req: Request, res: express.Response) => { return new Promise((resolve) => { res.on("finish", resolve); @@ -698,10 +711,15 @@ export function onCallHandler( }; } +function encodeSSE(data: unknown): string { + return `data: ${JSON.stringify(data)}\n`; +} + /** @internal */ function wrapOnCallHandler( options: CallableOptions, - handler: v1CallableHandler | v2CallableHandler + handler: v1CallableHandler | v2CallableHandler, + version: "gcfv1" | "gcfv2" ): (req: Request, res: express.Response) => Promise { return async (req: Request, res: express.Response): Promise => { try { @@ -719,7 +737,7 @@ function wrapOnCallHandler( // The original monkey-patched code lived in the functionsEmulatorRuntime // (link: https://github.com/firebase/firebase-tools/blob/accea7abda3cc9fa6bb91368e4895faf95281c60/src/emulator/functionsEmulatorRuntime.ts#L480) // and was not compatible with how monorepos separate out packages (see https://github.com/firebase/firebase-tools/issues/5210). - if (isDebugFeatureEnabled("skipTokenVerification") && handler.length === 2) { + if (isDebugFeatureEnabled("skipTokenVerification") && version === "gcfv1") { const authContext = context.rawRequest.header(CALLABLE_AUTH_HEADER); if (authContext) { logger.debug("Callable functions auth override", { @@ -763,18 +781,34 @@ function wrapOnCallHandler( context.instanceIdToken = req.header("Firebase-Instance-ID-Token"); } + const acceptsStreaming = req.header("accept") === "text/event-stream"; const data: Req = decode(req.body.data); let result: Res; - if (handler.length === 2) { - result = await handler(data, context); + if (version === "gcfv1") { + result = await (handler as v1CallableHandler)(data, context); } else { const arg: CallableRequest = { ...context, data, }; + // TODO: set up optional heartbeat + const responseProxy: CallableProxyResponse = { + write(chunk): boolean { + if (acceptsStreaming) { + const formattedData = encodeSSE({ message: chunk }); + return res.write(formattedData); + } + // if client doesn't accept sse-protocol, response.write() is no-op. + }, + acceptsStreaming, + }; + if (acceptsStreaming) { + // SSE always responds with 200 + res.status(200); + } // For some reason the type system isn't picking up that the handler // is a one argument function. - result = await (handler as any)(arg); + result = await (handler as any)(arg, responseProxy); } // Encode the result as JSON to preserve types like Dates. @@ -782,7 +816,12 @@ function wrapOnCallHandler( // If there was some result, encode it in the body. const responseBody: HttpResponseBody = { result }; - res.status(200).send(responseBody); + if (acceptsStreaming) { + res.write(encodeSSE(responseBody)); + res.end(); + } else { + res.status(200).send(responseBody); + } } catch (err) { let httpErr = err; if (!(err instanceof HttpsError)) { @@ -793,8 +832,11 @@ function wrapOnCallHandler( const { status } = httpErr.httpErrorCode; const body = { error: httpErr.toJSON() }; - - res.status(status).send(body); + if (req.header("accept") === "text/event-stream") { + res.send(encodeSSE(body)); + } else { + res.status(status).send(body); + } } }; } diff --git a/src/v1/providers/https.ts b/src/v1/providers/https.ts index e9cd5d132..8d079bfa1 100644 --- a/src/v1/providers/https.ts +++ b/src/v1/providers/https.ts @@ -102,9 +102,8 @@ export function _onCallWithOptions( handler: (data: any, context: CallableContext) => any | Promise, options: DeploymentOptions ): HttpsFunction & Runnable { - // onCallHandler sniffs the function length of the passed-in callback - // and the user could have only tried to listen to data. Wrap their handler - // in another handler to avoid accidentally triggering the v2 API + // fix the length of handler to make the call to handler consistent + // in the onCallHandler const fixedLen = (data: any, context: CallableContext) => { return withInit(handler)(data, context); }; @@ -115,7 +114,8 @@ export function _onCallWithOptions( consumeAppCheckToken: options.consumeAppCheckToken, cors: { origin: true, methods: "POST" }, }, - fixedLen + fixedLen, + "gcfv1" ) ); diff --git a/src/v2/providers/https.ts b/src/v2/providers/https.ts index 16ad9038c..b21c04ba1 100644 --- a/src/v2/providers/https.ts +++ b/src/v2/providers/https.ts @@ -33,6 +33,7 @@ import { isDebugFeatureEnabled } from "../../common/debug"; import { ResetValue } from "../../common/options"; import { CallableRequest, + CallableProxyResponse, FunctionsErrorCode, HttpsError, onCallHandler, @@ -347,7 +348,7 @@ export function onRequest( */ export function onCall>( opts: CallableOptions, - handler: (request: CallableRequest) => Return + handler: (request: CallableRequest, response?: CallableProxyResponse) => Return ): CallableFunction ? Return : Promise>; /** @@ -356,11 +357,11 @@ export function onCall>( * @returns A function that you can export and deploy. */ export function onCall>( - handler: (request: CallableRequest) => Return + handler: (request: CallableRequest, response?: CallableProxyResponse) => Return ): CallableFunction ? Return : Promise>; export function onCall>( optsOrHandler: CallableOptions | ((request: CallableRequest) => Return), - handler?: (request: CallableRequest) => Return + handler?: (request: CallableRequest, response?: CallableProxyResponse) => Return ): CallableFunction ? Return : Promise> { let opts: CallableOptions; if (arguments.length === 1) { @@ -378,16 +379,17 @@ export function onCall>( origin = origin[0]; } - // onCallHandler sniffs the function length to determine which API to present. - // fix the length to prevent api versions from being mismatched. - const fixedLen = (req: CallableRequest) => withInit(handler)(req); + // fix the length of handler to make the call to handler consistent + const fixedLen = (req: CallableRequest, resp?: CallableProxyResponse) => + withInit(handler)(req, resp); let func: any = onCallHandler( { cors: { origin, methods: "POST" }, enforceAppCheck: opts.enforceAppCheck ?? options.getGlobalOptions().enforceAppCheck, consumeAppCheckToken: opts.consumeAppCheckToken, }, - fixedLen + fixedLen, + "gcfv2" ); func = wrapTraceContext(func);