diff --git a/README.md b/README.md index e037aef..46c89a5 100644 --- a/README.md +++ b/README.md @@ -86,14 +86,9 @@ The **Results Validator** ensures that the pipeline has fulfilled the interface ## Middleware -If **Middleware** is specified, it will be run on the specified stage lifecycle event(s) for each stage in the pipeline. +If **Middleware** is specified, it will be wrapped around each stage in the pipeline. This follows [a pattern similar to Express](https://expressjs.com/en/guide/using-middleware.html). Each middleware is called in the order it is specified and includes a `next()` to call the next middleware/stage. -| Stage Event | Description | -| ----------------- | ---------------------------------- | -| `onStageStart` | Runs before each stage is executed | -| `onStageComplete` | Runs after each stage is executed | - -Middleware is specified as an object with middleware callbacks mapped to at least one of the above event keys. A middleware callback is provided the following attributes: +A middleware callback is provided the following attributes: | Parameter | Description | | -------------- | ----------------------------------------------------------------------------- | @@ -102,6 +97,7 @@ Middleware is specified as an object with middleware callbacks mapped to at leas | `results` | A read-only set of results returned by stages so far | | `stageNames` | An array of the names of the methods that make up the current pipeline stages | | `currentStage` | The name of the current pipeline stage | +| `next` | Calls the next middleware in the stack (or the stage if none) | See the [LogStageMiddlewareFactory](./src/middleware/logStageMiddlewareFactory.ts) for a simple middleware implementation. It is wrapped in a factory method so a log method can be properly injected. diff --git a/package.json b/package.json index fb05f36..fd710d4 100644 --- a/package.json +++ b/package.json @@ -5,6 +5,7 @@ "main": "build/index.js", "types": "build/index.d.ts", "scripts": { + "dev": "tsc --noEmit --watch", "prepack": "npm run build", "build": "tsc --project tsconfig.build.json", "eslint": "eslint --ext .js,.ts --cache --cache-location=node_modules/.cache/eslint --cache-strategy content .", diff --git a/src/__mocks__/TestPipeline.ts b/src/__mocks__/TestPipeline.ts index 8ce9775..c8759b6 100644 --- a/src/__mocks__/TestPipeline.ts +++ b/src/__mocks__/TestPipeline.ts @@ -1,6 +1,7 @@ import { last } from "lodash"; -import { +import type { PipelineInitializer, + PipelineMiddleware, PipelineResultValidator, PipelineStage, } from "../types"; @@ -24,6 +25,12 @@ export type TestStage = PipelineStage< TestPipelineResults >; +export type TestMiddleware = PipelineMiddleware< + TestPipelineArguments, + TestPipelineContext, + TestPipelineResults +>; + /** * A stage to set up the test pipeline */ @@ -82,7 +89,7 @@ export const errorStage: TestStage = () => { */ export const testPipelineResultValidator: PipelineResultValidator< TestPipelineResults -> = (results) => { +> = (results): results is TestPipelineResults => { // false if sum is not a number if (typeof results.sum !== "number") { return false; diff --git a/src/__tests__/buildPipeline.test.ts b/src/__tests__/buildPipeline.test.ts index 3b31075..20b6bd3 100644 --- a/src/__tests__/buildPipeline.test.ts +++ b/src/__tests__/buildPipeline.test.ts @@ -1,4 +1,6 @@ +import { logStageMiddlewareFactory } from "middleware/logStageMiddlewareFactory"; import { + TestMiddleware, TestPipelineArguments, TestPipelineContext, TestPipelineResults, @@ -12,7 +14,6 @@ import { import { buildPipeline } from "../buildPipeline"; import { PipelineError } from "../error/PipelineError"; import { returnSumResult } from "./../__mocks__/TestPipeline"; -import { PipelineMiddleware } from "./../types"; const INCREMENT = 5; @@ -53,42 +54,55 @@ describe("buildPipeline", () => { }); describe("when using middleware", () => { - const testStart = jest.fn(); - const testComplete = jest.fn(); - const testMiddleware: PipelineMiddleware = { - onStageStart: testStart, - onStageComplete: testComplete, - }; - - const partialComplete = jest.fn(); - const partialMiddleware: PipelineMiddleware = { - onStageComplete: partialComplete, - }; - - beforeEach(() => { - testStart.mockClear(); - testComplete.mockClear(); - partialComplete.mockClear(); - }); + let middlewareCalls: string[]; + + let testMiddleware1: TestMiddlewareMock; + let testMiddleware2: TestMiddlewareMock; + + beforeAll(async () => { + middlewareCalls = []; + + const createMiddlewareMock = (name: string): TestMiddlewareMock => { + return jest.fn(({ currentStage, next }) => { + middlewareCalls.push(`${currentStage}: ${name}`); - it("should run the test middleware", async () => { - await runPipelineForStages(successfulStages, [testMiddleware]); + return next(); + }); + }; - expect(testStart).toHaveBeenCalledTimes(successfulStages.length); - expect(testComplete).toHaveBeenCalledTimes(successfulStages.length); + testMiddleware1 = createMiddlewareMock("testMiddleware1"); + testMiddleware2 = createMiddlewareMock("testMiddleware2"); + + await runPipelineForStages(successfulStages, [ + logStageMiddlewareFactory(), + testMiddleware1, + testMiddleware2, + ]); }); - it("should run the partial middleware", async () => { - await runPipelineForStages(successfulStages, [partialMiddleware]); + it(`should run each middleware ${successfulStages.length} times`, () => { + expect(testMiddleware1).toHaveBeenCalledTimes(successfulStages.length); + expect(testMiddleware2).toHaveBeenCalledTimes(successfulStages.length); + }); - expect(partialComplete).toHaveBeenCalledTimes(successfulStages.length); + it("should run middleware in the correct order", () => { + expect(middlewareCalls).toEqual([ + "additionStage: testMiddleware1", + "additionStage: testMiddleware2", + "additionStage: testMiddleware1", + "additionStage: testMiddleware2", + "returnSumResult: testMiddleware1", + "returnSumResult: testMiddleware2", + "returnHistoryResult: testMiddleware1", + "returnHistoryResult: testMiddleware2", + ]); }); }); }); function runPipelineForStages( stages: TestStage[], - middleware: PipelineMiddleware[] = [], + middleware: TestMiddleware[] = [], ) { const pipeline = buildPipeline< TestPipelineArguments, @@ -104,3 +118,8 @@ function runPipelineForStages( return pipeline({ increment: INCREMENT }); } + +type TestMiddlewareMock = jest.Mock< + ReturnType, + Parameters +>; diff --git a/src/buildPipeline.ts b/src/buildPipeline.ts index 6cefc60..aefa0d1 100644 --- a/src/buildPipeline.ts +++ b/src/buildPipeline.ts @@ -1,13 +1,10 @@ -import { compact, merge } from "lodash"; +import { merge } from "lodash"; import { PipelineError } from "./error/PipelineError"; -import { +import type { Pipeline, PipelineInitializer, PipelineMetadata, PipelineMiddleware, - PipelineMiddlewareCallable, - PipelineMiddlewareEventType, - PipelineMiddlewarePayload, PipelineResultValidator, PipelineStage, } from "./types"; @@ -21,7 +18,7 @@ interface BuildPipelineInput< initializer: PipelineInitializer; stages: PipelineStage[]; resultsValidator: PipelineResultValidator; - middleware?: PipelineMiddleware[]; + middleware?: PipelineMiddleware[]; } /** @@ -36,7 +33,7 @@ export function buildPipeline< initializer, stages, resultsValidator, - middleware = [], + middleware: middlewares = [], }: BuildPipelineInput): Pipeline { return async (args) => { const results: Partial = {}; @@ -55,78 +52,55 @@ export function buildPipeline< const context = await initializer(args); maybeContext = context; - const buildMiddlewarePayload = ( + const reversedMiddleware = [...middlewares].reverse(); + const wrapMiddleware = ( + middleware: PipelineMiddleware, currentStage: string, - ): PipelineMiddlewarePayload => ({ - context, - metadata, - results, - stageNames, - currentStage, - }); + next: () => Promise>, + ) => { + return () => { + return middleware({ + context, + metadata, + results, + stageNames, + currentStage, + next, + }); + }; + }; for (const stage of stages) { - await executeMiddlewareForEvent( - "onStageStart", - middleware, - buildMiddlewarePayload(stage.name), - ); + // initialize next() with the stage itself + let next = () => stage(context, metadata) as Promise>; + + // wrap stage with middleware such that the first middleware is the outermost function + for (const middleware of reversedMiddleware) { + next = wrapMiddleware(middleware, stage.name, next); + } - const stageResults = await stage(context, metadata); + // invoke middleware-wrapped stage + const stageResults = await next(); // if the stage returns results, merge them onto the results object if (stageResults) { merge(results, stageResults); } - - await executeMiddlewareForEvent( - "onStageComplete", - [...middleware].reverse(), - buildMiddlewarePayload(stage.name), - ); } - if (!isValidResult(results, resultsValidator)) { + if (!resultsValidator(results)) { throw new Error("Results from pipeline failed validation"); } return results; - } catch (e) { + } catch (cause) { throw new PipelineError( - `${String(e)}`, + String(cause), maybeContext, results, metadata, - e, + cause, ); } }; } - -async function executeMiddlewareForEvent< - A extends object, - C extends object, - R extends object, ->( - event: PipelineMiddlewareEventType, - middleware: PipelineMiddleware[], - payload: PipelineMiddlewarePayload, -) { - const handlers = compact>( - middleware.map((m) => m[event]), - ); - - for (const handler of handlers) { - await handler(payload); - } -} - -/** - * Wraps the provided validator in a type guard - */ -function isValidResult( - result: Partial, - validator: PipelineResultValidator, -): result is R { - return validator(result); -} diff --git a/src/middleware/logStageMiddlewareFactory.ts b/src/middleware/logStageMiddlewareFactory.ts index ba00136..3ffc6bc 100644 --- a/src/middleware/logStageMiddlewareFactory.ts +++ b/src/middleware/logStageMiddlewareFactory.ts @@ -1,15 +1,20 @@ import { PipelineMiddleware } from "../types"; /** - * A simple implementation of Pipeline middleware that logs when each stage begins and finishes + * A simple implementation of Pipeline middleware that logs the duration of each stage */ export const logStageMiddlewareFactory = ( logger: (msg: string) => void = console.log, -): PipelineMiddleware => ({ - onStageStart: ({ metadata, currentStage }) => { - logger(`[${metadata.name}] starting ${currentStage}...`); - }, - onStageComplete: ({ metadata, currentStage }) => { - logger(`[${metadata.name}] ${currentStage} completed`); - }, -}); +): PipelineMiddleware => { + return async ({ metadata, currentStage, next }) => { + const started = performance.now(); + + try { + return await next(); + } finally { + logger( + `[${metadata.name}] ${currentStage} completed in ${performance.now() - started}ms`, + ); + } + }; +}; diff --git a/src/types.ts b/src/types.ts index 5a641e1..417f754 100644 --- a/src/types.ts +++ b/src/types.ts @@ -12,10 +12,15 @@ export type PipelineStage< A extends object, C extends object, R extends object, -> = ( - context: C, - metadata: PipelineMetadata, -) => Promise> | Partial | Promise | void; +> = (context: C, metadata: PipelineMetadata) => PipelineStageResult; + +/** + * Optional partial result that gets merged with results from other stages + */ +export type PipelineStageResult = + | Promise | void> + | Partial + | void; /** * A method that initializes the pipeline by creating the context object that gets passed to each stage. Note that because the context extends PipelineContext, this method must also include the pipeline name and arguments when constructing the context object. @@ -30,8 +35,8 @@ export type PipelineInitializer = ( * Validates that results at the conclusion of the pipeline's execution are complete */ export type PipelineResultValidator = ( - results: Partial, -) => boolean; + results: Readonly>, +) => results is R; /** * Basic metadata about a pipeline execution @@ -41,38 +46,17 @@ export interface PipelineMetadata { name: Readonly; } -interface BasePipelineMiddleware< - A extends object = object, - C extends object = object, - R extends object = object, -> { - /** Runs before a pipeline stage is executed */ - onStageStart: PipelineMiddlewareCallable; - /** Runs after a pipeline stage is executing and includes results returned by that stage */ - onStageComplete: PipelineMiddlewareCallable; -} - -/** - * Event-based middleware to run around each pipeline stage - */ -export type PipelineMiddleware = Partial; - -/** - * The events that are supported by pipeline middleware - */ -export type PipelineMiddlewareEventType = keyof BasePipelineMiddleware; - /** - * Functions that can be assigned to each event in the middleware + * Middleware function that can run code before and/or after each stage */ -export type PipelineMiddlewareCallable< +export type PipelineMiddleware< A extends object = object, C extends object = object, R extends object = object, -> = (input: PipelineMiddlewarePayload) => Promise | void; +> = (payload: PipelineMiddlewarePayload) => Promise>; /** - * The payload that gets passed to each `PipelineMiddlewareCallable` + * The payload that gets passed to each `PipelineMiddleware` */ export interface PipelineMiddlewarePayload< A extends object, @@ -84,4 +68,5 @@ export interface PipelineMiddlewarePayload< results: Readonly>; stageNames: string[]; currentStage: string; + next: () => Promise>; }