Skip to content

Commit

Permalink
Merge pull request #73 from nibble-4bits/feature/task-state-timeout
Browse files Browse the repository at this point in the history
Feature/task state timeout
  • Loading branch information
nibble-4bits authored Oct 1, 2023
2 parents 998905b + 191bcfa commit 474d6e4
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 7 deletions.
46 changes: 46 additions & 0 deletions __tests__/stateActions/TaskStateAction.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import type { TaskState } from '../../src/typings/TaskState';
import { LambdaClient } from '../../src/aws/LambdaClient';
import { TaskStateAction } from '../../src/stateMachine/stateActions/TaskStateAction';
import { StatesTimeoutError } from '../../src/error/predefined/StatesTimeoutError';
import { sleep } from '../../src/util';

afterEach(() => {
jest.clearAllMocks();
Expand Down Expand Up @@ -70,4 +72,48 @@ describe('Task State', () => {
expect(mockInvokeFunction).not.toHaveBeenCalled();
expect(stateResult).toBe(8);
});

test('should throw `States.Timeout` error if action runs longer than the value specified in `TimeoutSeconds` field', async () => {
const definition: TaskState = {
Type: 'Task',
Resource: 'mock-arn',
TimeoutSeconds: 1,
End: true,
};
const stateName = 'TaskState';
const input = { num1: 5, num2: 3 };
const context = {};

const localHandlerFn = async () => {
await sleep(1100);
return 1;
};
const options = { overrideFn: localHandlerFn, awsConfig: undefined };

const taskStateAction = new TaskStateAction(definition, stateName);

await expect(() => taskStateAction.execute(input, context, options)).rejects.toThrow(StatesTimeoutError);
});

test('should throw `States.Timeout` error if action runs longer than the value specified in `TimeoutSecondsPath` field', async () => {
const definition: TaskState = {
Type: 'Task',
Resource: 'mock-arn',
TimeoutSecondsPath: '$.taskTimeout',
End: true,
};
const stateName = 'TaskState';
const input = { num1: 5, num2: 3, taskTimeout: 1 };
const context = {};

const localHandlerFn = async () => {
await sleep(1100);
return 1;
};
const options = { overrideFn: localHandlerFn, awsConfig: undefined };

const taskStateAction = new TaskStateAction(definition, stateName);

await expect(() => taskStateAction.execute(input, context, options)).rejects.toThrow(StatesTimeoutError);
});
});
4 changes: 2 additions & 2 deletions docs/feature-support.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ The following features all come from the [Amazon States Language specification](
- [x] `TimestampPath`
- [x] Task
- [x] `Resource` (only Lambda functions supported)
- [x] `TimeoutSeconds`
- [x] `TimeoutSecondsPath`
- [x] Parallel
- [x] `Branches`
- [x] Map
Expand Down Expand Up @@ -106,9 +108,7 @@ The following features all come from the [Amazon States Language specification](

- States
- Task
- [ ] `TimeoutSeconds`
- [ ] `HeartbeatSeconds`
- [ ] `TimeoutSecondsPath`
- [ ] `HeartbeatSecondsPath`
- [ ] `Credentials`
- Map
Expand Down
54 changes: 49 additions & 5 deletions src/stateMachine/stateActions/TaskStateAction.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,69 @@ import type { TaskState } from '../../typings/TaskState';
import type { JSONValue } from '../../typings/JSONValue';
import type { ActionResult, TaskStateActionOptions } from '../../typings/StateActions';
import type { Context } from '../../typings/Context';
import { StatesTimeoutError } from '../../error/predefined/StatesTimeoutError';
import { BaseStateAction } from './BaseStateAction';
import { LambdaClient } from '../../aws/LambdaClient';
import { jsonPathQuery } from '../jsonPath/JsonPath';
import { IntegerConstraint } from '../jsonPath/constraints/IntegerConstraint';

class TaskStateAction extends BaseStateAction<TaskState> {
private timeoutAbortController: AbortController;

constructor(stateDefinition: TaskState, stateName: string) {
super(stateDefinition, stateName);
this.timeoutAbortController = new AbortController();
}

private createTimeoutPromise(input: JSONValue, context: Context): Promise<never> | undefined {
const state = this.stateDefinition;

if (!state.TimeoutSeconds && !state.TimeoutSecondsPath) return;

let timeout: number;
if (state.TimeoutSeconds) {
timeout = state.TimeoutSeconds;
} else if (state.TimeoutSecondsPath) {
timeout = jsonPathQuery<number>(state.TimeoutSecondsPath, input, context, {
constraints: [IntegerConstraint.greaterThanOrEqual(1)],
});
}

return new Promise<never>((_, reject) => {
const handleTimeoutAbort = () => clearTimeout(timeoutId);

const timeoutId = setTimeout(() => {
this.timeoutAbortController.signal.removeEventListener('abort', handleTimeoutAbort);
reject(new StatesTimeoutError());
}, timeout * 1000);

this.timeoutAbortController.signal.addEventListener('abort', handleTimeoutAbort, { once: true });
});
}

override async execute(input: JSONValue, context: Context, options: TaskStateActionOptions): Promise<ActionResult> {
const state = this.stateDefinition;
const racingPromises = [];
const timeoutPromise = this.createTimeoutPromise(input, context);

// If local override for task resource is defined, use that
if (options.overrideFn) {
const result = await options.overrideFn(input);
return this.buildExecutionResult(result);
// If local override for task resource is defined, use that
const resultPromise = options.overrideFn(input);
racingPromises.push(resultPromise);
} else {
// Else, call Lambda in AWS using SDK
const lambdaClient = new LambdaClient(options.awsConfig);
const resultPromise = lambdaClient.invokeFunction(state.Resource, input);
racingPromises.push(resultPromise);
}

const lambdaClient = new LambdaClient(options.awsConfig);
const result = await lambdaClient.invokeFunction(state.Resource, input);
if (timeoutPromise) {
racingPromises.push(timeoutPromise);
}

const result = await Promise.race(racingPromises);
this.timeoutAbortController.abort();

return this.buildExecutionResult(result);
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/typings/TaskState.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ interface BaseTaskState
CatchableState {
Type: 'Task';
Resource: string;
TimeoutSeconds?: number;
TimeoutSecondsPath?: string;
}

export type TaskState = (IntermediateState | TerminalState) & BaseTaskState;

0 comments on commit 474d6e4

Please sign in to comment.