Skip to content

Commit

Permalink
Add pre and post invocation hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
ejizba committed Mar 2, 2022
1 parent 6c8fb5c commit 81f0154
Show file tree
Hide file tree
Showing 7 changed files with 429 additions and 151 deletions.
1 change: 1 addition & 0 deletions .eslintrc.json
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
"@typescript-eslint/restrict-template-expressions": "off",
"@typescript-eslint/unbound-method": "off",
"no-empty": "off",
"prefer-const": ["error", { "destructuring": "all" }],
"prefer-rest-params": "off",
"prefer-spread": "off"
},
Expand Down
35 changes: 35 additions & 0 deletions src/Disposable.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the MIT License.

/**
* Based off of VS Code
* https://github.com/microsoft/vscode/blob/a64e8e5673a44e5b9c2d493666bde684bd5a135c/src/vs/workbench/api/common/extHostTypes.ts#L32
*/
export class Disposable {
static from(...inDisposables: { dispose(): any }[]): Disposable {
let disposables: ReadonlyArray<{ dispose(): any }> | undefined = inDisposables;
return new Disposable(function () {
if (disposables) {
for (const disposable of disposables) {
if (disposable && typeof disposable.dispose === 'function') {
disposable.dispose();
}
}
disposables = undefined;
}
});
}

#callOnDispose?: () => any;

constructor(callOnDispose: () => any) {
this.#callOnDispose = callOnDispose;
}

dispose(): any {
if (this.#callOnDispose instanceof Function) {
this.#callOnDispose();
this.#callOnDispose = undefined;
}
}
}
71 changes: 43 additions & 28 deletions src/WorkerChannel.ts
Original file line number Diff line number Diff line change
@@ -1,25 +1,23 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the MIT License.

import { Context } from '@azure/functions';
import { HookCallback } from '@azure/functions-worker';
import { AzureFunctionsRpcMessages as rpc } from '../azure-functions-language-worker-protobuf/src/rpc';
import { Disposable } from './Disposable';
import { IFunctionLoader } from './FunctionLoader';
import { IEventStream } from './GrpcClient';

type InvocationRequestBefore = (context: Context, userFn: Function) => Function;
type InvocationRequestAfter = (context: Context) => void;
import Module = require('module');

export class WorkerChannel {
public eventStream: IEventStream;
public functionLoader: IFunctionLoader;
private _invocationRequestBefore: InvocationRequestBefore[];
private _invocationRequestAfter: InvocationRequestAfter[];
private _preInvocationHooks: HookCallback[] = [];
private _postInvocationHooks: HookCallback[] = [];

constructor(eventStream: IEventStream, functionLoader: IFunctionLoader) {
this.eventStream = eventStream;
this.functionLoader = functionLoader;
this._invocationRequestBefore = [];
this._invocationRequestAfter = [];
this.initWorkerModule(this);
}

/**
Expand All @@ -33,32 +31,49 @@ export class WorkerChannel {
});
}

/**
* Register a patching function to be run before User Function is executed.
* Hook should return a patched version of User Function.
*/
public registerBeforeInvocationRequest(beforeCb: InvocationRequestBefore): void {
this._invocationRequestBefore.push(beforeCb);
public registerHook(hookName: string, callback: HookCallback): Disposable {
const hooks = this.getHooks(hookName);
hooks.push(callback);
return new Disposable(() => {
const index = hooks.indexOf(callback);
if (index > -1) {
hooks.splice(index, 1);
}
});
}

/**
* Register a function to be run after User Function resolves.
*/
public registerAfterInvocationRequest(afterCb: InvocationRequestAfter): void {
this._invocationRequestAfter.push(afterCb);
public async executeHooks(hookName: string, context: {}): Promise<void> {
const callbacks = this.getHooks(hookName);
for (const callback of callbacks) {
await callback(context);
}
}

public runInvocationRequestBefore(context: Context, userFunction: Function): Function {
let wrappedFunction = userFunction;
for (const before of this._invocationRequestBefore) {
wrappedFunction = before(context, wrappedFunction);
private getHooks(hookName: string): HookCallback[] {
switch (hookName) {
case 'preInvocation':
return this._preInvocationHooks;
case 'postInvocation':
return this._postInvocationHooks;
default:
throw new RangeError(`Unrecognized hook "${hookName}"`);
}
return wrappedFunction;
}

public runInvocationRequestAfter(context: Context) {
for (const after of this._invocationRequestAfter) {
after(context);
}
private initWorkerModule(channel: WorkerChannel) {
const workerApi = {
registerHook: (hookName: string, callback: HookCallback) => channel.registerHook(hookName, callback),
Disposable,
};

Module.prototype.require = new Proxy(Module.prototype.require, {
apply(target, thisArg, argArray) {
if (argArray[0] === '@azure/functions-worker') {
return workerApi;
} else {
return Reflect.apply(target, thisArg, argArray);
}
},
});
}
}
26 changes: 20 additions & 6 deletions src/eventHandlers/invocationRequest.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the MIT License.

import { PostInvocationContext, PreInvocationContext } from '@azure/functions-worker';
import { format } from 'util';
import { AzureFunctionsRpcMessages as rpc } from '../../azure-functions-language-worker-protobuf/src/rpc';
import { CreateContextAndInputs } from '../Context';
Expand Down Expand Up @@ -66,7 +67,7 @@ export async function invocationRequest(channel: WorkerChannel, requestId: strin
isDone = true;
}

const { context, inputs, doneEmitter } = CreateContextAndInputs(info, msg, userLog);
let { context, inputs, doneEmitter } = CreateContextAndInputs(info, msg, userLog);
try {
const legacyDoneTask = new Promise((resolve, reject) => {
doneEmitter.on('done', (err?: unknown, result?: any) => {
Expand All @@ -79,8 +80,12 @@ export async function invocationRequest(channel: WorkerChannel, requestId: strin
});
});

let userFunction = channel.functionLoader.getFunc(nonNullProp(msg, 'functionId'));
userFunction = channel.runInvocationRequestBefore(context, userFunction);
const userFunction = channel.functionLoader.getFunc(nonNullProp(msg, 'functionId'));
const preInvocContext: PreInvocationContext = { invocationContext: context, inputs };

await channel.executeHooks('preInvocation', preInvocContext);
inputs = preInvocContext.inputs;

let rawResult = userFunction(context, ...inputs);
resultIsPromise = rawResult && typeof rawResult.then === 'function';
let resultTask: Promise<any>;
Expand All @@ -94,7 +99,18 @@ export async function invocationRequest(channel: WorkerChannel, requestId: strin
resultTask = legacyDoneTask;
}

const result = await resultTask;
const postInvocContext: PostInvocationContext = Object.assign(preInvocContext, { result: null, error: null });
try {
postInvocContext.result = await resultTask;
} catch (err) {
postInvocContext.error = err;
}
await channel.executeHooks('postInvocation', postInvocContext);

if (isError(postInvocContext.error)) {
throw postInvocContext.error;
}
const result = postInvocContext.result;

// Allow HTTP response from context.res if HTTP response is not defined from the context.bindings object
if (info.httpOutputName && context.res && context.bindings[info.httpOutputName] === undefined) {
Expand Down Expand Up @@ -163,6 +179,4 @@ export async function invocationRequest(channel: WorkerChannel, requestId: strin
requestId: requestId,
invocationResponse: response,
});

channel.runInvocationRequestAfter(context);
}
Loading

0 comments on commit 81f0154

Please sign in to comment.