Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(workflows-sdk): Configurable retries upon step creation #5728

Merged
merged 10 commits into from
Dec 19, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,40 @@ describe("Workflow composer", function () {
jest.clearAllMocks()
})

it("should compose a new workflow composed retryable steps", async () => {
const maxRetries = 1

const mockStep1Fn = jest.fn().mockImplementation((input, context) => {
const attempt = context.metadata.attempt || 0
if (attempt <= maxRetries) {
throw new Error("test error")
}

return { inputs: [input], obj: "return from 1" }
})

const step1 = createStep({ name: "step1", maxRetries }, mockStep1Fn)
adrien2p marked this conversation as resolved.
Show resolved Hide resolved

const workflow = createWorkflow("workflow1", function (input) {
return step1(input)
})

const workflowInput = { test: "payload1" }
const { result: workflowResult } = await workflow().run({
input: workflowInput,
})

expect(mockStep1Fn).toHaveBeenCalledTimes(2)
expect(mockStep1Fn.mock.calls[0]).toHaveLength(2)
expect(mockStep1Fn.mock.calls[0][0]).toEqual(workflowInput)
expect(mockStep1Fn.mock.calls[1][0]).toEqual(workflowInput)

expect(workflowResult).toEqual({
inputs: [{ test: "payload1" }],
obj: "return from 1",
})
})

it("should compose a new workflow and execute it", async () => {
const mockStep1Fn = jest.fn().mockImplementation((input) => {
return { inputs: [input], obj: "return from 1" }
Expand Down Expand Up @@ -928,6 +962,73 @@ describe("Workflow composer", function () {
jest.clearAllMocks()
})

it("should compose a new workflow composed of retryable steps", async () => {
const maxRetries = 1

const mockStep1Fn = jest.fn().mockImplementation((input, context) => {
const attempt = context.metadata.attempt || 0
if (attempt <= maxRetries) {
throw new Error("test error")
}

return new StepResponse({ inputs: [input], obj: "return from 1" })
})

const step1 = createStep({ name: "step1", maxRetries }, mockStep1Fn)

const workflow = createWorkflow("workflow1", function (input) {
return step1(input)
})

const workflowInput = { test: "payload1" }
const { result: workflowResult } = await workflow().run({
input: workflowInput,
})

expect(mockStep1Fn).toHaveBeenCalledTimes(2)
expect(mockStep1Fn.mock.calls[0]).toHaveLength(2)
expect(mockStep1Fn.mock.calls[0][0]).toEqual(workflowInput)
expect(mockStep1Fn.mock.calls[1][0]).toEqual(workflowInput)

expect(workflowResult).toEqual({
inputs: [{ test: "payload1" }],
obj: "return from 1",
})
})

it("should compose a new workflow composed of retryable steps that should stop retries on permanent failure", async () => {
const maxRetries = 1

const mockStep1Fn = jest.fn().mockImplementation((input, context) => {
return StepResponse.permanentFailure("fail permanently")
})

const step1 = createStep({ name: "step1", maxRetries }, mockStep1Fn)

const workflow = createWorkflow("workflow1", function (input) {
return step1(input)
})

const workflowInput = { test: "payload1" }
const { errors } = await workflow().run({
input: workflowInput,
throwOnError: false,
})

expect(mockStep1Fn).toHaveBeenCalledTimes(1)
expect(mockStep1Fn.mock.calls[0]).toHaveLength(2)
expect(mockStep1Fn.mock.calls[0][0]).toEqual(workflowInput)

expect(errors).toHaveLength(1)
expect(errors[0]).toEqual({
action: "step1",
handlerType: "invoke",
error: expect.objectContaining({
message: "fail permanently",
}),
})
})

it("should compose a new workflow and execute it", async () => {
const mockStep1Fn = jest.fn().mockImplementation((input) => {
return new StepResponse({ inputs: [input], obj: "return from 1" })
Expand Down
15 changes: 15 additions & 0 deletions packages/orchestration/src/transaction/errors.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
export class PermanentStepFailureError extends Error {
static isPermanentStepFailureError(
error: Error
): error is PermanentStepFailureError {
return (
error instanceof PermanentStepFailureError ||
error.name === "PermanentStepFailure"
)
}

constructor(message?: string) {
super(message)
this.name = "PermanentStepFailure"
}
}
1 change: 1 addition & 0 deletions packages/orchestration/src/transaction/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ export * from "./transaction-orchestrator"
export * from "./transaction-step"
export * from "./distributed-transaction"
export * from "./orchestrator-builder"
export * from "./errors"
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import {

import { EventEmitter } from "events"
import { promiseAll } from "@medusajs/utils"
import { PermanentStepFailureError } from "./errors"

export type TransactionFlow = {
modelId: string
Expand Down Expand Up @@ -367,24 +368,37 @@ export class TransactionOrchestrator extends EventEmitter {
transaction.getContext()
)

const setStepFailure = async (
error: Error | any,
{ endRetry }: { endRetry?: boolean } = {}
) => {
return TransactionOrchestrator.setStepFailure(
transaction,
step,
error,
endRetry ? 0 : step.definition.maxRetries
)
}

if (!step.definition.async) {
execution.push(
transaction
.handler(step.definition.action + "", type, payload, transaction)
.then(async (response) => {
.then(async (response: any) => {
await TransactionOrchestrator.setStepSuccess(
transaction,
step,
response
)
})
.catch(async (error) => {
await TransactionOrchestrator.setStepFailure(
transaction,
step,
error,
step.definition.maxRetries
)
if (
PermanentStepFailureError.isPermanentStepFailureError(error)
) {
await setStepFailure(error, { endRetry: true })
return
}
await setStepFailure(error)
})
)
} else {
Expand All @@ -393,12 +407,13 @@ export class TransactionOrchestrator extends EventEmitter {
transaction
.handler(step.definition.action + "", type, payload, transaction)
.catch(async (error) => {
await TransactionOrchestrator.setStepFailure(
transaction,
step,
error,
step.definition.maxRetries
)
if (
PermanentStepFailureError.isPermanentStepFailureError(error)
) {
await setStepFailure(error, { endRetry: true })
return
}
await setStepFailure(error)
})
)
)
Expand Down
1 change: 1 addition & 0 deletions packages/utils/src/bundles.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ export * as ModulesSdkUtils from "./modules-sdk"
export * as ProductUtils from "./product"
export * as SearchUtils from "./search"
export * as ShippingProfileUtils from "./shipping"
export * as OrchestrationUtils from "./orchestration"
1 change: 1 addition & 0 deletions packages/utils/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ export * from "./pricing"
export * from "./product"
export * from "./search"
export * from "./shipping"
export * from "./orchestration"
1 change: 1 addition & 0 deletions packages/utils/src/orchestration/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
export * from "./symbol"
11 changes: 8 additions & 3 deletions packages/workflows-sdk/src/helper/workflow-export.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import { Context, LoadedModule, MedusaContainer } from "@medusajs/types"
import { MedusaModule } from "@medusajs/modules-sdk"
import { EOL } from "os"
import { ulid } from "ulid"
import { SymbolWorkflowWorkflowData } from "../utils/composer"
import { OrchestrationUtils } from "@medusajs/utils"

export type FlowRunOptions<TData = unknown> = {
input?: TData
Expand Down Expand Up @@ -99,11 +99,16 @@ export const exportWorkflow = <TData = unknown, TResult = unknown>(
if (Array.isArray(resultFrom)) {
result = resultFrom.map((from) => {
const res = transaction.getContext().invoke?.[from]
return res?.__type === SymbolWorkflowWorkflowData ? res.output : res
return res?.__type === OrchestrationUtils.SymbolWorkflowWorkflowData
? res.output
: res
})
} else {
const res = transaction.getContext().invoke?.[resultFrom]
result = res?.__type === SymbolWorkflowWorkflowData ? res.output : res
result =
res?.__type === OrchestrationUtils.SymbolWorkflowWorkflowData
? res.output
: res
}
}

Expand Down
47 changes: 25 additions & 22 deletions packages/workflows-sdk/src/utils/composer/create-step.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,4 @@
import {
resolveValue,
StepResponse,
SymbolMedusaWorkflowComposerContext,
SymbolWorkflowStep,
SymbolWorkflowStepBind,
SymbolWorkflowStepResponse,
SymbolWorkflowWorkflowData,
} from "./helpers"
import { resolveValue, StepResponse } from "./helpers"
import {
CreateWorkflowComposerContext,
StepExecutionContext,
Expand All @@ -15,6 +7,8 @@ import {
WorkflowData,
} from "./type"
import { proxify } from "./helpers/proxy"
import { TransactionStepsDefinition } from "@medusajs/orchestration"
import { isString, OrchestrationUtils } from "@medusajs/utils"

/**
* The type of invocation function passed to a step.
Expand Down Expand Up @@ -75,6 +69,7 @@ interface ApplyStepOptions<
TInvokeResultCompensateInput
> {
stepName: string
stepConfig?: TransactionStepsDefinition
input: TStepInputs
invokeFn: InvokeFn<
TInvokeInput,
Expand All @@ -91,6 +86,7 @@ interface ApplyStepOptions<
* This is where the inputs and context are passed to the underlying invoke and compensate function.
*
* @param stepName
* @param stepConfig
* @param input
* @param invokeFn
* @param compensateFn
Expand All @@ -104,6 +100,7 @@ function applyStep<
TInvokeResultCompensateInput
>({
stepName,
stepConfig = {},
input,
invokeFn,
compensateFn,
Expand Down Expand Up @@ -135,12 +132,12 @@ function applyStep<
)

const stepResponseJSON =
stepResponse?.__type === SymbolWorkflowStepResponse
stepResponse?.__type === OrchestrationUtils.SymbolWorkflowStepResponse
? stepResponse.toJSON()
: stepResponse

return {
__type: SymbolWorkflowWorkflowData,
__type: OrchestrationUtils.SymbolWorkflowWorkflowData,
output: stepResponseJSON,
}
},
Expand All @@ -154,7 +151,8 @@ function applyStep<

const stepOutput = transactionContext.invoke[stepName]?.output
const invokeResult =
stepOutput?.__type === SymbolWorkflowStepResponse
stepOutput?.__type ===
OrchestrationUtils.SymbolWorkflowStepResponse
? stepOutput.compensateInput &&
JSON.parse(JSON.stringify(stepOutput.compensateInput))
: stepOutput && JSON.parse(JSON.stringify(stepOutput))
Expand All @@ -168,13 +166,13 @@ function applyStep<
: undefined,
}

this.flow.addAction(stepName, {
noCompensation: !compensateFn,
})
stepConfig!.noCompensation = !compensateFn

this.flow.addAction(stepName, stepConfig)
this.handlers.set(stepName, handler)

const ret = {
__type: SymbolWorkflowStep,
__type: OrchestrationUtils.SymbolWorkflowStep,
__step__: stepName,
}

Expand Down Expand Up @@ -236,9 +234,11 @@ export function createStep<
TInvokeResultCompensateInput
>(
/**
* The name of the step.
* The name of the step or its configuration (currently support maxRetries).
*/
name: string,
nameOrConfig:
| string
| ({ name: string } & Pick<TransactionStepsDefinition, "maxRetries">),
/**
* An invocation function that will be executed when the workflow is executed. The function must return an instance of {@link StepResponse}. The constructor of {@link StepResponse}
* accepts the output of the step as a first argument, and optionally as a second argument the data to be passed to the compensation function as a parameter.
Expand All @@ -256,20 +256,22 @@ export function createStep<
*/
compensateFn?: CompensateFn<TInvokeResultCompensateInput>
): StepFunction<TInvokeInput, TInvokeResultOutput> {
const stepName = name ?? invokeFn.name
const stepName =
(isString(nameOrConfig) ? nameOrConfig : nameOrConfig.name) ?? invokeFn.name
const config = isString(nameOrConfig) ? {} : nameOrConfig

const returnFn = function (input: {
[K in keyof TInvokeInput]: WorkflowData<TInvokeInput[K]>
}): WorkflowData<TInvokeResultOutput> {
if (!global[SymbolMedusaWorkflowComposerContext]) {
if (!global[OrchestrationUtils.SymbolMedusaWorkflowComposerContext]) {
throw new Error(
"createStep must be used inside a createWorkflow definition"
)
}

const stepBinder = (
global[
SymbolMedusaWorkflowComposerContext
OrchestrationUtils.SymbolMedusaWorkflowComposerContext
] as CreateWorkflowComposerContext
).stepBinder

Expand All @@ -281,14 +283,15 @@ export function createStep<
TInvokeResultCompensateInput
>({
stepName,
stepConfig: config,
input,
invokeFn,
compensateFn,
})
)
}

returnFn.__type = SymbolWorkflowStepBind
returnFn.__type = OrchestrationUtils.SymbolWorkflowStepBind
returnFn.__step__ = stepName

return returnFn as unknown as StepFunction<TInvokeInput, TInvokeResultOutput>
Expand Down
Loading
Loading