Skip to content

Commit

Permalink
Refactor source code to use async/await instead of done
Browse files Browse the repository at this point in the history
  • Loading branch information
ejizba committed Mar 2, 2022
1 parent eba7624 commit 6c8fb5c
Show file tree
Hide file tree
Showing 5 changed files with 155 additions and 180 deletions.
92 changes: 16 additions & 76 deletions src/Context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,12 @@ import {
import { FunctionInfo } from './FunctionInfo';
import { Request } from './http/Request';
import { Response } from './http/Response';
import EventEmitter = require('events');
import LogLevel = rpc.RpcLog.Level;
import LogCategory = rpc.RpcLog.RpcLogCategory;

export function CreateContextAndInputs(
info: FunctionInfo,
request: rpc.IInvocationRequest,
logCallback: LogCallback,
callback: ResultCallback
) {
const context = new InvocationContext(info, request, logCallback, callback);
export function CreateContextAndInputs(info: FunctionInfo, request: rpc.IInvocationRequest, logCallback: LogCallback) {
const doneEmitter = new EventEmitter();
const context = new InvocationContext(info, request, logCallback, doneEmitter);

const bindings: ContextBindings = {};
const inputs: any[] = [];
Expand Down Expand Up @@ -76,6 +72,7 @@ export function CreateContextAndInputs(
return {
context: <Context>context,
inputs: inputs,
doneEmitter,
};
}

Expand All @@ -95,7 +92,7 @@ class InvocationContext implements Context {
info: FunctionInfo,
request: rpc.IInvocationRequest,
logCallback: LogCallback,
callback: ResultCallback
doneEmitter: EventEmitter
) {
this.invocationId = <string>request.invocationId;
this.traceContext = fromRpcTraceContext(request.traceContext);
Expand All @@ -107,89 +104,32 @@ class InvocationContext implements Context {
};
this.executionContext = executionContext;
this.bindings = {};
let _done = false;
let _promise = false;

// Log message that is tied to function invocation
this.log = Object.assign(
(...args: any[]) => logWithAsyncCheck(_done, logCallback, LogLevel.Information, executionContext, ...args),
{
error: (...args: any[]) =>
logWithAsyncCheck(_done, logCallback, LogLevel.Error, executionContext, ...args),
warn: (...args: any[]) =>
logWithAsyncCheck(_done, logCallback, LogLevel.Warning, executionContext, ...args),
info: (...args: any[]) =>
logWithAsyncCheck(_done, logCallback, LogLevel.Information, executionContext, ...args),
verbose: (...args: any[]) =>
logWithAsyncCheck(_done, logCallback, LogLevel.Trace, executionContext, ...args),
}
);
this.log = Object.assign((...args: any[]) => logCallback(LogLevel.Information, ...args), {
error: (...args: any[]) => logCallback(LogLevel.Error, ...args),
warn: (...args: any[]) => logCallback(LogLevel.Warning, ...args),
info: (...args: any[]) => logCallback(LogLevel.Information, ...args),
verbose: (...args: any[]) => logCallback(LogLevel.Trace, ...args),
});

this.bindingData = getNormalizedBindingData(request);
this.bindingDefinitions = getBindingDefinitions(info);

// isPromise is a hidden parameter that we set to true in the event of a returned promise
this.done = (err?: any, result?: any, isPromise?: boolean) => {
_promise = isPromise === true;
if (_done) {
if (_promise) {
logCallback(
LogLevel.Error,
LogCategory.User,
"Error: Choose either to return a promise or call 'done'. Do not use both in your script."
);
} else {
logCallback(
LogLevel.Error,
LogCategory.User,
"Error: 'done' has already been called. Please check your script for extraneous calls to 'done'."
);
}
return;
}
_done = true;

// Allow HTTP response from context.res if HTTP response is not defined from the context.bindings object
if (info.httpOutputName && this.res && this.bindings[info.httpOutputName] === undefined) {
this.bindings[info.httpOutputName] = this.res;
}

callback(err, {
return: result,
bindings: this.bindings,
});
this.done = (err?: unknown, result?: any) => {
doneEmitter.emit('done', err, result);
};
}
}

// Emit warning if trying to log after function execution is done.
function logWithAsyncCheck(
done: boolean,
log: LogCallback,
level: LogLevel,
executionContext: ExecutionContext,
...args: any[]
) {
if (done) {
let badAsyncMsg =
"Warning: Unexpected call to 'log' on the context object after function execution has completed. Please check for asynchronous calls that are not awaited or calls to 'done' made before function execution completes. ";
badAsyncMsg += `Function name: ${executionContext.functionName}. Invocation Id: ${executionContext.invocationId}. `;
badAsyncMsg += `Learn more: https://go.microsoft.com/fwlink/?linkid=2097909 `;
log(LogLevel.Warning, LogCategory.System, badAsyncMsg);
}
return log(level, LogCategory.User, ...args);
}

export interface InvocationResult {
return: any;
bindings: ContextBindings;
}

export type DoneCallback = (err?: Error | string, result?: any) => void;

export type LogCallback = (level: LogLevel, category: rpc.RpcLog.RpcLogCategory, ...args: any[]) => void;
export type DoneCallback = (err?: unknown, result?: any) => void;

export type ResultCallback = (err?: any, result?: InvocationResult) => void;
export type LogCallback = (level: LogLevel, ...args: any[]) => void;

export interface Dict<T> {
[key: string]: T;
Expand Down
213 changes: 125 additions & 88 deletions src/eventHandlers/invocationRequest.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@

import { format } from 'util';
import { AzureFunctionsRpcMessages as rpc } from '../../azure-functions-language-worker-protobuf/src/rpc';
import { CreateContextAndInputs, LogCallback, ResultCallback } from '../Context';
import { CreateContextAndInputs } from '../Context';
import { toTypedData } from '../converters';
import { isError } from '../utils/ensureErrorType';
import { nonNullProp } from '../utils/nonNull';
import { toRpcStatus } from '../utils/toRpcStatus';
import { WorkerChannel } from '../WorkerChannel';
Expand All @@ -16,28 +17,89 @@ import LogLevel = rpc.RpcLog.Level;
* @param requestId gRPC message request id
* @param msg gRPC message content
*/
export function invocationRequest(channel: WorkerChannel, requestId: string, msg: rpc.IInvocationRequest) {
export async function invocationRequest(channel: WorkerChannel, requestId: string, msg: rpc.IInvocationRequest) {
const response: rpc.IInvocationResponse = {
invocationId: msg.invocationId,
result: toRpcStatus(),
};
// explicitly set outputData to empty array to concat later
response.outputData = [];

let isDone = false;
let resultIsPromise = false;

const info = channel.functionLoader.getInfo(nonNullProp(msg, 'functionId'));
const logCallback: LogCallback = (level, category, ...args) => {

function log(level: LogLevel, category: LogCategory, ...args: any[]) {
channel.log({
invocationId: msg.invocationId,
category: `${info.name}.Invocation`,
message: format.apply(null, <[any, any[]]>args),
level: level,
logCategory: category,
});
};
}
function systemLog(level: LogLevel, ...args: any[]) {
log(level, LogCategory.System, ...args);
}
function userLog(level: LogLevel, ...args: any[]) {
if (isDone) {
let badAsyncMsg =
"Warning: Unexpected call to 'log' on the context object after function execution has completed. Please check for asynchronous calls that are not awaited or calls to 'done' made before function execution completes. ";
badAsyncMsg += `Function name: ${info.name}. Invocation Id: ${msg.invocationId}. `;
badAsyncMsg += `Learn more: https://go.microsoft.com/fwlink/?linkid=2097909 `;
systemLog(LogLevel.Warning, badAsyncMsg);
}
log(level, LogCategory.User, ...args);
}

// Log invocation details to ensure the invocation received by node worker
logCallback(LogLevel.Debug, LogCategory.System, 'Received FunctionInvocationRequest');
systemLog(LogLevel.Debug, 'Received FunctionInvocationRequest');

const resultCallback: ResultCallback = (err: unknown, result) => {
const response: rpc.IInvocationResponse = {
invocationId: msg.invocationId,
result: toRpcStatus(err),
};
// explicitly set outputData to empty array to concat later
response.outputData = [];
function onDone(): void {
if (isDone) {
const message = resultIsPromise
? "Error: Choose either to return a promise or call 'done'. Do not use both in your script."
: "Error: 'done' has already been called. Please check your script for extraneous calls to 'done'.";
systemLog(LogLevel.Error, message);
}
isDone = true;
}

const { context, inputs, doneEmitter } = CreateContextAndInputs(info, msg, userLog);
try {
const legacyDoneTask = new Promise((resolve, reject) => {
doneEmitter.on('done', (err?: unknown, result?: any) => {
onDone();
if (isError(err)) {
reject(err);
} else {
resolve(result);
}
});
});

let userFunction = channel.functionLoader.getFunc(nonNullProp(msg, 'functionId'));
userFunction = channel.runInvocationRequestBefore(context, userFunction);
let rawResult = userFunction(context, ...inputs);
resultIsPromise = rawResult && typeof rawResult.then === 'function';
let resultTask: Promise<any>;
if (resultIsPromise) {
rawResult = Promise.resolve(rawResult).then((r) => {
onDone();
return r;
});
resultTask = Promise.race([rawResult, legacyDoneTask]);
} else {
resultTask = legacyDoneTask;
}

const result = await resultTask;

// Allow HTTP response from context.res if HTTP response is not defined from the context.bindings object
if (info.httpOutputName && context.res && context.bindings[info.httpOutputName] === undefined) {
context.bindings[info.httpOutputName] = context.res;
}

// As legacy behavior, falsy values get serialized to `null` in AzFunctions.
// This breaks Durable Functions expectations, where customers expect any
Expand All @@ -46,86 +108,61 @@ export function invocationRequest(channel: WorkerChannel, requestId: string, msg
// values get serialized.
const isDurableBinding = info?.bindings?.name?.type == 'activityTrigger';

try {
if (result || (isDurableBinding && result != null)) {
const returnBinding = info.getReturnBinding();
// Set results from return / context.done
if (result.return || (isDurableBinding && result.return != null)) {
// $return binding is found: return result data to $return binding
if (returnBinding) {
response.returnValue = returnBinding.converter(result.return);
// $return binding is not found: read result as object of outputs
} else {
response.outputData = Object.keys(info.outputBindings)
.filter((key) => result.return[key] !== undefined)
.map(
(key) =>
<rpc.IParameterBinding>{
name: key,
data: info.outputBindings[key].converter(result.return[key]),
}
);
}
// returned value does not match any output bindings (named or $return)
// if not http, pass along value
if (!response.returnValue && response.outputData.length == 0 && !info.hasHttpTrigger) {
response.returnValue = toTypedData(result.return);
}
}
// Set results from context.bindings
if (result.bindings) {
response.outputData = response.outputData.concat(
Object.keys(info.outputBindings)
// Data from return prioritized over data from context.bindings
.filter((key) => {
const definedInBindings: boolean = result.bindings[key] !== undefined;
const hasReturnValue = !!result.return;
const hasReturnBinding = !!returnBinding;
const definedInReturn: boolean =
hasReturnValue && !hasReturnBinding && result.return[key] !== undefined;
return definedInBindings && !definedInReturn;
})
.map(
(key) =>
<rpc.IParameterBinding>{
name: key,
data: info.outputBindings[key].converter(result.bindings[key]),
}
)
const returnBinding = info.getReturnBinding();
// Set results from return / context.done
if (result || (isDurableBinding && result != null)) {
// $return binding is found: return result data to $return binding
if (returnBinding) {
response.returnValue = returnBinding.converter(result);
// $return binding is not found: read result as object of outputs
} else {
response.outputData = Object.keys(info.outputBindings)
.filter((key) => result[key] !== undefined)
.map(
(key) =>
<rpc.IParameterBinding>{
name: key,
data: info.outputBindings[key].converter(result[key]),
}
);
}
}
} catch (err) {
response.result = toRpcStatus(err);
// returned value does not match any output bindings (named or $return)
// if not http, pass along value
if (!response.returnValue && response.outputData.length == 0 && !info.hasHttpTrigger) {
response.returnValue = toTypedData(result);
}
}
channel.eventStream.write({
requestId: requestId,
invocationResponse: response,
});

channel.runInvocationRequestAfter(context);
};

const { context, inputs } = CreateContextAndInputs(info, msg, logCallback, resultCallback);
let userFunction = channel.functionLoader.getFunc(nonNullProp(msg, 'functionId'));

userFunction = channel.runInvocationRequestBefore(context, userFunction);

// catch user errors from the same async context in the event loop and correlate with invocation
// throws from asynchronous work (setTimeout, etc) are caught by 'unhandledException' and cannot be correlated with invocation
try {
const result = userFunction(context, ...inputs);

if (result && typeof result.then === 'function') {
result
.then((result) => {
(<any>context.done)(null, result, true);
})
.catch((err) => {
(<any>context.done)(err, null, true);
});
// Set results from context.bindings
if (context.bindings) {
response.outputData = response.outputData.concat(
Object.keys(info.outputBindings)
// Data from return prioritized over data from context.bindings
.filter((key) => {
const definedInBindings: boolean = context.bindings[key] !== undefined;
const hasReturnValue = !!result;
const hasReturnBinding = !!returnBinding;
const definedInReturn: boolean =
hasReturnValue && !hasReturnBinding && result[key] !== undefined;
return definedInBindings && !definedInReturn;
})
.map(
(key) =>
<rpc.IParameterBinding>{
name: key,
data: info.outputBindings[key].converter(context.bindings[key]),
}
)
);
}
} catch (err) {
resultCallback(err);
response.result = toRpcStatus(err);
isDone = true;
}

channel.eventStream.write({
requestId: requestId,
invocationResponse: response,
});

channel.runInvocationRequestAfter(context);
}
2 changes: 1 addition & 1 deletion src/setupEventStream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ export function setupEventStream(workerId: string, channel: WorkerChannel): void
void functionLoadRequest(channel, msg.requestId, nonNullProp(msg, eventName));
break;
case 'invocationRequest':
invocationRequest(channel, msg.requestId, nonNullProp(msg, eventName));
void invocationRequest(channel, msg.requestId, nonNullProp(msg, eventName));
break;
case 'workerInitRequest':
workerInitRequest(channel, msg.requestId, nonNullProp(msg, eventName));
Expand Down
Loading

0 comments on commit 6c8fb5c

Please sign in to comment.