Skip to content

Commit

Permalink
feat: Add custom description for CloudFormation template (#161)
Browse files Browse the repository at this point in the history
* feat: Add custom description to CFN template
  • Loading branch information
ca-nguyen authored Sep 9, 2021
1 parent f8bbfaf commit d2ce83d
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 5 deletions.
13 changes: 11 additions & 2 deletions src/stepfunctions/workflow/cloudformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@

logger = logging.getLogger('stepfunctions')


def repr_str(dumper, data):
if '\n' in data:
return dumper.represent_scalar(u'tag:yaml.org,2002:str', data, style='|')
return dumper.org_represent_str(data)


yaml.SafeDumper.org_represent_str = yaml.SafeDumper.represent_str
yaml.add_representer(dict, lambda self, data: yaml.representer.SafeRepresenter.represent_dict(self, data.items()), Dumper=yaml.SafeDumper)
yaml.add_representer(str, repr_str, Dumper=yaml.SafeDumper)
Expand All @@ -42,12 +44,19 @@ def repr_str(dumper, data):
}
}

def build_cloudformation_template(workflow):

def build_cloudformation_template(workflow, description=None):
"""
Creates a CloudFormation template from the provided Workflow
Args:
workflow (Workflow): Step Functions workflow instance
description (str, optional): Description of the template. If none provided, the default description will be used: "CloudFormation template for AWS Step Functions - State Machine"
"""
logger.warning('To reuse the CloudFormation template in different regions, please make sure to update the region specific AWS resources in the StateMachine definition.')

template = CLOUDFORMATION_BASE_TEMPLATE.copy()

template["Description"] = "CloudFormation template for AWS Step Functions - State Machine"
template["Description"] = description if description else "CloudFormation template for AWS Step Functions - State Machine"
template["Resources"]["StateMachineComponent"]["Properties"]["StateMachineName"] = workflow.name

definition = workflow.definition.to_dict()
Expand Down
6 changes: 4 additions & 2 deletions src/stepfunctions/workflow/stepfunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,11 +376,13 @@ def render_graph(self, portrait=False):
widget = WorkflowGraphWidget(self.definition.to_json())
return widget.show(portrait=portrait)

def get_cloudformation_template(self):
def get_cloudformation_template(self, description=None):
"""
Returns a CloudFormation template that contains only the StateMachine resource. To reuse the CloudFormation template in a different region, please make sure to update the region specific AWS resources (e.g: Lambda ARN, Training Image) in the StateMachine definition.
Args:
description (str, optional): Description of the template
"""
return build_cloudformation_template(self)
return build_cloudformation_template(self, description)

def __repr__(self):
return '{}(name={!r}, role={!r}, state_machine_arn={!r})'.format(
Expand Down
21 changes: 20 additions & 1 deletion tests/unit/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def client():
})
return sfn


@pytest.fixture
def workflow(client):
workflow = Workflow(
Expand All @@ -67,9 +68,11 @@ def workflow(client):
workflow.create()
return workflow


def test_workflow_creation(client, workflow):
assert workflow.state_machine_arn == state_machine_arn


def test_workflow_creation_failure_duplicate_state_ids(client):
improper_definition = steps.Chain([steps.Pass('HelloWorld'), steps.Succeed('HelloWorld')])
with pytest.raises(ValueError):
Expand All @@ -80,6 +83,7 @@ def test_workflow_creation_failure_duplicate_state_ids(client):
client=client
)


# calling update() before create()
def test_workflow_update_when_statemachinearn_is_none(client):
workflow = Workflow(
Expand All @@ -92,11 +96,13 @@ def test_workflow_update_when_statemachinearn_is_none(client):
with pytest.raises(WorkflowNotFound):
workflow.update(definition=new_definition)


# calling update() after create() without arguments
def test_workflow_update_when_arguments_are_missing(client, workflow):
with pytest.raises(MissingRequiredParameter):
workflow.update()


# calling update() after create()
def test_workflow_update(client, workflow):
client.update_state_machine = MagicMock(return_value={
Expand All @@ -106,12 +112,14 @@ def test_workflow_update(client, workflow):
new_role = 'arn:aws:iam::1234567890:role/service-role/StepFunctionsRoleNew'
assert workflow.update(definition=new_definition, role=new_role) == state_machine_arn


def test_attach_existing_workflow(client):
workflow = Workflow.attach(state_machine_arn, client)
assert workflow.name == state_machine_name
assert workflow.role == role_arn
assert workflow.state_machine_arn == state_machine_arn


def test_workflow_list_executions(client, workflow):
paginator = client.get_paginator('list_executions')
paginator.paginate = MagicMock(return_value=[
Expand Down Expand Up @@ -140,12 +148,14 @@ def test_workflow_list_executions(client, workflow):
workflow.state_machine_arn = None
assert workflow.list_executions() == []


def test_workflow_makes_deletion_call(client, workflow):
client.delete_state_machine = MagicMock(return_value=None)
workflow.delete()

client.delete_state_machine.assert_called_once_with(stateMachineArn=state_machine_arn)


def test_workflow_execute_creation(client, workflow):
execution = workflow.execute()
assert execution.workflow.state_machine_arn == state_machine_arn
Expand All @@ -164,11 +174,13 @@ def test_workflow_execute_creation(client, workflow):
input='{}'
)


def test_workflow_execute_when_statemachinearn_is_none(client, workflow):
workflow.state_machine_arn = None
with pytest.raises(WorkflowNotFound):
workflow.execute()


def test_execution_makes_describe_call(client, workflow):
execution = workflow.execute()

Expand All @@ -177,6 +189,7 @@ def test_execution_makes_describe_call(client, workflow):

client.describe_execution.assert_called_once()


def test_execution_makes_stop_call(client, workflow):
execution = workflow.execute()

Expand All @@ -194,6 +207,7 @@ def test_execution_makes_stop_call(client, workflow):
error='Error'
)


def test_execution_list_events(client, workflow):
paginator = client.get_paginator('get_execution_history')
paginator.paginate = MagicMock(return_value=[
Expand Down Expand Up @@ -229,6 +243,7 @@ def test_execution_list_events(client, workflow):
}
)


def test_list_workflows(client):
paginator = client.get_paginator('list_state_machines')
paginator.paginate = MagicMock(return_value=[
Expand All @@ -254,11 +269,14 @@ def test_list_workflows(client):
}
)


def test_cloudformation_export_with_simple_definition(workflow):
cfn_template = workflow.get_cloudformation_template()
cfn_template = yaml.load(cfn_template)
assert 'StateMachineComponent' in cfn_template['Resources']
assert workflow.role == cfn_template['Resources']['StateMachineComponent']['Properties']['RoleArn']
assert cfn_template['Description'] == "CloudFormation template for AWS Step Functions - State Machine"


def test_cloudformation_export_with_sagemaker_execution_role(workflow):
workflow.definition.to_dict = MagicMock(return_value={
Expand All @@ -281,7 +299,8 @@ def test_cloudformation_export_with_sagemaker_execution_role(workflow):
}
}
})
cfn_template = workflow.get_cloudformation_template()
cfn_template = workflow.get_cloudformation_template(description="CloudFormation template with Sagemaker role")
cfn_template = yaml.load(cfn_template)
assert json.dumps(workflow.definition.to_dict(), indent=2) == cfn_template['Resources']['StateMachineComponent']['Properties']['DefinitionString']
assert workflow.role == cfn_template['Resources']['StateMachineComponent']['Properties']['RoleArn']
assert cfn_template['Description'] == "CloudFormation template with Sagemaker role"

0 comments on commit d2ce83d

Please sign in to comment.