diff --git a/src/__test__/wrap.test.ts b/src/__test__/wrap.test.ts index 0298187..bfde055 100644 --- a/src/__test__/wrap.test.ts +++ b/src/__test__/wrap.test.ts @@ -1,5 +1,12 @@ /* eslint-disable @typescript-eslint/no-empty-function */ -import { ALBResult, APIGatewayProxyStructuredResultV2, CloudFrontResultResponse, Context } from 'aws-lambda'; +import { + ALBResult, + APIGatewayProxyStructuredResultV2, + CloudFrontResultResponse, + Context, + KinesisStreamBatchResponse, + KinesisStreamEvent, +} from 'aws-lambda'; import { describe, beforeEach, it } from 'node:test'; import assert from 'node:assert'; import { lf } from '../function.js'; @@ -235,7 +242,7 @@ describe('LambdaWrap', () => { it('should pass body through', async () => { const fn = lf.handler(() => { - return { body: 'fooBar' }; + return 'fooBar'; }); const ret = await new Promise((resolve) => fn(ApiGatewayExample, fakeContext, (a, b) => resolve(b))); assert.equal(ret, 'fooBar'); @@ -298,4 +305,28 @@ describe('LambdaWrap', () => { delete process.env['TRACE_LAMBDA']; }); + + it('should allow straight responses', async () => { + function fakeFn(req: LambdaRequest): KinesisStreamBatchResponse | void { + if (req.event.Records.length === 0) return; + + return { + batchItemFailures: req.event.Records.map((f) => { + return { itemIdentifier: f.kinesis.sequenceNumber }; + }), + }; + } + + const fn = lf.handler(fakeFn); + + const emptyResponse = await new Promise((resolve) => fn({ Records: [] }, fakeContext, (a, b) => resolve(b))); + assert.deepEqual(emptyResponse, undefined); + + const actualResponse = await new Promise((resolve) => + fn({ Records: [{ kinesis: { sequenceNumber: '123' } }] } as KinesisStreamEvent, fakeContext, (a, b) => + resolve(b), + ), + ); + assert.deepEqual(actualResponse, { batchItemFailures: [{ itemIdentifier: '123' }] }); + }); }); diff --git a/src/function.ts b/src/function.ts index d4c45df..6fccfa0 100644 --- a/src/function.ts +++ b/src/function.ts @@ -152,7 +152,7 @@ export class lf { * @param logger optional logger to use for the request @see lf.Logger */ public static handler( - fn: LambdaWrappedFunction, + fn: LambdaWrappedFunction, options?: Partial, logger?: LogType, ): LambdaHandler { @@ -176,8 +176,10 @@ export class lf { if (req.logContext['err']) return callback(req.logContext['err'] as Error); if (res.status > 399) return callback(req.toResponse(res) as unknown as string); } + return callback(null, req.toResponse(res)); } - return callback(null, req.toResponse(res)); + + return callback(null, res); }); } return handler;