Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support placeholders for processing step #155

Merged
merged 23 commits into from
Aug 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
927b24f
documentation: Add setup instructions to run/debug tests locally
ca-nguyen Jul 16, 2021
003b5e8
Merge branch 'main' into update-contributing
shivlaks Aug 9, 2021
a7700a6
Added sub section for debug setup and linked to run tests instructions
ca-nguyen Aug 10, 2021
6b6443a
Update table
ca-nguyen Aug 12, 2021
7f6ef30
Support placeholders for processor parameters in processingstep
ca-nguyen Aug 12, 2021
00830f3
Added doc
ca-nguyen Aug 12, 2021
c708da7
Removed contributing changes (included in another pr)
ca-nguyen Aug 12, 2021
2ea9e1f
Merge sagemaker generated parameters with placeholder compatible para…
ca-nguyen Aug 17, 2021
17543ed
documentation: Add setup instructions to run/debug tests locally
ca-nguyen Jul 16, 2021
36e2ee8
Added sub section for debug setup and linked to run tests instructions
ca-nguyen Aug 10, 2021
ea40f7c
Update table
ca-nguyen Aug 12, 2021
e499108
Support placeholders for processor parameters in processingstep
ca-nguyen Aug 12, 2021
4c63229
Added doc
ca-nguyen Aug 12, 2021
34bb281
Removed contributing changes (included in another pr)
ca-nguyen Aug 12, 2021
a098c61
Merge sagemaker generated parameters with placeholder compatible para…
ca-nguyen Aug 17, 2021
06eb069
Merge branch 'support-placeholders-for-processing-step' of https://gi…
ca-nguyen Aug 17, 2021
da99c92
Using == instead of is()
ca-nguyen Aug 17, 2021
37b2422
Removed unused InvalidPathToPlaceholderParameter exception
ca-nguyen Aug 17, 2021
c433576
Merge branch 'main' into support-placeholders-for-processing-step
ca-nguyen Aug 17, 2021
fd640ab
Added doc and renamed args
ca-nguyen Aug 18, 2021
1dfa0e3
Update src/stepfunctions/steps/sagemaker.py parameters description
ca-nguyen Aug 19, 2021
6143783
Removed dict name args to opt for more generic log message when overw…
ca-nguyen Aug 19, 2021
ebc5e22
Using fstring in test
ca-nguyen Aug 20, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/stepfunctions/steps/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ class Field(Enum):
HeartbeatSeconds = 'heartbeat_seconds'
HeartbeatSecondsPath = 'heartbeat_seconds_path'


# Retry and catch fields
ErrorEquals = 'error_equals'
IntervalSeconds = 'interval_seconds'
Expand Down
45 changes: 27 additions & 18 deletions src/stepfunctions/steps/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from stepfunctions.inputs import Placeholder
from stepfunctions.steps.states import Task
from stepfunctions.steps.fields import Field
from stepfunctions.steps.utils import tags_dict_to_kv_list
from stepfunctions.steps.utils import merge_dicts, tags_dict_to_kv_list
from stepfunctions.steps.integration_resources import IntegrationPattern, get_service_integration_arn

from sagemaker.workflow.airflow import training_config, transform_config, model_config, tuning_config, processing_config
Expand All @@ -30,6 +30,7 @@

SAGEMAKER_SERVICE_NAME = "sagemaker"


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice to see us embracing pep8 in files we touch 🙌

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🙌🙌🙌

class SageMakerApi(Enum):
CreateTrainingJob = "createTrainingJob"
CreateTransformJob = "createTransformJob"
Expand Down Expand Up @@ -479,7 +480,9 @@ class ProcessingStep(Task):
Creates a Task State to execute a SageMaker Processing Job.
"""

def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, experiment_config=None, container_arguments=None, container_entrypoint=None, kms_key_id=None, wait_for_completion=True, tags=None, **kwargs):
def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, experiment_config=None,
container_arguments=None, container_entrypoint=None, kms_key_id=None, wait_for_completion=True,
tags=None, **kwargs):
"""
Args:
ca-nguyen marked this conversation as resolved.
Show resolved Hide resolved
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
Expand All @@ -491,15 +494,18 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp
outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
the processing job. These can be specified as either path strings or
:class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
experiment_config (dict, optional): Specify the experiment config for the processing. (Default: None)
container_arguments ([str]): The arguments for a container used to run a processing job.
container_entrypoint ([str]): The entrypoint for a container used to run a processing job.
kms_key_id (str): The AWS Key Management Service (AWS KMS) key that Amazon SageMaker
experiment_config (dict or Placeholder, optional): Specify the experiment config for the processing. (Default: None)
container_arguments ([str] or Placeholder): The arguments for a container used to run a processing job.
container_entrypoint ([str] or Placeholder): The entrypoint for a container used to run a processing job.
kms_key_id (str or Placeholder): The AWS Key Management Service (AWS KMS) key that Amazon SageMaker
uses to encrypt the processing job output. KmsKeyId can be an ID of a KMS key,
ARN of a KMS key, or alias of a KMS key.
The KmsKeyId is applied to all outputs.
wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait for the processing job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the processing job and proceed to the next step. (default: True)
tags (list[dict], optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
tags (list[dict] or Placeholder, optional): `List of tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
parameters (dict, optional): The value of this field is merged with other arguments to become the request payload for SageMaker `CreateProcessingJob <https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateProcessingJob.html>`_.
You can use `parameters` to override the value provided by other arguments and specify any field's value dynamically using `Placeholders <https://aws-step-functions-data-science-sdk.readthedocs.io/en/stable/placeholders.html?highlight=placeholder#stepfunctions.inputs.Placeholder>`_.

"""
if wait_for_completion:
"""
Expand All @@ -518,22 +524,25 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp
SageMakerApi.CreateProcessingJob)

if isinstance(job_name, str):
parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id, job_name=job_name)
processing_parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id, job_name=job_name)
else:
parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id)
processing_parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id)

if isinstance(job_name, Placeholder):
ca-nguyen marked this conversation as resolved.
Show resolved Hide resolved
parameters['ProcessingJobName'] = job_name
processing_parameters['ProcessingJobName'] = job_name

if experiment_config is not None:
parameters['ExperimentConfig'] = experiment_config
processing_parameters['ExperimentConfig'] = experiment_config

if tags:
parameters['Tags'] = tags_dict_to_kv_list(tags)

if 'S3Operations' in parameters:
del parameters['S3Operations']

kwargs[Field.Parameters.value] = parameters
processing_parameters['Tags'] = tags if isinstance(tags, Placeholder) else tags_dict_to_kv_list(tags)

if 'S3Operations' in processing_parameters:
del processing_parameters['S3Operations']

if Field.Parameters.value in kwargs and isinstance(kwargs[Field.Parameters.value], dict):
# Update processing_parameters with input parameters
merge_dicts(processing_parameters, kwargs[Field.Parameters.value])

kwargs[Field.Parameters.value] = processing_parameters
super(ProcessingStep, self).__init__(state_id, **kwargs)
26 changes: 26 additions & 0 deletions src/stepfunctions/steps/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import boto3
import logging
from stepfunctions.inputs import Placeholder

logger = logging.getLogger('stepfunctions')

Expand Down Expand Up @@ -45,3 +46,28 @@ def get_aws_partition():
return cur_partition

return cur_partition


def merge_dicts(target, source):
    """
    Merges source dictionary into the target dictionary.
    Values in the target dict are updated with the values of the source dict.

    The merge happens in place, in the spirit of ``dict.update()``: ``target`` itself is
    modified and nothing is returned. Unlike ``dict.update()``, when the same key holds a
    dict in both arguments the two nested dicts are merged recursively instead of the
    target's one being replaced. If either argument is not a dict, the call does nothing.

    Args:
        target (dict): Base dictionary into which source is merged
        source (dict): Dictionary used to update target. If the same key is present in both dictionaries, source's value
            will overwrite target's value for the corresponding key
    """
    if not (isinstance(target, dict) and isinstance(source, dict)):
        return

    for key, value in source.items():
        if key not in target:
            # New key: take the source value as is
            target[key] = value
            continue

        current = target[key]
        if isinstance(current, dict) and isinstance(value, dict):
            # Both sides are dicts: merge one level down rather than replacing
            merge_dicts(current, value)
        elif current == value:
            # Nothing to change, and nothing worth logging
            continue
        else:
            logger.info(
                f"Property: <{key}> with value: <{target[key]}>"
                f" will be overwritten with provided value: <{value}>")
            target[key] = value
96 changes: 96 additions & 0 deletions tests/integ/test_sagemaker_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from sagemaker.tuner import HyperparameterTuner
from sagemaker.processing import ProcessingInput, ProcessingOutput

from stepfunctions.inputs import ExecutionInput
from stepfunctions.steps import Chain
from stepfunctions.steps.sagemaker import TrainingStep, TransformStep, ModelStep, EndpointStep, EndpointConfigStep, TuningStep, ProcessingStep
from stepfunctions.workflow import Workflow
Expand Down Expand Up @@ -352,3 +353,98 @@ def test_processing_step(sklearn_processor_fixture, sagemaker_session, sfn_clien
# Cleanup
state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
# End of Cleanup


def test_processing_step_with_placeholders(sklearn_processor_fixture, sagemaker_session, sfn_client, sfn_role_arn,
                                           sagemaker_role_arn):
    """
    Creates a workflow with a single ProcessingStep whose processor settings are supplied through
    ExecutionInput placeholders (via the `parameters` argument), executes it with concrete values
    and checks that the processing job completes.
    """
    region = boto3.session.Session().region_name
    input_data = f"s3://sagemaker-sample-data-{region}/processing/census/census-income.csv"

    input_s3 = sagemaker_session.upload_data(
        path=os.path.join(DATA_DIR, 'sklearn_processing'),
        bucket=sagemaker_session.default_bucket(),
        key_prefix='integ-test-data/sklearn_processing/code'
    )

    output_s3 = f"s3://{sagemaker_session.default_bucket()}/integ-test-data/sklearn_processing"

    inputs = [
        ProcessingInput(source=input_data, destination='/opt/ml/processing/input', input_name='input-1'),
        ProcessingInput(source=input_s3 + '/preprocessor.py', destination='/opt/ml/processing/input/code',
                        input_name='code'),
    ]

    outputs = [
        ProcessingOutput(source='/opt/ml/processing/train', destination=output_s3 + '/train_data',
                         output_name='train_data'),
        ProcessingOutput(source='/opt/ml/processing/test', destination=output_s3 + '/test_data',
                         output_name='test_data'),
    ]

    # Build workflow definition
    execution_input = ExecutionInput(schema={
        'image_uri': str,
        'instance_count': int,
        # The entrypoint is passed as a list of strings at execution time
        'entrypoint': [str],
        'role': str,
        'volume_size_in_gb': int,
        'max_runtime_in_seconds': int,
        'container_arguments': [str],
    })

    parameters = {
        'AppSpecification': {
            'ContainerEntrypoint': execution_input['entrypoint'],
            'ImageUri': execution_input['image_uri']
        },
        'ProcessingResources': {
            'ClusterConfig': {
                'InstanceCount': execution_input['instance_count'],
                'VolumeSizeInGB': execution_input['volume_size_in_gb']
            }
        },
        'RoleArn': execution_input['role'],
        'StoppingCondition': {
            'MaxRuntimeInSeconds': execution_input['max_runtime_in_seconds']
        }
    }

    job_name = generate_job_name()
    processing_step = ProcessingStep('create_processing_job_step',
                                     processor=sklearn_processor_fixture,
                                     job_name=job_name,
                                     inputs=inputs,
                                     outputs=outputs,
                                     container_arguments=execution_input['container_arguments'],
                                     container_entrypoint=execution_input['entrypoint'],
                                     parameters=parameters
                                     )
    workflow_graph = Chain([processing_step])

    with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
        workflow = create_workflow_and_check_definition(
            workflow_graph=workflow_graph,
            workflow_name=unique_name_from_base("integ-test-processing-step-workflow"),
            sfn_client=sfn_client,
            sfn_role_arn=sfn_role_arn
        )

        # Concrete values are written inline as literals for readability, since they are only
        # used by this test. Kept under a separate name so the ExecutionInput placeholder
        # collection above is not rebound.
        # NOTE(review): the image URI is pinned to us-east-1 while the sample data bucket follows
        # the session region - confirm this test is only expected to run in us-east-1.
        execution_input_values = {
            'image_uri': '683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-scikit-learn:0.20.0-cpu-py3',
            'instance_count': 1,
            'entrypoint': ['python3', '/opt/ml/processing/input/code/preprocessor.py'],
            'role': sagemaker_role_arn,
            'volume_size_in_gb': 30,
            'max_runtime_in_seconds': 500,
            'container_arguments': ['--train-test-split-ratio', '0.2']
        }

        # Execute workflow
        execution = workflow.execute(inputs=execution_input_values)
        execution_output = execution.get_output(wait=True)

        # Check workflow output
        assert execution_output.get("ProcessingJobStatus") == "Completed"

        # Cleanup
        state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
137 changes: 136 additions & 1 deletion tests/unit/test_sagemaker_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@

from unittest.mock import MagicMock, patch
from stepfunctions.inputs import ExecutionInput, StepInput
from stepfunctions.steps.sagemaker import TrainingStep, TransformStep, ModelStep, EndpointStep, EndpointConfigStep, ProcessingStep
from stepfunctions.steps.fields import Field
from stepfunctions.steps.sagemaker import TrainingStep, TransformStep, ModelStep, EndpointStep, EndpointConfigStep,\
ProcessingStep
from stepfunctions.steps.sagemaker import tuning_config

from tests.unit.utils import mock_boto_api_call
Expand Down Expand Up @@ -962,3 +964,136 @@ def test_processing_step_creation(sklearn_processor):
'Resource': 'arn:aws:states:::sagemaker:createProcessingJob.sync',
'End': True
}


def test_processing_step_creation_with_placeholders(sklearn_processor):
    """
    Builds a ProcessingStep whose arguments and `parameters` dict contain ExecutionInput and
    StepInput placeholders, and checks that the generated state definition merges them with the
    processor-generated parameters and renders each placeholder as a JSONPath ('<Key>.$') entry.
    """
    execution_input = ExecutionInput(schema={
        'image_uri': str,
        'instance_count': int,
        # container_entrypoint is documented as a list of strings
        'entrypoint': [str],
        'output_kms_key': str,
        'role': str,
        'env': str,
        'volume_size_in_gb': int,
        'volume_kms_key': str,
        'max_runtime_in_seconds': int,
        'tags': [{str: str}],
        'container_arguments': [str]
    })

    step_input = StepInput(schema={
        'instance_type': str
    })

    # Values given here are merged over the parameters generated from the processor
    parameters = {
        'AppSpecification': {
            'ContainerEntrypoint': execution_input['entrypoint'],
            'ImageUri': execution_input['image_uri']
        },
        'Environment': execution_input['env'],
        'ProcessingOutputConfig': {
            'KmsKeyId': execution_input['output_kms_key']
        },
        'ProcessingResources': {
            'ClusterConfig': {
                'InstanceCount': execution_input['instance_count'],
                'InstanceType': step_input['instance_type'],
                'VolumeKmsKeyId': execution_input['volume_kms_key'],
                'VolumeSizeInGB': execution_input['volume_size_in_gb']
            }
        },
        'RoleArn': execution_input['role'],
        'StoppingCondition': {
            'MaxRuntimeInSeconds': execution_input['max_runtime_in_seconds']
        },
        'Tags': execution_input['tags']
    }

    inputs = [ProcessingInput(source='dataset.csv', destination='/opt/ml/processing/input')]
    outputs = [
        ProcessingOutput(source='/opt/ml/processing/output/train'),
        ProcessingOutput(source='/opt/ml/processing/output/validation'),
        ProcessingOutput(source='/opt/ml/processing/output/test')
    ]
    step = ProcessingStep(
        'Feature Transformation',
        sklearn_processor,
        'MyProcessingJob',
        container_entrypoint=execution_input['entrypoint'],
        container_arguments=execution_input['container_arguments'],
        kms_key_id=execution_input['output_kms_key'],
        inputs=inputs,
        outputs=outputs,
        parameters=parameters
    )
    assert step.to_dict() == {
        'Type': 'Task',
        'Parameters': {
            'AppSpecification': {
                'ContainerArguments.$': "$$.Execution.Input['container_arguments']",
                'ContainerEntrypoint.$': "$$.Execution.Input['entrypoint']",
                'ImageUri.$': "$$.Execution.Input['image_uri']"
            },
            'Environment.$': "$$.Execution.Input['env']",
            'ProcessingInputs': [
                {
                    'InputName': None,
                    'AppManaged': False,
                    'S3Input': {
                        'LocalPath': '/opt/ml/processing/input',
                        'S3CompressionType': 'None',
                        'S3DataDistributionType': 'FullyReplicated',
                        'S3DataType': 'S3Prefix',
                        'S3InputMode': 'File',
                        'S3Uri': 'dataset.csv'
                    }
                }
            ],
            'ProcessingOutputConfig': {
                'KmsKeyId.$': "$$.Execution.Input['output_kms_key']",
                'Outputs': [
                    {
                        'OutputName': None,
                        'AppManaged': False,
                        'S3Output': {
                            'LocalPath': '/opt/ml/processing/output/train',
                            'S3UploadMode': 'EndOfJob',
                            'S3Uri': None
                        }
                    },
                    {
                        'OutputName': None,
                        'AppManaged': False,
                        'S3Output': {
                            'LocalPath': '/opt/ml/processing/output/validation',
                            'S3UploadMode': 'EndOfJob',
                            'S3Uri': None
                        }
                    },
                    {
                        'OutputName': None,
                        'AppManaged': False,
                        'S3Output': {
                            'LocalPath': '/opt/ml/processing/output/test',
                            'S3UploadMode': 'EndOfJob',
                            'S3Uri': None
                        }
                    }
                ]
            },
            'ProcessingResources': {
                'ClusterConfig': {
                    'InstanceCount.$': "$$.Execution.Input['instance_count']",
                    'InstanceType.$': "$['instance_type']",
                    'VolumeKmsKeyId.$': "$$.Execution.Input['volume_kms_key']",
                    'VolumeSizeInGB.$': "$$.Execution.Input['volume_size_in_gb']"
                }
            },
            'ProcessingJobName': 'MyProcessingJob',
            'RoleArn.$': "$$.Execution.Input['role']",
            'Tags.$': "$$.Execution.Input['tags']",
            'StoppingCondition': {'MaxRuntimeInSeconds.$': "$$.Execution.Input['max_runtime_in_seconds']"},
        },
        'Resource': 'arn:aws:states:::sagemaker:createProcessingJob.sync',
        'End': True
    }
Loading