Skip to content

Commit

Permalink
[sc-23218] Generalize middleware (#5)
Browse files Browse the repository at this point in the history
Co-authored-by: Nate Rutman <nrutman@users.noreply.github.com>
  • Loading branch information
namoscato and nrutman authored Feb 13, 2024
1 parent c5b1ca3 commit c3aef7d
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 134 deletions.
10 changes: 3 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,9 @@ The **Results Validator** ensures that the pipeline has fulfilled the interface

## Middleware

If **Middleware** is specified, it will be run on the specified stage lifecycle event(s) for each stage in the pipeline.
If **Middleware** is specified, it will be wrapped around each stage in the pipeline. This follows [a pattern similar to Express](https://expressjs.com/en/guide/using-middleware.html). Each middleware is called in the order it is specified and includes a `next()` to call the next middleware/stage.

| Stage Event | Description |
| ----------------- | ---------------------------------- |
| `onStageStart` | Runs before each stage is executed |
| `onStageComplete` | Runs after each stage is executed |

Middleware is specified as an object with middleware callbacks mapped to at least one of the above event keys. A middleware callback is provided the following attributes:
A middleware callback is provided the following attributes:

| Parameter | Description |
| -------------- | ----------------------------------------------------------------------------- |
Expand All @@ -102,6 +97,7 @@ Middleware is specified as an object with middleware callbacks mapped to at leas
| `results` | A read-only set of results returned by stages so far |
| `stageNames` | An array of the names of the methods that make up the current pipeline stages |
| `currentStage` | The name of the current pipeline stage |
| `next` | Calls the next middleware in the stack (or the stage if none) |

See the [LogStageMiddlewareFactory](./src/middleware/logStageMiddlewareFactory.ts) for a simple middleware implementation. It is wrapped in a factory method so a log method can be properly injected.

Expand Down
1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"main": "build/index.js",
"types": "build/index.d.ts",
"scripts": {
"dev": "tsc --noEmit --watch",
"prepack": "npm run build",
"build": "tsc --project tsconfig.build.json",
"eslint": "eslint --ext .js,.ts --cache --cache-location=node_modules/.cache/eslint --cache-strategy content .",
Expand Down
11 changes: 9 additions & 2 deletions src/__mocks__/TestPipeline.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { last } from "lodash";
import {
import type {
PipelineInitializer,
PipelineMiddleware,
PipelineResultValidator,
PipelineStage,
} from "../types";
Expand All @@ -24,6 +25,12 @@ export type TestStage = PipelineStage<
TestPipelineResults
>;

export type TestMiddleware = PipelineMiddleware<
TestPipelineArguments,
TestPipelineContext,
TestPipelineResults
>;

/**
* A stage to set up the test pipeline
*/
Expand Down Expand Up @@ -82,7 +89,7 @@ export const errorStage: TestStage = () => {
*/
export const testPipelineResultValidator: PipelineResultValidator<
TestPipelineResults
> = (results) => {
> = (results): results is TestPipelineResults => {
// false if sum is not a number
if (typeof results.sum !== "number") {
return false;
Expand Down
71 changes: 45 additions & 26 deletions src/__tests__/buildPipeline.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import { logStageMiddlewareFactory } from "middleware/logStageMiddlewareFactory";
import {
TestMiddleware,
TestPipelineArguments,
TestPipelineContext,
TestPipelineResults,
Expand All @@ -12,7 +14,6 @@ import {
import { buildPipeline } from "../buildPipeline";
import { PipelineError } from "../error/PipelineError";
import { returnSumResult } from "./../__mocks__/TestPipeline";
import { PipelineMiddleware } from "./../types";

const INCREMENT = 5;

Expand Down Expand Up @@ -53,42 +54,55 @@ describe("buildPipeline", () => {
});

describe("when using middleware", () => {
const testStart = jest.fn();
const testComplete = jest.fn();
const testMiddleware: PipelineMiddleware = {
onStageStart: testStart,
onStageComplete: testComplete,
};

const partialComplete = jest.fn();
const partialMiddleware: PipelineMiddleware = {
onStageComplete: partialComplete,
};

beforeEach(() => {
testStart.mockClear();
testComplete.mockClear();
partialComplete.mockClear();
});
let middlewareCalls: string[];

let testMiddleware1: TestMiddlewareMock;
let testMiddleware2: TestMiddlewareMock;

beforeAll(async () => {
middlewareCalls = [];

const createMiddlewareMock = (name: string): TestMiddlewareMock => {
return jest.fn(({ currentStage, next }) => {
middlewareCalls.push(`${currentStage}: ${name}`);

it("should run the test middleware", async () => {
await runPipelineForStages(successfulStages, [testMiddleware]);
return next();
});
};

expect(testStart).toHaveBeenCalledTimes(successfulStages.length);
expect(testComplete).toHaveBeenCalledTimes(successfulStages.length);
testMiddleware1 = createMiddlewareMock("testMiddleware1");
testMiddleware2 = createMiddlewareMock("testMiddleware2");

await runPipelineForStages(successfulStages, [
logStageMiddlewareFactory(),
testMiddleware1,
testMiddleware2,
]);
});

it("should run the partial middleware", async () => {
await runPipelineForStages(successfulStages, [partialMiddleware]);
it(`should run each middleware ${successfulStages.length} times`, () => {
expect(testMiddleware1).toHaveBeenCalledTimes(successfulStages.length);
expect(testMiddleware2).toHaveBeenCalledTimes(successfulStages.length);
});

expect(partialComplete).toHaveBeenCalledTimes(successfulStages.length);
it("should run middleware in the correct order", () => {
expect(middlewareCalls).toEqual([
"additionStage: testMiddleware1",
"additionStage: testMiddleware2",
"additionStage: testMiddleware1",
"additionStage: testMiddleware2",
"returnSumResult: testMiddleware1",
"returnSumResult: testMiddleware2",
"returnHistoryResult: testMiddleware1",
"returnHistoryResult: testMiddleware2",
]);
});
});
});

function runPipelineForStages(
stages: TestStage[],
middleware: PipelineMiddleware[] = [],
middleware: TestMiddleware[] = [],
) {
const pipeline = buildPipeline<
TestPipelineArguments,
Expand All @@ -104,3 +118,8 @@ function runPipelineForStages(

return pipeline({ increment: INCREMENT });
}

type TestMiddlewareMock = jest.Mock<
ReturnType<TestMiddleware>,
Parameters<TestMiddleware>
>;
92 changes: 33 additions & 59 deletions src/buildPipeline.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
import { compact, merge } from "lodash";
import { merge } from "lodash";
import { PipelineError } from "./error/PipelineError";
import {
import type {
Pipeline,
PipelineInitializer,
PipelineMetadata,
PipelineMiddleware,
PipelineMiddlewareCallable,
PipelineMiddlewareEventType,
PipelineMiddlewarePayload,
PipelineResultValidator,
PipelineStage,
} from "./types";
Expand All @@ -21,7 +18,7 @@ interface BuildPipelineInput<
initializer: PipelineInitializer<C, A>;
stages: PipelineStage<A, C, R>[];
resultsValidator: PipelineResultValidator<R>;
middleware?: PipelineMiddleware[];
middleware?: PipelineMiddleware<A, C, R>[];
}

/**
Expand All @@ -36,7 +33,7 @@ export function buildPipeline<
initializer,
stages,
resultsValidator,
middleware = [],
middleware: middlewares = [],
}: BuildPipelineInput<A, C, R>): Pipeline<A, R> {
return async (args) => {
const results: Partial<R> = {};
Expand All @@ -55,78 +52,55 @@ export function buildPipeline<
const context = await initializer(args);
maybeContext = context;

const buildMiddlewarePayload = (
const reversedMiddleware = [...middlewares].reverse();
const wrapMiddleware = (
middleware: PipelineMiddleware<A, C, R>,
currentStage: string,
): PipelineMiddlewarePayload<A, C, R> => ({
context,
metadata,
results,
stageNames,
currentStage,
});
next: () => Promise<Partial<R>>,
) => {
return () => {
return middleware({
context,
metadata,
results,
stageNames,
currentStage,
next,
});
};
};

for (const stage of stages) {
await executeMiddlewareForEvent(
"onStageStart",
middleware,
buildMiddlewarePayload(stage.name),
);
// initialize next() with the stage itself
let next = () => stage(context, metadata) as Promise<Partial<R>>;

// wrap stage with middleware such that the first middleware is the outermost function
for (const middleware of reversedMiddleware) {
next = wrapMiddleware(middleware, stage.name, next);
}

const stageResults = await stage(context, metadata);
// invoke middleware-wrapped stage
const stageResults = await next();

// if the stage returns results, merge them onto the results object
if (stageResults) {
merge(results, stageResults);
}

await executeMiddlewareForEvent(
"onStageComplete",
[...middleware].reverse(),
buildMiddlewarePayload(stage.name),
);
}

if (!isValidResult(results, resultsValidator)) {
if (!resultsValidator(results)) {
throw new Error("Results from pipeline failed validation");
}

return results;
} catch (e) {
} catch (cause) {
throw new PipelineError(
`${String(e)}`,
String(cause),
maybeContext,
results,
metadata,
e,
cause,
);
}
};
}

async function executeMiddlewareForEvent<
A extends object,
C extends object,
R extends object,
>(
event: PipelineMiddlewareEventType,
middleware: PipelineMiddleware[],
payload: PipelineMiddlewarePayload<A, C, R>,
) {
const handlers = compact<PipelineMiddlewareCallable<object, object, object>>(
middleware.map((m) => m[event]),
);

for (const handler of handlers) {
await handler(payload);
}
}

/**
* Wraps the provided validator in a type guard
*/
function isValidResult<R extends object>(
result: Partial<R>,
validator: PipelineResultValidator<R>,
): result is R {
return validator(result);
}
23 changes: 14 additions & 9 deletions src/middleware/logStageMiddlewareFactory.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
import { PipelineMiddleware } from "../types";

/**
* A simple implementation of Pipeline middleware that logs when each stage begins and finishes
* A simple implementation of Pipeline middleware that logs the duration of each stage
*/
export const logStageMiddlewareFactory = (
logger: (msg: string) => void = console.log,
): PipelineMiddleware => ({
onStageStart: ({ metadata, currentStage }) => {
logger(`[${metadata.name}] starting ${currentStage}...`);
},
onStageComplete: ({ metadata, currentStage }) => {
logger(`[${metadata.name}] ${currentStage} completed`);
},
});
): PipelineMiddleware => {
return async ({ metadata, currentStage, next }) => {
const started = performance.now();

try {
return await next();
} finally {
logger(
`[${metadata.name}] ${currentStage} completed in ${performance.now() - started}ms`,
);
}
};
};
Loading

0 comments on commit c3aef7d

Please sign in to comment.