Skip to content

Commit

Permalink
Add RedriveExecution support to StepFunctionStartExecutionOperator (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
gopidesupavan committed Jul 24, 2024
1 parent b4e82cf commit 68b3159
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 4 deletions.
18 changes: 18 additions & 0 deletions airflow/providers/amazon/aws/hooks/step_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import json

from airflow.exceptions import AirflowFailException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook


Expand All @@ -43,6 +44,7 @@ def start_execution(
state_machine_arn: str,
name: str | None = None,
state_machine_input: dict | str | None = None,
is_redrive_execution: bool = False,
) -> str:
"""
Start Execution of the State Machine.
Expand All @@ -51,10 +53,26 @@ def start_execution(
- :external+boto3:py:meth:`SFN.Client.start_execution`
:param state_machine_arn: AWS Step Function State Machine ARN.
:param is_redrive_execution: Restarts unsuccessful executions of Standard workflows that did not
complete successfully in the last 14 days.
:param name: The name of the execution.
:param state_machine_input: JSON data input to pass to the State Machine.
:return: Execution ARN.
"""
if is_redrive_execution:
if not name:
raise AirflowFailException(
"Execution name is required to start RedriveExecution for %s.", state_machine_arn
)
elements = state_machine_arn.split(":stateMachine:")
execution_arn = f"{elements[0]}:execution:{elements[1]}:{name}"
self.conn.redrive_execution(executionArn=execution_arn)
self.log.info(
"Successfully started RedriveExecution for Step Function State Machine: %s.",
state_machine_arn,
)
return execution_arn

execution_args = {"stateMachineArn": state_machine_arn}
if name is not None:
execution_args["name"] = name
Expand Down
14 changes: 12 additions & 2 deletions airflow/providers/amazon/aws/operators/step_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ class StepFunctionStartExecutionOperator(AwsBaseOperator[StepFunctionHook]):
:param state_machine_arn: ARN of the Step Function State Machine
:param name: The name of the execution.
:param is_redrive_execution: Restarts unsuccessful executions of Standard workflows that did not
complete successfully in the last 14 days.
:param state_machine_input: JSON data input to pass to the State Machine
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is None or empty then the default boto3 behaviour is used. If
Expand All @@ -73,7 +75,9 @@ class StepFunctionStartExecutionOperator(AwsBaseOperator[StepFunctionHook]):
"""

aws_hook_class = StepFunctionHook
template_fields: Sequence[str] = aws_template_fields("state_machine_arn", "name", "input")
template_fields: Sequence[str] = aws_template_fields(
"state_machine_arn", "name", "input", "is_redrive_execution"
)
ui_color = "#f9c915"
operator_extra_links = (StateMachineDetailsLink(), StateMachineExecutionsDetailsLink())

Expand All @@ -82,6 +86,7 @@ def __init__(
*,
state_machine_arn: str,
name: str | None = None,
is_redrive_execution: bool = False,
state_machine_input: dict | str | None = None,
waiter_max_attempts: int = 30,
waiter_delay: int = 60,
Expand All @@ -91,6 +96,7 @@ def __init__(
super().__init__(**kwargs)
self.state_machine_arn = state_machine_arn
self.name = name
self.is_redrive_execution = is_redrive_execution
self.input = state_machine_input
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
Expand All @@ -105,7 +111,11 @@ def execute(self, context: Context):
state_machine_arn=self.state_machine_arn,
)

if not (execution_arn := self.hook.start_execution(self.state_machine_arn, self.name, self.input)):
if not (
execution_arn := self.hook.start_execution(
self.state_machine_arn, self.name, self.input, self.is_redrive_execution
)
):
raise AirflowException(f"Failed to start State Machine execution for: {self.state_machine_arn}")

StateMachineExecutionsDetailsLink.persist(
Expand Down
30 changes: 30 additions & 0 deletions tests/providers/amazon/aws/hooks/test_step_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,13 @@
# under the License.
from __future__ import annotations

from datetime import datetime
from unittest import mock

import pytest
from moto import mock_aws

from airflow.exceptions import AirflowFailException
from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook


Expand All @@ -42,6 +47,31 @@ def test_start_execution(self):

assert execution_arn is not None

@mock.patch.object(StepFunctionHook, "conn")
def test_redrive_execution(self, mock_conn):
mock_conn.redrive_execution.return_value = {"redriveDate": datetime(2024, 1, 1)}
StepFunctionHook().start_execution(
state_machine_arn="arn:aws:states:us-east-1:123456789012:stateMachine:test-state-machine",
name="random-123",
is_redrive_execution=True,
)

mock_conn.redrive_execution.assert_called_once_with(
executionArn="arn:aws:states:us-east-1:123456789012:execution:test-state-machine:random-123"
)

@mock.patch.object(StepFunctionHook, "conn")
def test_redrive_execution_without_name_should_fail(self, mock_conn):
mock_conn.redrive_execution.return_value = {"redriveDate": datetime(2024, 1, 1)}

with pytest.raises(
AirflowFailException, match="Execution name is required to start RedriveExecution"
):
StepFunctionHook().start_execution(
state_machine_arn="arn:aws:states:us-east-1:123456789012:stateMachine:test-state-machine",
is_redrive_execution=True,
)

def test_describe_execution(self):
hook = StepFunctionHook(aws_conn_id="aws_default", region_name="us-east-1")
state_machine = hook.get_conn().create_state_machine(
Expand Down
36 changes: 34 additions & 2 deletions tests/providers/amazon/aws/operators/test_step_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def test_execute(self, mocked_hook, mocked_context):
aws_conn_id=None,
)
assert op.execute(mocked_context) == hook_response
mocked_hook.start_execution.assert_called_once_with(STATE_MACHINE_ARN, NAME, INPUT)
mocked_hook.start_execution.assert_called_once_with(STATE_MACHINE_ARN, NAME, INPUT, False)
self.mocked_details_link.assert_called_once_with(
aws_partition=mock.ANY,
context=mock.ANY,
Expand Down Expand Up @@ -189,7 +189,7 @@ def test_step_function_start_execution_deferrable(self, mocked_hook):
)
with pytest.raises(TaskDeferred):
operator.execute(None)
mocked_hook.start_execution.assert_called_once_with(STATE_MACHINE_ARN, NAME, INPUT)
mocked_hook.start_execution.assert_called_once_with(STATE_MACHINE_ARN, NAME, INPUT, False)

@mock.patch.object(StepFunctionStartExecutionOperator, "hook")
@pytest.mark.parametrize("execution_arn", [pytest.param(None, id="none"), pytest.param("", id="empty")])
Expand All @@ -200,3 +200,35 @@ def test_step_function_no_execution_arn_returns(self, mocked_hook, execution_arn
)
with pytest.raises(AirflowException, match="Failed to start State Machine execution"):
op.execute({})

@mock.patch.object(StepFunctionStartExecutionOperator, "hook")
def test_start_redrive_execution(self, mocked_hook, mocked_context):
hook_response = (
"arn:aws:states:us-east-1:123456789012:execution:"
"pseudo-state-machine:020f5b16-b1a1-4149-946f-92dd32d97934"
)
mocked_hook.start_execution.return_value = hook_response
op = StepFunctionStartExecutionOperator(
task_id=self.TASK_ID,
state_machine_arn=STATE_MACHINE_ARN,
name=NAME,
is_redrive_execution=True,
state_machine_input=None,
aws_conn_id=None,
)
assert op.execute(mocked_context) == hook_response
mocked_hook.start_execution.assert_called_once_with(STATE_MACHINE_ARN, NAME, None, True)
self.mocked_details_link.assert_called_once_with(
aws_partition=mock.ANY,
context=mock.ANY,
operator=mock.ANY,
region_name=mock.ANY,
state_machine_arn=STATE_MACHINE_ARN,
)
self.mocked_executions_details_link.assert_called_once_with(
aws_partition=mock.ANY,
context=mock.ANY,
operator=mock.ANY,
region_name=mock.ANY,
execution_arn=EXECUTION_ARN,
)

0 comments on commit 68b3159

Please sign in to comment.